diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index f8e356aedbf..bf137a6af55 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -528,9 +528,39 @@ class Bnb4bitTestMultiGpu(Base4bitTest):
         This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
         Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB of total, 3GB should suffice
         """
+        device_map = {
+            "transformer.word_embeddings": 0,
+            "transformer.word_embeddings_layernorm": 0,
+            "lm_head": 0,
+            "transformer.h.0": 0,
+            "transformer.h.1": 0,
+            "transformer.h.2": 0,
+            "transformer.h.3": 0,
+            "transformer.h.4": 0,
+            "transformer.h.5": 0,
+            "transformer.h.6": 0,
+            "transformer.h.7": 0,
+            "transformer.h.8": 0,
+            "transformer.h.9": 0,
+            "transformer.h.10": 1,
+            "transformer.h.11": 1,
+            "transformer.h.12": 1,
+            "transformer.h.13": 1,
+            "transformer.h.14": 1,
+            "transformer.h.15": 1,
+            "transformer.h.16": 1,
+            "transformer.h.17": 0,
+            "transformer.h.18": 0,
+            "transformer.h.19": 0,
+            "transformer.h.20": 0,
+            "transformer.h.21": 0,
+            "transformer.h.22": 0,
+            "transformer.h.23": 1,
+            "transformer.ln_f": 0,
+        }
         model_parallel = AutoModelForCausalLM.from_pretrained(
-            self.model_name, load_in_4bit=True, device_map="balanced"
+            self.model_name, load_in_4bit=True, device_map=device_map
         )
 
         # Check correct device map
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index fa783b3cbe2..bc7804de9b8 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -682,9 +682,39 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
         This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
         Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB of total, 3GB should suffice
         """
+        device_map = {
+            "transformer.word_embeddings": 0,
+            "transformer.word_embeddings_layernorm": 0,
+            "lm_head": 0,
+            "transformer.h.0": 0,
+            "transformer.h.1": 0,
+            "transformer.h.2": 0,
+            "transformer.h.3": 0,
+            "transformer.h.4": 0,
+            "transformer.h.5": 0,
+            "transformer.h.6": 0,
+            "transformer.h.7": 0,
+            "transformer.h.8": 0,
+            "transformer.h.9": 0,
+            "transformer.h.10": 1,
+            "transformer.h.11": 1,
+            "transformer.h.12": 1,
+            "transformer.h.13": 1,
+            "transformer.h.14": 1,
+            "transformer.h.15": 1,
+            "transformer.h.16": 1,
+            "transformer.h.17": 0,
+            "transformer.h.18": 0,
+            "transformer.h.19": 0,
+            "transformer.h.20": 0,
+            "transformer.h.21": 0,
+            "transformer.h.22": 0,
+            "transformer.h.23": 1,
+            "transformer.ln_f": 0,
+        }
         model_parallel = AutoModelForCausalLM.from_pretrained(
-            self.model_name, load_in_8bit=True, device_map="balanced"
+            self.model_name, load_in_8bit=True, device_map=device_map
         )
 
         # Check correct device map
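Both hunks make the same change: rather than letting accelerate's `device_map="balanced"` heuristic place modules at load time, each test now pins every module to a fixed GPU index, presumably so the multi-GPU split stays deterministic across accelerate versions. Below is a minimal sketch, outside the patch, of how such an explicit map can be built and passed to `from_pretrained`; the `bigscience/bloom-1b7` checkpoint name and the dict comprehension are illustrative assumptions (the tests use `self.model_name` and spell every key out).

```python
# Minimal sketch, assuming a two-GPU machine with bitsandbytes installed.
# "bigscience/bloom-1b7" is an assumption standing in for self.model_name;
# the module names below follow the BLOOM layout used in the diff.
from transformers import AutoModelForCausalLM

# Keys are fully qualified module names, values are CUDA device indices.
# Layers 0-9 and 17-22 go to GPU 0, layers 10-16 and 23 to GPU 1,
# reproducing the split in the patch above.
device_map = {
    "transformer.word_embeddings": 0,
    "transformer.word_embeddings_layernorm": 0,
    "lm_head": 0,
    **{f"transformer.h.{i}": 0 if (i < 10 or 17 <= i <= 22) else 1 for i in range(24)},
    "transformer.ln_f": 0,
}

model_parallel = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7", load_in_4bit=True, device_map=device_map
)

# Transformers records the final placement on the model; this is what the
# "# Check correct device map" assertions in these tests inspect.
print(model_parallel.hf_device_map)
```

With a fixed map, the set of devices the model ends up on is known before the test runs, whereas `"balanced"` leaves the exact split to accelerate and can differ between releases.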