mirror of https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00

Fix doc examples (#15257)

This commit is contained in:
parent ad7390636d
commit 515ed3ad2a
@@ -70,7 +70,8 @@ into a single instance to both extract the input features and decode the predict
 >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
 >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

->>> # load image from the IAM dataset url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
+>>> # load image from the IAM dataset
+>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

 >>> pixel_values = processor(image, return_tensors="pt").pixel_values
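For reference, the corrected TrOCR snippet reads straight through as a script. A minimal runnable sketch is below; the `generate`/`batch_decode` tail is an assumption about how the docstring continues, since it sits outside this hunk:

```python
# Runnable sketch of the corrected TrOCR example; the generate/decode
# lines at the end are assumed context, not part of this hunk.
import requests
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# load image from the IAM dataset
url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

pixel_values = processor(image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)
```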
@@ -42,10 +42,10 @@ from .configuration_vilt import ViltConfig
 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "ViltConfig"
-_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm-itm"
+_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm"

 VILT_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "dandelin/vilt-b32-mlm-itm",
+    "dandelin/vilt-b32-mlm",
     # See all ViLT models at https://huggingface.co/models?filter=vilt
 ]

@@ -775,17 +775,19 @@ class ViltModel(ViltPreTrainedModel):
         Examples:

         ```python
-        >>> from transformers import ViltFeatureExtractor, ViltModel
+        >>> from transformers import ViltProcessor, ViltModel
         >>> from PIL import Image
         >>> import requests

+        >>> # prepare image and text
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> text = "hello world"

-        >>> feature_extractor = ViltFeatureExtractor.from_pretrained("dandelin/vilt-b32-mlm-itm")
-        >>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm-itm")
+        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
+        >>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")

-        >>> inputs = feature_extractor(images=image, return_tensors="pt")
+        >>> inputs = processor(image, text, return_tensors="pt")
         >>> outputs = model(**inputs)
         >>> last_hidden_states = outputs.last_hidden_state
         ```"""
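Outside the docstring, the updated `ViltModel` example runs as a plain script. A sketch follows; the `torch.no_grad()` context and the final shape print are illustrative additions, not part of the commit:

```python
# Standalone version of the updated ViltModel example.
import requests
import torch
from PIL import Image
from transformers import ViltProcessor, ViltModel

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")

# prepare image and text
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "hello world"

inputs = processor(image, text, return_tensors="pt")
with torch.no_grad():  # inference only, so skip gradient tracking
    outputs = model(**inputs)
# (batch, sequence, hidden); text tokens are followed by image patch tokens
print(outputs.last_hidden_state.shape)
```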
@@ -930,10 +932,11 @@ class ViltForMaskedLM(ViltPreTrainedModel):
         >>> from transformers import ViltProcessor, ViltForMaskedLM
         >>> import requests
         >>> from PIL import Image
+        >>> import re

         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
-        >>> text = "How many cats are there?"
+        >>> text = "a bunch of [MASK] laying on a [MASK]."

         >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
         >>> model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")
@@ -943,7 +946,31 @@ class ViltForMaskedLM(ViltPreTrainedModel):

         >>> # forward pass
         >>> outputs = model(**encoding)
-        >>> logits = outputs.logits
+        >>> tl = len(re.findall("\[MASK\]", text))
+        >>> inferred_token = [text]
+
+        >>> # gradually fill in the MASK tokens, one by one
+        >>> with torch.no_grad():
+        ...     for i in range(tl):
+        ...         encoded = processor.tokenizer(inferred_token)
+        ...         input_ids = torch.tensor(encoded.input_ids).to(device)
+        ...         encoded = encoded["input_ids"][0][1:-1]
+        ...         outputs = model(input_ids=input_ids, pixel_values=pixel_values)
+        ...         mlm_logits = outputs.logits[0]  # shape (seq_len, vocab_size)
+        ...         # only take into account text features (minus CLS and SEP token)
+        ...         mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
+        ...         mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
+        ...         # only take into account text
+        ...         mlm_values[torch.tensor(encoded) != 103] = 0
+        ...         select = mlm_values.argmax().item()
+        ...         encoded[select] = mlm_ids[select].item()
+        ...         inferred_token = [processor.decode(encoded)]
+
+        >>> selected_token = ""
+        >>> encoded = processor.tokenizer(inferred_token)
+        >>> processor.decode(encoded.input_ids[0], skip_special_tokens=True)
+        a bunch of cats laying on a couch.
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
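Two details of the new loop are worth flagging: `device` and `pixel_values` are defined earlier in the docstring (outside this hunk), and `103` is the hard-coded BERT `[MASK]` token id. A self-contained sketch of the same iterative fill-in follows, with the mask id looked up from the tokenizer instead of hard-coded; those substitutions are mine, not the commit's code:

```python
# Self-contained variant of the iterative MASK-filling loop above.
# Looking up mask_token_id and running on CPU are substitutions for the
# literal 103 and the externally defined `device`/`pixel_values`.
import re
import requests
import torch
from PIL import Image
from transformers import ViltProcessor, ViltForMaskedLM

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "a bunch of [MASK] laying on a [MASK]."
pixel_values = processor(image, text, return_tensors="pt").pixel_values

mask_id = processor.tokenizer.mask_token_id  # replaces the literal 103
inferred_token = [text]
with torch.no_grad():
    for _ in range(len(re.findall(r"\[MASK\]", text))):
        encoding = processor.tokenizer(inferred_token, return_tensors="pt")
        input_ids = encoding.input_ids
        logits = model(input_ids=input_ids, pixel_values=pixel_values).logits[0]
        # keep only the text positions, dropping [CLS] and [SEP]
        # (image patch positions follow the text in ViLT's sequence)
        text_logits = logits[1 : input_ids.shape[1] - 1]
        probs, ids = text_logits.softmax(dim=-1).max(dim=-1)
        inner = input_ids[0, 1:-1].clone()
        probs[inner != mask_id] = 0  # consider only still-masked positions
        best = probs.argmax().item()  # fill the most confident mask first
        inner[best] = ids[best]
        inferred_token = [processor.decode(inner)]

print(inferred_token[0])  # e.g. "a bunch of cats laying on a couch."
```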
@@ -1093,6 +1120,7 @@ class ViltForQuestionAnswering(ViltPreTrainedModel):
         >>> logits = outputs.logits
         >>> idx = logits.argmax(-1).item()
         >>> print("Predicted answer:", model.config.id2label[idx])
+        Predicted answer: 2
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
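The added `Predicted answer: 2` line is doctest expected output. For context, here is a sketch of the full VQA example around this tail; the `dandelin/vilt-b32-finetuned-vqa` checkpoint and the question are assumptions drawn from the surrounding docstring, which the hunk does not show:

```python
# Sketch of the ViltForQuestionAnswering flow around the hunk above;
# checkpoint and question are assumed context, not shown in the hunk.
import requests
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
encoding = processor(image, "How many cats are there?", return_tensors="pt")

outputs = model(**encoding)
idx = outputs.logits.argmax(-1).item()
print("Predicted answer:", model.config.id2label[idx])  # expected: 2
```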
@@ -1297,13 +1325,13 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):

         >>> # prepare inputs
         >>> encoding = processor([image1, image2], text, return_tensors="pt")
-        >>> pixel_values = torch.stack([encoding_1.pixel_values, encoding_2.pixel_values], dim=1)

         >>> # forward pass
         >>> outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
         >>> logits = outputs.logits
         >>> idx = logits.argmax(-1).item()
         >>> print("Predicted answer:", model.config.id2label[idx])
+        Predicted answer: True
         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
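The dropped `torch.stack` line was dead code: `encoding.pixel_values.unsqueeze(0)` already produces the `(batch, num_images, channels, height, width)` layout the model expects, and `encoding_1`/`encoding_2` were never defined. A sketch of the full NLVR2 example around this hunk; the checkpoint, image URLs, and sentence are assumptions about context the hunk does not show:

```python
# Sketch of the surrounding ViltForImagesAndTextClassification example;
# checkpoint, image URLs, and sentence are assumed context.
import requests
from PIL import Image
from transformers import ViltProcessor, ViltForImagesAndTextClassification

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
model = ViltForImagesAndTextClassification.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")

image1 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
image2 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_1.jpg", stream=True).raw)
text = "The left image contains twice the number of dogs as the right image."

# prepare inputs: two images, one sentence
encoding = processor([image1, image2], text, return_tensors="pt")

# forward pass; unsqueeze(0) adds the batch axis in front of num_images
outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
idx = outputs.logits.argmax(-1).item()
print("Predicted answer:", model.config.id2label[idx])  # expected: True
```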