mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[test_model_parallelization] multiple fixes (#9354)
This commit is contained in:
parent
086718ac6e
commit
143289dcf7
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import inspect
|
||||
import os.path
|
||||
import random
|
||||
@ -1081,15 +1082,15 @@ class ModelTesterMixin:
|
||||
if not self.test_model_parallel:
|
||||
return
|
||||
|
||||
import subprocess
|
||||
|
||||
# a candidate for testing_utils
|
||||
def get_current_gpu_memory_use():
|
||||
run_process = subprocess.Popen(
|
||||
"nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader", shell=True, stdout=subprocess.PIPE
|
||||
)
|
||||
""" returns a list of cuda memory allocations per GPU in MBs"""
|
||||
|
||||
per_device_memory = []
|
||||
for id in range(torch.cuda.device_count()):
|
||||
with torch.cuda.device(id):
|
||||
per_device_memory.append(torch.cuda.memory_allocated() >> 20)
|
||||
|
||||
memory_usage = run_process.stdout.read().decode("utf-8").strip()
|
||||
per_device_memory = [int(memory) for memory in memory_usage.split("\n")]
|
||||
return per_device_memory
|
||||
|
||||
# Needs a large model to see the difference.
|
||||
@ -1098,39 +1099,44 @@ class ModelTesterMixin:
|
||||
for model_class in self.all_parallelizable_model_classes:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Retrieve initial memory usage (should be close to 0)
|
||||
initial_memory = get_current_gpu_memory_use()
|
||||
# 1. single gpu memory load + unload + memory measurements
|
||||
# Retrieve initial memory usage (can easily be ~0.6-1.5GB if cuda-kernels have been preloaded by previous tests)
|
||||
memory_at_start = get_current_gpu_memory_use()
|
||||
|
||||
# Put model on device
|
||||
model = model_class(config.from_pretrained("gpt2"))
|
||||
# Put model on device 0 and take a memory snapshot
|
||||
model = model_class(config)
|
||||
model.to("cuda:0")
|
||||
|
||||
# Retrieve the memory after the model is put on the device
|
||||
memory_after_model_load = get_current_gpu_memory_use()
|
||||
|
||||
# The memory use on device 0 should be higher than it was initially.
|
||||
self.assertGreater(memory_after_model_load[0], memory_at_start[0])
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# The memory use on that device should be higher than it was initially.
|
||||
self.assertGreater(memory_after_model_load[0], initial_memory[0])
|
||||
# 2. MP test
|
||||
# it's essential to re-calibrate the usage before the next stage
|
||||
memory_at_start = get_current_gpu_memory_use()
|
||||
|
||||
# Spread model layers over multiple devices
|
||||
model = model_class(config.from_pretrained("gpt2"))
|
||||
model = model_class(config)
|
||||
model.parallelize()
|
||||
memory_after_parallelization = get_current_gpu_memory_use()
|
||||
|
||||
# Assert that the memory use on all devices is higher than it was when loaded only on CPU
|
||||
for n in range(torch.cuda.device_count()):
|
||||
self.assertGreater(memory_after_parallelization[n], initial_memory[n])
|
||||
self.assertGreater(memory_after_parallelization[n], memory_at_start[n])
|
||||
|
||||
# Assert that the memory use of the first device is lower than it was when the entire model was loaded on it
|
||||
# Assert that the memory use of device 0 is lower than it was when the entire model was loaded on it
|
||||
self.assertLess(memory_after_parallelization[0], memory_after_model_load[0])
|
||||
|
||||
# Assert that the memory use of the second device is higher than it was when the entire model was loaded
|
||||
# on the other device.
|
||||
# Assert that the memory use of device 1 is higher than it was when the entire model was loaded
|
||||
# on device 0 and device 1 wasn't used at all
|
||||
self.assertGreater(memory_after_parallelization[1], memory_after_model_load[1])
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@require_torch_multi_gpu
|
||||
|
Loading…
Reference in New Issue
Block a user