[GemmaModel] fix small typo (#31202)

* fixes

* fix-copies
This commit is contained in:
Arthur 2024-06-03 11:02:38 +02:00 committed by GitHub
parent 39b2ff69d6
commit 1749841a0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 18 additions and 14 deletions

View File

@ -408,7 +408,7 @@ class GemmaFlashAttention2(GemmaAttention):
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
) )
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
if not output_attentions: if not output_attentions:
@ -594,7 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
) )
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
@ -866,9 +866,9 @@ class GemmaModel(GemmaPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False return_legacy_cache = False # noqa: F841
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None: if cache_position is None:

View File

@ -360,7 +360,7 @@ class LlamaAttention(nn.Module):
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = attn_output.reshape(bsz, q_len, -1)
if self.config.pretraining_tp > 1: if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
@ -467,7 +467,7 @@ class LlamaFlashAttention2(LlamaAttention):
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
) )
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
if not output_attentions: if not output_attentions:
@ -653,7 +653,7 @@ class LlamaSdpaAttention(LlamaAttention):
) )
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)

View File

@ -620,7 +620,6 @@ class MistralSdpaAttention(MistralAttention):
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids) cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
@ -656,7 +655,7 @@ class MistralSdpaAttention(MistralAttention):
) )
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)

View File

@ -238,11 +238,16 @@ class SuperTransformer(cst.CSTTransformer):
Helper method to update the body by removing duplicates before adding new statements. Helper method to update the body by removing duplicates before adding new statements.
""" """
deduplicated_new_body = [] deduplicated_new_body = []
existing_nodes = { existing_nodes = set()
self.python_module.code_for_node(node).strip() for node in new_statements if isinstance(node, cst.CSTNode) for node in new_statements:
} code = self.python_module.code_for_node(node)
comment_less_code = re.sub(r"#.*", "", code).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
existing_nodes.add(comment_less_code)
for stmt in existing_body: for stmt in existing_body:
if self.python_module.code_for_node(stmt).strip() not in existing_nodes: comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
if comment_less_code not in existing_nodes:
if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring: if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
continue continue
deduplicated_new_body.append(stmt) deduplicated_new_body.append(stmt)
@ -542,7 +547,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--files_to_parse", "--files_to_parse",
default=["/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model.py"], default=["all"],
nargs="+", nargs="+",
help="A list of `diff_xxxx` files that should be converted to single model file", help="A list of `diff_xxxx` files that should be converted to single model file",
) )