Update troubleshoot with more content (#16243)

* 📝 first draft

* 🖍 apply feedback
Steven Liu, 2022-03-21 · commit 5a42bb431e (parent fbb454307d)
@@ -120,4 +120,57 @@ Another option is to get a better traceback from the GPU. Add the following environment variable
>>> import os
>>> os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
```
## Incorrect output when padding tokens aren't masked
In some cases, the output `hidden_state` may be incorrect if the `input_ids` include padding tokens. To demonstrate, load a model and inspect its `pad_token_id` to see its value. The `pad_token_id` may be `None` for some models, but you can always set it manually.
```py
>>> from transformers import AutoModelForSequenceClassification
>>> import torch
>>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
>>> model.config.pad_token_id
0
```
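As a quick sanity check, you can decode the padding id with the matching tokenizer to confirm it maps to the `[PAD]` token. A minimal sketch (loading the tokenizer is an extra step, not part of the example above):
```py
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> tokenizer.decode([model.config.pad_token_id])  # id 0 is BERT's padding token
'[PAD]'
```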
The following example shows the output without masking the padding tokens:
```py
>>> input_ids = torch.tensor([[7592, 2057, 2097, 2393, 9611, 2115], [7592, 0, 0, 0, 0, 0]])
>>> output = model(input_ids)
>>> print(output.logits)
tensor([[ 0.0082, -0.2307],
        [ 0.1317, -0.1683]], grad_fn=<AddmmBackward0>)
```
Here is the actual output of the second sequence:
```py
>>> input_ids = torch.tensor([[7592]])
>>> output = model(input_ids)
>>> print(output.logits)
tensor([[-0.1008, -0.4061]], grad_fn=<AddmmBackward0>)
```
Most of the time, you should provide an `attention_mask` to your model so it ignores the padding tokens and avoids this silent error. Now the output of the second sequence matches its actual output:
<Tip>
The tokenizer creates an `attention_mask` for you by default, based on the settings of your specific tokenizer.
</Tip>
```py
>>> # redefine the padded batch; input_ids was overwritten in the previous example
>>> input_ids = torch.tensor([[7592, 2057, 2097, 2393, 9611, 2115], [7592, 0, 0, 0, 0, 0]])
>>> attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 0, 0, 0, 0, 0]])
>>> output = model(input_ids, attention_mask=attention_mask)
>>> print(output.logits)
tensor([[ 0.0082, -0.2307],
        [-0.1008, -0.4061]], grad_fn=<AddmmBackward0>)
```
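In practice you rarely build the mask by hand: as the tip above notes, the tokenizer returns one for you. A minimal sketch (the example sentences are placeholders, not the text behind the hand-written ids above):
```py
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> # padding=True pads the shorter sequence and builds the matching attention_mask
>>> inputs = tokenizer(["hello, how are you?", "hello"], padding=True, return_tensors="pt")
>>> list(inputs.keys())
['input_ids', 'token_type_ids', 'attention_mask']
>>> output = model(**inputs)  # the mask is passed along with the ids
```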
🤗 Transformers doesn't automatically create an `attention_mask` to mask a padding token, even if one is provided, because:
- Some models don't have a padding token (see the sketch after this list).
- For some use cases, users want a model to attend to a padding token.
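For the first case, a common workaround is to designate an existing special token as the padding token. A hedged sketch, assuming GPT-2 (which ships without a padding token) and reusing its EOS token:
```py
>>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> print(tokenizer.pad_token_id)  # GPT-2 has no padding token by default
None
>>> tokenizer.pad_token = tokenizer.eos_token  # reuse EOS for padding
>>> model = AutoModelForSequenceClassification.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
>>> model.config.pad_token_id
50256
```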