Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
switch to device agnostic device calling for test cases (#38247)
* use device agnostic APIs in test cases
* fix style
* add one more
* xpu now supports integer device id, aligning to CUDA behaviors
* update to use device_properties
* fix style
* update comment
* fix comments
* fix style

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent cba279f46c
commit a5a0c7b888
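The hunks below swap CUDA-specific calls in the test suite for the device-agnostic helpers exposed by transformers.testing_utils (torch_device, get_device_properties, Expectations, backend_empty_cache). As a rough, condensed sketch of the pattern — the test class and tensor values here are invented for illustration; only the helper names are taken from the diff:

import unittest

import torch

from transformers.testing_utils import get_device_properties, torch_device


class ExampleIntegrationTest(unittest.TestCase):
    # replaces the old cuda_compute_capability_major_version class attribute
    device_properties = None

    @classmethod
    def setUpClass(cls):
        # (device_type, major) tuple, e.g. ("cuda", 8) on A10/A100, ("cuda", 7) on T4
        cls.device_properties = get_device_properties()

    def test_runs_on_any_accelerator(self):
        if self.device_properties == ("cuda", 7):
            self.skipTest("known failure on T4")
        x = torch.ones(2, 2).to(torch_device)  # instead of .cuda()
        self.assertAlmostEqual((x + x).sum().item(), 8.0)

The intent is that the same test then runs unchanged on CUDA, XPU, HPU or CPU runners.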
@@ -273,7 +273,7 @@ class Bnb4BitHfQuantizer(HfQuantizer):
             elif is_torch_hpu_available():
                 device_map = {"": f"hpu:{torch.hpu.current_device()}"}
             elif is_torch_xpu_available():
-                device_map = {"": f"xpu:{torch.xpu.current_device()}"}
+                device_map = {"": torch.xpu.current_device()}
             else:
                 device_map = {"": "cpu"}
             logger.info(
@@ -136,7 +136,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
             if torch.cuda.is_available():
                 device_map = {"": torch.cuda.current_device()}
             elif is_torch_xpu_available():
-                device_map = {"": f"xpu:{torch.xpu.current_device()}"}
+                device_map = {"": torch.xpu.current_device()}
             else:
                 device_map = {"": "cpu"}
             logger.info(
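Per the commit message, XPU now accepts a bare integer device index in device_map, matching CUDA behaviour, which is why the f"xpu:{...}" string above becomes a plain torch.xpu.current_device(). A minimal standalone sketch of the resulting selection logic, outside the quantizer:

import torch

from transformers.utils import is_torch_xpu_available

# Both CUDA and XPU now map the whole model ("" key) to an integer device index.
if torch.cuda.is_available():
    device_map = {"": torch.cuda.current_device()}  # e.g. {"": 0}
elif is_torch_xpu_available():
    device_map = {"": torch.xpu.current_device()}   # also an int now, no "xpu:0" string
else:
    device_map = {"": "cpu"}
print(device_map)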
@@ -28,6 +28,7 @@ from transformers import (
 )
 from transformers.testing_utils import (
     Expectations,
+    get_device_properties,
     require_deterministic_for_xpu,
     require_flash_attn,
     require_torch,
@@ -572,10 +573,10 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
         )
         batch = data_collator(features)
-        batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}
+        batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()}

         res_padded = model(**inputs_dict)
-        res_padfree = model(**batch_cuda)
+        res_padfree = model(**batch_accelerator)

         logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
         logits_padfree = res_padfree.logits[0]
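The Bamba hunk above is the same idea applied to batch handling: tensors are moved with .to(torch_device) instead of .cuda(), so the padding-free comparison also runs on non-CUDA accelerators. A small sketch with an invented batch:

import torch

from transformers.testing_utils import torch_device

# .to(torch_device) works for cuda, xpu, hpu or cpu; .cuda() is CUDA-only.
batch = {"input_ids": torch.tensor([[1, 2, 3]]), "seq_idx": torch.tensor([[0, 0, 0]])}
batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()}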
@@ -594,7 +595,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
     tokenizer = None
     # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
     # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
+    device_properties = None

     @classmethod
     def setUpClass(cls):
@@ -606,9 +607,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
         cls.tokenizer.pad_token_id = cls.model.config.pad_token_id
         cls.tokenizer.padding_side = "left"

-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
+        cls.device_properties = get_device_properties()

     def test_simple_generate(self):
         expectations = Expectations(
@@ -639,7 +638,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
         self.assertEqual(output_sentence, expected)

         # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
-        if self.cuda_compute_capability_major_version == 8:
+        if self.device_properties == ("cuda", 8):
             with torch.no_grad():
                 logits = self.model(input_ids=input_ids, logits_to_keep=40).logits

@@ -692,7 +691,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
         self.assertEqual(output_sentences[1], EXPECTED_TEXT[1])

         # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
-        if self.cuda_compute_capability_major_version == 8:
+        if self.device_properties == ("cuda", 8):
             with torch.no_grad():
                 logits = self.model(input_ids=inputs["input_ids"]).logits

@@ -763,7 +763,6 @@ class BloomEmbeddingTest(unittest.TestCase):

     @require_torch
     def test_hidden_states_transformers(self):
-        cuda_available = torch.cuda.is_available()
         model = BloomModel.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to(
             torch_device
         )
@@ -782,7 +781,7 @@ class BloomEmbeddingTest(unittest.TestCase):
             "max": logits.last_hidden_state.max(dim=-1).values[0][0].item(),
         }

-        if cuda_available:
+        if torch_device == "cuda":
             self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=4)
         else:
             self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=3)
@@ -791,7 +790,6 @@ class BloomEmbeddingTest(unittest.TestCase):

     @require_torch
     def test_logits(self):
-        cuda_available = torch.cuda.is_available()
         model = BloomForCausalLM.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to(
             torch_device
         )  # load in bf16
@@ -807,9 +805,5 @@ class BloomEmbeddingTest(unittest.TestCase):
         output = model(tensor_ids).logits

         output_gpu_1, output_gpu_2 = output.split(125440, dim=-1)
-        if cuda_available:
-            self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6)
-            self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)
-        else:
-            self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6)  # 1e-06 precision!!
-            self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)
+        self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6)
+        self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)
@@ -133,15 +133,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
 @require_torch_large_gpu
 class Cohere2IntegrationTest(unittest.TestCase):
     input_text = ["Hello I am doing", "Hi today"]
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

     def test_model_bf16(self):
         model_id = "CohereForAI/c4ai-command-r7b-12-2024"
@@ -495,16 +495,6 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste

 @require_torch_accelerator
 class DeepseekV3IntegrationTest(unittest.TestCase):
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
-
     def tearDown(self):
         # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
         cleanup(torch_device, gc_collect=False)
@@ -565,16 +565,6 @@ class DiffLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester

 @require_torch_accelerator
 class DiffLlamaIntegrationTest(unittest.TestCase):
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
-
     def tearDown(self):
         # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
         cleanup(torch_device, gc_collect=False)
@@ -21,7 +21,9 @@ from packaging import version
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
+    get_device_properties,
     require_bitsandbytes,
     require_flash_attn,
     require_read_token,
@@ -105,15 +107,13 @@ class GemmaModelTest(CausalLMModelTest, unittest.TestCase):
 @require_torch_accelerator
 class GemmaIntegrationTest(unittest.TestCase):
     input_text = ["Hello I am doing", "Hi today"]
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
+    # This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4)
     # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
+    device_properties = None

     @classmethod
     def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
+        cls.device_properties = get_device_properties()

     def tearDown(self):
         # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
@@ -270,7 +270,7 @@ class GemmaIntegrationTest(unittest.TestCase):

     @require_read_token
     def test_model_7b_fp16(self):
-        if self.cuda_compute_capability_major_version == 7:
+        if self.device_properties == ("cuda", 7):
             self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")

         model_id = "google/gemma-7b"
@@ -293,7 +293,7 @@ class GemmaIntegrationTest(unittest.TestCase):

     @require_read_token
     def test_model_7b_bf16(self):
-        if self.cuda_compute_capability_major_version == 7:
+        if self.device_properties == ("cuda", 7):
             self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")

         model_id = "google/gemma-7b"
@@ -302,20 +302,16 @@ class GemmaIntegrationTest(unittest.TestCase):
         #
         # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
         # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_TEXTS = {
-            7: [
-                """Hello I am doing a project on a 1991 240sx and I am trying to find""",
-                "Hi today I am going to show you how to make a very simple and easy to make a very simple and",
-            ],
-            8: [
-                "Hello I am doing a project for my school and I am trying to make a program that will read a .txt file",
-                "Hi today I am going to show you how to make a very simple and easy to make a very simple and",
-            ],
-            9: [
-                "Hello I am doing a project for my school and I am trying to get a servo to move a certain amount of degrees",
-                "Hi today I am going to show you how to make a very simple and easy to make DIY light up sign",
-            ],
+        # fmt: off
+        EXPECTED_TEXTS = Expectations(
+            {
+                ("cuda", 7): ["""Hello I am doing a project on a 1991 240sx and I am trying to find""", "Hi today I am going to show you how to make a very simple and easy to make a very simple and",],
+                ("cuda", 8): ["Hello I am doing a project for my school and I am trying to make a program that will read a .txt file", "Hi today I am going to show you how to make a very simple and easy to make a very simple and",],
+                ("rocm", 9): ["Hello I am doing a project for my school and I am trying to get a servo to move a certain amount of degrees", "Hi today I am going to show you how to make a very simple and easy to make DIY light up sign",],
             }
+        )
+        # fmt: on
+        expected_text = EXPECTED_TEXTS.get_expectation()

         model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
             torch_device
@@ -326,11 +322,11 @@ class GemmaIntegrationTest(unittest.TestCase):

         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
-        self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
+        self.assertEqual(output_text, expected_text)

     @require_read_token
     def test_model_7b_fp16_static_cache(self):
-        if self.cuda_compute_capability_major_version == 7:
+        if self.device_properties == ("cuda", 7):
             self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")

         model_id = "google/gemma-7b"
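Expectations, used in the Gemma hunks above and in several files below, keys per-hardware reference outputs by a (device_type, major_version) tuple, and get_expectation() returns the entry matching the current runner. A sketch of how a test consumes it; the key set and strings here are illustrative only:

from transformers.testing_utils import Expectations, get_device_properties

# Keys follow the convention in the hunks: ("cuda", 7) ~ T4, ("cuda", 8) ~ A10/A100, ("rocm", 9) ~ MI300.
EXPECTED_TEXTS = Expectations(
    {
        ("cuda", 7): "reference output observed on T4",
        ("cuda", 8): "reference output observed on A10/A100",
        ("rocm", 9): "reference output observed on MI300",
    }
)

expected = EXPECTED_TEXTS.get_expectation()  # picks the entry for the current device
print(get_device_properties(), "->", expected)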
@@ -176,15 +176,6 @@ class Gemma2ModelTest(CausalLMModelTest, unittest.TestCase):
 @require_torch_accelerator
 class Gemma2IntegrationTest(unittest.TestCase):
     input_text = ["Hello I am doing", "Hi today"]
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

     @tooslow
     @require_read_token
@@ -80,15 +80,6 @@ class GlmIntegrationTest(unittest.TestCase):
     input_text = ["Hello I am doing", "Hi today"]
     model_id = "THUDM/glm-4-9b"
     revision = "refs/pr/15"
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

     def test_model_9b_fp16(self):
         EXPECTED_TEXTS = [
@@ -82,15 +82,6 @@ class Glm4IntegrationTest(unittest.TestCase):
     input_text = ["Hello I am doing", "Hi today"]
     model_id = "THUDM/glm-4-0414-9b-chat"
     revision = "refs/pr/15"
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

     def test_model_9b_fp16(self):
         EXPECTED_TEXTS = [
@@ -305,16 +305,6 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi

 @require_torch_accelerator
 class GraniteIntegrationTest(unittest.TestCase):
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
-
     @slow
     @require_read_token
     def test_model_3b_logits_bf16(self):
@@ -378,6 +368,7 @@ class GraniteIntegrationTest(unittest.TestCase):
                 ("cuda", 8): torch.tensor([[-3.2934, -2.6019, -2.6258, -2.1691, -2.6394, -2.6876, -2.7032, -2.9688]]),
             }
         )
+        # fmt: on
         EXPECTED_MEAN = EXPECTED_MEANS.get_expectation()

         torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), rtol=1e-2, atol=1e-2)
@@ -304,16 +304,6 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test

 @require_torch_accelerator
 class GraniteMoeIntegrationTest(unittest.TestCase):
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
-
     @slow
     @require_read_token
     def test_model_3b_logits(self):
@@ -360,6 +350,7 @@ class GraniteMoeIntegrationTest(unittest.TestCase):
     @slow
     def test_model_3b_generation(self):
         # ground truth text generated with dola_layers="low", repetition_penalty=1.2
+        # fmt: off
         EXPECTED_TEXT_COMPLETIONS = Expectations(
             {
                 ("xpu", 3): (
@@ -378,6 +369,7 @@ class GraniteMoeIntegrationTest(unittest.TestCase):
                 ),
             }
         )
+        # fmt: on
         EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()

         prompt = "Simply put, the theory of relativity states that "
@@ -105,16 +105,6 @@ class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest.
 @unittest.skip(reason="GraniteMoeHybrid models are not yet released")
 @require_torch_gpu
 class GraniteMoeHybridIntegrationTest(unittest.TestCase):
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
-
     @slow
     def test_model_logits(self):
         input_ids = [31390, 631, 4162, 30, 322, 25342, 432, 1875, 43826, 10066, 688, 225]
@@ -307,16 +307,6 @@ class GraniteMoeSharedModelTest(ModelTesterMixin, GenerationTesterMixin, unittes

 @require_torch_accelerator
 class GraniteMoeSharedIntegrationTest(unittest.TestCase):
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
-
     @slow
     @require_read_token
     def test_model_3b_logits(self):
@@ -363,6 +353,7 @@ class GraniteMoeSharedIntegrationTest(unittest.TestCase):
     @slow
     def test_model_3b_generation(self):
         # ground truth text generated with dola_layers="low", repetition_penalty=1.2
+        # fmt: off
        EXPECTED_TEXT_COMPLETIONS = Expectations(
             {
                 ("xpu", 3): (
@@ -381,6 +372,7 @@ class GraniteMoeSharedIntegrationTest(unittest.TestCase):
                 ),
             }
         )
+        # fmt: on
         EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()

         prompt = "Simply put, the theory of relativity states that "
@@ -79,15 +79,6 @@ class HeliumModelTest(GemmaModelTest, unittest.TestCase):
 # @require_torch_gpu
 class HeliumIntegrationTest(unittest.TestCase):
     input_text = ["Hello, today is a great day to"]
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

     @require_read_token
     def test_model_2b(self):
@@ -21,6 +21,8 @@ import pytest

 from transformers import AutoTokenizer, JambaConfig, is_torch_available
 from transformers.testing_utils import (
+    Expectations,
+    get_device_properties,
     require_bitsandbytes,
     require_flash_attn,
     require_torch,
@@ -554,30 +556,32 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
 class JambaModelIntegrationTest(unittest.TestCase):
     model = None
     tokenizer = None
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
+    # This variable is used to determine which acclerator are we using for our runners (e.g. A10 or T4)
     # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
+    device_properties = None

     @classmethod
     def setUpClass(cls):
         model_id = "ai21labs/Jamba-tiny-dev"
         cls.model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
         cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
+        cls.device_properties = get_device_properties()

     @slow
     def test_simple_generate(self):
-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
+        # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4.
         #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
         # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_TEXTS = {
-            7: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
-            8: "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
-            9: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
+        # fmt: off
+        EXPECTED_TEXTS = Expectations(
+            {
+                ("cuda", 7): "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
+                ("cuda", 8): "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
+                ("rocm", 9): "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
             }
+        )
+        # fmt: on
+        expected_sentence = EXPECTED_TEXTS.get_expectation()

         self.model.to(torch_device)

@@ -586,10 +590,10 @@ class JambaModelIntegrationTest(unittest.TestCase):
         ].to(torch_device)
         out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
         output_sentence = self.tokenizer.decode(out[0, :])
-        self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
+        self.assertEqual(output_sentence, expected_sentence)

         # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
-        if self.cuda_compute_capability_major_version == 8:
+        if self.device_properties == ("cuda", 8):
             with torch.no_grad():
                 logits = self.model(input_ids=input_ids).logits

@@ -607,24 +611,19 @@ class JambaModelIntegrationTest(unittest.TestCase):

     @slow
     def test_simple_batched_generate_with_padding(self):
-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
+        # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4.
         #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
         # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_TEXTS = {
-            7: [
-                "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats",
-                "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
-            ],
-            8: [
-                "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
-                "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States",
-            ],
-            9: [
-                "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
-                "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
-            ],
+        # fmt: off
+        EXPECTED_TEXTS = Expectations(
+            {
+                ("cuda", 7): ["<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",],
+                ("cuda", 8): ["<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States",],
+                ("rocm", 9): ["<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",],
             }
+        )
+        # fmt: on
+        expected_sentences = EXPECTED_TEXTS.get_expectation()

         self.model.to(torch_device)

@@ -633,11 +632,11 @@ class JambaModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
         output_sentences = self.tokenizer.batch_decode(out)
-        self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0])
-        self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1])
+        self.assertEqual(output_sentences[0], expected_sentences[0])
+        self.assertEqual(output_sentences[1], expected_sentences[1])

         # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
-        if self.cuda_compute_capability_major_version == 8:
+        if self.device_properties == ("cuda", 8):
             with torch.no_grad():
                 logits = self.model(input_ids=inputs["input_ids"]).logits

@@ -38,15 +38,9 @@ if is_torch_available():
 @require_read_token
 class Llama4IntegrationTest(unittest.TestCase):
     model_id = "meta-llama/Llama-4-Scout-17B-16E"
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None

     @classmethod
     def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
         cls.model = Llama4ForConditionalGeneration.from_pretrained(
             "meta-llama/Llama-4-Scout-17B-16E",
             device_map="auto",
@@ -21,8 +21,10 @@ from packaging import version

 from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
+    Expectations,
     backend_empty_cache,
     cleanup,
+    get_device_properties,
     require_bitsandbytes,
     require_flash_attn,
     require_read_token,
@@ -110,15 +112,13 @@ class MistralModelTest(CausalLMModelTest, unittest.TestCase):

 @require_torch_accelerator
 class MistralIntegrationTest(unittest.TestCase):
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
+    # This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4)
     # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
+    device_properties = None

     @classmethod
     def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
+        cls.device_properties = get_device_properties()

     def tearDown(self):
         cleanup(torch_device, gc_collect=True)
@@ -136,19 +136,20 @@ class MistralIntegrationTest(unittest.TestCase):
         EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
         torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)

-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
-        #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
+        # ("cuda", 8) for A100/A10, and ("cuda", 7) 7 for T4.
         # considering differences in hardware processing and potential deviations in output.
-        EXPECTED_SLICE = {
-            7: torch.tensor([-5.8828, -5.8633, -0.1042, -4.7266, -5.8828, -5.8789, -5.8789, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -1.0801, 1.7598, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828]),
-            8: torch.tensor([-5.8711, -5.8555, -0.1050, -4.7148, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -1.0781, 1.7568, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711]),
-            9: torch.tensor([-5.8750, -5.8594, -0.1047, -4.7188, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -1.0781, 1.7578, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750]),
-        } # fmt: skip
-
-        torch.testing.assert_close(
-            out[0, 0, :30], EXPECTED_SLICE[self.cuda_compute_capability_major_version], atol=1e-4, rtol=1e-4
+        # fmt: off
+        EXPECTED_SLICES = Expectations(
+            {
+                ("cuda", 7): torch.tensor([-5.8828, -5.8633, -0.1042, -4.7266, -5.8828, -5.8789, -5.8789, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -1.0801, 1.7598, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828]),
+                ("cuda", 8): torch.tensor([-5.8711, -5.8555, -0.1050, -4.7148, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -1.0781, 1.7568, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711]),
+                ("rocm", 9): torch.tensor([-5.8750, -5.8594, -0.1047, -4.7188, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -1.0781, 1.7578, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750]),
+            }
         )
+        # fmt: on
+        expected_slice = EXPECTED_SLICES.get_expectation()

+        torch.testing.assert_close(out[0, 0, :30], expected_slice, atol=1e-4, rtol=1e-4)

     @slow
     @require_bitsandbytes
@@ -278,7 +279,7 @@ class MistralIntegrationTest(unittest.TestCase):
         if version.parse(torch.__version__) < version.parse("2.3.0"):
             self.skipTest(reason="This test requires torch >= 2.3 to run.")

-        if self.cuda_compute_capability_major_version == 7:
+        if self.device_properties == ("cuda", 7):
             self.skipTest(reason="This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")

         NUM_TOKENS_TO_GENERATE = 40
@@ -19,6 +19,8 @@ import pytest

 from transformers import MixtralConfig, is_torch_available
 from transformers.testing_utils import (
+    Expectations,
+    get_device_properties,
     require_flash_attn,
     require_torch,
     require_torch_accelerator,
@@ -142,13 +144,11 @@ class MistralModelTest(CausalLMModelTest, unittest.TestCase):
 class MixtralIntegrationTest(unittest.TestCase):
     # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
     # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
+    device_properties = None

     @classmethod
     def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
+        cls.device_properties = get_device_properties()

     @slow
     @require_torch_accelerator
@@ -161,32 +161,26 @@ class MixtralIntegrationTest(unittest.TestCase):
         )
         # TODO: might need to tweak it in case the logits do not match on our daily runners
         # these logits have been obtained with the original megablocks implementation.
-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
-        #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
+        # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4
         # considering differences in hardware processing and potential deviations in output.
-        EXPECTED_LOGITS = {
-            7: torch.Tensor([[0.1640, 0.1621, 0.6093], [-0.8906, -0.1640, -0.6093], [0.1562, 0.1250, 0.7226]]).to(
-                torch_device
-            ),
-            8: torch.Tensor([[0.1631, 0.1621, 0.6094], [-0.8906, -0.1621, -0.6094], [0.1572, 0.1270, 0.7227]]).to(
-                torch_device
-            ),
-            9: torch.Tensor([[0.1641, 0.1621, 0.6094], [-0.8906, -0.1631, -0.6094], [0.1572, 0.1260, 0.7227]]).to(
-                torch_device
-            ),
+        # fmt: off
+        EXPECTED_LOGITS = Expectations(
+            {
+                ("cuda", 7): torch.Tensor([[0.1640, 0.1621, 0.6093], [-0.8906, -0.1640, -0.6093], [0.1562, 0.1250, 0.7226]]).to(torch_device),
+                ("cuda", 8): torch.Tensor([[0.1631, 0.1621, 0.6094], [-0.8906, -0.1621, -0.6094], [0.1572, 0.1270, 0.7227]]).to(torch_device),
+                ("rocm", 9): torch.Tensor([[0.1641, 0.1621, 0.6094], [-0.8906, -0.1631, -0.6094], [0.1572, 0.1260, 0.7227]]).to(torch_device),
             }
+        )
+        # fmt: on
+        expected_logit = EXPECTED_LOGITS.get_expectation()
+
         with torch.no_grad():
             logits = model(dummy_input).logits

         logits = logits.float()

-        torch.testing.assert_close(
-            logits[0, :3, :3], EXPECTED_LOGITS[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3
-        )
-        torch.testing.assert_close(
-            logits[1, :3, :3], EXPECTED_LOGITS[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3
-        )
+        torch.testing.assert_close(logits[0, :3, :3], expected_logit, atol=1e-3, rtol=1e-3)
+        torch.testing.assert_close(logits[1, :3, :3], expected_logit, atol=1e-3, rtol=1e-3)

     @slow
     @require_torch_accelerator
@@ -201,33 +195,28 @@ class MixtralIntegrationTest(unittest.TestCase):

         # TODO: might need to tweak it in case the logits do not match on our daily runners
         #
-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
+        # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4.
         #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
         # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_LOGITS_LEFT_UNPADDED = {
-            7: torch.Tensor(
-                [[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]],
-            ).to(torch_device),
-            8: torch.Tensor([[0.2207, 0.5234, -0.3828], [0.8203, -0.2285, 0.6055], [0.2656, -0.7109, 0.2451]]).to(
-                torch_device,
-            ),
-            9: torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]]).to(
-                torch_device
-            ),
+        # fmt: off
+        EXPECTED_LOGITS_LEFT_UNPADDED = Expectations(
+            {
+                ("cuda", 7): torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]]).to(torch_device),
+                ("cuda", 8): torch.Tensor([[0.2207, 0.5234, -0.3828], [0.8203, -0.2285, 0.6055], [0.2656, -0.7109, 0.2451]]).to(torch_device),
+                ("rocm", 9): torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]]).to(torch_device),
             }
+        )
+        expected_left_unpadded = EXPECTED_LOGITS_LEFT_UNPADDED.get_expectation()

-        EXPECTED_LOGITS_RIGHT_UNPADDED = {
-            7: torch.Tensor([[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]]).to(
-                torch_device
-            ),
-            8: torch.Tensor([[0.2178, 0.1270, -0.1621], [-0.3496, 0.3008, -1.0312], [0.0693, 0.7930, 0.7969]]).to(
-                torch_device,
-            ),
-            9: torch.Tensor([[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]]).to(
-                torch_device
-            ),
+        EXPECTED_LOGITS_RIGHT_UNPADDED = Expectations(
+            {
+                ("cuda", 7): torch.Tensor([[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]]).to(torch_device),
+                ("cuda", 8): torch.Tensor([[0.2178, 0.1270, -0.1621], [-0.3496, 0.3008, -1.0312], [0.0693, 0.7930, 0.7969]]).to(torch_device),
+                ("rocm", 9): torch.Tensor([[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]]).to(torch_device),
             }
+        )
+        expected_right_unpadded = EXPECTED_LOGITS_RIGHT_UNPADDED.get_expectation()
+        # fmt: on

         with torch.no_grad():
             logits = model(dummy_input, attention_mask=attention_mask).logits
@@ -235,13 +224,13 @@ class MixtralIntegrationTest(unittest.TestCase):

         torch.testing.assert_close(
             logits[0, -3:, -3:],
-            EXPECTED_LOGITS_LEFT_UNPADDED[self.cuda_compute_capability_major_version],
+            expected_left_unpadded,
             atol=1e-3,
             rtol=1e-3,
         )
         torch.testing.assert_close(
             logits[1, -3:, -3:],
-            EXPECTED_LOGITS_RIGHT_UNPADDED[self.cuda_compute_capability_major_version],
+            expected_right_unpadded,
             atol=1e-3,
             rtol=1e-3,
         )
@@ -99,16 +99,6 @@ class NemotronModelTest(CausalLMModelTest, unittest.TestCase):

 @require_torch_accelerator
 class NemotronIntegrationTest(unittest.TestCase):
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
-
     @slow
     @require_read_token
     def test_nemotron_8b_generation_sdpa(self):
@@ -22,6 +22,7 @@ from packaging import version

 from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM, StaticCache
 from transformers.testing_utils import (
+    backend_empty_cache,
     require_accelerate,
     require_aqlm,
     require_torch_gpu,
@@ -81,8 +82,6 @@ class AqlmTest(unittest.TestCase):

     EXPECTED_OUTPUT = "Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. I love to have fun and be active. I"

-    device_map = "cuda"
-
     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
@@ -92,12 +91,12 @@ class AqlmTest(unittest.TestCase):
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.quantized_model = AutoModelForCausalLM.from_pretrained(
             cls.model_name,
-            device_map=cls.device_map,
+            device_map=torch_device,
         )

     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()

     def test_quantized_model_conversion(self):
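The AQLM hunks drop the hard-coded device_map = "cuda" in favour of torch_device and replace torch.cuda.empty_cache() with backend_empty_cache, which dispatches to the matching backend. A brief sketch of the device-agnostic teardown, with the surrounding test class omitted:

import gc

from transformers.testing_utils import backend_empty_cache, torch_device


def teardown_example() -> None:
    # backend_empty_cache calls torch.cuda.empty_cache(), torch.xpu.empty_cache(), etc.,
    # depending on what torch_device resolves to.
    gc.collect()
    backend_empty_cache(torch_device)
    gc.collect()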
@ -170,7 +169,7 @@ class AqlmTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
self.quantized_model.save_pretrained(tmpdirname)
|
self.quantized_model.save_pretrained(tmpdirname)
|
||||||
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
|
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)
|
||||||
|
|
||||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
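Note: the AqlmTest hunks above show the pattern this commit applies across the quantization tests: hard-coded "cuda" device maps become torch_device, and torch.cuda.empty_cache() becomes backend_empty_cache(torch_device). A minimal sketch of that pattern follows; the test class and checkpoint name are illustrative only, not taken from the diff.

    import gc
    import unittest

    from transformers import AutoModelForCausalLM
    from transformers.testing_utils import backend_empty_cache, torch_device


    class ExampleQuantizedModelTest(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            # torch_device resolves to "cuda", "xpu", "hpu", ... on the current host,
            # so no device string needs to be hard-coded here.
            cls.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", device_map=torch_device)

        def tearDown(self):
            gc.collect()
            # dispatches to the matching empty_cache() of the active backend
            backend_empty_cache(torch_device)
            gc.collect()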
@@ -19,6 +19,7 @@ import unittest
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AwqConfig, OPTForCausalLM
 from transformers.testing_utils import (
     backend_empty_cache,
+    get_device_properties,
     require_accelerate,
     require_auto_awq,
     require_flash_attn,
@@ -61,12 +62,10 @@ class AwqConfigTest(unittest.TestCase):
 
         # Only cuda and xpu devices can run this function
         support_llm_awq = False
-        if torch.cuda.is_available():
-            compute_capability = torch.cuda.get_device_capability()
-            major, minor = compute_capability
-            if major >= 8:
-                support_llm_awq = True
-        elif torch.xpu.is_available():
-            support_llm_awq = True
+        device_type, major = get_device_properties()
+        if device_type == "cuda" and major >= 8:
+            support_llm_awq = True
+        elif device_type == "xpu":
+            support_llm_awq = True
 
         if support_llm_awq:
@@ -357,7 +356,7 @@ class AwqFusedTest(unittest.TestCase):
         self.assertTrue(isinstance(model.model.layers[0].block_sparse_moe.gate, torch.nn.Linear))
 
     @unittest.skipIf(
-        torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
+        get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8,
         "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
     )
     @require_flash_attn
@@ -388,7 +387,7 @@ class AwqFusedTest(unittest.TestCase):
     @require_flash_attn
     @require_torch_gpu
     @unittest.skipIf(
-        torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
+        get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8,
         "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
     )
     def test_generation_fused_batched(self):
@@ -441,7 +440,7 @@ class AwqFusedTest(unittest.TestCase):
     @require_flash_attn
     @require_torch_multi_gpu
     @unittest.skipIf(
-        torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
+        get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8,
         "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
     )
     def test_generation_custom_model(self):
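Note: the AwqConfigTest/AwqFusedTest hunks above replace ad-hoc torch.cuda capability probing with get_device_properties(), which, as used in this commit, yields the device type and its major compute capability. A short sketch of a capability-gated skip in that style; the test class below is illustrative only, and the threshold of 8 (Ampere) is taken from the skip messages above.

    import unittest

    from transformers.testing_utils import get_device_properties

    device_type, major = get_device_properties()  # e.g. ("cuda", 8) on an A100


    @unittest.skipIf(
        device_type == "cuda" and major < 8,
        "FlashAttention only supports Ampere GPUs or newer",
    )
    class ExampleFlashAttentionTest(unittest.TestCase):
        def test_generation(self):
            ...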
@@ -23,6 +23,7 @@ from transformers import (
     OPTForCausalLM,
 )
 from transformers.testing_utils import (
+    backend_empty_cache,
     require_accelerate,
     require_torch_gpu,
     slow,
@@ -56,7 +57,6 @@ class BitNetQuantConfigTest(unittest.TestCase):
 @require_accelerate
 class BitNetTest(unittest.TestCase):
     model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens"
-    device = "cuda"
 
     # called only once for all test in this class
     @classmethod
@@ -65,11 +65,11 @@ class BitNetTest(unittest.TestCase):
         Load the model
         """
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
-        cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=cls.device)
+        cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=torch_device)
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_replace_with_bitlinear(self):
@@ -100,7 +100,7 @@ class BitNetTest(unittest.TestCase):
         """
         input_text = "What are we having for dinner?"
         expected_output = "What are we having for dinner? What are we going to do for fun this weekend?"
-        input_ids = self.tokenizer(input_text, return_tensors="pt").to("cuda")
+        input_ids = self.tokenizer(input_text, return_tensors="pt").to(torch_device)
 
         output = self.quantized_model.generate(**input_ids, max_new_tokens=11, do_sample=False)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
@@ -127,7 +127,7 @@ class BitNetTest(unittest.TestCase):
         from transformers.integrations import BitLinear
 
         layer = BitLinear(in_features=4, out_features=2, bias=False, dtype=torch.float32)
-        layer.to(self.device)
+        layer.to(torch_device)
 
         input_tensor = torch.tensor([1.0, -1.0, -1.0, 1.0], dtype=torch.float32).to(torch_device)
 
@@ -202,9 +202,8 @@ class BitNetTest(unittest.TestCase):
 class BitNetSerializationTest(unittest.TestCase):
     def test_model_serialization(self):
         model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens"
-        device = "cuda"
-        quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
-        input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=device)
+        quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=torch_device)
+        input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
 
         with torch.no_grad():
             logits_ref = quantized_model.forward(input_tensor).logits
@@ -215,10 +214,10 @@ class BitNetSerializationTest(unittest.TestCase):
 
         # Remove old model
         del quantized_model
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         # Load and check if the logits match
-        model_loaded = AutoModelForCausalLM.from_pretrained("quant_model", device_map=device)
+        model_loaded = AutoModelForCausalLM.from_pretrained("quant_model", device_map=torch_device)
 
         with torch.no_grad():
             logits_loaded = model_loaded.forward(input_tensor).logits
@@ -32,6 +32,7 @@ from transformers.models.opt.modeling_opt import OPTAttention
 from transformers.testing_utils import (
     apply_skip_if_not_implemented,
     backend_empty_cache,
+    backend_torch_accelerator_module,
     is_bitsandbytes_available,
     is_torch_available,
     require_accelerate,
@@ -376,7 +377,7 @@ class Bnb4BitT5Test(unittest.TestCase):
         avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
         """
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_inference_without_keep_in_fp32(self):
         r"""
@@ -460,7 +461,7 @@ class Classes4BitModelTest(Base4bitTest):
         del self.seq_to_seq_model
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_correct_head_class(self):
         r"""
@@ -491,7 +492,7 @@ class Pipeline4BitTest(Base4bitTest):
         del self.pipe
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_pipeline(self):
         r"""
@@ -589,10 +590,10 @@ class Bnb4BitTestTraining(Base4bitTest):
         # Step 1: freeze all parameters
         model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True)
 
-        if torch.cuda.is_available():
-            self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})
-        elif torch.xpu.is_available():
-            self.assertEqual(set(model.hf_device_map.values()), {f"xpu:{torch.xpu.current_device()}"})
+        if torch_device in ["cuda", "xpu"]:
+            self.assertEqual(
+                set(model.hf_device_map.values()), {backend_torch_accelerator_module(torch_device).current_device()}
+            )
         else:
             self.assertTrue(all(param.device.type == "cpu" for param in model.parameters()))
 
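Note: backend_torch_accelerator_module(torch_device), used in the Bnb4BitTestTraining hunk above, returns the torch submodule that matches the active backend (torch.cuda, torch.xpu, ...), so backend-specific calls can be written once. A small sketch, assuming the host actually has an accelerator:

    from transformers.testing_utils import backend_torch_accelerator_module, torch_device

    accelerator = backend_torch_accelerator_module(torch_device)
    # e.g. torch.cuda on a CUDA box, torch.xpu on an Intel XPU box
    print(accelerator.current_device())  # index of the active device
    print(accelerator.device_count())    # number of visible devices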
@@ -31,6 +31,8 @@ from transformers import (
 from transformers.models.opt.modeling_opt import OPTAttention
 from transformers.testing_utils import (
     apply_skip_if_not_implemented,
+    backend_empty_cache,
+    backend_torch_accelerator_module,
     is_accelerate_available,
     is_bitsandbytes_available,
     is_torch_available,
@@ -137,7 +139,7 @@ class MixedInt8Test(BaseMixedInt8Test):
         del self.model_8bit
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_get_keys_to_not_convert(self):
         r"""
@@ -484,7 +486,7 @@ class MixedInt8T5Test(unittest.TestCase):
         avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
         """
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_inference_without_keep_in_fp32(self):
         r"""
@@ -599,7 +601,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
         del self.seq_to_seq_model
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_correct_head_class(self):
         r"""
@@ -631,7 +633,7 @@ class MixedInt8TestPipeline(BaseMixedInt8Test):
         del self.pipe
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_pipeline(self):
         r"""
@@ -872,10 +874,10 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
         model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True)
         model.train()
 
-        if torch.cuda.is_available():
-            self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})
-        elif torch.xpu.is_available():
-            self.assertEqual(set(model.hf_device_map.values()), {f"xpu:{torch.xpu.current_device()}"})
+        if torch_device in ["cuda", "xpu"]:
+            self.assertEqual(
+                set(model.hf_device_map.values()), {backend_torch_accelerator_module(torch_device).current_device()}
+            )
         else:
             self.assertTrue(all(param.device.type == "cpu" for param in model.parameters()))
 
@@ -3,7 +3,7 @@ import unittest
 import warnings
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.testing_utils import require_compressed_tensors, require_torch
+from transformers.testing_utils import backend_empty_cache, require_compressed_tensors, require_torch, torch_device
 from transformers.utils import is_torch_available
 from transformers.utils.quantization_config import CompressedTensorsConfig
 
@@ -41,7 +41,7 @@ class StackCompressedModelTest(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_compressed_uncompressed_model_shapes(self):
@@ -160,7 +160,7 @@ class RunCompressedTest(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_default_run_compressed__True(self):
@@ -2,7 +2,7 @@ import gc
 import unittest
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig
-from transformers.testing_utils import require_compressed_tensors, require_torch
+from transformers.testing_utils import backend_empty_cache, require_compressed_tensors, require_torch, torch_device
 from transformers.utils import is_torch_available
 
 
@@ -22,7 +22,7 @@ class CompressedTensorsTest(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_config_args(self):
@@ -18,6 +18,7 @@ import unittest
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, EetqConfig, OPTForCausalLM
 from transformers.testing_utils import (
+    backend_empty_cache,
     require_accelerate,
     require_eetq,
     require_torch_gpu,
@@ -87,7 +88,7 @@ class EetqTest(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_quantized_model_conversion(self):
@@ -18,6 +18,7 @@ import unittest
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config, OPTForCausalLM
 from transformers.testing_utils import (
+    backend_empty_cache,
     require_accelerate,
     require_fbgemm_gpu,
     require_read_token,
@@ -126,7 +127,7 @@ class FbgemmFp8Test(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_quantized_model_conversion(self):
@@ -19,6 +19,7 @@ import unittest
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM
 from transformers.testing_utils import (
     backend_empty_cache,
+    get_device_properties,
     require_accelerate,
     require_read_token,
     require_torch_accelerator,
@@ -254,7 +255,7 @@ class FP8LinearTest(unittest.TestCase):
     device = torch_device
 
     @unittest.skipIf(
-        torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
+        get_device_properties()[0] == "cuda" and get_device_properties()[1] < 9,
         "Skipping FP8LinearTest because it is not supported on GPU with capability < 9.0",
     )
     def test_linear_preserves_shape(self):
@@ -270,7 +271,7 @@ class FP8LinearTest(unittest.TestCase):
         self.assertEqual(x_.shape, x.shape)
 
     @unittest.skipIf(
-        torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
+        get_device_properties()[0] == "cuda" and get_device_properties()[1] < 9,
         "Skipping FP8LinearTest because it is not supported on GPU with capability < 9.0",
     )
     def test_linear_with_diff_feature_size_preserves_shape(self):
@@ -18,6 +18,7 @@ import unittest
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HiggsConfig, OPTForCausalLM
 from transformers.testing_utils import (
+    backend_empty_cache,
     require_accelerate,
     require_flute_hadamard,
     require_torch_gpu,
@@ -87,7 +88,7 @@ class HiggsTest(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_quantized_model_conversion(self):
@@ -17,6 +17,7 @@ import unittest
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
 from transformers.testing_utils import (
+    backend_empty_cache,
     require_accelerate,
     require_hqq,
     require_torch_gpu,
@@ -50,7 +51,7 @@ class HQQLLMRunner:
 
 
 def cleanup():
-    torch.cuda.empty_cache()
+    backend_empty_cache(torch_device)
     gc.collect()
 
 
@@ -187,7 +188,7 @@ class HQQTestBias(unittest.TestCase):
             hqq_runner.model.save_pretrained(tmpdirname)
 
             del hqq_runner.model
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
 
             model_loaded = AutoModelForCausalLM.from_pretrained(
                 tmpdirname, torch_dtype=torch.float16, device_map=torch_device
@@ -228,7 +229,7 @@ class HQQSerializationTest(unittest.TestCase):
 
         # Remove old model
         del hqq_runner.model
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         # Load and check if the logits match
         model_loaded = AutoModelForCausalLM.from_pretrained(
@@ -18,6 +18,7 @@ import unittest
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, SpQRConfig, StaticCache
 from transformers.testing_utils import (
+    backend_empty_cache,
     require_accelerate,
     require_spqr,
     require_torch_gpu,
@@ -82,8 +83,6 @@ class SpQRTest(unittest.TestCase):
     )
     EXPECTED_OUTPUT_COMPILE = "Hello my name is Jake and I am a 20 year old student at the University of North Texas. (Go Mean Green!) I am a huge fan of the Dallas"
 
-    device_map = "cuda"
-
     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
@@ -93,12 +92,12 @@ class SpQRTest(unittest.TestCase):
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.quantized_model = AutoModelForCausalLM.from_pretrained(
             cls.model_name,
-            device_map=cls.device_map,
+            device_map=torch_device,
         )
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_quantized_model_conversion(self):
@@ -158,7 +157,7 @@ class SpQRTest(unittest.TestCase):
         """
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantized_model.save_pretrained(tmpdirname)
-            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
+            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)
 
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
 
@@ -21,10 +21,13 @@ from packaging import version
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 from transformers.testing_utils import (
+    backend_empty_cache,
+    get_device_properties,
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torchao,
     require_torchao_version_greater_or_equal,
+    torch_device,
 )
 from transformers.utils import is_torch_available, is_torchao_available
 
@@ -131,7 +134,7 @@ class TorchAoTest(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_int4wo_quant(self):
@@ -260,7 +263,7 @@ class TorchAoTest(unittest.TestCase):
 
 @require_torch_gpu
 class TorchAoGPUTest(TorchAoTest):
-    device = "cuda"
+    device = torch_device
     quant_scheme_kwargs = {"group_size": 32}
 
     def test_int4wo_offload(self):
@@ -397,7 +400,7 @@ class TorchAoSerializationTest(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_original_model_expected_output(self):
@@ -452,33 +455,33 @@ class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
 @require_torch_gpu
 class TorchAoSerializationGPTTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
-    device = "cuda:0"
+    device = f"{torch_device}:0"
 
 
 @require_torch_gpu
 class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
     EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    device = "cuda:0"
+    device = f"{torch_device}:0"
 
 
 @require_torch_gpu
 class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
     EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    device = "cuda:0"
+    device = f"{torch_device}:0"
 
 
 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.10.0")
 class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
     EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    device = "cuda:0"
+    device = f"{torch_device}:0"
 
     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
-        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
+        if not (get_device_properties()[0] == "cuda" and get_device_properties()[1] >= 9):
             raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
 
         from torchao.quantization import Float8WeightOnlyConfig
@@ -493,12 +496,12 @@ class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
 @require_torchao_version_greater_or_equal("0.10.0")
 class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
     EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    device = "cuda:0"
+    device = f"{torch_device}:0"
 
     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
-        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
+        if not (get_device_properties()[0] == "cuda" and get_device_properties()[1] >= 9):
             raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
 
         from torchao.quantization import Int8DynamicActivationInt4WeightConfig
@@ -18,6 +18,7 @@ import unittest
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, VptqConfig
 from transformers.testing_utils import (
+    backend_empty_cache,
     require_accelerate,
     require_torch_gpu,
     require_torch_multi_gpu,
@@ -74,7 +75,7 @@ class VptqTest(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_quantized_model(self):
@@ -73,7 +73,10 @@ from transformers.models.auto.modeling_auto import (
 )
 from transformers.testing_utils import (
     CaptureLogger,
+    backend_device_count,
     backend_empty_cache,
+    backend_memory_allocated,
+    backend_torch_accelerator_module,
     get_device_properties,
     hub_retry,
     is_flaky,
@@ -2613,7 +2616,7 @@ class ModelTesterMixin:
         for k in blacklist_non_batched_params:
             inputs_dict.pop(k, None)
 
-        # move input tensors to cuda:O
+        # move input tensors to accelerator O
         for k, v in inputs_dict.items():
             if torch.is_tensor(v):
                 inputs_dict[k] = v.to(0)
@@ -2636,12 +2639,12 @@ class ModelTesterMixin:
 
         # a candidate for testing_utils
         def get_current_gpu_memory_use():
-            """returns a list of cuda memory allocations per GPU in MBs"""
+            """returns a list of VRAM allocations per GPU in MBs"""
 
             per_device_memory = []
-            for id in range(torch.cuda.device_count()):
-                with torch.cuda.device(id):
-                    per_device_memory.append(torch.cuda.memory_allocated() >> 20)
+            for id in range(backend_device_count(torch_device)):
+                with backend_torch_accelerator_module(torch_device).device(id):
+                    per_device_memory.append(backend_memory_allocated(torch_device) >> 20)
 
             return per_device_memory
 
@@ -2657,7 +2660,7 @@ class ModelTesterMixin:
 
         # Put model on device 0 and take a memory snapshot
         model = model_class(config)
-        model.to("cuda:0")
+        model.to(f"{torch_device}:0")
         memory_after_model_load = get_current_gpu_memory_use()
 
         # The memory use on device 0 should be higher than it was initially.
@@ -2717,7 +2720,7 @@ class ModelTesterMixin:
 
         model.parallelize()
 
-        parallel_output = model(**cast_to_device(inputs_dict, "cuda:0"))
+        parallel_output = model(**cast_to_device(inputs_dict, f"{torch_device}:0"))
 
         for value, parallel_value in zip(output, parallel_output):
             if isinstance(value, torch.Tensor):
@@ -4240,10 +4243,10 @@ class ModelTesterMixin:
         # add position_ids + fa_kwargs
         data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
         batch = data_collator(features)
-        batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}
+        batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()}
 
         res_padded = model(**inputs_dict)
-        res_padfree = model(**batch_cuda)
+        res_padfree = model(**batch_accelerator)
 
         logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
         logits_padfree = res_padfree.logits[0]
@@ -3224,7 +3224,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         # For more than 1 GPUs, since the randomness is introduced in the model and with DataParallel (which is used
         # in this test for more than 2 GPUs), the calls to the torch RNG will happen in a random order (sometimes
         # GPU 0 will call first and sometimes GPU 1).
-        random_torch = not torch.cuda.is_available() or torch.cuda.device_count() <= 1
+        random_torch = not torch.cuda.is_available() or backend_device_count(torch_device) <= 1
 
         if torch.cuda.is_available():
             torch.backends.cudnn.deterministic = True
@@ -22,6 +22,7 @@ from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATION
 from transformers.testing_utils import (
     CaptureStderr,
     backend_device_count,
+    backend_torch_accelerator_module,
     cleanup,
     get_gpu_count,
     is_torch_available,
@@ -430,11 +431,7 @@ class CacheHardIntegrationTest(unittest.TestCase):
         original = GenerationConfig(**common)
         offloaded = GenerationConfig(cache_implementation="offloaded", **common)
 
-        torch_accelerator_module = None
-        if device.type == "cuda":
-            torch_accelerator_module = torch.cuda
-        elif device.type == "xpu":
-            torch_accelerator_module = torch.xpu
+        torch_accelerator_module = backend_torch_accelerator_module(device.type)
 
         torch_accelerator_module.reset_peak_memory_stats(device)
         model.generate(generation_config=original, **inputs)
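Note: taken together, the ModelTesterMixin hunks above make per-device memory accounting backend-agnostic. A sketch of the same idea as a standalone helper; the function name is made up for illustration, the calls mirror the new code in the diff.

    from transformers.testing_utils import (
        backend_device_count,
        backend_memory_allocated,
        backend_torch_accelerator_module,
        torch_device,
    )


    def current_accelerator_memory_use_mb():
        """Returns the allocated memory per visible accelerator device, in MBs."""
        per_device = []
        for idx in range(backend_device_count(torch_device)):
            # select device idx on the active backend (torch.cuda, torch.xpu, ...)
            with backend_torch_accelerator_module(torch_device).device(idx):
                per_device.append(backend_memory_allocated(torch_device) >> 20)
        return per_device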