mirror of https://github.com/huggingface/transformers.git
parent 39b2ff69d6
commit 1749841a0e
@@ -408,7 +408,7 @@ class GemmaFlashAttention2(GemmaAttention):
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
         )

-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)

         if not output_attentions:
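Why `-1` instead of `self.hidden_size`: `reshape` infers the flattened last dimension (`num_heads * head_dim`) from the tensor itself, so the line no longer assumes `hidden_size == num_heads * head_dim`, which can break when `head_dim` is configured independently. A minimal sketch with toy shapes, not the actual Gemma modules:

import torch

# Toy attention output in (batch, num_heads, seq_len, head_dim) layout.
bsz, num_heads, q_len, head_dim = 2, 8, 16, 64
attn_output = torch.randn(bsz, num_heads, q_len, head_dim)

# The flash-attention / SDPA paths end with (batch, seq_len, num_heads, head_dim).
attn_output = attn_output.transpose(1, 2).contiguous()

# `-1` lets PyTorch infer num_heads * head_dim (512 here); hard-coding
# self.hidden_size fails whenever hidden_size != num_heads * head_dim.
flat = attn_output.reshape(bsz, q_len, -1)
print(flat.shape)  # torch.Size([2, 16, 512])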
@@ -594,7 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)

         attn_output = self.o_proj(attn_output)

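The SDPA path uses `.view()` rather than `.reshape()`, and `view` only works on contiguous memory, which is why `.contiguous()` precedes it after the transpose; `-1` again infers the flattened head dimension. A small illustration of the failure mode, with toy shapes and plain PyTorch:

import torch

x = torch.randn(2, 8, 16, 64)   # (batch, num_heads, seq_len, head_dim)
t = x.transpose(1, 2)           # (batch, seq_len, num_heads, head_dim), not contiguous

try:
    t.view(2, 16, -1)           # view() cannot merge the last two dims of a strided tensor
except RuntimeError as err:
    print("view failed:", err)

# contiguous() copies into a dense layout, after which view(..., -1) works.
print(t.contiguous().view(2, 16, -1).shape)  # torch.Size([2, 16, 512])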
@@ -866,9 +866,9 @@ class GemmaModel(GemmaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        return_legacy_cache = False
+        return_legacy_cache = False  # noqa: F841
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
-            return_legacy_cache = True
+            return_legacy_cache = True  # noqa: F841
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

         if cache_position is None:
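`F841` is the flake8 code for a local variable that is assigned but never read; the trailing `# noqa: F841` keeps the linter from flagging `return_legacy_cache`, which is retained for backwards compatibility even where nothing downstream consumes it. A generic illustration of the pattern, not the actual GemmaModel code:

def forward_like(use_cache, past_key_values):
    # Kept to mirror the upstream control flow; nothing below reads it, so
    # flake8 would otherwise report "F841 local variable is never used".
    return_legacy_cache = False  # noqa: F841
    if use_cache and not isinstance(past_key_values, list):
        return_legacy_cache = True  # noqa: F841
        past_key_values = list(past_key_values) if past_key_values else []
    return past_key_values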
@@ -360,7 +360,7 @@ class LlamaAttention(nn.Module):

         attn_output = attn_output.transpose(1, 2).contiguous()

-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, -1)

         if self.config.pretraining_tp > 1:
             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
@@ -467,7 +467,7 @@ class LlamaFlashAttention2(LlamaAttention):
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
         )

-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)

         if not output_attentions:
@@ -653,7 +653,7 @@ class LlamaSdpaAttention(LlamaAttention):
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)

         attn_output = self.o_proj(attn_output)

@@ -620,7 +620,6 @@ class MistralSdpaAttention(MistralAttention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         cos, sin = self.rotary_emb(value_states, position_ids)
-
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
@@ -656,7 +655,7 @@ class MistralSdpaAttention(MistralAttention):
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)

         attn_output = self.o_proj(attn_output)

@@ -238,11 +238,16 @@ class SuperTransformer(cst.CSTTransformer):
         Helper method to update the body by removing duplicates before adding new statements.
         """
         deduplicated_new_body = []
-        existing_nodes = {
-            self.python_module.code_for_node(node).strip() for node in new_statements if isinstance(node, cst.CSTNode)
-        }
+        existing_nodes = set()
+        for node in new_statements:
+            code = self.python_module.code_for_node(node)
+            comment_less_code = re.sub(r"#.*", "", code).strip()
+            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
+            existing_nodes.add(comment_less_code)
         for stmt in existing_body:
-            if self.python_module.code_for_node(stmt).strip() not in existing_nodes:
+            comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
+            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
+            if comment_less_code not in existing_nodes:
                 if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
                     continue
                 deduplicated_new_body.append(stmt)
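The deduplication now compares a normalized form of each statement: comments are stripped with `re.sub(r"#.*", "", code)` and trailing spaces before newlines removed, so two statements that differ only cosmetically are still treated as duplicates. A standalone sketch of that normalization on plain strings (the real code applies it to libcst nodes via `code_for_node`):

import re

def normalize(code: str) -> str:
    # Drop comments, then trailing spaces before newlines, then outer whitespace.
    comment_less = re.sub(r"#.*", "", code)
    comment_less = re.sub(r"\ *\n", "\n", comment_less)
    return comment_less.strip()

new_statements = ["x = compute()  # regenerated with a comment\n"]
existing_body = ["x = compute()\n", "y = other()\n"]

existing_nodes = {normalize(stmt) for stmt in new_statements}
deduplicated = [stmt for stmt in existing_body if normalize(stmt) not in existing_nodes]
print(deduplicated)  # ['y = other()\n'] -- the commented variant counts as a duplicate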
@@ -542,7 +547,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--files_to_parse",
-        default=["/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model.py"],
+        default=["all"],
         nargs="+",
         help="A list of `diff_xxxx` files that should be converted to single model file",
     )
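Replacing the contributor-local path with the sentinel `["all"]` makes the converter runnable out of the box; with `nargs="+"`, an explicit list of diff files still overrides the default. A minimal argparse sketch of the resulting behaviour (the file names below are illustrative only):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--files_to_parse",
    default=["all"],
    nargs="+",
    help="A list of `diff_xxxx` files that should be converted to single model file",
)

print(parser.parse_args([]).files_to_parse)  # ['all'] -- new default, no machine-specific path
args = parser.parse_args(["--files_to_parse", "diff_gemma.py", "diff_llama.py"])
print(args.files_to_parse)  # ['diff_gemma.py', 'diff_llama.py']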