* Update test_modeling_common.py

* style
This commit is contained in:
Cyril Vallez 2025-04-03 10:24:34 +02:00 committed by GitHub
parent 12048990a9
commit 6ce238fe7a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -860,11 +860,12 @@ class ModelTesterMixin:
model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)
model_eager.save_pretrained(tmpdir)
with torch.device(torch_device):
model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32)
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
inputs_dict["labels"] = inputs_dict["input_ids"]
_ = model(**inputs_dict, return_dict=False)
model = AutoModelForCausalLM.from_pretrained(
tmpdir, torch_dtype=torch.float32, device_map=torch_device
)
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
inputs_dict["labels"] = inputs_dict["input_ids"]
_ = model(**inputs_dict, return_dict=False)
def test_training_gradient_checkpointing(self):
# Scenario - 1 default behaviour