diff --git a/.github/workflows/self-scheduled-intel-gaudi.yml b/.github/workflows/self-scheduled-intel-gaudi.yml index 2db5ece064b..ad14d66b58b 100644 --- a/.github/workflows/self-scheduled-intel-gaudi.yml +++ b/.github/workflows/self-scheduled-intel-gaudi.yml @@ -84,8 +84,6 @@ jobs: machine_type: ${{ matrix.machine_type }} folder_slices: ${{ needs.setup.outputs.folder_slices }} runner: ${{ inputs.runner_scale_set }}-${{ matrix.machine_type }} - report_name_prefix: run_models_gpu - secrets: inherit run_trainer_and_fsdp_gpu: @@ -104,11 +102,10 @@ jobs: folder_slices: ${{ needs.setup.outputs.folder_slices }} runner: ${{ inputs.runner_scale_set }}-${{ matrix.machine_type }} report_name_prefix: run_trainer_and_fsdp_gpu - secrets: inherit - run_pipelines_gpu: - if: ${{ inputs.job == 'run_pipelines_gpu' }} + run_pipelines_torch_gpu: + if: ${{ inputs.job == 'run_pipelines_torch_gpu' }} name: Pipelines strategy: fail-fast: false @@ -161,20 +158,20 @@ jobs: - name: Run all pipeline tests on Intel Gaudi run: | - python3 -m pytest -v --make-reports=${{ env.machine_type }}_run_pipelines_gpu_test_reports tests/pipelines -m "not not_device_test" + python3 -m pytest -v --make-reports=${{ env.machine_type }}_run_pipelines_torch_gpu_test_reports tests/pipelines -m "not not_device_test" - name: Failure short reports if: ${{ failure() }} continue-on-error: true run: | - cat reports/${{ env.machine_type }}_run_pipelines_gpu_test_reports/failures_short.txt + cat reports/${{ env.machine_type }}_run_pipelines_torch_gpu_test_reports/failures_short.txt - - name: "Test suite reports artifacts: ${{ env.machine_type }}_run_pipelines_gpu_test_reports" + - name: "Test suite reports artifacts: ${{ env.machine_type }}_run_pipelines_torch_gpu_test_reports" if: ${{ always() }} uses: actions/upload-artifact@v4 with: - name: ${{ env.machine_type }}_run_pipelines_gpu_test_reports - path: reports/${{ env.machine_type }}_run_pipelines_gpu_test_reports + name: ${{ env.machine_type }}_run_pipelines_torch_gpu_test_reports + path: reports/${{ env.machine_type }}_run_pipelines_torch_gpu_test_reports run_examples_gpu: if: ${{ inputs.job == 'run_examples_gpu' }} @@ -248,8 +245,8 @@ jobs: name: ${{ env.machine_type }}_run_examples_gpu_test_reports path: reports/${{ env.machine_type }}_run_examples_gpu_test_reports - run_deepspeed_gpu: - if: ${{ inputs.job == 'run_deepspeed_gpu' }} + run_torch_cuda_extensions_gpu: + if: ${{ inputs.job == 'run_torch_cuda_extensions_gpu' }} name: Intel Gaudi deepspeed tests strategy: fail-fast: false @@ -305,20 +302,20 @@ jobs: - name: Run all deepspeed tests on intel Gaudi run: | - python3 -m pytest -v --make-reports=${{ env.machine_type }}_run_deepspeed_gpu_test_reports tests/deepspeed -m "not not_device_test" + python3 -m pytest -v --make-reports=${{ env.machine_type }}_run_torch_cuda_extensions_gpu_test_reports tests/deepspeed -m "not not_device_test" - name: Failure short reports if: ${{ failure() }} continue-on-error: true run: | - cat reports/${{ env.machine_type }}_run_deepspeed_gpu_test_reports/failures_short.txt + cat reports/${{ env.machine_type }}_run_torch_cuda_extensions_gpu_test_reports/failures_short.txt - - name: "Test suite reports artifacts: ${{ env.machine_type }}_run_deepspeed_gpu_test_reports" + - name: "Test suite reports artifacts: ${{ env.machine_type }}_run_torch_cuda_extensions_gpu_test_reports" if: ${{ always() }} uses: actions/upload-artifact@v4 with: - name: ${{ env.machine_type }}_run_deepspeed_gpu_test_reports - path: reports/${{ env.machine_type }}_run_deepspeed_gpu_test_reports + name: ${{ env.machine_type }}_run_torch_cuda_extensions_gpu_test_reports + path: reports/${{ env.machine_type }}_run_torch_cuda_extensions_gpu_test_reports send_results: name: Slack Report @@ -327,8 +324,8 @@ jobs: setup, run_models_gpu, run_examples_gpu, - run_pipelines_gpu, - run_deepspeed_gpu, + run_torch_cuda_extensions_gpu, + run_pipelines_torch_gpu, run_trainer_and_fsdp_gpu, ] if: ${{ always() }} diff --git a/.github/workflows/self-scheduled-intel-gaudi3-caller.yml b/.github/workflows/self-scheduled-intel-gaudi3-caller.yml index 83cb89290d3..8a3d70c4d43 100644 --- a/.github/workflows/self-scheduled-intel-gaudi3-caller.yml +++ b/.github/workflows/self-scheduled-intel-gaudi3-caller.yml @@ -23,7 +23,7 @@ jobs: name: Pipeline CI uses: ./.github/workflows/self-scheduled-intel-gaudi.yml with: - job: run_pipelines_gpu + job: run_pipelines_torch_gpu ci_event: Scheduled CI (Intel) - Gaudi3 runner_scale_set: itac-bm-emr-gaudi3-dell slack_report_channel: "#transformers-ci-daily-intel-gaudi3" @@ -47,7 +47,7 @@ jobs: name: DeepSpeed CI uses: ./.github/workflows/self-scheduled-intel-gaudi.yml with: - job: run_deepspeed_gpu + job: run_torch_cuda_extensions_gpu ci_event: Scheduled CI (Intel) - Gaudi3 runner_scale_set: itac-bm-emr-gaudi3-dell slack_report_channel: "#transformers-ci-daily-intel-gaudi3" diff --git a/docs/source/en/model_doc/dia.md b/docs/source/en/model_doc/dia.md index 67c4a3be0b6..a4a2f84c78b 100644 --- a/docs/source/en/model_doc/dia.md +++ b/docs/source/en/model_doc/dia.md @@ -44,7 +44,7 @@ tokens and decodes them back into audio. from transformers import AutoProcessor, DiaForConditionalGeneration torch_device = "cuda" -model_checkpoint = "buttercrab/dia-v1-1.6b" +model_checkpoint = "nari-labs/Dia-1.6B-0626" text = ["[S1] Dia is an open weights text to dialogue model."] processor = AutoProcessor.from_pretrained(model_checkpoint) @@ -66,7 +66,7 @@ from datasets import load_dataset, Audio from transformers import AutoProcessor, DiaForConditionalGeneration torch_device = "cuda" -model_checkpoint = "buttercrab/dia-v1-1.6b" +model_checkpoint = "nari-labs/Dia-1.6B-0626" ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") ds = ds.cast_column("audio", Audio(sampling_rate=44100)) @@ -93,7 +93,7 @@ from datasets import load_dataset, Audio from transformers import AutoProcessor, DiaForConditionalGeneration torch_device = "cuda" -model_checkpoint = "buttercrab/dia-v1-1.6b" +model_checkpoint = "nari-labs/Dia-1.6B-0626" ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") ds = ds.cast_column("audio", Audio(sampling_rate=44100)) diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py index 91979590046..8f6f49f26bc 100644 --- a/src/transformers/commands/chat.py +++ b/src/transformers/commands/chat.py @@ -333,6 +333,11 @@ class ChatCommand(BaseTransformersCLICommand): ) args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1) + + if args.model_name_or_path is None: + raise ValueError( + "When connecting to a server, please specify a model name with the --model_name_or_path flag." + ) else: self.spawn_backend = True args.model_name_or_path = args.model_name_or_path_or_address diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index 9b886b27210..d8f61603692 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -347,7 +347,7 @@ class ServeCommand(BaseTransformersCLICommand): if not req.stream: return {"error": "Only streaming mode is supported."} - update_model = req.model != self.loaded_model + update_model = self.canonicalized_model_name(req.model) != self.loaded_model if update_model: self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args) @@ -402,7 +402,7 @@ class ServeCommand(BaseTransformersCLICommand): if self.last_messages is None: req_continues_last_messages = False # The new request has fewer rounds of conversation: this is a new request - elif len(self.last_messages) > len(req.messages): + elif len(self.last_messages) >= len(req.messages): req_continues_last_messages = False # Otherwise, check that the last messages are a subset of the new request else: @@ -417,7 +417,7 @@ class ServeCommand(BaseTransformersCLICommand): def generate(self, app): @app.post("/v1/chat/completions") def _serve(req: "ChatCompletionInput"): - update_model = req.model != self.loaded_model + update_model = self.canonicalized_model_name(req.model) != self.loaded_model if update_model: self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args) @@ -585,6 +585,11 @@ class ServeCommand(BaseTransformersCLICommand): return quantization_config + def canonicalized_model_name(self, model_id: str) -> str: + if "@" in model_id: + return model_id + return f"{model_id}@main" + def load_model_and_tokenizer( self, model_id_and_revision: str, args: ServeArguments ) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]: @@ -621,9 +626,9 @@ class ServeCommand(BaseTransformersCLICommand): if getattr(model, "hf_device_map", None) is None: model = model.to(args.device) - self.loaded_model = model_id_and_revision + self.loaded_model = f"{model_id}@{revision}" - print("Loaded model", model_id_and_revision) + logger.warning(f"Loaded model {self.loaded_model}") return model, tokenizer diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index de6ae44bb5a..13a1c83a719 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3773,16 +3773,28 @@ class GenerationMixin(ContinuousMixin): Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False """ # a. Can the open beams improve the top completed scores? - # early_stopping == False -> apply heuristic = always get the best score from - # `cur_len - decoder_prompt_len`. See the discussion below for more details. - # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 + # early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`. # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the # sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use # `max_length` there. + # !! + # Be sure to check the docstring for `early_stopping` and `length_penalty`. The default parameterization + # does NOT correspond to a canonical beam search implementation, and tends to favor shorter output sequences + # compared to it (the heuristic active by default underestimates the maximum achievable score, and thus cut + # generation short). Also, be mindful that length penalty > 0.0 actually favors longer sequences, despite + # its name. These modifications were empirically found in the past (prior to 2022) to produce better quality + # generations, and changing them is BC breaking. + # For a canonical beam search implementation, set `early_stopping="never"` and `length_penalty=0.0`. + # See the discussion below for more details. + # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 + # !! if early_stopping == "never" and length_penalty > 0.0: best_hypothetical_length = max_length - decoder_prompt_len else: best_hypothetical_length = cur_len - decoder_prompt_len + + # best-case scenario: the next tokens have logprobs=0 (probability=1), and the score stays the same before + # applying length penalty best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty) worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9) improvement_possible = torch.any(best_possible_running_score > worst_finished_score) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 269382c3768..fc0a249850c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4431,10 +4431,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi "`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization." ) - # If torchrun was used, make sure to TP by default. This way people don't need to change tp or device map - if device_map == "auto" and tp_plan is None and int(os.environ.get("WORLD_SIZE", 0)): - tp_plan = "auto" # device_map = "auto" in torchrun equivalent to TP plan = AUTO! - device_map = None + if device_map == "auto" and int(os.environ.get("WORLD_SIZE", 0)): + logger.info( + "You've set device_map=`auto` while triggering a distributed run with torchrun. This might lead to unexpected behavior. " + "If your plan is to load the model on each device, you should set device_map={" + ": PartialState().process_index} where PartialState comes from accelerate library" + ) # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple # `device_map` pointing to the correct device diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 4148dfd10ac..a97d9c0f827 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1269,13 +1269,13 @@ class Glm4vModel(Glm4vPreTrainedModel): if input_ids is None: video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) video_mask = video_mask.all(-1) else: - video_mask = input_ids == self.config.video_token_id + video_mask = input_ids == self.config.image_token_id - n_video_tokens = (video_mask).sum() + n_video_tokens = video_mask.sum() n_video_features = video_embeds.shape[0] video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index cf4a6b9233f..5732503daa6 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -1269,13 +1269,13 @@ class Glm4vModel(Qwen2_5_VLModel): if input_ids is None: video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) video_mask = video_mask.all(-1) else: - video_mask = input_ids == self.config.video_token_id + video_mask = input_ids == self.config.image_token_id - n_video_tokens = (video_mask).sum() + n_video_tokens = video_mask.sum() n_video_features = video_embeds.shape[0] video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index b5c07bebb86..abfac77418c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2357,7 +2357,7 @@ class Trainer: model = self.accelerator.prepare(self.model) else: if delay_optimizer_creation: - self.optimizer = self.accelerator.prepare(self.optimizer) + model = self.accelerator.prepare(self.model) else: model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) else: diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 2c28cbde292..5c9c6d5690d 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -865,50 +865,59 @@ def is_torch_hpu_available(): if not hasattr(torch, "hpu") or not torch.hpu.is_available(): return False - import habana_frameworks.torch.utils.experimental as htexp # noqa: F401 - - # IlyasMoutawwakil: We patch masked_fill_ for int64 tensors to avoid a bug on Gaudi1 - # synNodeCreateWithId failed for node: masked_fill_fwd_i64 with synStatus 26 [Generic failure] - # This can be removed once Gaudi1 support is discontinued but for now we need it to keep using - # dl1.24xlarge Gaudi1 instances on AWS for testing. - # check if the device is Gaudi1 (vs Gaudi2, Gaudi3). - if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi: - original_masked_fill_ = torch.Tensor.masked_fill_ - - def patched_masked_fill_(self, mask, value): - if self.dtype == torch.int64: - logger.warning_once( - "In-place tensor.masked_fill_(mask, value) is not supported for int64 tensors on Gaudi1. " - "This operation will be performed out-of-place using tensor[mask] = value." - ) - self[mask] = value - else: - original_masked_fill_(self, mask, value) - - torch.Tensor.masked_fill_ = patched_masked_fill_ - # We patch torch.gather for int64 tensors to avoid a bug on Gaudi # Graph compile failed with synStatus 26 [Generic failure] # This can be removed once bug is fixed but for now we need it. - original_gather = torch.Tensor.gather + original_gather = torch.gather def patched_gather(input: torch.Tensor, dim: int, index: torch.LongTensor) -> torch.Tensor: if input.dtype == torch.int64 and input.device.type == "hpu": - logger.warning_once( - "torch.gather is not supported for int64 tensors on Gaudi. " - "This operation will be performed patched_gather using indexing." - ) - - idx = [torch.arange(size, device=input.device, dtype=input.dtype) for size in input.shape] - idx[dim] = index - idx = tuple(idx) - output = input[idx] - return output + return original_gather(input.to(torch.int32), dim, index).to(torch.int64) else: return original_gather(input, dim, index) + torch.gather = patched_gather torch.Tensor.gather = patched_gather + original_take_along_dim = torch.take_along_dim + + def patched_take_along_dim( + input: torch.Tensor, indices: torch.LongTensor, dim: Optional[int] = None + ) -> torch.Tensor: + if input.dtype == torch.int64 and input.device.type == "hpu": + return original_take_along_dim(input.to(torch.int32), indices, dim).to(torch.int64) + else: + return original_take_along_dim(input, indices, dim) + + torch.take_along_dim = patched_take_along_dim + + original_cholesky = torch.linalg.cholesky + + def safe_cholesky(A, *args, **kwargs): + output = original_cholesky(A, *args, **kwargs) + + if torch.isnan(output).any(): + jitter_value = 1e-9 + diag_jitter = torch.eye(A.size(-1), dtype=A.dtype, device=A.device) * jitter_value + output = original_cholesky(A + diag_jitter, *args, **kwargs) + + return output + + torch.linalg.cholesky = safe_cholesky + + original_scatter = torch.scatter + + def patched_scatter( + input: torch.Tensor, dim: int, index: torch.Tensor, src: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + if input.device.type == "hpu" and input is src: + return original_scatter(input, dim, index, src.clone(), *args, **kwargs) + else: + return original_scatter(input, dim, index, src, *args, **kwargs) + + torch.scatter = patched_scatter + torch.Tensor.scatter = patched_scatter + # IlyasMoutawwakil: we patch torch.compile to use the HPU backend by default # https://github.com/huggingface/transformers/pull/38790#discussion_r2157043944 # This is necessary for cases where torch.compile is used as a decorator (defaulting to inductor) diff --git a/tests/commands/test_chat.py b/tests/commands/test_chat.py index 6ba3413fafa..e07df4a3938 100644 --- a/tests/commands/test_chat.py +++ b/tests/commands/test_chat.py @@ -29,12 +29,34 @@ class ChatCLITest(unittest.TestCase): self.assertIn("chat interface", cs.out.lower()) @patch.object(ChatCommand, "run") - def test_cli_dispatch(self, run_mock): + def test_cli_dispatch_model(self, run_mock): + """ + Running transformers chat with just a model should work & spawn a serve underneath + """ args = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"] with patch("sys.argv", args): cli.main() run_mock.assert_called_once() + def test_cli_dispatch_url(self): + """ + Running transformers chat with just a URL should not work as a model should additionally be specified + """ + args = ["transformers", "chat", "localhost:8000"] + with self.assertRaises(ValueError): + with patch("sys.argv", args): + cli.main() + + @patch.object(ChatCommand, "run") + def test_cli_dispatch_url_and_model(self, run_mock): + """ + Running transformers chat with a URL and a model should work + """ + args = ["transformers", "chat", "localhost:8000", "--model_name_or_path=hf-internal-testing/tiny-random-gpt2"] + with patch("sys.argv", args): + cli.main() + run_mock.assert_called_once() + def test_parsed_args(self): with ( patch.object(ChatCommand, "__init__", return_value=None) as init_mock, diff --git a/tests/models/align/test_modeling_align.py b/tests/models/align/test_modeling_align.py index 5f147220daf..15c520d1d21 100644 --- a/tests/models/align/test_modeling_align.py +++ b/tests/models/align/test_modeling_align.py @@ -462,6 +462,9 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() + def test_batching_equivalence(self, atol=3e-4, rtol=3e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + @unittest.skip(reason="Start to fail after using torch `cu118`.") def test_multi_gpu_data_parallel_forward(self): super().test_multi_gpu_data_parallel_forward() diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index ece423338a6..cdab28a3a7d 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -25,6 +25,7 @@ from transformers import ( AriaTextConfig, AutoProcessor, AutoTokenizer, + BitsAndBytesConfig, is_torch_available, is_vision_available, ) @@ -52,6 +53,9 @@ if is_torch_available(): if is_vision_available(): from PIL import Image +# Used to be https://aria-vl.github.io/static/images/view.jpg but it was removed, llava-vl has the same image +IMAGE_OF_VIEW_URL = "https://llava-vl.github.io/static/images/view.jpg" + class AriaVisionText2TextModelTester: def __init__( @@ -262,23 +266,38 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): @require_bitsandbytes def test_small_model_integration_test(self): # Let's make sure we test the preprocessing to replace what is used - model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True) + model = AriaForConditionalGeneration.from_pretrained( + "rhymes-ai/Aria", + quantization_config=BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"]), + ) - prompt = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:" - image_file = "https://aria-vl.github.io/static/images/view.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt") + prompt = "<|img|>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:" + raw_image = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw) + inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device, model.dtype) - EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip + non_img_tokens = [ + 109, 3905, 2000, 93415, 4551, 1162, 901, 3894, 970, 2478, 1017, 19312, 2388, 1596, 1809, 970, 5449, 1235, + 3333, 93483, 109, 61081, 11984, 14800, 93415 + ] # fmt: skip + EXPECTED_INPUT_IDS = torch.tensor([[9] * 256 + non_img_tokens]).to(inputs["input_ids"].device) self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip + decoded_output = self.processor.decode(output[0], skip_special_tokens=True) - self.assertEqual( - self.processor.decode(output[0], skip_special_tokens=True), - EXPECTED_DECODED_TEXT, - ) + expected_output = Expectations( + { + ( + "cuda", + None, + ): "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,", + ( + "rocm", + (9, 5), + ): "\n USER: What are the things I should be cautious about when I visit this place?\n ASSISTANT: When you visit this place, you should be cautious about the following things:\n\n- The", + } + ).get_expectation() + self.assertEqual(decoded_output, expected_output) @slow @require_torch_large_accelerator @@ -287,20 +306,29 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): # Let's make sure we test the preprocessing to replace what is used model_id = "rhymes-ai/Aria" - model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + model = AriaForConditionalGeneration.from_pretrained( + model_id, + quantization_config=BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"]), + ) processor = AutoProcessor.from_pretrained(model_id) - prompt = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT:" - image_file = "https://aria-vl.github.io/static/images/view.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) + prompt = "USER: <|img|>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:" + raw_image = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device, model.dtype) - output = model.generate(**inputs, max_new_tokens=900, do_sample=False) - EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip + output = model.generate(**inputs, max_new_tokens=90, do_sample=False) + EXPECTED_DECODED_TEXT = Expectations( + { + ("cuda", (8, 0)): "USER: \n What are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this beautiful location, it's important to be mindful of a few things to ensure both your safety and the preservation of the environment. Firstly, always be cautious when walking on the wooden pier, as it can be slippery, especially during or after rain. Secondly, be aware of the local wildlife and do not feed or disturb them. Lastly, respect the natural surroundings by not littering and sticking to", + ("rocm", (9, 5)): "USER: \n What are the things I should be cautious about when I visit this place? ASSISTANT: \n\nWhen visiting this place, you should be cautious about the following:\n\n1. **Weather Conditions**: The weather can be unpredictable, so it's important to check the forecast and dress in layers. Sudden changes in weather can occur, so be prepared for rain or cold temperatures.\n\n2. **Safety on the Dock**: The dock may be slippery, especially when", + } + ).get_expectation() # fmt: off + decoded_output = processor.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) self.assertEqual( - processor.decode(output[0], skip_special_tokens=True), + decoded_output, EXPECTED_DECODED_TEXT, + f"Expected: {repr(EXPECTED_DECODED_TEXT)}\nActual: {repr(decoded_output)}", ) @slow @@ -310,53 +338,77 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): # Let's make sure we test the preprocessing to replace what is used model_id = "rhymes-ai/Aria" - model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + model = AriaForConditionalGeneration.from_pretrained( + model_id, + quantization_config=BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"]), + ) processor = AutoProcessor.from_pretrained(model_id) prompts = [ - "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:", - "USER: \nWhat is this? ASSISTANT:", + "USER: <|img|>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:", + "USER: <|img|>\nWhat is this? ASSISTANT:", ] - image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw) + image1 = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) + inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True).to( + model.device, model.dtype + ) output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip + EXPECTED_DECODED_TEXT = Expectations( + { + ("cuda", None): [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you", + "USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on", + ], + ("rocm", (9, 5)): [ + "USER: \n What are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: \n\nWhen visiting this place, you should be cautious about the weather conditions, as it", + "USER: \n What is this? ASSISTANT: This is a picture of two cats sleeping on a couch. USER: What is the color of", + ], + } + ).get_expectation() - self.assertEqual( - processor.batch_decode(output, skip_special_tokens=True), - EXPECTED_DECODED_TEXT, - ) + decoded_output = processor.batch_decode(output, skip_special_tokens=True) + self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT) @slow @require_torch_large_accelerator @require_bitsandbytes def test_small_model_integration_test_batch(self): # Let's make sure we test the preprocessing to replace what is used - model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True) + model = AriaForConditionalGeneration.from_pretrained( + "rhymes-ai/Aria", + quantization_config=BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"]), + ) # The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!. prompts = [ - "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", - "USER: \nWhat is this?\nASSISTANT:", + "USER: <|img|>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: <|img|>\nWhat is this?\nASSISTANT:", ] - image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw) + image1 = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) + inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True).to( + model.device, model.dtype + ) output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = [ - 'USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring.', - 'USER: \nWhat is this?\nASSISTANT: Cats' - ] # fmt: skip - self.assertEqual( - self.processor.batch_decode(output, skip_special_tokens=True), - EXPECTED_DECODED_TEXT, - ) + EXPECTED_DECODED_TEXT = Expectations({ + ("cuda", None): [ + 'USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring.', + 'USER: \nWhat is this?\nASSISTANT: Cats', + ], + ("rocm", (9, 5)): [ + 'USER: \n What are the things I should be cautious about when I visit this place? What should I bring with me?\n ASSISTANT: \n\nWhen visiting this place, you should be cautious about the following:\n\n-', + 'USER: \n What is this?\n ASSISTANT: This is a picture of two cats sleeping on a couch. The couch is red, and the cats', + ], + }).get_expectation() # fmt: skip + + decoded_output = self.processor.batch_decode(output, skip_special_tokens=True) + self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT) @slow @require_torch_large_accelerator @@ -366,26 +418,31 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): model_id = "rhymes-ai/Aria" # Multi-image & multi-prompt (e.g. 3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before) - model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True, attn_implementation="eager") + model = AriaForConditionalGeneration.from_pretrained( + model_id, + quantization_config=BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"]), + ) processor = AutoProcessor.from_pretrained(model_id, pad_token="") prompts = [ - "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", - "USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT:", + "USER: <|img|>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: <|img|>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <|img|>\nAnd this?\nASSISTANT:", ] - image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw) + image1 = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True) + inputs = inputs.to(model.device, model.dtype) output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip + EXPECTED_DECODED_TEXT = Expectations({ + ("cuda", None): ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'], + ("rocm", (9, 5)): ['USER: \n What are the things I should be cautious about when I visit this place? What should I bring with me?\n ASSISTANT: \n\nWhen visiting this place, you should be cautious about the weather conditions, as it', 'USER: \n What is this?\n ASSISTANT: Two cats lying on a bed!\n USER: \n And this?\n ASSISTANT: A serene lake scene with a wooden dock extending into the water.\n USER: \n'] + }).get_expectation() # fmt: skip - self.assertEqual( - processor.batch_decode(output, skip_special_tokens=True), - EXPECTED_DECODED_TEXT, - ) + decoded_output = processor.batch_decode(output, skip_special_tokens=True) + self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT) @slow @require_torch_large_accelerator @@ -395,7 +452,8 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): # Skip multihead_attn for 4bit because MHA will read the original weight without dequantize. # See https://github.com/huggingface/transformers/pull/37444#discussion_r2045852538. model = AriaForConditionalGeneration.from_pretrained( - "rhymes-ai/Aria", load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"] + "rhymes-ai/Aria", + quantization_config=BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"]), ) processor = AutoProcessor.from_pretrained("rhymes-ai/Aria") @@ -447,6 +505,10 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): "<|im_start|>user\n \n \n USER: What's the difference of two images?\n ASSISTANT: \n USER: Describe the image.\n ASSISTANT:<|im_end|>\n <|im_start|>assistant\n The first image features a cute, light-colored puppy sitting on a paved surface with", "<|im_start|>user\n \n USER: Describe the image.\n ASSISTANT:<|im_end|>\n <|im_start|>assistant\n The image shows a young alpaca standing on a patch of ground with some dry grass. The", ], + ("rocm", (9, 5)): [ + "<|im_start|>user\n \n \n USER: What's the difference of two images?\n ASSISTANT: \n USER: Describe the image.\n ASSISTANT:<|im_end|>\n <|im_start|>assistant\n The first image shows a cute golden retriever puppy sitting on a paved surface with a stick", + '<|im_start|>user\n \n USER: Describe the image.\n ASSISTANT:<|im_end|>\n <|im_start|>assistant\n The image shows a young llama standing on a patch of ground with some dry grass and dirt. The' + ], } ) # fmt: skip EXPECTED_OUTPUT = EXPECTED_OUTPUTS.get_expectation() @@ -480,9 +542,12 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): @require_bitsandbytes def test_generation_no_images(self): model_id = "rhymes-ai/Aria" - model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + model = AriaForConditionalGeneration.from_pretrained( + model_id, + quantization_config=BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"]), + ) processor = AutoProcessor.from_pretrained(model_id) - + assert model.device.type == "cuda", "This test is only supported on CUDA" # TODO: remove this # Prepare inputs with no images inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device) diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py index 248b40121a5..eb968ad9f68 100644 --- a/tests/models/dpt/test_modeling_dpt.py +++ b/tests/models/dpt/test_modeling_dpt.py @@ -18,7 +18,7 @@ import unittest from transformers import DPTConfig from transformers.file_utils import is_torch_available, is_vision_available from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor @@ -342,11 +342,15 @@ class DPTModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 384, 384)) self.assertEqual(predicted_depth.shape, expected_shape) - expected_slice = torch.tensor( - [[6.3199, 6.3629, 6.4148], [6.3850, 6.3615, 6.4166], [6.3519, 6.3176, 6.3575]] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [[6.3199, 6.3629, 6.4148], [6.3850, 6.3615, 6.4166], [6.3519, 6.3176, 6.3575]], + ("cuda", 8): [[6.3215, 6.3635, 6.4155], [6.3863, 6.3622, 6.4174], [6.3530, 6.3184, 6.3583]], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) def test_inference_semantic_segmentation(self): image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade") diff --git a/tests/models/dpt/test_modeling_dpt_auto_backbone.py b/tests/models/dpt/test_modeling_dpt_auto_backbone.py index 5ef6c11c375..1505be27cf7 100644 --- a/tests/models/dpt/test_modeling_dpt_auto_backbone.py +++ b/tests/models/dpt/test_modeling_dpt_auto_backbone.py @@ -17,7 +17,7 @@ import unittest from transformers import Dinov2Config, DPTConfig from transformers.file_utils import is_torch_available, is_vision_available -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from transformers.utils.import_utils import get_torch_major_and_minor_version from ...test_configuration_common import ConfigTester @@ -267,11 +267,15 @@ class DPTModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 576, 736)) self.assertEqual(predicted_depth.shape, expected_shape) - expected_slice = torch.tensor( - [[6.0336, 7.1502, 7.4130], [6.8977, 7.2383, 7.2268], [7.9180, 8.0525, 8.0134]] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [[6.0336, 7.1502, 7.4130], [6.8977, 7.2383, 7.2268], [7.9180, 8.0525, 8.0134]], + ("cuda", 8): [[6.0350, 7.1518, 7.4144], [6.8992, 7.2396, 7.2280], [7.9194, 8.0538, 8.0145]], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) def test_inference_depth_estimation_beit(self): image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-beit-base-384") @@ -289,11 +293,23 @@ class DPTModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 384, 384)) self.assertEqual(predicted_depth.shape, expected_shape) - expected_slice = torch.tensor( - [[2669.7061, 2663.7144, 2674.9399], [2633.9326, 2650.9092, 2665.4270], [2621.8271, 2632.0129, 2637.2290]] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [2669.7061, 2663.7144, 2674.9399], + [2633.9326, 2650.9092, 2665.4270], + [2621.8271, 2632.0129, 2637.2290], + ], + ("cuda", 8): [ + [2669.4292, 2663.4121, 2674.6233], + [2633.7400, 2650.7026, 2665.2085], + [2621.6572, 2631.8452, 2637.0525], + ], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) def test_inference_depth_estimation_swinv2(self): image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256") @@ -311,8 +327,20 @@ class DPTModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 256, 256)) self.assertEqual(predicted_depth.shape, expected_shape) - expected_slice = torch.tensor( - [[1032.7719, 1025.1886, 1030.2661], [1023.7619, 1021.0075, 1024.9121], [1022.5667, 1018.8522, 1021.4145]] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [1032.7719, 1025.1886, 1030.2661], + [1023.7619, 1021.0075, 1024.9121], + [1022.5667, 1018.8522, 1021.4145], + ], + ("cuda", 8): [ + [1032.7170, 1025.0629, 1030.1941], + [1023.7309, 1020.9786, 1024.8594], + [1022.5233, 1018.8235, 1021.3312], + ], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) diff --git a/tests/models/dpt/test_modeling_dpt_hybrid.py b/tests/models/dpt/test_modeling_dpt_hybrid.py index fbdd88278ea..79cad886db4 100644 --- a/tests/models/dpt/test_modeling_dpt_hybrid.py +++ b/tests/models/dpt/test_modeling_dpt_hybrid.py @@ -194,6 +194,9 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() + def test_batching_equivalence(self, atol=2e-5, rtol=2e-5): + super().test_batching_equivalence(atol=atol, rtol=rtol) + @unittest.skip(reason="DPT does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py index 22201f42b0e..ccd88576f86 100644 --- a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py +++ b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py @@ -24,7 +24,14 @@ from transformers import ( FastSpeech2ConformerWithHifiGanConfig, is_torch_available, ) -from transformers.testing_utils import require_g2p_en, require_torch, require_torch_accelerator, slow, torch_device +from transformers.testing_utils import ( + Expectations, + require_g2p_en, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor @@ -373,24 +380,38 @@ class FastSpeech2ConformerModelIntegrationTest(unittest.TestCase): # mel-spectrogram is too large (1, 205, 80), so only check top-left 100 elements # fmt: off - expected_mel_spectrogram = torch.tensor( - [ - [-1.2426, -1.7286, -1.6754, -1.7451, -1.6402, -1.5219, -1.4480, -1.3345, -1.4031, -1.4497], - [-0.7858, -1.4966, -1.3602, -1.4876, -1.2949, -1.0723, -1.0021, -0.7553, -0.6521, -0.6929], - [-0.7298, -1.3908, -1.0369, -1.2656, -1.0342, -0.7883, -0.7420, -0.5249, -0.3734, -0.3977], - [-0.4784, -1.3508, -1.1558, -1.4678, -1.2820, -1.0252, -1.0868, -0.9006, -0.8947, -0.8448], - [-0.3963, -1.2895, -1.2813, -1.6147, -1.4658, -1.2560, -1.4134, -1.2650, -1.3255, -1.1715], - [-1.4914, -1.3097, -0.3821, -0.3898, -0.5748, -0.9040, -1.0755, -1.0575, -1.2205, -1.0572], - [0.0197, -0.0582, 0.9147, 1.1512, 1.1651, 0.6628, -0.1010, -0.3085, -0.2285, 0.2650], - [1.1780, 0.1803, 0.7251, 1.5728, 1.6678, 0.4542, -0.1572, -0.1787, 0.0744, 0.8168], - [-0.2078, -0.3211, 1.1096, 1.5085, 1.4632, 0.6299, -0.0515, 0.0589, 0.8609, 1.4429], - [0.7831, -0.2663, 1.0352, 1.4489, 0.9088, 0.0247, -0.3995, 0.0078, 1.2446, 1.6998], - ], - device=torch_device, + expectations = Expectations( + { + (None, None): [ + [-1.2426, -1.7286, -1.6754, -1.7451, -1.6402, -1.5219, -1.4480, -1.3345, -1.4031, -1.4497], + [-0.7858, -1.4966, -1.3602, -1.4876, -1.2949, -1.0723, -1.0021, -0.7553, -0.6521, -0.6929], + [-0.7298, -1.3908, -1.0369, -1.2656, -1.0342, -0.7883, -0.7420, -0.5249, -0.3734, -0.3977], + [-0.4784, -1.3508, -1.1558, -1.4678, -1.2820, -1.0252, -1.0868, -0.9006, -0.8947, -0.8448], + [-0.3963, -1.2895, -1.2813, -1.6147, -1.4658, -1.2560, -1.4134, -1.2650, -1.3255, -1.1715], + [-1.4914, -1.3097, -0.3821, -0.3898, -0.5748, -0.9040, -1.0755, -1.0575, -1.2205, -1.0572], + [0.0197, -0.0582, 0.9147, 1.1512, 1.1651, 0.6628, -0.1010, -0.3085, -0.2285, 0.2650], + [1.1780, 0.1803, 0.7251, 1.5728, 1.6678, 0.4542, -0.1572, -0.1787, 0.0744, 0.8168], + [-0.2078, -0.3211, 1.1096, 1.5085, 1.4632, 0.6299, -0.0515, 0.0589, 0.8609, 1.4429], + [0.7831, -0.2663, 1.0352, 1.4489, 0.9088, 0.0247, -0.3995, 0.0078, 1.2446, 1.6998], + ], + ("cuda", 8): [ + [-1.2425, -1.7282, -1.6750, -1.7448, -1.6400, -1.5217, -1.4478, -1.3341, -1.4026, -1.4493], + [-0.7858, -1.4967, -1.3601, -1.4875, -1.2950, -1.0725, -1.0021, -0.7553, -0.6522, -0.6929], + [-0.7303, -1.3911, -1.0370, -1.2656, -1.0345, -0.7888, -0.7423, -0.5251, -0.3737, -0.3979], + [-0.4784, -1.3506, -1.1556, -1.4677, -1.2820, -1.0253, -1.0868, -0.9006, -0.8949, -0.8448], + [-0.3968, -1.2896, -1.2811, -1.6145, -1.4660, -1.2564, -1.4135, -1.2652, -1.3258, -1.1716], + [-1.4912, -1.3092, -0.3812, -0.3886, -0.5737, -0.9034, -1.0749, -1.0571, -1.2202, -1.0567], + [0.0200, -0.0577, 0.9151, 1.1516, 1.1656, 0.6628, -0.1012, -0.3086, -0.2283, 0.2658], + [1.1778, 0.1805, 0.7255, 1.5732, 1.6680, 0.4539, -0.1572, -0.1785, 0.0751, 0.8175], + [-0.2088, -0.3212, 1.1101, 1.5085, 1.4625, 0.6293, -0.0522, 0.0587, 0.8615, 1.4432], + [0.7834, -0.2659, 1.0355, 1.4486, 0.9080, 0.0244, -0.3995, 0.0083, 1.2452, 1.6998], + ], + } ) + expected_mel_spectrogram = torch.tensor(expectations.get_expectation()).to(torch_device) # fmt: on - torch.testing.assert_close(spectrogram[0, :10, :10], expected_mel_spectrogram, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(spectrogram[0, :10, :10], expected_mel_spectrogram, rtol=2e-4, atol=2e-4) self.assertEqual(spectrogram.shape, (1, 205, model.config.num_mel_bins)) def test_training_integration(self): diff --git a/tests/models/focalnet/test_modeling_focalnet.py b/tests/models/focalnet/test_modeling_focalnet.py index d272f258910..893d9ed1ee6 100644 --- a/tests/models/focalnet/test_modeling_focalnet.py +++ b/tests/models/focalnet/test_modeling_focalnet.py @@ -17,7 +17,7 @@ import collections import unittest from transformers import FocalNetConfig -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_backbone_common import BackboneTesterMixin @@ -425,8 +425,16 @@ class FocalNetModelIntegrationTest(unittest.TestCase): # verify the logits expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([0.2166, -0.4368, 0.2191]).to(torch_device) - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + + expectations = Expectations( + { + (None, None): [0.2166, -0.4368, 0.2191], + ("cuda", 8): [0.2168, -0.4367, 0.2190], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) self.assertTrue(outputs.logits.argmax(dim=-1).item(), 281) diff --git a/tests/models/glpn/test_modeling_glpn.py b/tests/models/glpn/test_modeling_glpn.py index b3e1852373a..b98743de357 100644 --- a/tests/models/glpn/test_modeling_glpn.py +++ b/tests/models/glpn/test_modeling_glpn.py @@ -164,6 +164,9 @@ class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_batching_equivalence(self, atol=3e-4, rtol=3e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + def test_for_depth_estimation(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_depth_estimation(*config_and_inputs) diff --git a/tests/models/grounding_dino/test_modeling_grounding_dino.py b/tests/models/grounding_dino/test_modeling_grounding_dino.py index d632f99e2ca..953255797b5 100644 --- a/tests/models/grounding_dino/test_modeling_grounding_dino.py +++ b/tests/models/grounding_dino/test_modeling_grounding_dino.py @@ -681,25 +681,48 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase): expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.d_model)) self.assertEqual(outputs.logits.shape, expected_shape_logits) - expected_boxes = torch.tensor( - [[0.7674, 0.4136, 0.4572], [0.2566, 0.5463, 0.4760], [0.2585, 0.5442, 0.4641]] - ).to(torch_device) - expected_logits = torch.tensor( - [[-4.8913, -0.1900, -0.2161], [-4.9653, -0.3719, -0.3950], [-5.9599, -3.3765, -3.3104]] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [[0.7674, 0.4136, 0.4572], [0.2566, 0.5463, 0.4760], [0.2585, 0.5442, 0.4641]], + ("cuda", 8): [[0.7674, 0.4135, 0.4571], [0.2566, 0.5463, 0.4760], [0.2585, 0.5442, 0.4640]], + } + ) + expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device) + + expectations = Expectations( + { + (None, None): [[-4.8913, -0.1900, -0.2161], [-4.9653, -0.3719, -0.3950], [-5.9599, -3.3765, -3.3104]], + ("cuda", 8): [[-4.8927, -0.1910, -0.2169], [-4.9657, -0.3748, -0.3980], [-5.9579, -3.3812, -3.3153]], + } + ) + expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, rtol=1e-3, atol=1e-3) expected_shape_boxes = torch.Size((1, model.config.num_queries, 4)) self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) - torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, rtol=2e-4, atol=2e-4) # verify postprocessing results = processor.image_processor.post_process_object_detection( outputs, threshold=0.35, target_sizes=[(image.height, image.width)] )[0] - expected_scores = torch.tensor([0.4526, 0.4082]).to(torch_device) - expected_slice_boxes = torch.tensor([344.8143, 23.1796, 637.4004, 373.8295]).to(torch_device) + + expectations = Expectations( + { + (None, None): [[0.4526, 0.4082]], + ("cuda", 8): [0.4524, 0.4074], + } + ) + expected_scores = torch.tensor(expectations.get_expectation()).to(torch_device) + + expectations = Expectations( + { + (None, None): [344.8143, 23.1796, 637.4004, 373.8295], + ("cuda", 8): [344.8210, 23.1831, 637.3943, 373.8227], + } + ) + expected_slice_boxes = torch.tensor(expectations.get_expectation()).to(torch_device) self.assertEqual(len(results["scores"]), 2) torch.testing.assert_close(results["scores"], expected_scores, rtol=1e-3, atol=1e-3) diff --git a/tests/models/hiera/test_modeling_hiera.py b/tests/models/hiera/test_modeling_hiera.py index dfbec4a4b8a..1e3ed8e7952 100644 --- a/tests/models/hiera/test_modeling_hiera.py +++ b/tests/models/hiera/test_modeling_hiera.py @@ -262,6 +262,9 @@ class HieraModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): self.config_tester.check_config_can_be_init_without_params() self.config_tester.check_config_arguments_init() + def test_batching_equivalence(self, atol=3e-4, rtol=3e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + # Overriding as Hiera `get_input_embeddings` returns HieraPatchEmbeddings def test_model_get_set_embeddings(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/levit/test_modeling_levit.py b/tests/models/levit/test_modeling_levit.py index f6226be1f87..80f8f822c51 100644 --- a/tests/models/levit/test_modeling_levit.py +++ b/tests/models/levit/test_modeling_levit.py @@ -19,7 +19,7 @@ from math import ceil, floor from transformers import LevitConfig from transformers.file_utils import cached_property, is_torch_available, is_vision_available -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -406,6 +406,11 @@ class LevitModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([1.0448, -0.3745, -1.8317]).to(torch_device) - - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + expectations = Expectations( + { + (None, None): [1.0448, -0.3745, -1.8317], + ("cuda", 8): [1.0453, -0.3739, -1.8314], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) diff --git a/tests/models/lightglue/test_modeling_lightglue.py b/tests/models/lightglue/test_modeling_lightglue.py index 20d9f2ef61c..7f36469bf53 100644 --- a/tests/models/lightglue/test_modeling_lightglue.py +++ b/tests/models/lightglue/test_modeling_lightglue.py @@ -17,7 +17,7 @@ import unittest from datasets import load_dataset from transformers.models.lightglue.configuration_lightglue import LightGlueConfig -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import get_device_properties, require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -143,6 +143,13 @@ class LightGlueModelTest(ModelTesterMixin, unittest.TestCase): self.config_tester.check_config_can_be_init_without_params() self.config_tester.check_config_arguments_init() + def test_batching_equivalence(self, atol=1e-5, rtol=1e-5): + device_properties = get_device_properties() + if device_properties[0] == "cuda" and device_properties[1] == 8: + # TODO: (ydshieh) fix this + self.skipTest(reason="After switching to A10, this test always fails, but pass on CPU or T4.") + super().test_batching_equivalence(atol=atol, rtol=rtol) + @unittest.skip(reason="LightGlueForKeypointMatching does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index 352a3ef1915..2cddb1ecfd3 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -29,6 +29,7 @@ from transformers import ( is_vision_available, ) from transformers.testing_utils import ( + Expectations, cleanup, require_bitsandbytes, require_torch, @@ -378,12 +379,16 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase): # verify generation output = model.generate(**inputs, do_sample=False, max_new_tokens=40) - EXPECTED_DECODED_TEXT = ( - "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems", # cuda output - "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while wearing a pair of glasses that are too large for them. The glasses are", # xpu output - ) + expected_decoded_text = Expectations( + { + ("cuda", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems", + ("xpu", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while wearing a pair of glasses that are too large for them. The glasses are", + ("rocm", (9, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The child is seen reading a book, but instead of turning the pages like one would typically do, they", + } + ).get_expectation() # fmt: off - self.assertTrue(self.processor.decode(output[0], skip_special_tokens=True) in EXPECTED_DECODED_TEXT) + decoded_text = self.processor.decode(output[0], skip_special_tokens=True) + self.assertEqual(decoded_text, expected_decoded_text) @slow @require_bitsandbytes @@ -400,15 +405,17 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase): ).to(torch_device) output = model.generate(**inputs, do_sample=False, max_new_tokens=20) + decoded_text = self.processor.batch_decode(output, skip_special_tokens=True) - EXPECTED_DECODED_TEXT = [ - 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a', - 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a' - ] # fmt: skip - self.assertEqual( - self.processor.batch_decode(output, skip_special_tokens=True), - EXPECTED_DECODED_TEXT, - ) + expected_decoded_text = Expectations( + { + ("cuda", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a", + ("rocm", (9, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The", + } + ).get_expectation() # fmt: off + EXPECTED_DECODED_TEXT = [expected_decoded_text, expected_decoded_text] + + self.assertEqual(decoded_text, EXPECTED_DECODED_TEXT) @slow @require_bitsandbytes @@ -435,8 +442,15 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase): # verify generation output = model.generate(**inputs, do_sample=False, max_new_tokens=50) - EXPECTED_DECODED_TEXT = 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a machine learning model\'s performance on a task, likely related to natural language processing or text understanding. It shows a scatter plot with two axes, one labeled "BLIP-2"' # fmt: skip - self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT) + EXPECTED_DECODED_TEXT = Expectations( + { + ("rocm", (9, 5)): "USER: \nWhat is shown in this image? ASSISTANT: The image displays a chart that appears to be a comparison of different models or versions of a machine learning (ML) model, likely a neural network, based on their performance on a task or dataset. The chart is a scatter plot with axes labeled", + ("cuda", None): 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a machine learning model\'s performance on a task, likely related to natural language processing or text understanding. It shows a scatter plot with two axes, one labeled "BLIP-2"', + } + ).get_expectation() # fmt: off + + decoded_text = self.processor.decode(output[0], skip_special_tokens=True) + self.assertEqual(decoded_text, EXPECTED_DECODED_TEXT) @slow @require_bitsandbytes diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py index 5762a1f6ffc..cf6521424bb 100644 --- a/tests/models/mask2former/test_modeling_mask2former.py +++ b/tests/models/mask2former/test_modeling_mask2former.py @@ -21,6 +21,7 @@ from tests.test_modeling_common import floats_tensor from transformers import AutoModelForImageClassification, Mask2FormerConfig, is_torch_available, is_vision_available from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 from transformers.testing_utils import ( + Expectations, require_timm, require_torch, require_torch_accelerator, @@ -403,7 +404,7 @@ class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC ) -TOLERANCE = 1e-4 +TOLERANCE = 2e-4 # We will verify our results on an image of cute cats @@ -438,31 +439,52 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase): outputs = model(**inputs) expected_slice_hidden_state = torch.tensor( - [[-0.2790, -1.0717, -1.1668], [-0.5128, -0.3128, -0.4987], [-0.5832, 0.1971, -0.0197]] + [ + [-0.2790, -1.0717, -1.1668], + [-0.5128, -0.3128, -0.4987], + [-0.5832, 0.1971, -0.0197], + ] ).to(torch_device) - self.assertTrue( - torch.allclose( - outputs.encoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE - ) + torch.testing.assert_close( + outputs.encoder_last_hidden_state[0, 0, :3, :3], + expected_slice_hidden_state, + atol=TOLERANCE, + rtol=TOLERANCE, ) - expected_slice_hidden_state = torch.tensor( - [[0.8973, 1.1847, 1.1776], [1.1934, 1.5040, 1.5128], [1.1153, 1.4486, 1.4951]] - ).to(torch_device) - self.assertTrue( - torch.allclose( - outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE - ) + expectations = Expectations( + { + (None, None): [ + [0.8973, 1.1847, 1.1776], + [1.1934, 1.5040, 1.5128], + [1.1153, 1.4486, 1.4951], + ], + ("cuda", 8): [ + [0.8974, 1.1848, 1.1777], + [1.1933, 1.5041, 1.5128], + [1.1154, 1.4487, 1.4950], + ], + } ) + expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE,rtol=TOLERANCE) # fmt: skip - expected_slice_hidden_state = torch.tensor( - [[2.1152, 1.7000, -0.8603], [1.5808, 1.8004, -0.9353], [1.6043, 1.7495, -0.5999]] - ).to(torch_device) - self.assertTrue( - torch.allclose( - outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE - ) + expectations = Expectations( + { + (None, None): [ + [2.1152, 1.7000, -0.8603], + [1.5808, 1.8004, -0.9353], + [1.6043, 1.7495, -0.5999], + ], + ("cuda", 8): [ + [2.1153, 1.7004, -0.8604], + [1.5807, 1.8007, -0.9354], + [1.6040, 1.7498, -0.6001], + ], + } ) + expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE) # fmt: skip def test_inference_universal_segmentation_head(self): model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() @@ -482,23 +504,40 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase): self.assertEqual( masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4) ) - expected_slice = [ - [-8.7839, -9.0056, -8.8121], - [-7.4104, -7.0313, -6.5401], - [-6.6105, -6.3427, -6.4675], - ] - expected_slice = torch.tensor(expected_slice).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [-8.7839, -9.0056, -8.8121], + [-7.4104, -7.0313, -6.5401], + [-6.6105, -6.3427, -6.4675], + ], + ("cuda", 8): [ + [-8.7809, -9.0041, -8.8087], + [-7.4075, -7.0307, -6.5385], + [-6.6088, -6.3417, -6.4627], + ], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) # class_queries_logits class_queries_logits = outputs.class_queries_logits self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1)) - expected_slice = torch.tensor( - [ - [1.8324, -8.0835, -4.1922], - [0.8450, -9.0050, -3.6053], - [0.3045, -7.7293, -3.0275], - ] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [1.8324, -8.0835, -4.1922], + [0.8450, -9.0050, -3.6053], + [0.3045, -7.7293, -3.0275], + ], + ("cuda", 8): [ + [1.8326, -8.0834, -4.1916], + [0.8446, -9.0048, -3.6048], + [0.3042, -7.7296, -3.0277], + ], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close( outputs.class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE ) diff --git a/tests/models/maskformer/test_modeling_maskformer.py b/tests/models/maskformer/test_modeling_maskformer.py index 2f30d4dc3c6..8644439f4a8 100644 --- a/tests/models/maskformer/test_modeling_maskformer.py +++ b/tests/models/maskformer/test_modeling_maskformer.py @@ -21,6 +21,7 @@ import numpy as np from tests.test_modeling_common import floats_tensor from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available from transformers.testing_utils import ( + Expectations, require_timm, require_torch, require_torch_accelerator, @@ -478,7 +479,7 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3]) -TOLERANCE = 1e-4 +TOLERANCE = 2e-4 # We will verify our results on an image of cute cats @@ -513,31 +514,43 @@ class MaskFormerModelIntegrationTest(unittest.TestCase): outputs = model(**inputs) expected_slice_hidden_state = torch.tensor( - [[-0.0482, 0.9228, 0.4951], [-0.2547, 0.8017, 0.8527], [-0.0069, 0.3385, -0.0089]] + [ + [-0.0482, 0.9228, 0.4951], + [-0.2547, 0.8017, 0.8527], + [-0.0069, 0.3385, -0.0089], + ] ).to(torch_device) - self.assertTrue( - torch.allclose( - outputs.encoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE - ) - ) + torch.allclose(outputs.encoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE) # fmt: skip - expected_slice_hidden_state = torch.tensor( - [[-0.8422, -0.8434, -0.9718], [-1.0144, -0.5565, -0.4195], [-1.0038, -0.4484, -0.1961]] - ).to(torch_device) - self.assertTrue( - torch.allclose( - outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE - ) + expectations = Expectations( + { + (None, None): [[-0.8422, -0.8434, -0.9718], [-1.0144, -0.5565, -0.4195], [-1.0038, -0.4484, -0.1961]], + ("cuda", 8): [ + [-0.8422, -0.8435, -0.9717], + [-1.0145, -0.5564, -0.4195], + [-1.0040, -0.4486, -0.1962], + ], + } ) + expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.allclose(outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE,rtol=TOLERANCE) # fmt: skip - expected_slice_hidden_state = torch.tensor( - [[0.2852, -0.0159, 0.9735], [0.6254, 0.1858, 0.8529], [-0.0680, -0.4116, 1.8413]] - ).to(torch_device) - self.assertTrue( - torch.allclose( - outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE - ) + expectations = Expectations( + { + (None, None): [ + [0.2852, -0.0159, 0.9735], + [0.6254, 0.1858, 0.8529], + [-0.0680, -0.4116, 1.8413], + ], + ("cuda", 8): [ + [0.2853, -0.0162, 0.9736], + [0.6256, 0.1856, 0.8530], + [-0.0679, -0.4118, 1.8416], + ], + } ) + expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.allclose(outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE) # fmt: skip def test_inference_instance_segmentation_head(self): model = ( @@ -562,25 +575,42 @@ class MaskFormerModelIntegrationTest(unittest.TestCase): masks_queries_logits.shape, (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4), ) - expected_slice = [ - [-1.3737124, -1.7724937, -1.9364233], - [-1.5977281, -1.9867939, -2.1523695], - [-1.5795398, -1.9269832, -2.093942], - ] - expected_slice = torch.tensor(expected_slice).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [-1.3737124, -1.7724937, -1.9364233], + [-1.5977281, -1.9867939, -2.1523695], + [-1.5795398, -1.9269832, -2.093942], + ], + ("cuda", 8): [ + [-1.3737, -1.7727, -1.9367], + [-1.5979, -1.9871, -2.1527], + [-1.5797, -1.9271, -2.0941], + ], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) # class_queries_logits class_queries_logits = outputs.class_queries_logits self.assertEqual( class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1) ) - expected_slice = torch.tensor( - [ - [1.6512e00, -5.2572e00, -3.3519e00], - [3.6169e-02, -5.9025e00, -2.9313e00], - [1.0766e-04, -7.7630e00, -5.1263e00], - ] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [1.6512e00, -5.2572e00, -3.3519e00], + [3.6169e-02, -5.9025e00, -2.9313e00], + [1.0766e-04, -7.7630e00, -5.1263e00], + ], + ("cuda", 8): [ + [1.6507e00, -5.2568e00, -3.3520e00], + [3.5767e-02, -5.9023e00, -2.9313e00], + [-6.2712e-04, -7.7627e00, -5.1268e00], + ], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close( outputs.class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE ) @@ -608,17 +638,34 @@ class MaskFormerModelIntegrationTest(unittest.TestCase): masks_queries_logits.shape, (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4), ) - expected_slice = [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]] - expected_slice = torch.tensor(expected_slice).to(torch_device) + expectations = Expectations( + { + (None, None): [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]], + ("cuda", 8): [[-0.9000, -2.6283, -4.5964], [-3.4123, -5.7789, -8.7919], [-4.9132, -7.6444, -10.7557]], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) # class_queries_logits class_queries_logits = outputs.class_queries_logits self.assertEqual( class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1) ) - expected_slice = torch.tensor( - [[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [4.7188, -3.2585, -2.8857], + [6.6871, -2.9181, -1.2487], + [7.2449, -2.2764, -2.1874], + ], + ("cuda", 8): [ + [4.7177, -3.2586, -2.8853], + [6.6845, -2.9186, -1.2491], + [7.2443, -2.2760, -2.1858], + ], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close( outputs.class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE ) diff --git a/tests/models/mgp_str/test_modeling_mgp_str.py b/tests/models/mgp_str/test_modeling_mgp_str.py index 586e9f0bc49..1ff9927f89e 100644 --- a/tests/models/mgp_str/test_modeling_mgp_str.py +++ b/tests/models/mgp_str/test_modeling_mgp_str.py @@ -140,6 +140,9 @@ class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_batching_equivalence(self, atol=1e-4, rtol=1e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + @unittest.skip(reason="MgpstrModel does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/minimax/test_modeling_minimax.py b/tests/models/minimax/test_modeling_minimax.py index b9ae9d45155..0e36c7219be 100644 --- a/tests/models/minimax/test_modeling_minimax.py +++ b/tests/models/minimax/test_modeling_minimax.py @@ -20,6 +20,7 @@ import pytest from transformers import MiniMaxConfig, is_torch_available from transformers.cache_utils import Cache from transformers.testing_utils import ( + Expectations, require_flash_attn, require_torch, require_torch_accelerator, @@ -250,15 +251,20 @@ class MiniMaxIntegrationTest(unittest.TestCase): model_id, torch_dtype=torch.bfloat16, ).to(torch_device) - expected_slice = torch.tensor( - [[1.0312, -0.5156, -0.3262], [-0.1152, 0.4336, 0.2412], [1.2188, -0.5898, -0.0381]] - ).to(torch_device) with torch.no_grad(): logits = model(dummy_input).logits logits = logits.float() + expectations = Expectations( + { + (None, None): [[1.0312, -0.5156, -0.3262], [-0.1152, 0.4336, 0.2412], [1.2188, -0.5898, -0.0381]], + ("cuda", 8): [[1.0312, -0.5156, -0.3203], [-0.1201, 0.4375, 0.2402], [1.2188, -0.5898, -0.0396]], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(logits[0, :3, :3], expected_slice, atol=1e-3, rtol=1e-3) torch.testing.assert_close(logits[1, :3, :3], expected_slice, atol=1e-3, rtol=1e-3) diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 94ceb0e4a70..56c71fe7350 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -191,27 +191,26 @@ class MixtralIntegrationTest(unittest.TestCase): # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4. # # considering differences in hardware processing and potential deviations in generated text. - # fmt: off + EXPECTED_LOGITS_LEFT_UNPADDED = Expectations( { - ("xpu", 3): torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7070, 0.2461]]).to(torch_device), - ("cuda", 7): torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]]).to(torch_device), - ("cuda", 8): torch.Tensor([[0.2207, 0.5234, -0.3828], [0.8203, -0.2285, 0.6055], [0.2656, -0.7109, 0.2451]]).to(torch_device), - ("rocm", 9): torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]]).to(torch_device), + ("xpu", 3): [[0.2236, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7070, 0.2461]], + ("cuda", 7): [[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]], + ("cuda", 8): [[0.2217, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7109, 0.2461]], + ("rocm", 9): [[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]], } ) - expected_left_unpadded = EXPECTED_LOGITS_LEFT_UNPADDED.get_expectation() + expected_left_unpadded = torch.tensor(EXPECTED_LOGITS_LEFT_UNPADDED.get_expectation(), device=torch_device) EXPECTED_LOGITS_RIGHT_UNPADDED = Expectations( { - ("xpu", 3): torch.Tensor([[0.2178, 0.1270, -0.1641], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]]).to(torch_device), - ("cuda", 7): torch.Tensor([[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]]).to(torch_device), - ("cuda", 8): torch.Tensor([[0.2178, 0.1270, -0.1621], [-0.3496, 0.3008, -1.0312], [0.0693, 0.7930, 0.7969]]).to(torch_device), - ("rocm", 9): torch.Tensor([[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]]).to(torch_device), + ("xpu", 3): [[0.2178, 0.1270, -0.1641], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]], + ("cuda", 7): [[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]], + ("cuda", 8): [[0.2178, 0.1260, -0.1621], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]], + ("rocm", 9): [[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]], } ) - expected_right_unpadded = EXPECTED_LOGITS_RIGHT_UNPADDED.get_expectation() - # fmt: on + expected_right_unpadded = torch.tensor(EXPECTED_LOGITS_RIGHT_UNPADDED.get_expectation(), device=torch_device) with torch.no_grad(): logits = model(dummy_input, attention_mask=attention_mask).logits diff --git a/tests/models/mobilenet_v1/test_modeling_mobilenet_v1.py b/tests/models/mobilenet_v1/test_modeling_mobilenet_v1.py index 7c05d4b41c9..688542c727a 100644 --- a/tests/models/mobilenet_v1/test_modeling_mobilenet_v1.py +++ b/tests/models/mobilenet_v1/test_modeling_mobilenet_v1.py @@ -16,7 +16,7 @@ import unittest from transformers import MobileNetV1Config -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -246,6 +246,12 @@ class MobileNetV1ModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 1001)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-4.1739, -1.1233, 3.1205]).to(torch_device) + expectations = Expectations( + { + (None, None): [-4.1739, -1.1233, 3.1205], + ("cuda", 8): [-4.1725, -1.1238, 3.1191], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) diff --git a/tests/models/mobilenet_v2/test_modeling_mobilenet_v2.py b/tests/models/mobilenet_v2/test_modeling_mobilenet_v2.py index 13c7698af5f..5f3807cda82 100644 --- a/tests/models/mobilenet_v2/test_modeling_mobilenet_v2.py +++ b/tests/models/mobilenet_v2/test_modeling_mobilenet_v2.py @@ -16,7 +16,7 @@ import unittest from transformers import MobileNetV2Config -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -301,9 +301,15 @@ class MobileNetV2ModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 1001)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([0.2445, -1.1993, 0.1905]).to(torch_device) + expectations = Expectations( + { + (None, None): [0.2445, -1.1993, 0.1905], + ("cuda", 8): [0.2445, -1.1970, 0.1868], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) @slow def test_inference_semantic_segmentation(self): @@ -324,13 +330,20 @@ class MobileNetV2ModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 21, 65, 65)) self.assertEqual(logits.shape, expected_shape) - expected_slice = torch.tensor( - [ - [[17.5790, 17.7581, 18.3355], [18.3257, 18.4230, 18.8973], [18.6169, 18.8650, 19.2187]], - [[-2.1595, -2.0977, -2.3741], [-2.4226, -2.3028, -2.6835], [-2.7819, -2.5991, -2.7706]], - [[4.2058, 4.8317, 4.7638], [4.4136, 5.0361, 4.9383], [4.5028, 4.9644, 4.8734]], - ], - device=torch_device, + expectations = Expectations( + { + (None, None): [ + [[17.5790, 17.7581, 18.3355], [18.3257, 18.4230, 18.8973], [18.6169, 18.8650, 19.2187]], + [[-2.1595, -2.0977, -2.3741], [-2.4226, -2.3028, -2.6835], [-2.7819, -2.5991, -2.7706]], + [[4.2058, 4.8317, 4.7638], [4.4136, 5.0361, 4.9383], [4.5028, 4.9644, 4.8734]], + ], + ("cuda", 8): [ + [[17.5809, 17.7571, 18.3341], [18.3240, 18.4216, 18.8974], [18.6174, 18.8662, 19.2177]], + [[-2.1562, -2.0942, -2.3703], [-2.4199, -2.2999, -2.6818], [-2.7800, -2.5944, -2.7678]], + [[4.2092, 4.8356, 4.7694], [4.4181, 5.0401, 4.9409], [4.5089, 4.9700, 4.8802]], + ], + } ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) diff --git a/tests/models/mobilevit/test_modeling_mobilevit.py b/tests/models/mobilevit/test_modeling_mobilevit.py index f6cc09edddd..43fb0d638eb 100644 --- a/tests/models/mobilevit/test_modeling_mobilevit.py +++ b/tests/models/mobilevit/test_modeling_mobilevit.py @@ -16,7 +16,7 @@ import unittest from transformers import MobileViTConfig -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -304,9 +304,15 @@ class MobileViTModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-1.9364, -1.2327, -0.4653]).to(torch_device) + expectations = Expectations( + { + (None, None): [-1.9364, -1.2327, -0.4653], + ("cuda", 8): [-1.9401, -1.2384, -0.4702], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) @slow def test_inference_semantic_segmentation(self): @@ -327,16 +333,23 @@ class MobileViTModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 21, 32, 32)) self.assertEqual(logits.shape, expected_shape) - expected_slice = torch.tensor( - [ - [[6.9713, 6.9786, 7.2422], [7.2893, 7.2825, 7.4446], [7.6580, 7.8797, 7.9420]], - [[-10.6869, -10.3250, -10.3471], [-10.4228, -9.9868, -9.7132], [-11.0405, -11.0221, -10.7318]], - [[-3.3089, -2.8539, -2.6740], [-3.2706, -2.5621, -2.5108], [-3.2534, -2.6615, -2.6651]], - ], - device=torch_device, + expectations = Expectations( + { + (None, None): [ + [[6.9713, 6.9786, 7.2422], [7.2893, 7.2825, 7.4446], [7.6580, 7.8797, 7.9420]], + [[-10.6869, -10.3250, -10.3471], [-10.4228, -9.9868, -9.7132], [-11.0405, -11.0221, -10.7318]], + [[-3.3089, -2.8539, -2.6740], [-3.2706, -2.5621, -2.5108], [-3.2534, -2.6615, -2.6651]], + ], + ("cuda", 8): [ + [[6.9661, 6.9753, 7.2386], [7.2864, 7.2785, 7.4429], [7.6577, 7.8770, 7.9387]], + [[-10.7046, -10.3411, -10.3641], [-10.4402, -10.0004, -9.7269], [-11.0579, -11.0358, -10.7459]], + [[-3.3022, -2.8465, -2.6661], [-3.2654, -2.5542, -2.5055], [-3.2477, -2.6544, -2.6562]], + ], + } ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) @slow def test_post_processing_semantic_segmentation(self): diff --git a/tests/models/mobilevitv2/test_modeling_mobilevitv2.py b/tests/models/mobilevitv2/test_modeling_mobilevitv2.py index 7a0433f123b..daca2394be2 100644 --- a/tests/models/mobilevitv2/test_modeling_mobilevitv2.py +++ b/tests/models/mobilevitv2/test_modeling_mobilevitv2.py @@ -16,7 +16,14 @@ import unittest from transformers import MobileViTV2Config -from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device +from transformers.testing_utils import ( + Expectations, + require_torch, + require_torch_multi_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -317,9 +324,15 @@ class MobileViTV2ModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-1.6336e00, -7.3204e-02, -5.1883e-01]).to(torch_device) + expectations = Expectations( + { + (None, None): [-1.6336e00, -7.3204e-02, -5.1883e-01], + ("cuda", 8): [-1.6341, -0.0665, -0.5158], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) @slow def test_inference_semantic_segmentation(self): @@ -340,16 +353,23 @@ class MobileViTV2ModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 21, 32, 32)) self.assertEqual(logits.shape, expected_shape) - expected_slice = torch.tensor( - [ - [[7.0863, 7.1525, 6.8201], [6.6931, 6.8770, 6.8933], [6.2978, 7.0366, 6.9636]], - [[-3.7134, -3.6712, -3.6675], [-3.5825, -3.3549, -3.4777], [-3.3435, -3.3979, -3.2857]], - [[-2.9329, -2.8003, -2.7369], [-3.0564, -2.4780, -2.0207], [-2.6889, -1.9298, -1.7640]], - ], - device=torch_device, + expectations = Expectations( + { + (None, None): [ + [[7.0863, 7.1525, 6.8201], [6.6931, 6.8770, 6.8933], [6.2978, 7.0366, 6.9636]], + [[-3.7134, -3.6712, -3.6675], [-3.5825, -3.3549, -3.4777], [-3.3435, -3.3979, -3.2857]], + [[-2.9329, -2.8003, -2.7369], [-3.0564, -2.4780, -2.0207], [-2.6889, -1.9298, -1.7640]], + ], + ("cuda", 8): [ + [[7.0866, 7.1509, 6.8188], [6.6935, 6.8757, 6.8927], [6.2988, 7.0365, 6.9631]], + [[-3.7113, -3.6686, -3.6643], [-3.5801, -3.3516, -3.4739], [-3.3432, -3.3966, -3.2832]], + [[-2.9359, -2.8037, -2.7387], [-3.0595, -2.4798, -2.0222], [-2.6901, -1.9306, -1.7659]], + ], + } ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) @slow def test_post_processing_semantic_segmentation(self): diff --git a/tests/models/moonshine/test_modeling_moonshine.py b/tests/models/moonshine/test_modeling_moonshine.py index 99573cff096..a551244a6e1 100644 --- a/tests/models/moonshine/test_modeling_moonshine.py +++ b/tests/models/moonshine/test_modeling_moonshine.py @@ -17,7 +17,7 @@ import copy import unittest from transformers import MoonshineConfig, is_torch_available -from transformers.testing_utils import cleanup, require_torch, slow, torch_device +from transformers.testing_utils import Expectations, cleanup, require_torch, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( @@ -457,13 +457,15 @@ class MoonshineModelIntegrationTests(unittest.TestCase): outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True) # fmt: off - EXPECTED_LOGITS = torch.tensor([ - -9.1106, 4.5542, 6.3892, -6.8139, -7.2456, -7.9074, -7.2839, -7.6043, -8.0384, -7.8351, - -7.3867, -7.2450, -7.7420, -7.3912, -7.3866, -7.6979, -7.6420, -7.0504, -7.3979, -7.2483, - -8.0796, -7.3300, -7.3672, -6.8765, -7.6876, -7.2682, -6.9866, -6.7457, -7.6855, -7.3050, - ]) + expectations = Expectations( + { + (None, None): [-9.1106, 4.5542, 6.3892, -6.8139, -7.2456, -7.9074, -7.2839, -7.6043, -8.0384, -7.8351, -7.3867, -7.2450, -7.7420, -7.3912, -7.3866, -7.6979, -7.6420, -7.0504, -7.3979, -7.2483, -8.0796, -7.3300, -7.3672, -6.8765, -7.6876, -7.2682, -6.9866, -6.7457, -7.6855, -7.3050], + ("cuda", 8): [-9.1107, 4.5538, 6.3902, -6.8141, -7.2459, -7.9076, -7.2842, -7.6045, -8.0387, -7.8354, -7.3869, -7.2453, -7.7423, -7.3914, -7.3869, -7.6982, -7.6422, -7.0507, -7.3982, -7.2486, -8.0798, -7.3302, -7.3675, -6.8769, -7.6878, -7.2684, -6.9868, -6.7459, -7.6858, -7.3052], + } + ) + EXPECTED_LOGITS = torch.tensor(expectations.get_expectation()).to(torch_device) # fmt: on - torch.testing.assert_close(outputs.logits[0][0, :30].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.logits[0][0, :30], EXPECTED_LOGITS, rtol=2e-4, atol=2e-4) @slow def test_base_logits_single(self): @@ -476,7 +478,7 @@ class MoonshineModelIntegrationTests(unittest.TestCase): # fmt: off EXPECTED_LOGITS = torch.tensor([ - -6.7336, 1.9482, 5.2448, -8.0277, -7.9167, -7.8956, -7.9649, -7.9348, -8.1312, -8.0616, + -6.7336, 1.9482, 5.2448, -8.0277, -7.9167, -7.8956, -7.9649, -7.9348, -8.1312, -8.0616, -8.1070, -7.7696, -7.8809, -7.9450, -8.1013, -7.8177, -7.8598, -7.8257, -7.8729, -7.9657, -7.9310, -8.1024, -7.8699, -7.8231, -8.0752, -7.9764, -7.8127, -8.0536, -7.9492, -7.9290, ]) @@ -493,9 +495,9 @@ class MoonshineModelIntegrationTests(unittest.TestCase): outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True) # fmt: off EXPECTED_LOGITS = torch.tensor([ - [-8.0109, 5.0241, 4.5979, -6.8125, -7.1675, -7.8783, -7.2152, -7.5188, -7.9077, -7.7394], - [-4.4399, -1.4422, 6.6710, -6.8929, -7.3751, -7.0969, -6.5257, -7.0257, -7.2585, -7.0008], - [-10.0086, 3.2859, 0.7345, -6.5557, -6.8514, -6.5308, -6.4172, -6.9484, -6.6214, -6.6229], + [-8.0109, 5.0241, 4.5979, -6.8125, -7.1675, -7.8783, -7.2152, -7.5188, -7.9077, -7.7394], + [-4.4399, -1.4422, 6.6710, -6.8929, -7.3751, -7.0969, -6.5257, -7.0257, -7.2585, -7.0008], + [-10.0086, 3.2859, 0.7345, -6.5557, -6.8514, -6.5308, -6.4172, -6.9484, -6.6214, -6.6229], [-10.8078, 4.0030, -0.0633, -5.0505, -5.3906, -5.4590, -5.2420, -5.4746, -5.2665, -5.3158] ]) # fmt: on @@ -512,10 +514,10 @@ class MoonshineModelIntegrationTests(unittest.TestCase): # fmt: off EXPECTED_LOGITS = torch.tensor([ - [-7.7272, 1.4630, 5.2294, -7.7313, -7.6252, -7.6011, -7.6788, -7.6441, -7.8452, -7.7549], - [-6.2173, -0.5891, 7.9493, -7.0694, -6.9997, -6.9982, -7.0953, -7.0831, -7.1686, -7.0137], - [-7.3184, 3.1192, 3.8937, -5.7206, -5.8428, -5.7609, -5.9996, -5.8212, -5.8615, -5.8719], - [-9.5475, 1.0146, 4.1179, -5.9971, -6.0614, -6.0329, -6.2103, -6.0318, -6.0789, -6.0873] + [-7.7272, 1.4630, 5.2294, -7.7313, -7.6252, -7.6011, -7.6788, -7.6441, -7.8452, -7.7549], + [-6.2173, -0.5891, 7.9493, -7.0694, -6.9997, -6.9982, -7.0953, -7.0831, -7.1686, -7.0137], + [-7.3184, 3.1192, 3.8937, -5.7206, -5.8428, -5.7609, -5.9996, -5.8212, -5.8615, -5.8719], + [-9.5475, 1.0146, 4.1179, -5.9971, -6.0614, -6.0329, -6.2103, -6.0318, -6.0789, -6.0873] ]) # fmt: on diff --git a/tests/models/mpt/test_modeling_mpt.py b/tests/models/mpt/test_modeling_mpt.py index 449ddfbc2aa..15d8fddb9f7 100644 --- a/tests/models/mpt/test_modeling_mpt.py +++ b/tests/models/mpt/test_modeling_mpt.py @@ -446,7 +446,8 @@ class MptIntegrationTests(unittest.TestCase): input_text = "Hello" expected_outputs = Expectations({ - ("cuda", None): "Hello, I'm a new user of the forum. I have a question about the \"Solaris", + (None, None): "Hello, I'm a new user of the forum. I have a question about the \"Solaris", + ("cuda", 8): "Hello, I'm a new user of the forum. I have a question. I have a problem with", ("rocm", (9, 5)): "Hello, I'm a newbie to the forum. I have a question about the \"B\" in", }) # fmt: off expected_output = expected_outputs.get_expectation() @@ -468,10 +469,10 @@ class MptIntegrationTests(unittest.TestCase): input_text = "Hello" expected_outputs = Expectations({ + (None, None): "Hello and welcome to the first episode of the new podcast, The Frugal Feminist.\n", ("rocm", (9, 5)): "Hello and welcome to the first day of the new release at The Stamp Man!\nToday we are", ("xpu", 3): "Hello and welcome to the first ever episode of the new and improved, and hopefully improved, podcast.\n", - ("cuda", 7): "Hello and welcome to the first episode of the new podcast, The Frugal Feminist.\n", - ("cuda", 8): "Hello and welcome to the first day of the new release countdown for the month of May!\nToday", + ("cuda", 8): "Hello and welcome to the first ever episode of the new and improved, and hopefully improved, podcast.\n", }) # fmt: off expected_output = expected_outputs.get_expectation() @@ -499,13 +500,17 @@ class MptIntegrationTests(unittest.TestCase): expected_outputs = Expectations( { + (None, None): [ + "Hello my name is Tiffany and I am a mother of two beautiful children. I have been a nanny for the", + "Today I am going at the gym and then I am going to go to the grocery store. I am going to buy some food and some", + ], ("xpu", 3): [ "Hello my name is Tiffany. I am a mother of two beautiful children. I have been a nanny for over", "Today I am going at the gym and then I am going to go to the mall with my mom. I am going to go to the", ], - ("cuda", 7): [ - "Hello my name is Tiffany and I am a mother of two beautiful children. I have been a nanny for the", - "Today I am going at the gym and then I am going to go to the grocery store. I am going to buy some food and some", + ("cuda", 8): [ + "Hello my name is Tiffany and I am a mother of two beautiful children. I have been a nanny for over", + "Today I am going at the gym and then I am going to go to the grocery store. I am going to make a list of things", ], ("rocm", (9, 5)): [ "Hello my name is Jasmine and I am a very sweet and loving dog. I am a very playful dog and I", @@ -534,8 +539,9 @@ class MptIntegrationTests(unittest.TestCase): expected_slices = Expectations( { + (None, None): torch.Tensor([-0.2520, -0.2178, -0.1953]), ("xpu", 3): torch.Tensor([-0.2090, -0.2061, -0.1465]), - ("cuda", 7): torch.Tensor([-0.2520, -0.2178, -0.1953]), + ("cuda", 8): torch.Tensor([-0.2559, -0.2227, -0.2217]), # TODO: This is quite a bit off, check BnB ("rocm", (9, 5)): torch.Tensor([-0.3008, -0.1309, -0.1562]), } diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 3de8b482d70..9356ddf92e5 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -31,6 +31,7 @@ from transformers import ( T5Config, ) from transformers.testing_utils import ( + Expectations, get_device_properties, is_torch_available, require_flash_attn, @@ -1377,16 +1378,17 @@ class MusicgenIntegrationTests(unittest.TestCase): output_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=10) # fmt: off - EXPECTED_VALUES = torch.tensor( - [ - -0.0099, -0.0140, 0.0079, 0.0080, -0.0046, 0.0065, -0.0068, -0.0185, - 0.0105, 0.0059, 0.0329, 0.0249, -0.0204, -0.0341, -0.0465, 0.0053, - ] + expectations = Expectations( + { + (None, None): [-0.0099, -0.0140, 0.0079, 0.0080, -0.0046, 0.0065, -0.0068, -0.0185, 0.0105, 0.0059, 0.0329, 0.0249, -0.0204, -0.0341, -0.0465, 0.0053], + ("cuda", 8): [-0.0099, -0.0140, 0.0079, 0.0080, -0.0046, 0.0065, -0.0068, -0.0185, 0.0105, 0.0058, 0.0328, 0.0249, -0.0205, -0.0342, -0.0466, 0.0052], + } ) + EXPECTED_VALUES = torch.tensor(expectations.get_expectation()).to(torch_device) # fmt: on self.assertTrue(output_values.shape == (2, 1, 4480)) - torch.testing.assert_close(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(output_values[0, 0, :16], EXPECTED_VALUES, rtol=2e-4, atol=2e-4) @slow def test_generate_text_prompt_greedy(self): @@ -1459,16 +1461,17 @@ class MusicgenIntegrationTests(unittest.TestCase): ) # fmt: off - EXPECTED_VALUES = torch.tensor( - [ - -0.0111, -0.0154, 0.0047, 0.0058, -0.0068, 0.0012, -0.0109, -0.0229, - 0.0010, -0.0038, 0.0167, 0.0042, -0.0421, -0.0610, -0.0764, -0.0326, - ] + expectations = Expectations( + { + (None, None): [-0.0111, -0.0154, 0.0047, 0.0058, -0.0068, 0.0012, -0.0109, -0.0229, 0.0010, -0.0038, 0.0167, 0.0042, -0.0421, -0.0610, -0.0764, -0.0326], + ("cuda", 8): [-0.0110, -0.0153, 0.0048, 0.0058, -0.0068, 0.0012, -0.0109, -0.0229, 0.0010, -0.0037, 0.0168, 0.0042, -0.0420, -0.0609, -0.0763, -0.0326], + } ) + EXPECTED_VALUES = torch.tensor(expectations.get_expectation()).to(torch_device) # fmt: on self.assertTrue(output_values.shape == (2, 1, 4480)) - torch.testing.assert_close(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(output_values[0, 0, :16], EXPECTED_VALUES, rtol=2e-4, atol=2e-4) @slow def test_generate_text_audio_prompt(self): @@ -1521,13 +1524,13 @@ class MusicgenStereoIntegrationTests(unittest.TestCase): # fmt: off EXPECTED_VALUES_LEFT = torch.tensor( [ - 0.0017, 0.0004, 0.0004, 0.0005, 0.0002, 0.0002, -0.0002, -0.0013, + 0.0017, 0.0004, 0.0004, 0.0005, 0.0002, 0.0002, -0.0002, -0.0013, -0.0010, -0.0015, -0.0018, -0.0032, -0.0060, -0.0082, -0.0096, -0.0099, ] ) EXPECTED_VALUES_RIGHT = torch.tensor( [ - 0.0038, 0.0028, 0.0031, 0.0032, 0.0031, 0.0032, 0.0030, 0.0019, + 0.0038, 0.0028, 0.0031, 0.0032, 0.0031, 0.0032, 0.0030, 0.0019, 0.0021, 0.0015, 0.0009, -0.0008, -0.0040, -0.0067, -0.0087, -0.0096, ] ) @@ -1555,13 +1558,13 @@ class MusicgenStereoIntegrationTests(unittest.TestCase): # fmt: off EXPECTED_VALUES_LEFT = torch.tensor( [ - 0.2535, 0.2008, 0.1471, 0.0896, 0.0306, -0.0200, -0.0501, -0.0728, + 0.2535, 0.2008, 0.1471, 0.0896, 0.0306, -0.0200, -0.0501, -0.0728, -0.0832, -0.0856, -0.0867, -0.0884, -0.0864, -0.0866, -0.0744, -0.0430, ] ) EXPECTED_VALUES_RIGHT = torch.tensor( [ - 0.1695, 0.1213, 0.0732, 0.0239, -0.0264, -0.0705, -0.0935, -0.1103, + 0.1695, 0.1213, 0.0732, 0.0239, -0.0264, -0.0705, -0.0935, -0.1103, -0.1163, -0.1139, -0.1104, -0.1082, -0.1027, -0.1004, -0.0900, -0.0614, ] ) diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index eef833750ba..4aa812a0aee 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -30,6 +30,7 @@ from transformers import ( T5Config, ) from transformers.testing_utils import ( + Expectations, get_device_properties, is_torch_available, is_torchaudio_available, @@ -1472,16 +1473,17 @@ class MusicgenMelodyIntegrationTests(unittest.TestCase): ) # fmt: off - EXPECTED_VALUES = torch.tensor( - [ - -0.0165, -0.0222, -0.0041, -0.0058, -0.0145, -0.0023, -0.0160, -0.0310, - -0.0055, -0.0127, 0.0104, 0.0105, -0.0326, -0.0611, -0.0744, -0.0083 - ] + expectations = Expectations( + { + (None, None): [-0.0165, -0.0222, -0.0041, -0.0058, -0.0145, -0.0023, -0.0160, -0.0310, -0.0055, -0.0127, 0.0104, 0.0105, -0.0326, -0.0611, -0.0744, -0.0083], + ("cuda", 8): [-0.0165, -0.0221, -0.0040, -0.0058, -0.0145, -0.0024, -0.0160, -0.0310, -0.0055, -0.0127, 0.0104, 0.0105, -0.0326, -0.0612, -0.0744, -0.0082], + } ) + EXPECTED_VALUES = torch.tensor(expectations.get_expectation()).to(torch_device) # fmt: on self.assertTrue(output_values.shape == (2, 1, 4480)) - torch.testing.assert_close(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(output_values[0, 0, :16], EXPECTED_VALUES, rtol=2e-4, atol=2e-4) @slow def test_generate_text_audio_prompt(self): diff --git a/tests/models/oneformer/test_modeling_oneformer.py b/tests/models/oneformer/test_modeling_oneformer.py index 58a93a8c4fa..670756a9bfa 100644 --- a/tests/models/oneformer/test_modeling_oneformer.py +++ b/tests/models/oneformer/test_modeling_oneformer.py @@ -21,6 +21,7 @@ import numpy as np from tests.test_modeling_common import floats_tensor from transformers import AutoModelForImageClassification, OneFormerConfig, is_torch_available, is_vision_available from transformers.testing_utils import ( + Expectations, is_flaky, require_timm, require_torch, @@ -528,7 +529,7 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3]) -TOLERANCE = 1e-4 +TOLERANCE = 2e-4 # We will verify our results on an image of cute cats @@ -574,12 +575,15 @@ class OneFormerModelIntegrationTest(unittest.TestCase): slice_hidden_state = outputs.pixel_decoder_hidden_states[0][0, 0, :3, :3] torch.testing.assert_close(slice_hidden_state, expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE) - # fmt: off - expected_slice_hidden_state = [[3.0668, -1.1833, -5.1103], [3.344, -3.362, -5.1101], [2.6017, -4.3613, -4.1444]] - expected_slice_hidden_state = torch.tensor(expected_slice_hidden_state).to(torch_device) + expectations = Expectations( + { + (None, None): [[3.0668, -1.1833, -5.1103], [3.344, -3.362, -5.1101], [2.6017, -4.3613, -4.1444]], + ("cuda", 8): [[3.0590, -1.1903, -5.1119], [3.3919, -3.3547, -5.1469], [2.6041, -4.3592, -4.1406]], + } + ) + expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device) slice_hidden_state = outputs.transformer_decoder_class_predictions[0, :3, :3] torch.testing.assert_close(slice_hidden_state, expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE) - # fmt: on def test_inference_universal_segmentation_head(self): model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() @@ -599,8 +603,13 @@ class OneFormerModelIntegrationTest(unittest.TestCase): masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, (inputs_shape[-1] + 2) // 4), ) - expected_slice = [[3.1848, 4.2141, 4.1993], [2.9000, 3.5721, 3.6603], [2.5358, 3.0883, 3.6168]] - expected_slice = torch.tensor(expected_slice).to(torch_device) + expectations = Expectations( + { + (None, None): [[3.1848, 4.2141, 4.1993], [2.9000, 3.5721, 3.6603], [2.5358, 3.0883, 3.6168]], + ("cuda", 8): [[3.1687, 4.1893, 4.1742], [2.8768, 3.5380, 3.6257], [2.5121, 3.0552, 3.5822]], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) # class_queries_logits @@ -609,8 +618,13 @@ class OneFormerModelIntegrationTest(unittest.TestCase): class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1), ) - expected_slice = [[3.0668, -1.1833, -5.1103], [3.3440, -3.3620, -5.1101], [2.6017, -4.3613, -4.1444]] - expected_slice = torch.tensor(expected_slice).to(torch_device) + expectations = Expectations( + { + (None, None): [[3.0668, -1.1833, -5.1103], [3.3440, -3.3620, -5.1101], [2.6017, -4.3613, -4.1444]], + ("cuda", 8): [[3.0590, -1.1903, -5.1119], [3.3919, -3.3547, -5.1469], [2.6041, -4.3592, -4.1406]], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) torch.testing.assert_close(class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) @require_torch_accelerator diff --git a/tests/models/poolformer/test_modeling_poolformer.py b/tests/models/poolformer/test_modeling_poolformer.py index 0fee2b295f0..56300abbe8c 100644 --- a/tests/models/poolformer/test_modeling_poolformer.py +++ b/tests/models/poolformer/test_modeling_poolformer.py @@ -17,7 +17,7 @@ import unittest from transformers import is_torch_available, is_vision_available from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -144,6 +144,9 @@ class PoolFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_batching_equivalence(self, atol=2e-4, rtol=2e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + @unittest.skip(reason="PoolFormer does not use inputs_embeds") def test_inputs_embeds(self): pass @@ -235,5 +238,11 @@ class PoolFormerModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-0.6113, 0.1685, -0.0492]).to(torch_device) - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + expectations = Expectations( + { + (None, None): [-0.6113, 0.1685, -0.0492], + ("cuda", 8): [-0.6112, 0.1690, -0.0481], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) diff --git a/tests/models/pvt/test_modeling_pvt.py b/tests/models/pvt/test_modeling_pvt.py index d52348555ad..eeaabcbd608 100644 --- a/tests/models/pvt/test_modeling_pvt.py +++ b/tests/models/pvt/test_modeling_pvt.py @@ -17,6 +17,7 @@ import unittest from transformers import is_torch_available, is_vision_available from transformers.testing_utils import ( + Expectations, require_accelerate, require_torch, require_torch_accelerator, @@ -153,6 +154,9 @@ class PvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): self.model_tester = PvtModelTester(self) self.config_tester = PvtConfigTester(self, config_class=PvtConfig) + def test_batching_equivalence(self, atol=1e-4, rtol=1e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + def test_config(self): self.config_tester.run_common_tests() @@ -257,9 +261,15 @@ class PvtModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, model.config.num_labels)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-1.4192, -1.9158, -0.9702]).to(torch_device) + expectations = Expectations( + { + (None, None): [-1.4192, -1.9158, -0.9702], + ("cuda", 8): [-1.4194, -1.9161, -0.9705], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) @slow def test_inference_model(self): @@ -278,11 +288,15 @@ class PvtModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 50, 512)) self.assertEqual(outputs.last_hidden_state.shape, expected_shape) - expected_slice = torch.tensor( - [[-0.3086, 1.0402, 1.1816], [-0.2880, 0.5781, 0.6124], [0.1480, 0.6129, -0.0590]] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [[-0.3086, 1.0402, 1.1816], [-0.2880, 0.5781, 0.6124], [0.1480, 0.6129, -0.0590]], + ("cuda", 8): [[-0.3084, 1.0402, 1.1816], [-0.2883, 0.5781, 0.6123], [0.1487, 0.6119, -0.0584]], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) @slow @require_accelerate diff --git a/tests/models/pvt_v2/test_modeling_pvt_v2.py b/tests/models/pvt_v2/test_modeling_pvt_v2.py index d1a765b19d4..0aca4e6652b 100644 --- a/tests/models/pvt_v2/test_modeling_pvt_v2.py +++ b/tests/models/pvt_v2/test_modeling_pvt_v2.py @@ -167,6 +167,9 @@ class PvtV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_batching_equivalence(self, atol=5e-4, rtol=5e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + @unittest.skip(reason="Pvt-V2 does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/regnet/test_modeling_regnet.py b/tests/models/regnet/test_modeling_regnet.py index 9f88bc8c9c1..8fc8e452da9 100644 --- a/tests/models/regnet/test_modeling_regnet.py +++ b/tests/models/regnet/test_modeling_regnet.py @@ -17,7 +17,7 @@ import unittest from transformers import RegNetConfig from transformers.file_utils import cached_property, is_torch_available, is_vision_available -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -146,6 +146,9 @@ class RegNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() + def test_batching_equivalence(self, atol=3e-5, rtol=3e-5): + super().test_batching_equivalence(atol=atol, rtol=rtol) + @unittest.skip(reason="RegNet does not use inputs_embeds") def test_inputs_embeds(self): pass @@ -248,6 +251,11 @@ class RegNetModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-0.4180, -1.5051, -3.4836]).to(torch_device) - - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + expectations = Expectations( + { + (None, None): [-0.4180, -1.5051, -3.4836], + ("cuda", 8): [-0.4168, -1.5056, -3.4836], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py index e63d617c0e8..3778bd40054 100644 --- a/tests/models/resnet/test_modeling_resnet.py +++ b/tests/models/resnet/test_modeling_resnet.py @@ -16,7 +16,7 @@ import unittest from transformers import ResNetConfig -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_backbone_common import BackboneTesterMixin @@ -301,9 +301,14 @@ class ResNetModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-11.1069, -9.7877, -8.3777]).to(torch_device) - - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + expectations = Expectations( + { + (None, None): [-11.1069, -9.7877, -8.3777], + ("cuda", 8): [-11.1112, -9.7916, -8.3788], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) @require_torch diff --git a/tests/models/rt_detr/test_modeling_rt_detr.py b/tests/models/rt_detr/test_modeling_rt_detr.py index fa2938160d7..fad90934265 100644 --- a/tests/models/rt_detr/test_modeling_rt_detr.py +++ b/tests/models/rt_detr/test_modeling_rt_detr.py @@ -29,6 +29,7 @@ from transformers import ( is_vision_available, ) from transformers.testing_utils import ( + Expectations, require_torch, require_torch_accelerator, require_vision, @@ -732,45 +733,69 @@ class RTDetrModelIntegrationTest(unittest.TestCase): expected_shape_logits = torch.Size((1, 300, model.config.num_labels)) self.assertEqual(outputs.logits.shape, expected_shape_logits) - expected_logits = torch.tensor( - [ - [-4.64763879776001, -5.001153945922852, -4.978509902954102], - [-4.159348487854004, -4.703853607177734, -5.946484565734863], - [-4.437461853027344, -4.65836238861084, -6.235235691070557], - ] - ).to(torch_device) - expected_boxes = torch.tensor( - [ - [0.1688060760498047, 0.19992263615131378, 0.21225441992282867], - [0.768376350402832, 0.41226309537887573, 0.4636859893798828], - [0.25953856110572815, 0.5483334064483643, 0.4777486026287079], - ] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [-4.64763879776001, -5.001153945922852, -4.978509902954102], + [-4.159348487854004, -4.703853607177734, -5.946484565734863], + [-4.437461853027344, -4.65836238861084, -6.235235691070557], + ], + ("cuda", 8): [[-4.6471, -5.0008, -4.9786], [-4.1599, -4.7041, -5.9458], [-4.4374, -4.6582, -6.2340]], + } + ) + expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, rtol=1e-4, atol=1e-4) + expectations = Expectations( + { + (None, None): [ + [0.1688060760498047, 0.19992263615131378, 0.21225441992282867], + [0.768376350402832, 0.41226309537887573, 0.4636859893798828], + [0.25953856110572815, 0.5483334064483643, 0.4777486026287079], + ], + ("cuda", 8): [[0.1688, 0.1999, 0.2123], [0.7684, 0.4123, 0.4637], [0.2596, 0.5483, 0.4777]], + } + ) + expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device) + + torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, rtol=2e-4, atol=2e-4) expected_shape_boxes = torch.Size((1, 300, 4)) self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) - torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, rtol=2e-4, atol=2e-4) # verify postprocessing results = image_processor.post_process_object_detection( outputs, threshold=0.0, target_sizes=[image.size[::-1]] )[0] - expected_scores = torch.tensor( - [0.9703017473220825, 0.9599503874778748, 0.9575679302215576, 0.9506784677505493], device=torch_device - ) - expected_labels = [57, 15, 15, 65] - expected_slice_boxes = torch.tensor( - [ - [0.13774872, 0.37821293, 640.13074, 476.21088], - [343.38132, 24.276838, 640.1404, 371.49573], - [13.225126, 54.179348, 318.98422, 472.2207], - [40.114475, 73.44104, 175.9573, 118.48469], - ], - device=torch_device, - ) - torch.testing.assert_close(results["scores"][:4], expected_scores, rtol=1e-4, atol=1e-4) + expectations = Expectations( + { + (None, None): [0.9703017473220825, 0.9599503874778748, 0.9575679302215576, 0.9506784677505493], + ("cuda", 8): [0.9704, 0.9599, 0.9576, 0.9507], + } + ) + expected_scores = torch.tensor(expectations.get_expectation()).to(torch_device) + + expected_labels = [57, 15, 15, 65] + + expectations = Expectations( + { + (None, None): [ + [0.13774872, 0.37821293, 640.13074, 476.21088], + [343.38132, 24.276838, 640.1404, 371.49573], + [13.225126, 54.179348, 318.98422, 472.2207], + [40.114475, 73.44104, 175.9573, 118.48469], + ], + ("cuda", 8): [ + [1.4183e-01, 3.8063e-01, 6.4013e02, 4.7621e02], + [3.4338e02, 2.4275e01, 6.4014e02, 3.7150e02], + [1.3236e01, 5.4179e01, 3.1899e02, 4.7222e02], + [4.0114e01, 7.3441e01, 1.7596e02, 1.1848e02], + ], + } + ) + expected_slice_boxes = torch.tensor(expectations.get_expectation()).to(torch_device) + + torch.testing.assert_close(results["scores"][:4], expected_scores, rtol=2e-4, atol=2e-4) self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels) - torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, rtol=2e-4, atol=2e-4) diff --git a/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py b/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py index a78f11ea46c..79202d3cf71 100644 --- a/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py +++ b/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py @@ -28,6 +28,7 @@ from transformers import ( is_vision_available, ) from transformers.testing_utils import ( + Expectations, require_torch, require_torch_accelerator, require_vision, @@ -736,42 +737,60 @@ class RTDetrV2ModelIntegrationTest(unittest.TestCase): expected_shape_logits = torch.Size((1, 300, model.config.num_labels)) self.assertEqual(outputs.logits.shape, expected_shape_logits) - expected_logits = torch.tensor( - [ - [-3.7047, -5.1914, -6.1787], - [-4.0108, -9.3449, -5.2047], - [-4.1287, -4.7461, -5.8633], - ] - ).to(torch_device) - expected_boxes = torch.tensor( - [ - [0.2582, 0.5497, 0.4764], - [0.1684, 0.1985, 0.2120], - [0.7665, 0.4146, 0.4669], - ] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [[-3.7047, -5.1914, -6.1787], [-4.0108, -9.3449, -5.2047], [-4.1287, -4.7461, -5.8633]], + ("cuda", 8): [[-3.7039, -5.1923, -6.1787], [-4.0106, -9.3452, -5.2045], [-4.1285, -4.7468, -5.8641]], + } + ) + expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, atol=1e-4, rtol=1e-4) + expectations = Expectations( + { + (None, None): [[0.2582, 0.5497, 0.4764], [0.1684, 0.1985, 0.2120], [0.7665, 0.4146, 0.4669]], + } + ) + expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device) + + torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, atol=2e-4, rtol=2e-4) expected_shape_boxes = torch.Size((1, 300, 4)) self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) - torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=2e-4, rtol=2e-4) # verify postprocessing results = image_processor.post_process_object_detection( outputs, threshold=0.0, target_sizes=[image.size[::-1]] )[0] - expected_scores = torch.tensor([0.9652, 0.9599, 0.9462, 0.8613], device=torch_device) - expected_labels = [15, 15, 65, 57] - expected_slice_boxes = torch.tensor( - [ - [3.4114e02, 2.5111e01, 6.3998e02, 3.7289e02], - [1.2780e01, 5.6346e01, 3.1767e02, 4.7134e02], - [3.9959e01, 7.3117e01, 1.7565e02, 1.1744e02], - [-1.0521e-01, 2.9717e00, 6.3989e02, 4.7362e02], - ], - device=torch_device, + + expectations = Expectations( + { + (None, None): [0.9652, 0.9599, 0.9462, 0.8613], + ("cuda", 8): [0.9652, 0.9599, 0.9461, 0.8613], + } ) - self.assertTrue(torch.allclose(results["scores"][:4], expected_scores, atol=1e-3, rtol=1e-4)) + expected_scores = torch.tensor(expectations.get_expectation()).to(torch_device) + + expected_labels = [15, 15, 65, 57] + + expectations = Expectations( + { + (None, None): [ + [3.4114e02, 2.5111e01, 6.3998e02, 3.7289e02], + [1.2780e01, 5.6346e01, 3.1767e02, 4.7134e02], + [3.9959e01, 7.3117e01, 1.7565e02, 1.1744e02], + [-1.0521e-01, 2.9717e00, 6.3989e02, 4.7362e02], + ], + ("cuda", 8): [ + [3.4115e02, 2.5109e01, 6.3997e02, 3.7290e02], + [1.2785e01, 5.6350e01, 3.1767e02, 4.7134e02], + [3.9959e01, 7.3117e01, 1.7565e02, 1.1744e02], + [-1.0471e-01, 2.9680e00, 6.3989e02, 4.7362e02], + ], + } + ) + expected_slice_boxes = torch.tensor(expectations.get_expectation()).to(torch_device) + + torch.testing.assert_close(results["scores"][:4], expected_scores, atol=1e-3, rtol=2e-4) self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels) - torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, atol=1e-3, rtol=1e-4) + torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, atol=1e-3, rtol=2e-4) diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index fa1ada4f616..660d529dc9f 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -19,7 +19,7 @@ import unittest import requests from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline -from transformers.testing_utils import cleanup, require_torch, require_torch_sdpa, slow, torch_device +from transformers.testing_utils import Expectations, cleanup, require_torch, require_torch_sdpa, slow, torch_device from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -771,9 +771,18 @@ class SamModelIntegrationTest(unittest.TestCase): with torch.no_grad(): outputs = model(**inputs) scores = outputs.iou_scores.squeeze().cpu() - masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu() + masks = outputs.pred_masks[0, 0, 0, 0, :3] + + expectations = Expectations( + { + (None, None): [-12.7729, -12.3665, -12.6061], + ("cuda", 8): [-12.7657, -12.3683, -12.5983], + } + ) + expected_masks = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(scores[-1], torch.tensor(0.9566), rtol=2e-4, atol=2e-4) - torch.testing.assert_close(masks, torch.tensor([-12.7729, -12.3665, -12.6061]), rtol=2e-4, atol=2e-4) + torch.testing.assert_close(masks, expected_masks, rtol=2e-4, atol=2e-4) def test_inference_mask_generation_batched_points_batched_images(self): model = SamModel.from_pretrained("facebook/sam-vit-base") diff --git a/tests/models/sam_hq/test_modeling_sam_hq.py b/tests/models/sam_hq/test_modeling_sam_hq.py index 830b537031d..b4701fa975d 100644 --- a/tests/models/sam_hq/test_modeling_sam_hq.py +++ b/tests/models/sam_hq/test_modeling_sam_hq.py @@ -27,7 +27,7 @@ from transformers import ( SamHQVisionModel, pipeline, ) -from transformers.testing_utils import cleanup, require_torch, require_torch_sdpa, slow, torch_device +from transformers.testing_utils import Expectations, cleanup, require_torch, require_torch_sdpa, slow, torch_device from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -802,9 +802,15 @@ class SamHQModelIntegrationTest(unittest.TestCase): masks = outputs.pred_masks[0, 0, 0, 0, :3] self.assertTrue(torch.allclose(scores[0][0][-1], torch.tensor(0.4482), atol=2e-4)) - self.assertTrue( - torch.allclose(masks, torch.tensor([-13.1695, -14.6201, -14.8989]).to(torch_device), atol=2e-3) + + expectations = Expectations( + { + (None, None): [-13.1695, -14.6201, -14.8989], + ("cuda", 8): [-13.1668, -14.6182, -14.8970], + } ) + EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(masks, EXPECTED_MASKS, atol=2e-3, rtol=2e-3) def test_inference_mask_generation_one_point_one_bb(self): model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base") @@ -849,28 +855,53 @@ class SamHQModelIntegrationTest(unittest.TestCase): with torch.no_grad(): outputs = model(**inputs) - scores = outputs.iou_scores.squeeze().cpu() - masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu() - EXPECTED_SCORES = torch.tensor( - [ - [ - [0.9195, 0.8316, 0.6614], - [0.9195, 0.8316, 0.6614], - [0.9195, 0.8316, 0.6614], - [0.9195, 0.8316, 0.6614], - ], - [ - [0.7598, 0.7388, 0.3110], - [0.9195, 0.8317, 0.6614], - [0.9195, 0.8317, 0.6614], - [0.9195, 0.8317, 0.6614], - ], - ] - ) - EXPECTED_MASKS = torch.tensor([-40.2445, -37.4300, -38.1577]) + scores = outputs.iou_scores.squeeze() + masks = outputs.pred_masks[0, 0, 0, 0, :3] - self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) - self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=9e-3)) + expectations = Expectations( + { + (None, None): [ + [ + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + ], + [ + [0.7598, 0.7388, 0.3110], + [0.9195, 0.8317, 0.6614], + [0.9195, 0.8317, 0.6614], + [0.9195, 0.8317, 0.6614], + ], + ], + ("cuda", 8): [ + [ + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + ], + [ + [0.7597, 0.7387, 0.3110], + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + ], + ], + } + ) + EXPECTED_SCORES = torch.tensor(expectations.get_expectation()).to(torch_device) + + expectations = Expectations( + { + (None, None): [-40.2445, -37.4300, -38.1577], + ("cuda", 8): [-40.2351, -37.4334, -38.1526], + } + ) + EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device) + + torch.testing.assert_close(scores, EXPECTED_SCORES, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(masks, EXPECTED_MASKS, atol=9e-3, rtol=9e-3) def test_inference_mask_generation_one_point_one_bb_zero(self): model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base") diff --git a/tests/models/segformer/test_modeling_segformer.py b/tests/models/segformer/test_modeling_segformer.py index cd75545c62a..fcd6594217c 100644 --- a/tests/models/segformer/test_modeling_segformer.py +++ b/tests/models/segformer/test_modeling_segformer.py @@ -16,7 +16,7 @@ import unittest from transformers import SegformerConfig, is_torch_available, is_vision_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -200,6 +200,9 @@ class SegformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs) + def test_batching_equivalence(self, atol=2e-4, rtol=2e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + @unittest.skip(reason="SegFormer does not use inputs_embeds") def test_inputs_embeds(self): pass @@ -367,14 +370,22 @@ class SegformerModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, model.config.num_labels, 128, 128)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor( - [ - [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]], - [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]], - [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]], - ] - ).to(torch_device) - torch.testing.assert_close(outputs.logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + expectations = Expectations( + { + (None, None): [ + [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]], + [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]], + [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]], + ], + ("cuda", 8): [ + [[-4.6310, -5.5232, -6.2361], [-5.1918, -6.1445, -6.5996], [-5.4427, -6.2792, -6.7580]], + [[-12.1397, -13.3124, -13.9551], [-12.8736, -13.9347, -14.3569], [-12.9440, -13.8222, -14.2514]], + [[-12.5135, -13.4682, -14.4913], [-12.8670, -14.4339, -14.7766], [-13.2519, -14.5800, -15.0685]], + ], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) @slow def test_inference_image_segmentation_city(self): @@ -396,13 +407,24 @@ class SegformerModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, model.config.num_labels, 128, 128)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor( - [ - [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]], - [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]], - [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]], - ] - ).to(torch_device) + expected_slice = torch.tensor([]).to(torch_device) + + expectations = Expectations( + { + (None, None): [ + [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]], + [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]], + [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]], + ], + ("cuda", 8): [ + [[-13.5728, -13.9089, -12.6492], [-14.3478, -15.3656, -14.2309], [-14.7512, -16.0394, -15.6065]], + [[-17.1642, -15.8704, -12.9641], [-17.2572, -17.3701, -14.8214], [-16.6043, -16.8761, -16.7425]], + [[-3.6444, -3.0189, -1.4195], [-3.0787, -3.1953, -1.9993], [-1.8755, -1.9219, -1.7002]], + ], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.logits[0, :3, :3, :3], expected_slice, rtol=1e-1, atol=1e-1) @slow diff --git a/tests/models/seggpt/test_modeling_seggpt.py b/tests/models/seggpt/test_modeling_seggpt.py index 1176613fa20..4083276e185 100644 --- a/tests/models/seggpt/test_modeling_seggpt.py +++ b/tests/models/seggpt/test_modeling_seggpt.py @@ -21,6 +21,7 @@ from datasets import load_dataset from transformers import SegGptConfig from transformers.testing_utils import ( + Expectations, require_torch, require_vision, slow, @@ -379,15 +380,23 @@ class SegGptModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 3, 896, 448)) self.assertEqual(outputs.pred_masks.shape, expected_shape) - expected_slice = torch.tensor( - [ - [[-2.1208, -2.1190, -2.1198], [-2.1237, -2.1228, -2.1227], [-2.1232, -2.1226, -2.1228]], - [[-2.0405, -2.0396, -2.0403], [-2.0434, -2.0434, -2.0433], [-2.0428, -2.0432, -2.0434]], - [[-1.8102, -1.8088, -1.8099], [-1.8131, -1.8126, -1.8129], [-1.8130, -1.8128, -1.8131]], - ] - ).to(torch_device) + expectations = Expectations( + { + (None, None): [ + [[-2.1208, -2.1190, -2.1198], [-2.1237, -2.1228, -2.1227], [-2.1232, -2.1226, -2.1228]], + [[-2.0405, -2.0396, -2.0403], [-2.0434, -2.0434, -2.0433], [-2.0428, -2.0432, -2.0434]], + [[-1.8102, -1.8088, -1.8099], [-1.8131, -1.8126, -1.8129], [-1.8130, -1.8128, -1.8131]], + ], + ("cuda", 8): [ + [[-2.1208, -2.1189, -2.1198], [-2.1236, -2.1229, -2.1230], [-2.1233, -2.1227, -2.1228]], + [[-2.0408, -2.0398, -2.0405], [-2.0435, -2.0437, -2.0438], [-2.0431, -2.0435, -2.0436]], + [[-1.8101, -1.8086, -1.8098], [-1.8129, -1.8126, -1.8130], [-1.8128, -1.8128, -1.8130]], + ], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) - torch.testing.assert_close(outputs.pred_masks[0, :, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.pred_masks[0, :, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) result = image_processor.post_process_semantic_segmentation(outputs, [input_image.size[::-1]])[0] diff --git a/tests/models/swin2sr/test_modeling_swin2sr.py b/tests/models/swin2sr/test_modeling_swin2sr.py index 125d5418e8e..a1767a0ab24 100644 --- a/tests/models/swin2sr/test_modeling_swin2sr.py +++ b/tests/models/swin2sr/test_modeling_swin2sr.py @@ -16,7 +16,7 @@ import unittest from transformers import Swin2SRConfig -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -360,7 +360,12 @@ class Swin2SRModelIntegrationTest(unittest.TestCase): # verify the logits expected_shape = torch.Size([1, 3, 976, 1296]) self.assertEqual(outputs.reconstruction.shape, expected_shape) - expected_slice = torch.tensor( - [[0.5454, 0.5542, 0.5640], [0.5518, 0.5562, 0.5649], [0.5391, 0.5425, 0.5620]], dtype=model.dtype - ).to(torch_device) - torch.testing.assert_close(outputs.reconstruction[0, 0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + + expectations = Expectations( + { + (None, None): [[0.5454, 0.5542, 0.5640], [0.5518, 0.5562, 0.5649], [0.5391, 0.5425, 0.5620]], + ("cuda", 8): [[0.5454, 0.5547, 0.5640], [0.5522, 0.5562, 0.5649], [0.5391, 0.5425, 0.5620]], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device, dtype=model.dtype) + torch.testing.assert_close(outputs.reconstruction[0, 0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 2b5eb30dcf4..67b59fef4ff 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -19,6 +19,7 @@ import unittest from transformers import SwitchTransformersConfig, is_torch_available from transformers.testing_utils import ( + Expectations, require_tokenizers, require_torch, require_torch_accelerator, @@ -1035,18 +1036,28 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase): decoder_input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device) # fmt: off - EXPECTED_MEAN_LOGITS = torch.Tensor( - [ - -0.204102, -0.193359, 0.523438, -0.296875, 0.108887, - 0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875, - 0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445, - 0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883, - 0.390625, -0.203125, -0.122559, -0.180664, 0.0437012, - -0.349609, -0.0250244, -0.104004, -0.15918, -0.133789 - ] - ).to(torch.bfloat16) + expectations = Expectations( + { + (None, None): [ + -0.204102, -0.193359, 0.523438, -0.296875, 0.108887, + 0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875, + 0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445, + 0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883, + 0.390625, -0.203125, -0.122559, -0.180664, 0.0437012, + -0.349609, -0.0250244, -0.104004, -0.15918, -0.133789 + ], + ("cuda", 8): [ + -0.2051, -0.1914, 0.5352, -0.2988, 0.1108, 0.0200, 0.6094, -0.1025, + -0.0549, 0.2988, -0.0018, 0.1758, 0.1348, -0.1689, -0.1035, 0.0266, + 0.0383, 0.0493, -0.2119, 0.1328, 0.3906, -0.2041, -0.1240, -0.1836, + 0.0454, -0.3477, -0.0256, -0.1050, -0.1572, -0.1338 + ], + } + ) + EXPECTED_MEAN_LOGITS = torch.tensor(expectations.get_expectation()).to(torch_device, dtype=torch.bfloat16) # fmt: on - hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu() + + hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state hf_logits = hf_logits[0, 0, :30] torch.testing.assert_close(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3) diff --git a/tests/models/timesformer/test_modeling_timesformer.py b/tests/models/timesformer/test_modeling_timesformer.py index 247d0a5fba6..5b74a8507a0 100644 --- a/tests/models/timesformer/test_modeling_timesformer.py +++ b/tests/models/timesformer/test_modeling_timesformer.py @@ -21,7 +21,7 @@ from huggingface_hub import hf_hub_download from transformers import TimesformerConfig from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -350,6 +350,11 @@ class TimesformerModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 400)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-0.3016, -0.7713, -0.4205]).to(torch_device) - - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + expectations = Expectations( + { + (None, None): [-0.3016, -0.7713, -0.4205], + ("cuda", 8): [-0.3004, -0.7708, -0.4205], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py index 3f103309a04..a37f10d3818 100644 --- a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py +++ b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py @@ -18,6 +18,7 @@ import unittest from transformers import pipeline from transformers.testing_utils import ( + Expectations, require_bitsandbytes, require_timm, require_torch, @@ -304,10 +305,16 @@ class TimmWrapperModelIntegrationTest(unittest.TestCase): expected_label = 281 # tabby cat self.assertEqual(torch.argmax(outputs.logits).item(), expected_label) - expected_slice = torch.tensor([-11.2618, -9.6192, -10.3205]).to(torch_device) + expectations = Expectations( + { + (None, None): [-11.2618, -9.6192, -10.3205], + ("cuda", 8): [-11.2634, -9.6208, -10.3199], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + resulted_slice = outputs.logits[0, :3] - is_close = torch.allclose(resulted_slice, expected_slice, atol=1e-3) - self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}") + torch.testing.assert_close(resulted_slice, expected_slice, atol=1e-3, rtol=1e-3) @slow def test_inference_with_pipeline(self): @@ -349,10 +356,16 @@ class TimmWrapperModelIntegrationTest(unittest.TestCase): expected_label = 281 # tabby cat self.assertEqual(torch.argmax(outputs.logits).item(), expected_label) - expected_slice = torch.tensor([-2.4043, 1.4492, -0.5127]).to(outputs.logits.dtype) - resulted_slice = outputs.logits[0, :3].cpu() - is_close = torch.allclose(resulted_slice, expected_slice, atol=0.1) - self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}") + expectations = Expectations( + { + (None, None): [-2.4043, 1.4492, -0.5127], + ("cuda", 8): [-2.2676, 1.5303, -0.4409], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + + resulted_slice = outputs.logits[0, :3].to(dtype=torch.float32) + torch.testing.assert_close(resulted_slice, expected_slice, atol=0.1, rtol=0.1) @slow def test_transformers_model_for_classification_is_equivalent_to_timm(self): diff --git a/tests/models/videomae/test_modeling_videomae.py b/tests/models/videomae/test_modeling_videomae.py index 0b85c31f8a9..2c592290a62 100644 --- a/tests/models/videomae/test_modeling_videomae.py +++ b/tests/models/videomae/test_modeling_videomae.py @@ -24,6 +24,7 @@ from pytest import mark from transformers import VideoMAEConfig from transformers.models.auto import get_values from transformers.testing_utils import ( + Expectations, is_flaky, require_flash_attn, require_torch, @@ -442,9 +443,14 @@ class VideoMAEModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 400)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([0.3669, -0.0688, -0.2421]).to(torch_device) - - torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4) + expectations = Expectations( + { + (None, None): [0.3669, -0.0688, -0.2421], + ("cuda", 8): [0.3668, -0.0690, -0.2421], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4) @slow def test_inference_for_pretraining(self): diff --git a/tests/models/vitpose/test_modeling_vitpose.py b/tests/models/vitpose/test_modeling_vitpose.py index 6f4ac621327..e9bce2d4c6a 100644 --- a/tests/models/vitpose/test_modeling_vitpose.py +++ b/tests/models/vitpose/test_modeling_vitpose.py @@ -169,6 +169,9 @@ class VitPoseModelTest(ModelTesterMixin, unittest.TestCase): self.config_tester.check_config_can_be_init_without_params() self.config_tester.check_config_arguments_init() + def test_batching_equivalence(self, atol=3e-4, rtol=3e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + @unittest.skip(reason="VitPose does not support input and output embeddings") def test_model_common_attributes(self): pass diff --git a/tests/models/vitpose_backbone/test_modeling_vitpose_backbone.py b/tests/models/vitpose_backbone/test_modeling_vitpose_backbone.py index 64ff79a68ec..a95d6ca1fa5 100644 --- a/tests/models/vitpose_backbone/test_modeling_vitpose_backbone.py +++ b/tests/models/vitpose_backbone/test_modeling_vitpose_backbone.py @@ -137,6 +137,9 @@ class VitPoseBackboneModelTest(ModelTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() + def test_batching_equivalence(self, atol=3e-4, rtol=3e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + # TODO: @Pavel @unittest.skip(reason="currently failing") def test_initialization(self): diff --git a/tests/models/vivit/test_modeling_vivit.py b/tests/models/vivit/test_modeling_vivit.py index d4d3efe3748..f2866febb7f 100644 --- a/tests/models/vivit/test_modeling_vivit.py +++ b/tests/models/vivit/test_modeling_vivit.py @@ -22,7 +22,7 @@ from huggingface_hub import hf_hub_download from transformers import VivitConfig from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -355,10 +355,14 @@ class VivitModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 400)) self.assertEqual(outputs.logits.shape, expected_shape) - # taken from original model - expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device) - - torch.testing.assert_close(outputs.logits[0, :5], expected_slice, rtol=1e-4, atol=1e-4) + expectations = Expectations( + { + (None, None): [-0.9498, 2.7971, -1.4049, 0.1024, -1.8353], + ("cuda", 8): [-0.9502, 2.7967, -1.4046, 0.1027, -1.8345], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close(outputs.logits[0, :5], expected_slice, rtol=2e-4, atol=2e-4) @slow def test_inference_interpolate_pos_encoding(self): diff --git a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py index 1408e097443..a8e2c4843eb 100644 --- a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py +++ b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py @@ -20,6 +20,7 @@ from datasets import load_dataset from transformers import Wav2Vec2BertConfig, is_torch_available from transformers.testing_utils import ( + is_flaky, require_torch, require_torch_accelerator, require_torch_fp16, @@ -434,6 +435,10 @@ class Wav2Vec2BertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + @is_flaky(description="Get lager difference with A10 and even with the new `5e-4` still flaky") + def test_batching_equivalence(self, atol=5e-4, rtol=5e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + def test_model_with_relative(self): config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative") self.model_tester.create_and_check_model(*config_and_inputs) diff --git a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py index f9805820246..430653721d1 100644 --- a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py +++ b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py @@ -428,8 +428,8 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest @is_flaky( description="The `codevector_idx` computed with `argmax()` in `Wav2Vec2ConformerGumbelVectorQuantizer.forward` is not stable." ) - def test_batching_equivalence(self): - super().test_batching_equivalence() + def test_batching_equivalence(self, atol=1e-4, rtol=1e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) def test_model_with_relative(self): config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative") diff --git a/tests/models/x_clip/test_modeling_x_clip.py b/tests/models/x_clip/test_modeling_x_clip.py index 1a0c7dda6e0..8c5134fc6db 100644 --- a/tests/models/x_clip/test_modeling_x_clip.py +++ b/tests/models/x_clip/test_modeling_x_clip.py @@ -22,7 +22,14 @@ import numpy as np from huggingface_hub import hf_hub_download from transformers import XCLIPConfig, XCLIPTextConfig, XCLIPVisionConfig -from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device +from transformers.testing_utils import ( + Expectations, + require_torch, + require_torch_multi_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -751,10 +758,13 @@ class XCLIPModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) - expected_slice = torch.tensor( - [[0.0126, 0.2109, 0.0609], [0.0448, 0.5862, -0.1688], [-0.0881, 0.8525, -0.3044]] - ).to(torch_device) - - torch.testing.assert_close( - outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4 + expectations = Expectations( + { + (None, None): [[0.0126, 0.2109, 0.0609], [0.0448, 0.5862, -0.1688], [-0.0881, 0.8525, -0.3044]], + ("cuda", 8): [[0.0141, 0.2114, 0.0599], [0.0446, 0.5866, -0.1674], [-0.0876, 0.8592, -0.3025]], + } + ) + expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device) + torch.testing.assert_close( + outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4 ) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 878940b937f..1a7f5120253 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -662,6 +662,11 @@ class TrainerIntegrationCommon: metrics = trainer.evaluate() self.assertEqual(metrics[metric], best_value) + def remove_nan_logs(self, log): + for key in list(log.keys()): + if log[key] != log[key]: # Check if the value is NaN + del log[key] + def check_trainer_state_are_the_same(self, trainer_state, trainer_state1): # We'll pop things so operate on copies. state = trainer_state.copy() @@ -675,6 +680,10 @@ class TrainerIntegrationCommon: for key in skip_log_keys: _ = log.pop(key, None) _ = log1.pop(key, None) + + self.remove_nan_logs(log) + self.remove_nan_logs(log1) + self.assertEqual(log, log1) def convert_to_sharded_checkpoint(self, folder, save_safe=True, load_safe=True): @@ -3174,6 +3183,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertAlmostEqual(b, b1, delta=1e-5) @slow + @require_non_hpu @require_accelerate @require_torch_non_multi_accelerator def test_auto_batch_size_finder(self): diff --git a/utils/split_model_tests.py b/utils/split_model_tests.py index e5083aaeb46..3539a2fb317 100644 --- a/utils/split_model_tests.py +++ b/utils/split_model_tests.py @@ -62,4 +62,5 @@ if __name__ == "__main__": start = end end = start + num_jobs_per_splits + (1 if idx < num_jobs % args.num_splits else 0) model_splits.append(d[start:end]) + print(model_splits)