mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Adding DocTest to TrOCR (#16398)
* docstring still WIP | adding to documentation_tests * clean version | passes tests * adding to documentation_test * adding forward for training pass * make fixup applied * address comments * fix doctest * apply make fixup * remove additional blank * fix file to have correct split for prepare_for_doc_test * Update src/transformers/models/trocr/modeling_trocr.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * address comments * changing text | adding loss check | make fixup * make fixup * Update src/transformers/models/trocr/modeling_trocr.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * Update src/transformers/models/trocr/modeling_trocr.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * Update src/transformers/models/trocr/modeling_trocr.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * make fixup Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
parent
85295621f1
commit
ed31ab3f10
@ -597,8 +597,8 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
||||
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||
all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor`
|
||||
of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
|
||||
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
|
||||
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
|
||||
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
|
||||
control over how to convert `input_ids` indices into associated vectors than the model's internal
|
||||
embedding lookup matrix.
|
||||
@ -891,13 +891,49 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import VisionEncoderDecoderModel, TrOCRForCausalLM, ViTModel, TrOCRConfig, ViTConfig
|
||||
>>> from transformers import (
|
||||
... TrOCRConfig,
|
||||
... TrOCRProcessor,
|
||||
... TrOCRForCausalLM,
|
||||
... ViTConfig,
|
||||
... ViTModel,
|
||||
... VisionEncoderDecoderModel,
|
||||
... )
|
||||
>>> import requests
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel
|
||||
>>> # init vision2text model with random weights
|
||||
>>> encoder = ViTModel(ViTConfig())
|
||||
>>> decoder = TrOCRForCausalLM(TrOCRConfig())
|
||||
# init vision2text model
|
||||
|
||||
>>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||
|
||||
>>> # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel`
|
||||
>>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
|
||||
>>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
|
||||
|
||||
>>> # load image from the IAM dataset
|
||||
>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
>>> pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
>>> text = "industry, ' Mr. Brown commented icily. ' Let us have a"
|
||||
|
||||
>>> # training
|
||||
>>> model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
||||
>>> model.config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
>>> model.config.vocab_size = model.config.decoder.vocab_size
|
||||
|
||||
>>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
|
||||
>>> outputs = model(pixel_values, labels=labels)
|
||||
>>> loss = outputs.loss
|
||||
>>> round(loss.item(), 2)
|
||||
5.30
|
||||
|
||||
>>> # inference
|
||||
>>> generated_ids = model.generate(pixel_values)
|
||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
>>> generated_text
|
||||
'industry, " Mr. Brown commented icily. " Let us have a'
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
@ -39,6 +39,7 @@ src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.p
|
||||
src/transformers/models/speech_to_text/modeling_speech_to_text.py
|
||||
src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
|
||||
src/transformers/models/swin/modeling_swin.py
|
||||
src/transformers/models/trocr/modeling_trocr.py
|
||||
src/transformers/models/unispeech/modeling_unispeech.py
|
||||
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
|
||||
src/transformers/models/van/modeling_van.py
|
||||
|
Loading…
Reference in New Issue
Block a user