Adding DocTest to TrOCR (#16398)

* docstring still WIP | adding to documentation_tests

* clean version | passes tests

* adding to documentation_test

* adding forward for training pass

* make fixup applied

* address comments

* fix doctest

* apply make fixup

* remove additional blank

* fix file to have correct split for prepare_for_doc_test

* Update src/transformers/models/trocr/modeling_trocr.py

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* address comments

* changing text | adding loss check | make fixup

* make fixup

* Update src/transformers/models/trocr/modeling_trocr.py

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* Update src/transformers/models/trocr/modeling_trocr.py

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* Update src/transformers/models/trocr/modeling_trocr.py

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* make fixup

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Arnaud Stiegler, 2022-03-29 10:19:06 -04:00, committed by GitHub
commit ed31ab3f10 (parent 85295621f1)
2 changed files with 42 additions and 5 deletions

src/transformers/models/trocr/modeling_trocr.py

@@ -597,8 +597,8 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
 If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
 that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
-all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor`
-of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
+all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
+shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
 `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
 control over how to convert `input_ids` indices into associated vectors than the model's internal
 embedding lookup matrix.
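
As an aside on the docstring above: the `past_key_values` shortcut it describes — feeding only the last `decoder_input_ids` once the cache is populated — can be exercised directly with a randomly initialized decoder. A minimal sketch (the dummy token ids are made up for illustration):

```python
import torch
from transformers import TrOCRConfig, TrOCRForCausalLM

model = TrOCRForCausalLM(TrOCRConfig())  # random weights, illustration only
model.eval()

input_ids = torch.tensor([[2, 15, 27, 33]])  # (batch_size, sequence_length), arbitrary ids
with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True)
    # With the cache in hand, pass only the last token: shape (batch_size, 1)
    next_token = out.logits[:, -1:].argmax(-1)
    step = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)

print(step.logits.shape)  # torch.Size([1, 1, vocab_size])
```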
@@ -891,13 +891,49 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
 Example:

 ```python
->>> from transformers import VisionEncoderDecoderModel, TrOCRForCausalLM, ViTModel, TrOCRConfig, ViTConfig
+>>> from transformers import (
+...     TrOCRConfig,
+...     TrOCRProcessor,
+...     TrOCRForCausalLM,
+...     ViTConfig,
+...     ViTModel,
+...     VisionEncoderDecoderModel,
+... )
+>>> import requests
+>>> from PIL import Image

 >>> # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel
->>> # init vision2text model
+>>> # init vision2text model with random weights
 >>> encoder = ViTModel(ViTConfig())
 >>> decoder = TrOCRForCausalLM(TrOCRConfig())
 >>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
+
+>>> # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel`
+>>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+>>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
+
+>>> # load image from the IAM dataset
+>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
+>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+>>> pixel_values = processor(image, return_tensors="pt").pixel_values
+>>> text = "industry, ' Mr. Brown commented icily. ' Let us have a"
+
+>>> # training
+>>> model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
+>>> model.config.pad_token_id = processor.tokenizer.pad_token_id
+>>> model.config.vocab_size = model.config.decoder.vocab_size
+
+>>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
+>>> outputs = model(pixel_values, labels=labels)
+>>> loss = outputs.loss
+>>> round(loss.item(), 2)
+5.30
+
+>>> # inference
+>>> generated_ids = model.generate(pixel_values)
+>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+>>> generated_text
+'industry, " Mr. Brown commented icily. " Let us have a'
 ```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
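
To sanity-check the new example locally (the PR's own CI path goes through `prepare_for_doc_test`, per the commit log), the stdlib `doctest` module is one option. A hedged sketch — it will download the `microsoft/trocr-base-handwritten` checkpoint and the IAM image, so it needs network access:

```python
import doctest
import transformers.models.trocr.modeling_trocr as trocr_module

# Runs every `>>>` example in the module's docstrings, including the one added above.
results = doctest.testmod(trocr_module, verbose=False)
print(f"attempted={results.attempted}, failed={results.failed}")
```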

utils/documentation_tests.txt

@@ -39,6 +39,7 @@ src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
 src/transformers/models/speech_to_text/modeling_speech_to_text.py
 src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
 src/transformers/models/swin/modeling_swin.py
+src/transformers/models/trocr/modeling_trocr.py
 src/transformers/models/unispeech/modeling_unispeech.py
 src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
 src/transformers/models/van/modeling_van.py
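
Registering the module in this list is what opts it into the doc-test run. A hypothetical runner sketch — the `utils/documentation_tests.txt` location and the exact pytest invocation are assumptions, not the repo's actual harness:

```python
import subprocess
from pathlib import Path

# Collect every registered module and hand it to pytest's doctest mode
# (file path assumed; adjust to the repo layout).
files = Path("utils/documentation_tests.txt").read_text().split()
subprocess.run(
    ["pytest", "--doctest-modules", *files, "-sv", "--doctest-continue-on-failure"],
    check=False,
)
```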