mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
parent
b9374a0763
commit
a861db01e5
@ -528,9 +528,39 @@ class Bnb4bitTestMultiGpu(Base4bitTest):
|
||||
This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
|
||||
Let's just try to load a model on 2 GPUs and see if it works. The model we test takes ~2GB in total, so 3GB should suffice.
|
||||
"""
|
||||
device_map = {
|
||||
"transformer.word_embeddings": 0,
|
||||
"transformer.word_embeddings_layernorm": 0,
|
||||
"lm_head": 0,
|
||||
"transformer.h.0": 0,
|
||||
"transformer.h.1": 0,
|
||||
"transformer.h.2": 0,
|
||||
"transformer.h.3": 0,
|
||||
"transformer.h.4": 0,
|
||||
"transformer.h.5": 0,
|
||||
"transformer.h.6": 0,
|
||||
"transformer.h.7": 0,
|
||||
"transformer.h.8": 0,
|
||||
"transformer.h.9": 0,
|
||||
"transformer.h.10": 1,
|
||||
"transformer.h.11": 1,
|
||||
"transformer.h.12": 1,
|
||||
"transformer.h.13": 1,
|
||||
"transformer.h.14": 1,
|
||||
"transformer.h.15": 1,
|
||||
"transformer.h.16": 1,
|
||||
"transformer.h.17": 0,
|
||||
"transformer.h.18": 0,
|
||||
"transformer.h.19": 0,
|
||||
"transformer.h.20": 0,
|
||||
"transformer.h.21": 0,
|
||||
"transformer.h.22": 0,
|
||||
"transformer.h.23": 1,
|
||||
"transformer.ln_f": 0,
|
||||
}
|
||||
|
||||
model_parallel = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, load_in_4bit=True, device_map="balanced"
|
||||
self.model_name, load_in_4bit=True, device_map=device_map
|
||||
)
|
||||
|
||||
# Check correct device map
|
||||
|
@ -682,9 +682,39 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
|
||||
This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
|
||||
Let's just try to load a model on 2 GPUs and see if it works. The model we test takes ~2GB in total, so 3GB should suffice.
|
||||
"""
|
||||
device_map = {
|
||||
"transformer.word_embeddings": 0,
|
||||
"transformer.word_embeddings_layernorm": 0,
|
||||
"lm_head": 0,
|
||||
"transformer.h.0": 0,
|
||||
"transformer.h.1": 0,
|
||||
"transformer.h.2": 0,
|
||||
"transformer.h.3": 0,
|
||||
"transformer.h.4": 0,
|
||||
"transformer.h.5": 0,
|
||||
"transformer.h.6": 0,
|
||||
"transformer.h.7": 0,
|
||||
"transformer.h.8": 0,
|
||||
"transformer.h.9": 0,
|
||||
"transformer.h.10": 1,
|
||||
"transformer.h.11": 1,
|
||||
"transformer.h.12": 1,
|
||||
"transformer.h.13": 1,
|
||||
"transformer.h.14": 1,
|
||||
"transformer.h.15": 1,
|
||||
"transformer.h.16": 1,
|
||||
"transformer.h.17": 0,
|
||||
"transformer.h.18": 0,
|
||||
"transformer.h.19": 0,
|
||||
"transformer.h.20": 0,
|
||||
"transformer.h.21": 0,
|
||||
"transformer.h.22": 0,
|
||||
"transformer.h.23": 1,
|
||||
"transformer.ln_f": 0,
|
||||
}
|
||||
|
||||
model_parallel = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, load_in_8bit=True, device_map="balanced"
|
||||
self.model_name, load_in_8bit=True, device_map=device_map
|
||||
)
|
||||
|
||||
# Check correct device map
|
||||
|
Loading…
Reference in New Issue
Block a user