mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
[Reformer] Fix example and error message (#4191)
* fix example reformer * fix error message and example docstring * improved error message
This commit is contained in:
parent
96c78396ce
commit
74ffc9ea6b
@ -124,8 +124,8 @@ class AxialPositionEmbeddings(nn.Module):
|
|||||||
if self.training is True:
|
if self.training is True:
|
||||||
assert (
|
assert (
|
||||||
reduce(mul, self.axial_pos_shape) == sequence_length
|
reduce(mul, self.axial_pos_shape) == sequence_length
|
||||||
), "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(
|
), "If training, make sure that config.axial_pos_shape factors: {} multiply to sequence length. Got prod({}) != sequence_length: {}. You might want to consider padding your sequence length to {} or changing config.axial_pos_shape.".format(
|
||||||
self.axial_pos_shape, sequence_length
|
self.axial_pos_shape, self.axial_pos_shape, sequence_length, reduce(mul, self.axial_pos_shape)
|
||||||
)
|
)
|
||||||
if self.dropout > 0:
|
if self.dropout > 0:
|
||||||
weights = torch.cat(broadcasted_weights, dim=-1)
|
weights = torch.cat(broadcasted_weights, dim=-1)
|
||||||
@ -1515,11 +1515,11 @@ class ReformerModel(ReformerPreTrainedModel):
|
|||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
from transformers import ReformerModel, ReformerTokenizer
|
from transformers import ReformerModelWithLMHead, ReformerTokenizer
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
tokenizer = ReformerTokenizer.from_pretrained('bert-base-uncased')
|
tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
|
||||||
model = ReformerModel.from_pretrained('bert-base-uncased')
|
model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')
|
||||||
|
|
||||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||||
outputs = model(input_ids)
|
outputs = model(input_ids)
|
||||||
@ -1562,7 +1562,7 @@ class ReformerModel(ReformerPreTrainedModel):
|
|||||||
if self.training is True:
|
if self.training is True:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If training, sequence Length {} has to be a multiple of least common multiple chunk_length {}. Please consider padding the input to a length of {}.".format(
|
"If training, sequence Length {} has to be a multiple of least common multiple chunk_length {}. Please consider padding the input to a length of {}.".format(
|
||||||
input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length
|
input_shape[-1], least_common_mult_chunk_length, input_shape[-1] + padding_length
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user