Fix test_model_parallelization (#17249)

* Fix test_model_parallelization

* Modify

Kyungmin Lee authored 2022-05-17 06:30:49 +09:00, committed by GitHub
parent e705e1267c
commit f0395cf58e

@@ -2065,7 +2065,7 @@ class ModelTesterMixin:
             memory_after_parallelization = get_current_gpu_memory_use()
 
             # Assert that the memory use on all devices is higher than it was when loaded only on CPU
-            for n in range(torch.cuda.device_count()):
+            for n in range(len(model.device_map.keys())):
                 self.assertGreater(memory_after_parallelization[n], memory_at_start[n])
 
             # Assert that the memory use of device 0 is lower than it was when the entire model was loaded on it
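
The change swaps torch.cuda.device_count() for len(model.device_map.keys()), so the assertion loop only visits devices that actually received layers; GPUs that are visible but outside the model's device map cannot be expected to show increased memory use. Below is a minimal sketch of that idea outside the test suite, assuming a machine with at least two CUDA devices and the legacy parallelize() API; GPT-2 and the explicit device map are illustrative choices, not part of this patch.

# Sketch only: shows why memory checks after parallelization should iterate
# over model.device_map rather than every visible GPU.
import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

# Spread GPT-2's 12 transformer blocks over only the first two GPUs,
# even if torch.cuda.device_count() reports more devices.
device_map = {0: list(range(0, 6)), 1: list(range(6, 12))}
model.parallelize(device_map)

# Only the devices named in the map should be expected to gain memory;
# any remaining visible GPUs may legitimately stay empty.
assert len(model.device_map.keys()) <= torch.cuda.device_count()
for n in model.device_map.keys():
    print(f"cuda:{n} allocated bytes:", torch.cuda.memory_allocated(n))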