mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
a21557fa3e
commit
0b0ede8b2b
@ -365,21 +365,6 @@ class SuperTransformer(cst.CSTTransformer):
|
||||
return updated_node.with_changes(body=new_body, params=updated_node.params)
|
||||
return updated_node
|
||||
|
||||
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.Return:
|
||||
""" "When a return statement is reached, it is replaced with the unrolled super code"""
|
||||
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
|
||||
func_def = self.get_metadata(ParentNodeProvider, original_node)
|
||||
if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_modeling_methods:
|
||||
updated_return_value = updated_node.value.with_changes(
|
||||
args=[
|
||||
cst.Arg(
|
||||
value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))])
|
||||
)
|
||||
]
|
||||
)
|
||||
return updated_node.with_changes(value=updated_return_value)
|
||||
return updated_node
|
||||
|
||||
|
||||
def find_all_dependencies(
|
||||
dependency_mapping: dict[str, set],
|
||||
|
Loading…
Reference in New Issue
Block a user