diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index ff0f7082e95..2c948dd74c8 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -408,7 +408,7 @@ class GemmaFlashAttention2(GemmaAttention):
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
         )
 
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
@@ -594,7 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)
 
         attn_output = self.o_proj(attn_output)
 
@@ -866,9 +866,9 @@ class GemmaModel(GemmaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        return_legacy_cache = False
+        return_legacy_cache = False  # noqa: F841
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
-            return_legacy_cache = True
+            return_legacy_cache = True  # noqa: F841
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 836528ee210..92c3249247e 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -360,7 +360,7 @@ class LlamaAttention(nn.Module):
 
         attn_output = attn_output.transpose(1, 2).contiguous()
 
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, -1)
 
         if self.config.pretraining_tp > 1:
             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
@@ -467,7 +467,7 @@ class LlamaFlashAttention2(LlamaAttention):
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
         )
 
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
@@ -653,7 +653,7 @@ class LlamaSdpaAttention(LlamaAttention):
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)
 
         attn_output = self.o_proj(attn_output)
 
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index c54b8774eea..4a489ecb25d 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -620,7 +620,6 @@ class MistralSdpaAttention(MistralAttention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         cos, sin = self.rotary_emb(value_states, position_ids)
-
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
@@ -656,7 +655,7 @@ class MistralSdpaAttention(MistralAttention):
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)
 
         attn_output = self.o_proj(attn_output)
 
diff --git a/utils/diff_model_converter.py b/utils/diff_model_converter.py
index d9786a9b3c4..b37ad74cb1b 100644
--- a/utils/diff_model_converter.py
+++ b/utils/diff_model_converter.py
@@ -238,11 +238,16 @@ class SuperTransformer(cst.CSTTransformer):
         Helper method to update the body by removing duplicates before adding new statements.
         """
         deduplicated_new_body = []
-        existing_nodes = {
-            self.python_module.code_for_node(node).strip() for node in new_statements if isinstance(node, cst.CSTNode)
-        }
+        existing_nodes = set()
+        for node in new_statements:
+            code = self.python_module.code_for_node(node)
+            comment_less_code = re.sub(r"#.*", "", code).strip()
+            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
+            existing_nodes.add(comment_less_code)
         for stmt in existing_body:
-            if self.python_module.code_for_node(stmt).strip() not in existing_nodes:
+            comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
+            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
+            if comment_less_code not in existing_nodes:
                 if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
                     continue
                 deduplicated_new_body.append(stmt)
@@ -542,7 +547,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files_to_parse",
-        default=["/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model.py"],
+        default=["all"],
        nargs="+",
        help="A list of `diff_xxxx` files that should be converted to single model file",
    )
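Note on the recurring reshape change (not part of the patch): switching from `reshape(bsz, q_len, self.hidden_size)` / `view(bsz, q_len, self.hidden_size)` to `..., -1)` lets the collapsed last dimension follow `num_heads * head_dim` instead of assuming it equals `hidden_size`. The snippet below is a minimal, standalone sketch of that difference; it assumes an `o_proj` whose input features are `num_heads * head_dim` (as in Gemma) and uses invented config numbers.

# Minimal sketch (not from the diff) of why `reshape(bsz, q_len, -1)` is the safer choice
# when num_heads * head_dim != hidden_size. Config values below are made up.
import torch
import torch.nn as nn

bsz, q_len = 2, 5
hidden_size, num_heads, head_dim = 2048, 8, 128  # num_heads * head_dim == 1024 != hidden_size
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

# Attention output as produced by SDPA / flash attention, after transpose: (bsz, q_len, num_heads, head_dim)
attn_output = torch.randn(bsz, num_heads, q_len, head_dim).transpose(1, 2).contiguous()

try:
    attn_output.reshape(bsz, q_len, hidden_size)  # old code path: fails for this config
except RuntimeError as e:
    print("reshape to hidden_size fails:", e)

out = o_proj(attn_output.reshape(bsz, q_len, -1))  # new code path: (2, 5, 1024) -> (2, 5, 2048)
print(out.shape)  # torch.Size([2, 5, 2048])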