Mirror of https://github.com/huggingface/transformers.git
fixing CTRL tests and OpenAI GPT tests
commit c19b8e4ae0
parent 6dce6dda1b
@@ -303,11 +303,6 @@ class CTRLModel(CTRLPreTrainedModel):
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_shape[-1])
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])

         if past is None:
             past_length = 0
             past = [None] * len(self.h)
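Note: the reshaping removed here is not dropped entirely; the token_type_ids and position_ids views reappear lower in forward(), next to the embedding lookup (see the next hunk). For reference, a standalone sketch with toy tensors (not library code) of what .view(-1, input_shape[-1]) does: it flattens any extra leading dimensions into the batch dimension while keeping the sequence dimension.

import torch

input_ids = torch.zeros(2, 3, 7, dtype=torch.long)   # e.g. (batch, num_choices, seq_len)
input_shape = input_ids.size()
flat = input_ids.view(-1, input_shape[-1])           # flatten leading dims into the batch
print(flat.shape)                                    # torch.Size([6, 7])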
@@ -349,42 +344,51 @@ class CTRLModel(CTRLPreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer

-        x = self.w(input_ids)
-        # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+            token_type_embeds = self.w(token_type_ids)
+            token_type_embeds *= np.sqrt(self.d_model_size)
+        else:
+            token_type_embeds = 0
+        position_ids = position_ids.view(-1, input_shape[-1])
+
+        inputs_embeds = self.w(input_ids)
+        # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
         seq_len = input_ids.shape[-1]
-        mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
+        mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)

-        x *= np.sqrt(self.d_model_size)
+        inputs_embeds *= np.sqrt(self.d_model_size)

-        pos_x = self.pos_encoding[position_ids, :].to(x.device)
-        x += pos_x
+        pos_embeds = self.pos_encoding[position_ids, :].to(inputs_embeds.device)

-        x = self.dropout(x)
+        hidden_states = inputs_embeds + pos_embeds + token_type_embeds

-        output_shape = input_shape + (x.size(-1),)
+        hidden_states = self.dropout(hidden_states)
+
+        output_shape = input_shape + (inputs_embeds.size(-1),)
         presents = ()
         all_hidden_states = ()
         all_attentions = []
         for i, (h, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states = all_hidden_states + (x.view(*output_shape),)
-            outputs = h(x,
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
+            outputs = h(hidden_states,
                         mask,
                         layer_past=layer_past,
                         attention_mask=attention_mask,
                         head_mask=head_mask[i])
-            x, present = outputs[:2]
+            hidden_states, present = outputs[:2]
             presents = presents + (present,)

             if self.output_attentions:
                 all_attentions.append(outputs[2])

-        x = self.layernorm(x)
-        x = x.view(*output_shape)
+        hidden_states = self.layernorm(hidden_states)
+        hidden_states = hidden_states.view(*output_shape)
         if self.output_hidden_states:
-            all_hidden_states = all_hidden_states + (x,)
+            all_hidden_states = all_hidden_states + (hidden_states,)

-        outputs = (x, presents)
+        outputs = (hidden_states, presents)
         if self.output_hidden_states:
             outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
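Note: this hunk renames the running tensor from x to hidden_states and builds the initial hidden states as an explicit sum of scaled token embeddings, gathered position encodings, and (optional) token-type embeddings before dropout. A minimal standalone sketch of that combination, with toy dimensions and random stand-ins for self.w and self.pos_encoding:

import numpy as np
import torch
import torch.nn as nn

d_model, vocab_size, batch, seq_len = 16, 100, 2, 5
w = nn.Embedding(vocab_size, d_model)          # stands in for self.w
pos_encoding = torch.randn(512, d_model)       # stands in for self.pos_encoding

input_ids = torch.randint(0, vocab_size, (batch, seq_len))
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

inputs_embeds = w(input_ids) * np.sqrt(d_model)   # scaled token embeddings
pos_embeds = pos_encoding[position_ids, :]        # gathered position encodings
token_type_embeds = 0                             # no token_type_ids passed here
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
print(hidden_states.shape)                        # torch.Size([2, 5, 16])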
@@ -170,7 +170,7 @@ class Attention(nn.Module):
         # w = w * self.bias + -1e9 * (1 - self.bias)  # TF implem method: mask_attn_weights
         # XD: self.b may be larger than w, so we need to crop it
         b = self.bias[:, :, : w.size(-2), : w.size(-1)]
-        w = w * b + -1e9 * (1 - b)
+        w = w * b + - 1e4 * (1 - b)

         if attention_mask is not None:
             # Apply the attention mask
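Note: the only change is the additive-mask constant, -1e9 -> -1e4. The commit message does not say why, but a smaller finite constant keeps masked scores representable in lower precision and closer to the TensorFlow implementation's masking, which fits the tightened PT/TF tolerance further down. A toy illustration (not library code) of the masking pattern on this line:

import torch

w = torch.randn(1, 1, 4, 4)                         # attention scores (batch, heads, query, key)
b = torch.tril(torch.ones(4, 4)).view(1, 1, 4, 4)   # causal "bias": 1 on and below the diagonal
w = w * b + -1e4 * (1 - b)                          # masked positions get a large negative score
probs = torch.softmax(w, dim=-1)
print(probs[0, 0])                                  # upper-triangular entries are ~0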
@@ -238,6 +238,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             past_length = shape_list(past[0][0])[-2]
         if position_ids is None:
             position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
+            position_ids = tf.tile(position_ids, [shape_list(input_ids)[0], 1])

         # Attention mask.
         if attention_mask is not None:
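Note: the added tf.tile expands the default position_ids from shape [1, seq_len] to [batch_size, seq_len], so every example in the batch gets its own row of positions before the later reshape and gather. A standalone sketch with toy sizes:

import tensorflow as tf

batch_size, seq_len, past_length = 3, 4, 0
position_ids = tf.range(past_length, seq_len + past_length, dtype=tf.int32)[tf.newaxis, :]
position_ids = tf.tile(position_ids, [batch_size, 1])   # repeat along the batch dimension
print(position_ids.shape)                               # (3, 4)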
@@ -276,7 +277,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             token_type_embeds = 0
         position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

-        inputs_embeds = self.w(input_ids)
+        inputs_embeds = self.w(input_ids, mode='embedding')
         # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
         seq_len = input_shape[-1]
         mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
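Note: self.w is the shared input/output embedding layer, and mode='embedding' selects the lookup path explicitly (the same weights can also be used as an output projection). The class below is a generic sketch of that weight-tying pattern, not the transformers implementation:

import tensorflow as tf

class SharedEmbeddingsSketch(tf.keras.layers.Layer):
    """Toy shared embedding: one matrix for both lookup and output projection."""
    def __init__(self, vocab_size, hidden_size, **kwargs):
        super().__init__(**kwargs)
        self.weight = self.add_weight("weight", shape=[vocab_size, hidden_size])

    def call(self, inputs, mode="embedding"):
        if mode == "embedding":
            return tf.gather(self.weight, inputs)                    # ids -> vectors
        if mode == "linear":
            return tf.matmul(inputs, self.weight, transpose_b=True)  # vectors -> logits
        raise ValueError("mode {} is not recognized".format(mode))

layer = SharedEmbeddingsSketch(vocab_size=10, hidden_size=8)
print(layer(tf.constant([[1, 2, 3]]), mode="embedding").shape)       # (1, 3, 8)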
@@ -81,8 +81,9 @@ class TFCommonTestCases:
                 pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beggining
                 pt_model_class = getattr(transformers, pt_model_class_name)

-                tf_model = model_class(config, output_hidden_states=True)
-                pt_model = pt_model_class(config, output_hidden_states=True)
+                config.output_hidden_states = True
+                tf_model = model_class(config)
+                pt_model = pt_model_class(config)

                 # Check we can load pt model in tf and vice-versa (architecture similar)
                 tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
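Note: rather than passing output_hidden_states as a constructor keyword, the test now flips the flag on the shared config before instantiating either model, so the TF and PyTorch models are built from identical settings. A sketch of the pattern (the model class here is an arbitrary example and assumes both backends are installed):

import transformers

config = transformers.BertConfig(hidden_size=32, num_hidden_layers=2,
                                 num_attention_heads=2, intermediate_size=37)
config.output_hidden_states = True             # set once on the config...
pt_model = transformers.BertModel(config)      # ...so both models pick it up
tf_model = transformers.TFBertModel(config)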
@@ -96,7 +97,7 @@ class TFCommonTestCases:
                 pto = pt_model(**pt_inputs_dict)
                 tfo = tf_model(inputs_dict)
                 max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
-                self.assertLessEqual(max_diff, 2e-2)
+                self.assertLessEqual(max_diff, 2e-5)

         def test_keyword_and_dict_args(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
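Note: the PT/TF equivalence tolerance is tightened by three orders of magnitude, from 2e-2 to 2e-5, i.e. the two implementations are now expected to agree to within floating-point noise. The check itself is a plain NumPy comparison; a self-contained illustration with placeholder arrays:

import numpy as np

tfo = np.random.rand(2, 5, 32).astype(np.float32)                         # stand-in TF output
pto = tfo + np.random.uniform(-1e-6, 1e-6, tfo.shape).astype(np.float32)  # stand-in PT output
max_diff = np.amax(np.abs(tfo - pto))
assert max_diff <= 2e-5, max_diff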