Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-15 18:48:24 +06:00)

[CI] green llama tests (#37244)

* green llama tests
* use cleanup instead
* better test comment; cleanup upgrade
* better test comment; cleanup upgrade
This commit is contained in: parent 782d7d945d, commit 9a1c1fe7ed
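Most hunks below apply the same mechanical change: `StaticCache` is now constructed with `max_batch_size` instead of the older `batch_size` keyword. A minimal before/after sketch of that call-site migration (the config name and sizes here are illustrative, not taken from the diff):

import torch
from transformers import AutoConfig, StaticCache

# Illustrative config; any decoder-only model config works the same way.
config = AutoConfig.from_pretrained("gpt2")

# Old call site (deprecated keyword), kept as a comment for contrast:
# past_key_values = StaticCache(config, batch_size=1, max_cache_len=256, device="cpu", dtype=torch.float16)

# New call site, as applied throughout this commit:
past_key_values = StaticCache(
    config,
    max_batch_size=1,
    max_cache_len=256,
    device="cpu",
    dtype=torch.float16,
)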
@@ -118,7 +118,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
     with torch.no_grad():
         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + num_tokens_to_generate,

@@ -144,7 +144,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
 
         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + num_tokens_to_generate,

@@ -187,7 +187,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
         # TODO use decode_one_token(model, input_id.clone(), cache_position) for verification
         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + num_tokens_to_generate + 10,

@@ -254,7 +254,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
 
         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,

@@ -271,7 +271,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
 
         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
            device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,

@@ -287,7 +287,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
 
         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,

@@ -303,7 +303,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
 
         past_key_values = StaticCache(
             model.config,
-            batch_size=batch_size,
+            max_batch_size=batch_size,
             device=device,
             dtype=torch.float16,
             max_cache_len=seq_length + 128,

@@ -93,7 +93,7 @@ model.generation_config.max_new_tokens = 16
 
 past_key_values = StaticCache(
     config=model.config,
-    batch_size=1,
+    max_batch_size=1,
     # If you plan to reuse the cache, make sure the cache length is large enough for all cases
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
     device=model.device,

@@ -159,7 +159,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
     past_key_values = StaticCache(
-        config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+        config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
     )
     cache_position = torch.arange(seq_length, device=torch_device)
     generated_ids = torch.zeros(

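The documentation hunks above keep the note that a pre-allocated StaticCache can be reused across generations as long as max_cache_len covers every call. A minimal sketch of that reuse pattern, assuming a checkpoint that supports static caches (the model id, prompt length, and prompts below are illustrative, not from the diff):

from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-3.2-1B"  # illustrative; any static-cache-capable checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.generation_config.max_new_tokens = 16
prompt_length = 64  # upper bound on prompt tokens for this sketch

past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length + (model.generation_config.max_new_tokens * 2),
    device=model.device,
    dtype=model.dtype,
)

for prompt in ["Simply put, the theory of relativity states that", "My favourite condiment is"]:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, past_key_values=past_key_values)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    past_key_values.reset()  # clear cached keys/values before the next, unrelated prompt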
@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16
 
 past_key_values = StaticCache(
     config=model.config,
-    batch_size=1,
+    max_batch_size=1,
     # 캐시를 재사용할 계획이 있는 경우, 모든 경우에 충분한 캐시 길이를 설정해야 합니다
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
     device=model.device,

@@ -3075,6 +3075,7 @@ def cleanup(device: str, gc_collect=False):
     if gc_collect:
         gc.collect()
     backend_empty_cache(device)
+    torch._dynamo.reset()
 
 
 # Type definition of key used in `Expectations` class.

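The hunk above extends the shared cleanup() helper in testing_utils so that it also resets torch._dynamo, since leftover compiled graphs can keep cache tensors alive between tests. A minimal sketch of the helper's shape after this change; backend_empty_cache is the repo's own dispatcher, and the CUDA-only stand-in below is an assumption for illustration:

import gc

import torch


def backend_empty_cache(device: str) -> None:
    # Simplified stand-in for transformers.testing_utils.backend_empty_cache: only handles CUDA here.
    if "cuda" in device:
        torch.cuda.empty_cache()


def cleanup(device: str, gc_collect=False):
    # Shape of the helper after this commit: optionally run the garbage collector,
    # empty the accelerator cache, and reset torch._dynamo so compiled graphs do not
    # pin static-cache tensors between tests.
    if gc_collect:
        gc.collect()
    backend_empty_cache(device)
    torch._dynamo.reset()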
@@ -2285,6 +2285,14 @@ class GenerationTesterMixin:
             inputs_dict[input_name] = input_data
         main_input = inputs_dict[model_class.main_input_name]
 
+        # FA2 doesn't accept masking in the middle of the sequence for now. We usually generate right-padded
+        # attention masks at test time and, with generate, the mask will be appended with 1s on the right,
+        # resulting in a mask with holes (not supported properly by FA2).
+        if attn_implementation == "flash_attention_2":
+            for input_name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"):
+                if input_name in inputs_dict:
+                    inputs_dict[input_name] = torch.ones_like(inputs_dict[input_name])
+
         # make sure that all models have enough positions for generation
         if hasattr(config, "max_position_embeddings"):
             config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1
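The comment in the hunk above is easiest to see with a concrete tensor: a right-padded mask gets 1s appended during generation, leaving a 0 surrounded by 1s, a pattern flash_attention_2 does not handle, so the mixin now swaps in an all-ones mask for FA2 runs. A small illustration (not repo code):

import torch

# Right-padded attention mask for a batch of two prompts (second row is padded at the end).
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

# generate() appends a 1 per new token, producing a "hole" (a 0 between 1s) in the padded row:
extended = torch.cat([attention_mask, torch.ones(2, 1, dtype=torch.long)], dim=-1)
# tensor([[1, 1, 1, 1, 1],
#         [1, 1, 0, 0, 1]])  <- masked positions now sit in the middle of the sequence

# The test-side workaround applied in this hunk: use a hole-free, all-ones mask when testing FA2.
attention_mask = torch.ones_like(attention_mask)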
@@ -2339,8 +2347,6 @@ class GenerationTesterMixin:
     @slow
     def test_eager_matches_fa2_generate(self):
         """Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
-        # TODO (@joao @raushan) -- this test is failing the output checks on most models, investigate. After fixing,
-        # check whether we still need the overwrites
         self._test_attention_implementation("flash_attention_2")
 
     def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):

@@ -3974,7 +3980,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         # TODO: We need to raise a warning in case the cache is not set correctly
         # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
         #     past_key_values = StaticCache(
-        #         config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
+        #         config=model.config, max_batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
         #     )
         #     results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
 

@@ -3982,7 +3988,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         layer_device_map = {0: 0, 1: 1}
         past_key_values = StaticCache(
             config=model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=30,
             device=torch_device,
             dtype=model.dtype,

@@ -4183,7 +4189,11 @@ class GenerationIntegrationTests(unittest.TestCase):
         batch_size = 2
         query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
         static_cache = StaticCache(
-            config=config, batch_size=batch_size, max_cache_len=max_cache_len, device=torch_device, dtype=torch.float32
+            config=config,
+            max_batch_size=batch_size,
+            max_cache_len=max_cache_len,
+            device=torch_device,
+            dtype=torch.float32,
         )
         static_cache = model(init_input_ids, past_key_values=static_cache).past_key_values
         model_inputs = model.prepare_inputs_for_generation(

@@ -21,6 +21,7 @@ from parameterized import parameterized
 
 from transformers import AutoTokenizer, DeepseekV3Config, is_torch_available, set_seed
 from transformers.testing_utils import (
+    cleanup,
     require_read_token,
     require_torch,
     require_torch_accelerator,

@@ -605,6 +606,10 @@ class DeepseekV3IntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @slow
     @require_torch_accelerator
     @require_read_token

@@ -25,6 +25,7 @@ from parameterized import parameterized
 from transformers import AutoTokenizer, DiffLlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.testing_utils import (
     backend_empty_cache,
+    cleanup,
     require_bitsandbytes,
     require_flash_attn,
     require_read_token,

@@ -685,6 +686,10 @@ class DiffLlamaIntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @slow
     @require_torch_accelerator
     @require_read_token

@@ -884,7 +889,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,

@@ -932,7 +937,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,

@@ -23,6 +23,7 @@ from packaging import version
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
+    cleanup,
     is_flaky,
     require_bitsandbytes,
     require_flash_attn,

@@ -498,6 +499,10 @@ class GemmaIntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @require_read_token
     def test_model_2b_fp16(self):
         model_id = "google/gemma-2b"

@@ -549,6 +549,13 @@ class LlamaIntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
+    def tearDown(self):
+        # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
+        # some memory allocated in the cache, which means some object is not being released properly. This causes some
+        # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
+        # Investigate the root cause.
+        cleanup(torch_device, gc_collect=False)
+
     @slow
     @require_read_token
     def test_llama_3_1_hard(self):

@@ -748,14 +755,6 @@ class LlamaIntegrationTest(unittest.TestCase):
                 "Simply put, the theory of relativity states that 1) the speed of light is the same for all "
                 "observers, regardless of their location, and 2) the laws of physics are the same for all observers"
             ],
-            "meta-llama/Llama-3.2-3B": [
-                "Simply put, the theory of relativity states that 1. the speed of light is constant, and 2. "
-                "the speed of light is the fastest speed possible"
-            ],
-            "meta-llama/Llama-2-7b-hf": [
-                "Simply put, the theory of relativity states that 1) the speed of light is a constant, and 2) "
-                "the laws of physics are the same for all",
-            ],
         }
 
         for llama_model_ckp, EXPECTED_TEXT_COMPLETION in llama_models.items():

@@ -946,7 +945,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,

@@ -994,7 +993,7 @@ class Mask4DTestHard(unittest.TestCase):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,

@@ -53,7 +53,7 @@ if is_torch_available():
             self.model = model
             self.cache = StaticCache(
                 config=model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=max_seq_len,
                 device=self.model.device,
                 dtype=self.model.dtype,

@@ -227,10 +227,6 @@ class Phi4MultimodalModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
     def test_flash_attn_2_inference_equivalence_right_padding(self):
         pass
 
-    @unittest.skip(reason="This one tries to use right padding as well")
-    def test_eager_matches_fa2_generate(self):
-        pass
-
     @unittest.skip(reason="Depending on input modalities, some params may not have gradients")
     def test_training_gradient_checkpointing(self):
         pass

@@ -52,7 +52,7 @@ if is_torch_available():
             self.model = model
             self.cache = StaticCache(
                 config=model.config,
-                batch_size=batch_size,
+                max_batch_size=batch_size,
                 max_cache_len=max_seq_len,
                 device=self.model.device,
                 dtype=self.model.dtype,

@@ -24,6 +24,7 @@ from transformers import T5Config, is_torch_available
 from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import (
+    cleanup,
     require_accelerate,
     require_sentencepiece,
     require_tokenizers,

@@ -1170,6 +1171,10 @@ class T5ModelFp16Tests(unittest.TestCase):
 @require_sentencepiece
 @require_tokenizers
 class T5ModelIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
+        cleanup(torch_device, gc_collect=False)
+
     @cached_property
     def model(self):
         return T5ForConditionalGeneration.from_pretrained("google-t5/t5-base").to(torch_device)

@@ -226,7 +226,7 @@ class AqlmTest(unittest.TestCase):
         # Setup static KV cache for generation
         past_key_values = StaticCache(
             config=self.quantized_model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=seq_length + self.max_new_tokens + 1,
             device=torch_device,
             dtype=self.quantized_model.config._pre_quantization_dtype,

@@ -207,7 +207,7 @@ class SpQRTest(unittest.TestCase):
         # Setup static KV cache for generation
         past_key_values = StaticCache(
             config=self.quantized_model.config,
-            batch_size=1,
+            max_batch_size=1,
             max_cache_len=seq_length + self.max_new_tokens + 1,
             device=torch_device,
             dtype=self.quantized_model.config._pre_quantization_dtype,