Fix LayoutLMv2 init issue and doctest (#30278)

* fix

* try suggestion

* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2024-04-23 15:33:17 +02:00 committed by GitHub
parent d179b9dc78
commit 416fdbad7a

@@ -503,6 +503,9 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, LayoutLMv2Model):
+            if hasattr(module, "visual_segment_embedding"):
+                module.visual_segment_embedding.data.normal_(mean=0.0, std=self.config.initializer_range)
 
 
 def my_convert_sync_batchnorm(module, process_group=None):
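
For context: `visual_segment_embedding` is an `nn.Parameter` that `LayoutLMv2Model` registers only when the config's `has_visual_segment_embedding` flag is enabled, which is why the new branch guards with `hasattr`. Below is a minimal sketch of how the new branch can be checked locally; it assumes `detectron2` and `torchvision` are installed (LayoutLMv2's visual backbone requires them), and the printed numbers are illustrative rather than part of this commit.

```python
from transformers import LayoutLMv2Config, LayoutLMv2Model

# Fresh, randomly initialized model: post_init() applies _init_weights to
# every submodule, including the root LayoutLMv2Model itself.
# has_visual_segment_embedding defaults to False, so enable it explicitly;
# otherwise the parameter does not exist and the hasattr guard skips it.
config = LayoutLMv2Config(has_visual_segment_embedding=True)
model = LayoutLMv2Model(config)

emb = model.visual_segment_embedding
# With the new branch, the parameter is drawn from a normal distribution with
# std = config.initializer_range (0.02 by default), so the std printed here
# should be close to 0.02 rather than the nn.Embedding default of ~1.0.
print(round(emb.mean().item(), 3), round(emb.std().item(), 3))
```
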
@@ -822,7 +825,7 @@ class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
         >>> import torch
         >>> from datasets import load_dataset
 
-        >>> set_seed(88)
+        >>> set_seed(0)
 
         >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
         >>> model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")
@@ -993,7 +996,7 @@ class LayoutLMv2ForSequenceClassification(LayoutLMv2PreTrainedModel):
         >>> import torch
         >>> from datasets import load_dataset
 
-        >>> set_seed(88)
+        >>> set_seed(0)
 
         >>> dataset = load_dataset("rvl_cdip", split="train", streaming=True)
         >>> data = next(iter(dataset))
@@ -1012,8 +1015,8 @@ class LayoutLMv2ForSequenceClassification(LayoutLMv2PreTrainedModel):
         >>> loss, logits = outputs.loss, outputs.logits
         >>> predicted_idx = logits.argmax(dim=-1).item()
         >>> predicted_answer = dataset.info.features["label"].names[4]
-        >>> predicted_idx, predicted_answer
-        (4, 'advertisement')
+        >>> predicted_idx, predicted_answer  # results are not good without further fine-tuning
+        (7, 'advertisement')
         ```
         """
@@ -1172,7 +1175,7 @@ class LayoutLMv2ForTokenClassification(LayoutLMv2PreTrainedModel):
         >>> from PIL import Image
         >>> from datasets import load_dataset
 
-        >>> set_seed(88)
+        >>> set_seed(0)
 
         >>> datasets = load_dataset("nielsr/funsd", split="test")
         >>> labels = datasets.features["ner_tags"].feature.names
@@ -1203,8 +1206,8 @@ class LayoutLMv2ForTokenClassification(LayoutLMv2PreTrainedModel):
         >>> predicted_token_class_ids = logits.argmax(-1)
         >>> predicted_tokens_classes = [id2label[t.item()] for t in predicted_token_class_ids[0]]
-        >>> predicted_tokens_classes[:5]
-        ['B-ANSWER', 'B-HEADER', 'B-HEADER', 'B-HEADER', 'B-HEADER']
+        >>> predicted_tokens_classes[:5]  # results are not good without further fine-tuning
+        ['I-HEADER', 'I-HEADER', 'I-QUESTION', 'I-HEADER', 'I-QUESTION']
         ```
         """
@@ -1314,7 +1317,7 @@ class LayoutLMv2ForQuestionAnswering(LayoutLMv2PreTrainedModel):
         >>> from PIL import Image
         >>> from datasets import load_dataset
 
-        >>> set_seed(88)
+        >>> set_seed(0)
 
         >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
         >>> model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
@@ -1328,12 +1331,12 @@ class LayoutLMv2ForQuestionAnswering(LayoutLMv2PreTrainedModel):
         >>> predicted_start_idx = outputs.start_logits.argmax(-1).item()
         >>> predicted_end_idx = outputs.end_logits.argmax(-1).item()
         >>> predicted_start_idx, predicted_end_idx
-        (154, 287)
+        (30, 191)
 
         >>> predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
         >>> predicted_answer = processor.tokenizer.decode(predicted_answer_tokens)
-        >>> predicted_answer  # results are not very good without further fine-tuning
-        'council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public ...
+        >>> predicted_answer  # results are not good without further fine-tuning
+        '44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from'
         ```
 
         ```python
@@ -1343,7 +1346,7 @@ class LayoutLMv2ForQuestionAnswering(LayoutLMv2PreTrainedModel):
         >>> predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
         >>> predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
         >>> predicted_answer_span_start, predicted_answer_span_end
-        (154, 287)
+        (30, 191)
         ```
         """