mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Reformer] Fix example and error message (#4191)
* fix example reformer * fix error message and example docstring * improved error message
This commit is contained in:
parent
96c78396ce
commit
74ffc9ea6b
@ -124,8 +124,8 @@ class AxialPositionEmbeddings(nn.Module):
|
||||
if self.training is True:
|
||||
assert (
|
||||
reduce(mul, self.axial_pos_shape) == sequence_length
|
||||
), "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(
|
||||
self.axial_pos_shape, sequence_length
|
||||
), "If training, make sure that config.axial_pos_shape factors: {} multiply to sequence length. Got prod({}) != sequence_length: {}. You might want to consider padding your sequence length to {} or changing config.axial_pos_shape.".format(
|
||||
self.axial_pos_shape, self.axial_pos_shape, sequence_length, reduce(mul, self.axial_pos_shape)
|
||||
)
|
||||
if self.dropout > 0:
|
||||
weights = torch.cat(broadcasted_weights, dim=-1)
|
||||
@ -1515,11 +1515,11 @@ class ReformerModel(ReformerPreTrainedModel):
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import ReformerModel, ReformerTokenizer
|
||||
from transformers import ReformerModelWithLMHead, ReformerTokenizer
|
||||
import torch
|
||||
|
||||
tokenizer = ReformerTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = ReformerModel.from_pretrained('bert-base-uncased')
|
||||
tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
|
||||
model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')
|
||||
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
@ -1562,7 +1562,7 @@ class ReformerModel(ReformerPreTrainedModel):
|
||||
if self.training is True:
|
||||
raise ValueError(
|
||||
"If training, sequence Length {} has to be a multiple of least common multiple chunk_length {}. Please consider padding the input to a length of {}.".format(
|
||||
input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length
|
||||
input_shape[-1], least_common_mult_chunk_length, input_shape[-1] + padding_length
|
||||
)
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user