diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 425db5ecdcf..bd3bbe7c60c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1073,6 +1073,9 @@ class GenerationTesterMixin: @require_torch_multi_accelerator def test_model_parallel_beam_search(self): for model_class in self.all_generative_model_classes: + if "xpu" in torch_device: + return unittest.skip("device_map='auto' does not work with XPU devices") + if model_class._no_split_modules is None: continue