fixing CTRL tests and OpenAI GPT tests

This commit is contained in:
thomwolf 2019-10-09 13:51:05 +02:00
parent 6dce6dda1b
commit c19b8e4ae0
4 changed files with 31 additions and 25 deletions

View File

@ -303,11 +303,6 @@ class CTRLModel(CTRLPreTrainedModel):
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past is None:
past_length = 0
past = [None] * len(self.h)
@ -349,42 +344,51 @@ class CTRLModel(CTRLPreTrainedModel):
else:
head_mask = [None] * self.config.n_layer
x = self.w(input_ids)
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
token_type_embeds = self.w(token_type_ids)
token_type_embeds *= np.sqrt(self.d_model_size)
else:
token_type_embeds = 0
position_ids = position_ids.view(-1, input_shape[-1])
inputs_embeds = self.w(input_ids)
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len = input_ids.shape[-1]
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
x *= np.sqrt(self.d_model_size)
inputs_embeds *= np.sqrt(self.d_model_size)
pos_x = self.pos_encoding[position_ids, :].to(x.device)
x += pos_x
pos_embeds = self.pos_encoding[position_ids, :].to(inputs_embeds.device)
x = self.dropout(x)
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
output_shape = input_shape + (x.size(-1),)
hidden_states = self.dropout(hidden_states)
output_shape = input_shape + (inputs_embeds.size(-1),)
presents = ()
all_hidden_states = ()
all_attentions = []
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (x.view(*output_shape),)
outputs = h(x,
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = h(hidden_states,
mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i])
x, present = outputs[:2]
hidden_states, present = outputs[:2]
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
x = self.layernorm(x)
x = x.view(*output_shape)
hidden_states = self.layernorm(hidden_states)
hidden_states = hidden_states.view(*output_shape)
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (x,)
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (x, presents)
outputs = (hidden_states, presents)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:

View File

@ -170,7 +170,7 @@ class Attention(nn.Module):
# w = w * self.bias + -1e9 * (1 - self.bias) # TF implem method: mask_attn_weights
# XD: self.b may be larger than w, so we need to crop it
b = self.bias[:, :, : w.size(-2), : w.size(-1)]
w = w * b + -1e9 * (1 - b)
w = w * b + - 1e4 * (1 - b)
if attention_mask is not None:
# Apply the attention mask

View File

@ -238,6 +238,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
past_length = shape_list(past[0][0])[-2]
if position_ids is None:
position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
position_ids = tf.tile(position_ids, [shape_list(input_ids)[0], 1])
# Attention mask.
if attention_mask is not None:
@ -276,7 +277,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
token_type_embeds = 0
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
inputs_embeds = self.w(input_ids)
inputs_embeds = self.w(input_ids, mode='embedding')
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len = input_shape[-1]
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

View File

@ -81,8 +81,9 @@ class TFCommonTestCases:
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining
pt_model_class = getattr(transformers, pt_model_class_name)
tf_model = model_class(config, output_hidden_states=True)
pt_model = pt_model_class(config, output_hidden_states=True)
config.output_hidden_states = True
tf_model = model_class(config)
pt_model = pt_model_class(config)
# Check we can load pt model in tf and vice-versa (architecture similar)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
@ -96,7 +97,7 @@ class TFCommonTestCases:
pto = pt_model(**pt_inputs_dict)
tfo = tf_model(inputs_dict)
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2)
self.assertLessEqual(max_diff, 2e-5)
def test_keyword_and_dict_args(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()