Fix Device map for bitsandbytes tests (#36800)

fix: replace device_map="balanced" with an explicit per-module device map in the multi-GPU bitsandbytes tests, so the split across the two GPUs is deterministic.
Mohamed Mekkouri 2025-03-19 11:57:13 +01:00 committed by GitHub
parent b9374a0763
commit a861db01e5
2 changed files with 62 additions and 2 deletions


@@ -528,9 +528,39 @@ class Bnb4bitTestMultiGpu(Base4bitTest):
         This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
         Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB in total, so 3GB should suffice.
         """
+        device_map = {
+            "transformer.word_embeddings": 0,
+            "transformer.word_embeddings_layernorm": 0,
+            "lm_head": 0,
+            "transformer.h.0": 0,
+            "transformer.h.1": 0,
+            "transformer.h.2": 0,
+            "transformer.h.3": 0,
+            "transformer.h.4": 0,
+            "transformer.h.5": 0,
+            "transformer.h.6": 0,
+            "transformer.h.7": 0,
+            "transformer.h.8": 0,
+            "transformer.h.9": 0,
+            "transformer.h.10": 1,
+            "transformer.h.11": 1,
+            "transformer.h.12": 1,
+            "transformer.h.13": 1,
+            "transformer.h.14": 1,
+            "transformer.h.15": 1,
+            "transformer.h.16": 1,
+            "transformer.h.17": 0,
+            "transformer.h.18": 0,
+            "transformer.h.19": 0,
+            "transformer.h.20": 0,
+            "transformer.h.21": 0,
+            "transformer.h.22": 0,
+            "transformer.h.23": 1,
+            "transformer.ln_f": 0,
+        }
         model_parallel = AutoModelForCausalLM.from_pretrained(
-            self.model_name, load_in_4bit=True, device_map="balanced"
+            self.model_name, load_in_4bit=True, device_map=device_map
         )
         # Check correct device map
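
For reference, device_map="balanced" asks accelerate to recompute an even split at load time, while the explicit dict above pins every module to a fixed GPU. A minimal sketch of the resulting check, with the mapping written compactly and the checkpoint name assumed (the test itself uses self.model_name):

```python
from transformers import AutoModelForCausalLM

# Same placement as the diff: h.0-9 and h.17-22 on GPU 0,
# h.10-16 and h.23 on GPU 1, embeddings/lm_head/ln_f on GPU 0.
device_map = {
    "transformer.word_embeddings": 0,
    "transformer.word_embeddings_layernorm": 0,
    "lm_head": 0,
    **{f"transformer.h.{i}": 0 for i in range(10)},
    **{f"transformer.h.{i}": 1 for i in range(10, 17)},
    **{f"transformer.h.{i}": 0 for i in range(17, 23)},
    "transformer.h.23": 1,
    "transformer.ln_f": 0,
}

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7",  # assumption: a 24-layer BLOOM checkpoint, matching the module names
    load_in_4bit=True,
    device_map=device_map,
)
# accelerate records the resolved placement; both GPUs should appear.
assert set(model.hf_device_map.values()) == {0, 1}
```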


@@ -682,9 +682,39 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
         This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
         Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB in total, so 3GB should suffice.
         """
+        device_map = {
+            "transformer.word_embeddings": 0,
+            "transformer.word_embeddings_layernorm": 0,
+            "lm_head": 0,
+            "transformer.h.0": 0,
+            "transformer.h.1": 0,
+            "transformer.h.2": 0,
+            "transformer.h.3": 0,
+            "transformer.h.4": 0,
+            "transformer.h.5": 0,
+            "transformer.h.6": 0,
+            "transformer.h.7": 0,
+            "transformer.h.8": 0,
+            "transformer.h.9": 0,
+            "transformer.h.10": 1,
+            "transformer.h.11": 1,
+            "transformer.h.12": 1,
+            "transformer.h.13": 1,
+            "transformer.h.14": 1,
+            "transformer.h.15": 1,
+            "transformer.h.16": 1,
+            "transformer.h.17": 0,
+            "transformer.h.18": 0,
+            "transformer.h.19": 0,
+            "transformer.h.20": 0,
+            "transformer.h.21": 0,
+            "transformer.h.22": 0,
+            "transformer.h.23": 1,
+            "transformer.ln_f": 0,
+        }
         model_parallel = AutoModelForCausalLM.from_pretrained(
-            self.model_name, load_in_8bit=True, device_map="balanced"
+            self.model_name, load_in_8bit=True, device_map=device_map
         )
         # Check correct device map
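
The int8 hunk mirrors the 4-bit one. As a hedged end-to-end sketch (checkpoint and prompt are assumptions; the real test asserts against self.model_name and its expected outputs), a short generation after loading with the explicit map exercises the cross-GPU handoffs at h.9/h.10, h.16/h.17, and h.22/h.23:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Same per-module map as the diff, written compactly.
device_map = {
    "transformer.word_embeddings": 0,
    "transformer.word_embeddings_layernorm": 0,
    "lm_head": 0,
    **{f"transformer.h.{i}": 0 for i in range(10)},
    **{f"transformer.h.{i}": 1 for i in range(10, 17)},
    **{f"transformer.h.{i}": 0 for i in range(17, 23)},
    "transformer.h.23": 1,
    "transformer.ln_f": 0,
}

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b7")  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7",  # assumed; stands in for self.model_name
    load_in_8bit=True,
    device_map=device_map,
)

# Inputs go to GPU 0, where the embeddings live; generate() handles the rest.
inputs = tokenizer("Hello my name is", return_tensors="pt").to(0)
output = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```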