[Reformer] Fix example and error message (#4191)

* fix Reformer example

* fix error message and example docstring

* improved error message
Author: Patrick von Platen, 2020-05-07 10:50:11 +02:00, committed by GitHub
parent 96c78396ce
commit 74ffc9ea6b

@@ -124,8 +124,8 @@ class AxialPositionEmbeddings(nn.Module):
         if self.training is True:
             assert (
                 reduce(mul, self.axial_pos_shape) == sequence_length
-            ), "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(
-                self.axial_pos_shape, sequence_length
+            ), "If training, make sure that config.axial_pos_shape factors: {} multiply to sequence length. Got prod({}) != sequence_length: {}. You might want to consider padding your sequence length to {} or changing config.axial_pos_shape.".format(
+                self.axial_pos_shape, self.axial_pos_shape, sequence_length, reduce(mul, self.axial_pos_shape)
             )
             if self.dropout > 0:
                 weights = torch.cat(broadcasted_weights, dim=-1)
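
For context, not part of the diff: a minimal sketch of the constraint the improved assertion enforces, using made-up shape and length values. The factors in config.axial_pos_shape must multiply exactly to the sequence length seen during training.

# Sketch only; axial_pos_shape and sequence_length are hypothetical example values.
from functools import reduce
from operator import mul

axial_pos_shape = (64, 64)   # hypothetical factorization of the sequence length
sequence_length = 4096       # passes, since 64 * 64 == 4096

assert reduce(mul, axial_pos_shape) == sequence_length, (
    "config.axial_pos_shape factors {} must multiply to sequence length {}, "
    "got prod = {}".format(axial_pos_shape, sequence_length, reduce(mul, axial_pos_shape))
)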
@@ -1515,11 +1515,11 @@ class ReformerModel(ReformerPreTrainedModel):
     Examples::
-        from transformers import ReformerModel, ReformerTokenizer
+        from transformers import ReformerModelWithLMHead, ReformerTokenizer
         import torch
-        tokenizer = ReformerTokenizer.from_pretrained('bert-base-uncased')
-        model = ReformerModel.from_pretrained('bert-base-uncased')
+        tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
+        model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')
         input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
         outputs = model(input_ids)
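
For reference, the corrected docstring example as a standalone, runnable snippet; the final line, reading the LM logits from outputs[0], is an addition here rather than part of the diff.

from transformers import ReformerModelWithLMHead, ReformerTokenizer
import torch

tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')

input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
outputs = model(input_ids)
prediction_scores = outputs[0]  # language-modeling logits, shape (1, sequence_length, vocab_size)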
@@ -1562,7 +1562,7 @@ class ReformerModel(ReformerPreTrainedModel):
             if self.training is True:
                 raise ValueError(
                     "If training, sequence Length {} has to be a multiple of least common multiple chunk_length {}. Please consider padding the input to a length of {}.".format(
-                        input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length
+                        input_shape[-1], least_common_mult_chunk_length, input_shape[-1] + padding_length
                    )
                )
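
The length suggested by the corrected message is the current sequence length (input_shape[-1], the sequence dimension, which this fix uses instead of input_shape[-2]) plus the distance to the next multiple of the least common multiple of all chunk lengths. A hypothetical helper, not part of the commit, showing how that padding length can be computed:

# Hypothetical helper (assumed name and signature), not from the commit.
def required_padding_length(sequence_length, least_common_mult_chunk_length):
    # Distance from sequence_length up to the next multiple of the chunk length;
    # 0 when sequence_length is already a multiple.
    remainder = sequence_length % least_common_mult_chunk_length
    return 0 if remainder == 0 else least_common_mult_chunk_length - remainder

assert required_padding_length(100, 64) == 28   # pad 100 -> 128
assert required_padding_length(128, 64) == 0    # already a multiple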