Fix torch.fx issue related to the new loss_kwargs keyword argument (#34380)

* Fix FX

* Unskip tests
This commit is contained in:
Michael Benayoun 2024-10-24 18:34:28 +02:00 committed by GitHub
parent d9989e0b9a
commit 1c5918d910
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 1 addition and 6 deletions

View File

@@ -1416,7 +1416,7 @@ class HFTracer(Tracer):
your custom tracer.
"""
attribute = HFAttribute(obj, "keys")()
if obj.node.target == "**kwargs":
if obj.node.target.startswith("**"):
return attribute._metadata
return attribute

View File

@@ -304,7 +304,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
# NOTE(review): in this commit's diff the @unittest.skip line is the one being
# DELETED ("Unskip tests") — presumably the FX fix above makes the test pass
# again; confirm against the applied state of the file.
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
# Pure delegation: runs the shared mixin implementation unchanged.
super().test_torch_fx_output_loss()

View File

@@ -356,7 +356,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
# NOTE(review): the @unittest.skip line is a deletion in this hunk (the commit
# re-enables the test after the FX tracer fix) — verify against the merged file.
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
# Pure delegation: runs the shared mixin implementation unchanged.
super().test_torch_fx_output_loss()

View File

@@ -356,7 +356,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
# NOTE(review): the @unittest.skip line is a deletion in this hunk (the commit
# re-enables the test after the FX tracer fix) — verify against the merged file.
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
# Pure delegation: runs the shared mixin implementation unchanged.
super().test_torch_fx_output_loss()

View File

@@ -368,7 +368,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
# NOTE(review): the @unittest.skip line is a deletion in this hunk (the commit
# re-enables the test after the FX tracer fix) — verify against the merged file.
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
# Pure delegation: runs the shared mixin implementation unchanged.
super().test_torch_fx_output_loss()

View File

@@ -391,7 +391,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
# NOTE(review): the @unittest.skip line is a deletion in this hunk (the commit
# re-enables the test after the FX tracer fix) — verify against the merged file.
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
# Pure delegation: runs the shared mixin implementation unchanged.
super().test_torch_fx_output_loss()