From 304c6a1e0dcad79f4930858efd6a96f9e46a8fcc Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Wed, 17 Apr 2024 14:38:48 +0500
Subject: [PATCH] Enable fx tracing for Mistral (#30209)

* tracing for mistral

* typo

* fix copies
---
 src/transformers/models/mixtral/modeling_mixtral.py     | 3 ---
 src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 3 ---
 src/transformers/utils/fx.py                            | 5 +++++
 tests/models/mistral/test_modeling_mistral.py           | 1 +
 tests/models/mixtral/test_modeling_mixtral.py           | 1 +
 tests/models/qwen2/test_modeling_qwen2.py               | 1 +
 tests/models/qwen2_moe/test_modeling_qwen2_moe.py       | 1 +
 7 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index eb5794f640a..c78e907d5fd 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -868,9 +868,6 @@ class MixtralSparseMoeBlock(nn.Module):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])

-            if top_x.shape[0] == 0:
-                continue
-
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 7b9e2f6fc0a..70072c91720 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -840,9 +840,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])

-            if top_x.shape[0] == 0:
-                continue
-
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index df0aba8d5d4..ab4f823c2fc 100755
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -141,12 +141,16 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "marian",
     "mbart",
     "megatron-bert",
+    "mistral",
+    "mixtral",
     "mobilebert",
     "mt5",
     "nezha",
     "opt",
     "pegasus",
     "plbart",
+    "qwen2",
+    "qwen2_moe",
     "resnet",
     "roberta",
     "segformer",
@@ -758,6 +762,7 @@ class HFTracer(Tracer):
         "tensor",
         "clamp",
         "finfo",
+        "tril",
     ]

     supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 59f3cdea695..3500024b3ea 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -303,6 +303,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True

     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py
index 0cc8c9fc442..0d92595d8cf 100644
--- a/tests/models/mixtral/test_modeling_mixtral.py
+++ b/tests/models/mixtral/test_modeling_mixtral.py
@@ -302,6 +302,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True

     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index 21ee694bdca..f4e88a97f06 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -313,6 +313,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True

     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
index e322e5f8490..f0818e680d3 100644
--- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
+++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
@@ -342,6 +342,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True

     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
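A note on the modeling change: torch.fx symbolic tracing cannot follow control flow that branches on tensor values, which is presumably why the data-dependent `if top_x.shape[0] == 0: continue` shortcut is dropped from the two sparse-MoE blocks; indexing with an empty `top_x` is a no-op, so the removal does not change results.

Below is a minimal sketch (not part of the patch) of the tracing this change enables. The tiny config values are illustrative assumptions, not anything taken from the diff; any Mistral config should trace the same way, and the same call applies to the newly registered `mixtral`, `qwen2`, and `qwen2_moe` models.

```python
import torch
from transformers import MistralConfig, MistralForCausalLM
from transformers.utils.fx import symbolic_trace

# Tiny, randomly initialized model so the sketch runs quickly;
# the sizes below are arbitrary illustrative values.
config = MistralConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = MistralForCausalLM(config).eval()

# Before this patch, symbolic_trace raised for "mistral" because it was
# missing from _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS and torch.tril
# (used for the sliding-window mask) was not in the patched methods list.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)
out = traced(input_ids=input_ids, attention_mask=attention_mask)

# The traced GraphModule returns the model outputs; handle both dict and
# ModelOutput forms to stay agnostic about the exact return type.
logits = out["logits"] if isinstance(out, dict) else out.logits
print(logits.shape)  # torch.Size([1, 8, 1000])
```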