Offload fixes (#17810)

* Offload fixes

* Add a test
This commit is contained in:
Sylvain Gugger 2022-06-22 12:23:07 -04:00 committed by GitHub
parent 0d0c392c45
commit df8e6804c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 5 deletions

View File

@ -2166,11 +2166,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_state_dict=False,
dtype=None,
):
if device_map is not None and "disk" in device_map.values() and offload_folder is None:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder` for"
" them."
)
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them."
)
os.makedirs(offload_folder, exist_ok=True)
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
@ -2344,6 +2346,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
gc.collect()
if offload_index is not None and len(offload_index) > 0:
if model != model_to_load:
# We need to add the prefix of the base model
prefix = cls.base_model_prefix
for weight_name in offload_index:
shutil.move(
os.path.join(offload_folder, f"{weight_name}.dat"),
os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
)
offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
save_offload_index(offload_index, offload_folder)
if offload_state_dict:

View File

@ -2811,6 +2811,48 @@ class ModelUtilsTest(TestCasePlus):
text_output = tokenizer.decode(output[0].tolist())
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
@require_accelerate
@require_torch_gpu
def test_from_pretrained_disk_offload_task_model(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
device_map = {
"transformer.wte": 0,
"transformer.wpe": 0,
"transformer.h.0": "cpu",
"transformer.h.1": "cpu",
"transformer.h.2": "cpu",
"transformer.h.3": "disk",
"transformer.h.4": "disk",
"transformer.ln_f": 0,
"lm_head": 0,
}
with tempfile.TemporaryDirectory() as tmp_dir:
inputs = torch.tensor([[1, 2, 3]]).to(0)
model.save_pretrained(tmp_dir)
new_model = AutoModelForCausalLM.from_pretrained(tmp_dir).to(0)
outputs1 = new_model.to(0)(inputs)
offload_folder = os.path.join(tmp_dir, "offload")
new_model_with_offload = AutoModelForCausalLM.from_pretrained(
tmp_dir, device_map=device_map, offload_folder=offload_folder
)
outputs2 = new_model_with_offload(inputs)
self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
# With state dict temp offload
offload_folder = os.path.join(tmp_dir, "offload")
new_model_with_offload = AutoModelForCausalLM.from_pretrained(
tmp_dir,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=True,
)
outputs2 = new_model_with_offload(inputs)
self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()