mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
12048990a9
commit
6ce238fe7a
@ -860,11 +860,12 @@ class ModelTesterMixin:
|
||||
model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)
|
||||
|
||||
model_eager.save_pretrained(tmpdir)
|
||||
with torch.device(torch_device):
|
||||
model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32)
|
||||
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
_ = model(**inputs_dict, return_dict=False)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
tmpdir, torch_dtype=torch.float32, device_map=torch_device
|
||||
)
|
||||
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
_ = model(**inputs_dict, return_dict=False)
|
||||
|
||||
def test_training_gradient_checkpointing(self):
|
||||
# Scenario - 1 default behaviour
|
||||
|
Loading…
Reference in New Issue
Block a user