diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index ce8921b3337..42aea7b67d0 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -860,11 +860,12 @@ class ModelTesterMixin:
         model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)
         model_eager.save_pretrained(tmpdir)
 
-        with torch.device(torch_device):
-            model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32)
-            inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
-            inputs_dict["labels"] = inputs_dict["input_ids"]
-            _ = model(**inputs_dict, return_dict=False)
+        model = AutoModelForCausalLM.from_pretrained(
+            tmpdir, torch_dtype=torch.float32, device_map=torch_device
+        )
+        inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
+        inputs_dict["labels"] = inputs_dict["input_ids"]
+        _ = model(**inputs_dict, return_dict=False)
 
     def test_training_gradient_checkpointing(self):
         # Scenario - 1 default behaviour