diff --git a/transformers/modeling_ctrl.py b/transformers/modeling_ctrl.py
index af651413c79..2d8f6c38335 100644
--- a/transformers/modeling_ctrl.py
+++ b/transformers/modeling_ctrl.py
@@ -303,11 +303,6 @@ class CTRLModel(CTRLPreTrainedModel):
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
-
         if past is None:
             past_length = 0
             past = [None] * len(self.h)
@@ -349,42 +344,51 @@ class CTRLModel(CTRLPreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer
 
-        x = self.w(input_ids)
-        # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+            token_type_embeds = self.w(token_type_ids)
+            token_type_embeds *= np.sqrt(self.d_model_size)
+        else:
+            token_type_embeds = 0
+        position_ids = position_ids.view(-1, input_shape[-1])
+
+        inputs_embeds = self.w(input_ids)
+        # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
         seq_len = input_ids.shape[-1]
-        mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
+        mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
 
-        x *= np.sqrt(self.d_model_size)
+        inputs_embeds *= np.sqrt(self.d_model_size)
 
-        pos_x = self.pos_encoding[position_ids, :].to(x.device)
-        x += pos_x
+        pos_embeds = self.pos_encoding[position_ids, :].to(inputs_embeds.device)
 
-        x = self.dropout(x)
+        hidden_states = inputs_embeds + pos_embeds + token_type_embeds
 
-        output_shape = input_shape + (x.size(-1),)
+        hidden_states = self.dropout(hidden_states)
+
+        output_shape = input_shape + (inputs_embeds.size(-1),)
         presents = ()
         all_hidden_states = ()
         all_attentions = []
         for i, (h, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states = all_hidden_states + (x.view(*output_shape),)
-            outputs = h(x,
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
+            outputs = h(hidden_states,
                         mask,
                         layer_past=layer_past,
                         attention_mask=attention_mask,
                         head_mask=head_mask[i])
-            x, present = outputs[:2]
+            hidden_states, present = outputs[:2]
             presents = presents + (present,)
 
             if self.output_attentions:
                 all_attentions.append(outputs[2])
 
-        x = self.layernorm(x)
-        x = x.view(*output_shape)
+        hidden_states = self.layernorm(hidden_states)
+        hidden_states = hidden_states.view(*output_shape)
         if self.output_hidden_states:
-            all_hidden_states = all_hidden_states + (x,)
+            all_hidden_states = all_hidden_states + (hidden_states,)
 
-        outputs = (x, presents)
+        outputs = (hidden_states, presents)
         if self.output_hidden_states:
             outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
diff --git a/transformers/modeling_openai.py b/transformers/modeling_openai.py
index 2827bf11e50..52f3b7db72a 100644
--- a/transformers/modeling_openai.py
+++ b/transformers/modeling_openai.py
@@ -170,7 +170,7 @@ class Attention(nn.Module):
         # w = w * self.bias + -1e9 * (1 - self.bias)  # TF implem method: mask_attn_weights
         # XD: self.b may be larger than w, so we need to crop it
         b = self.bias[:, :, : w.size(-2), : w.size(-1)]
-        w = w * b + -1e9 * (1 - b)
+        w = w * b + -1e4 * (1 - b)
 
         if attention_mask is not None:
             # Apply the attention mask
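Note on the modeling_openai.py hunk above: the additive masking constant moves from -1e9 to -1e4, presumably so the masked logits stay representable in half precision (float16 tops out around 65504, so -1e9 overflows to -inf). A minimal sketch of that reasoning, not part of the patch:

```python
import torch

# -1e9 is not representable in float16 (max ~65504), so it overflows to -inf;
# -1e4 stays finite and is still large enough to zero out masked positions after softmax.
print(torch.tensor(-1e9, dtype=torch.float16))   # tensor(-inf, dtype=torch.float16)
print(torch.tensor(-1e4, dtype=torch.float16))   # tensor(-10000., dtype=torch.float16)

# The patched masking pattern, applied to dummy values:
w = torch.zeros(1, 4)                    # dummy attention logits
b = torch.tensor([[1., 1., 0., 0.]])     # causal bias: 1 = keep, 0 = mask
w = w * b + -1e4 * (1 - b)
print(torch.softmax(w, dim=-1))          # masked positions get ~0 probability
```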
diff --git a/transformers/modeling_tf_ctrl.py b/transformers/modeling_tf_ctrl.py
index d9c3f494caf..b6127d27893 100644
--- a/transformers/modeling_tf_ctrl.py
+++ b/transformers/modeling_tf_ctrl.py
@@ -238,6 +238,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             past_length = shape_list(past[0][0])[-2]
         if position_ids is None:
             position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
+            position_ids = tf.tile(position_ids, [shape_list(input_ids)[0], 1])
 
         # Attention mask.
         if attention_mask is not None:
@@ -276,7 +277,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             token_type_embeds = 0
         position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
 
-        inputs_embeds = self.w(input_ids)
+        inputs_embeds = self.w(input_ids, mode='embedding')
         # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
         seq_len = input_shape[-1]
         mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
diff --git a/transformers/tests/modeling_tf_common_test.py b/transformers/tests/modeling_tf_common_test.py
index 4b363c6dc83..20f649ca64d 100644
--- a/transformers/tests/modeling_tf_common_test.py
+++ b/transformers/tests/modeling_tf_common_test.py
@@ -81,8 +81,9 @@ class TFCommonTestCases:
                 pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beggining
                 pt_model_class = getattr(transformers, pt_model_class_name)
 
-                tf_model = model_class(config, output_hidden_states=True)
-                pt_model = pt_model_class(config, output_hidden_states=True)
+                config.output_hidden_states = True
+                tf_model = model_class(config)
+                pt_model = pt_model_class(config)
 
                 # Check we can load pt model in tf and vice-versa (architecture similar)
                 tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
@@ -96,7 +97,7 @@ class TFCommonTestCases:
                 pto = pt_model(**pt_inputs_dict)
                 tfo = tf_model(inputs_dict)
                 max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
-                self.assertLessEqual(max_diff, 2e-2)
+                self.assertLessEqual(max_diff, 2e-5)
 
         def test_keyword_and_dict_args(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
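Note on the modeling_tf_ctrl.py hunk: when position_ids are not supplied, the default row of positions is now tiled across the batch dimension, so the later tf.reshape and position-encoding lookup see one row per example. A rough sketch of the shape change, using made-up sizes in place of the shape_list helper:

```python
import tensorflow as tf

# Made-up sizes for illustration; the real code derives these from input_ids / past.
batch_size, past_length, seq_len = 2, 0, 5

# Before the fix: a single row of positions, shape (1, seq_len).
position_ids = tf.range(past_length, seq_len + past_length, dtype=tf.int32)[tf.newaxis, :]
print(position_ids.shape)   # (1, 5)

# After the fix: tiled across the batch, shape (batch_size, seq_len),
# so each example in the batch gets its own row of positions.
position_ids = tf.tile(position_ids, [batch_size, 1])
print(position_ids.shape)   # (2, 5)
```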