Merge pull request #3011 from patrickvonplaten/add_models_special_tokens_to_specific_configs

Add models' special tokens to their pretrained configs.

Commit: fa2aa699da
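For orientation (editor's sketch, not part of the diff): this PR makes each model config carry its default special token IDs, so callers of generate() no longer have to pass them on every call. A minimal before/after, assuming a transformers version from around this commit, where the config stores eos_token_ids as a list:

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    input_ids = torch.tensor(tokenizer.encode("The dog")).unsqueeze(0)

    # Before this PR, special tokens had to be passed explicitly:
    #   model.generate(input_ids, bos_token_id=50256, eos_token_ids=50256)
    # After it, the defaults come from model.config:
    output_ids = model.generate(input_ids)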
@@ -135,6 +135,8 @@ class GPT2Config(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
+        bos_token_id=50256,
+        eos_token_id=50256,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -156,6 +158,9 @@ class GPT2Config(PretrainedConfig):
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels

+        self.bos_token_id = bos_token_id
+        self.eos_token_ids = [eos_token_id]
+
     @property
     def max_position_embeddings(self):
         return self.n_positions
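A small sketch of the resulting GPT2Config defaults (an illustration of expected usage, not part of the diff). Note the asymmetry the hunk introduces: a single eos_token_id argument is stored as the one-element list eos_token_ids:

    from transformers import GPT2Config

    config = GPT2Config()
    print(config.bos_token_id)   # 50256 (GPT-2's <|endoftext|> token)
    print(config.eos_token_ids)  # [50256] -- stored as a list

    # The defaults can still be overridden at construction time:
    custom = GPT2Config(bos_token_id=0, eos_token_id=1)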
@@ -149,6 +149,7 @@ class TransfoXLConfig(PretrainedConfig):
         proj_init_std=0.01,
         init_std=0.02,
         layer_norm_epsilon=1e-5,
+        eos_token_id=0,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -186,6 +187,8 @@ class TransfoXLConfig(PretrainedConfig):
         self.init_std = init_std
         self.layer_norm_epsilon = layer_norm_epsilon

+        self.eos_token_ids = [eos_token_id]
+
     @property
     def max_position_embeddings(self):
         return self.tgt_len + self.ext_len + self.mem_len
@@ -194,6 +194,8 @@ class XLMConfig(PretrainedConfig):
         end_n_top=5,
         mask_token_id=0,
         lang_id=0,
+        bos_token_id=0,
+        pad_token_id=2,
         **kwargs
     ):
         """Constructs XLMConfig.
@@ -234,6 +236,9 @@ class XLMConfig(PretrainedConfig):
         if "n_words" in kwargs:
             self.n_words = kwargs["n_words"]

+        self.bos_token_id = bos_token_id
+        self.pad_token_id = pad_token_id
+
     @property
     def n_words(self):  # For backward compatibility
         return self.vocab_size
@@ -155,6 +155,9 @@ class XLNetConfig(PretrainedConfig):
         summary_last_dropout=0.1,
         start_n_top=5,
         end_n_top=5,
+        bos_token_id=1,
+        pad_token_id=5,
+        eos_token_id=2,
         **kwargs
     ):
         """Constructs XLNetConfig.
@@ -188,6 +191,10 @@ class XLNetConfig(PretrainedConfig):
         self.start_n_top = start_n_top
         self.end_n_top = end_n_top

+        self.bos_token_id = bos_token_id
+        self.pad_token_id = pad_token_id
+        self.eos_token_ids = [eos_token_id]
+
     @property
     def max_position_embeddings(self):
         return -1
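The XLNet hunks follow the same pattern but set all three IDs. A sketch of the resulting defaults (again assuming this commit's API, where the eos ID is kept as a one-element list):

    from transformers import XLNetConfig

    config = XLNetConfig()
    print(config.bos_token_id)   # 1
    print(config.pad_token_id)   # 5
    print(config.eos_token_ids)  # [2]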
@@ -677,7 +677,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):

         tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
         model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
-        outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, do_sample=False)  # do greedy decoding
+        outputs = model.generate(max_length=40, do_sample=False)  # do greedy decoding
         print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

         tokenizer = AutoTokenizer.from_pretrained('openai-gpt')  # Initialize tokenizer
@@ -692,7 +692,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
         input_context = 'The dog'
         input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
-        outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.pad_token_id, eos_token_ids=tokenizer.eos_token_id, num_return_sequences=3)  # generate 3 sequences by sampling
+        outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3)  # generate 3 sequences by sampling
         for i in range(3):  # 3 output sequences were generated
             print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
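Put together, the simplified docstring example would run as follows (a sketch; it assumes the distilgpt2 checkpoint is available and uses the AutoModelWithLMHead API of this era):

    import torch
    from transformers import AutoModelWithLMHead, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
    model = AutoModelWithLMHead.from_pretrained('distilgpt2')
    input_ids = torch.tensor(tokenizer.encode('The dog')).unsqueeze(0)

    # bos/pad/eos now come from model.config, so only sampling
    # parameters need to be passed:
    outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3)
    for i in range(3):
        print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))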
@@ -339,14 +339,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
         self.assertIsNotNone(model)


-def prepare_generation_special_tokens():
-    return {"bos_token_id": 50256, "eos_token_id": 50256}
-
-
 class GPT2ModelLanguageGenerationTest(unittest.TestCase):
-
-    special_tokens = prepare_generation_special_tokens()
-
     @slow
     def test_lm_generate_gpt2(self):
         model = GPT2LMHeadModel.from_pretrained("gpt2")
@@ -375,11 +368,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
         ]  # The dog is cute too. It likes to rub on me and is good for me (the dog
         torch.manual_seed(0)

-        output_ids = model.generate(
-            input_ids,
-            bos_token_id=self.special_tokens["bos_token_id"],
-            eos_token_ids=self.special_tokens["eos_token_id"],
-        )
+        output_ids = model.generate(input_ids)

         self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@@ -410,11 +399,5 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
             2635,
         ]  # The president of the United States, and the president of the United Kingdom, have been in the White

-        output_ids = model.generate(
-            input_ids,
-            do_sample=False,
-            bos_token_id=self.special_tokens["bos_token_id"],
-            eos_token_ids=self.special_tokens["eos_token_id"],
-        )
+        output_ids = model.generate(input_ids, do_sample=False)

         self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@@ -214,14 +214,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
         self.assertIsNotNone(model)


-def prepare_generation_special_tokens():
-    return {"eos_token_id": 0}
-
-
 class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
-
-    special_tokens = prepare_generation_special_tokens()
-
     @slow
     def test_lm_generate_transfo_xl_wt103(self):
         model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
@@ -578,6 +571,5 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):

         torch.manual_seed(0)

-        output_ids = model.generate(input_ids, eos_token_ids=self.special_tokens["eos_token_id"], max_length=200)
+        output_ids = model.generate(input_ids, max_length=200)

         self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@@ -399,14 +399,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
         self.assertIsNotNone(model)


-def prepare_generation_special_tokens():
-    return {"bos_token_id": 0, "pad_token_id": 2}
-
-
 class XLMModelLanguageGenerationTest(unittest.TestCase):
-
-    special_tokens = prepare_generation_special_tokens()
-
     @slow
     def test_lm_generate_xlm_mlm_en_2048(self):
         model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
@@ -435,10 +428,6 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
         ]  # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
         torch.manual_seed(0)

-        output_ids = model.generate(
-            input_ids,
-            bos_token_id=self.special_tokens["bos_token_id"],
-            pad_token_id=self.special_tokens["pad_token_id"],
-        )
+        output_ids = model.generate(input_ids)

         self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@@ -513,14 +513,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         self.assertIsNotNone(model)


-def prepare_generation_special_tokens():
-    return {"bos_token_id": 1, "pad_token_id": 5, "eos_token_id": 2}
-
-
 class XLNetModelLanguageGenerationTest(unittest.TestCase):
-
-    special_tokens = prepare_generation_special_tokens()
-
     @slow
     def test_lm_generate_xlnet_base_cased(self):
         model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
@@ -917,12 +910,6 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
         # Since, however, he has had difficulty walking with Maria

         torch.manual_seed(0)
-        output_ids = model.generate(
-            input_ids,
-            bos_token_id=self.special_tokens["bos_token_id"],
-            pad_token_id=self.special_tokens["pad_token_id"],
-            eos_token_ids=self.special_tokens["eos_token_id"],
-            max_length=200,
-        )
+        output_ids = model.generate(input_ids, max_length=200)

         self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
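A design note on the test simplifications above: the per-call keyword arguments and the prepare_generation_special_tokens() helpers are gone because generate() now reads the IDs from the model's config. On the assumption that generate() still accepts these keywords as per-call overrides (as the old docstring signature suggests), a caller could still pass them explicitly when a checkpoint's defaults are wrong, e.g.:

    # Hypothetical override; the keyword name eos_token_ids follows
    # this commit's convention of accepting a list-valued argument.
    output_ids = model.generate(input_ids, max_length=200, eos_token_ids=[0])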