mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix test_model_parallelization (#17249)
* Fix test_model_parallelization * Modify
This commit is contained in:
parent
e705e1267c
commit
f0395cf58e
@ -2065,7 +2065,7 @@ class ModelTesterMixin:
|
||||
memory_after_parallelization = get_current_gpu_memory_use()
|
||||
|
||||
# Assert that the memory use on all devices is higher than it was when loaded only on CPU
|
||||
for n in range(torch.cuda.device_count()):
|
||||
for n in range(len(model.device_map.keys())):
|
||||
self.assertGreater(memory_after_parallelization[n], memory_at_start[n])
|
||||
|
||||
# Assert that the memory use of device 0 is lower than it was when the entire model was loaded on it
|
||||
|
Loading…
Reference in New Issue
Block a user