diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 900b425b38c..3c01286c6b7 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2065,7 +2065,7 @@ class ModelTesterMixin:
             memory_after_parallelization = get_current_gpu_memory_use()
 
             # Assert that the memory use on all devices is higher than it was when loaded only on CPU
-            for n in range(torch.cuda.device_count()):
+            for n in range(len(model.device_map.keys())):
                 self.assertGreater(memory_after_parallelization[n], memory_at_start[n])
 
             # Assert that the memory use of device 0 is lower than it was when the entire model was loaded on it