mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
fix llama
tests (#39161)
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
New model PR merged notification / Notify new model (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
New model PR merged notification / Notify new model (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
* fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
4c1715b610
commit
8e87adc45f
@ -25,6 +25,7 @@ from transformers.testing_utils import (
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
run_test_using_subprocess,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@ -96,36 +97,28 @@ class LlamaModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
@require_read_token
|
||||
class LlamaIntegrationTest(unittest.TestCase):
|
||||
def setup(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def tearDown(self):
|
||||
# TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
|
||||
# some memory allocated in the cache, which means some object is not being released properly. This causes some
|
||||
# unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
|
||||
# Investigate the root cause.
|
||||
cleanup(torch_device, gc_collect=False)
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_llama_3_1_hard(self):
|
||||
"""
|
||||
An integration test for llama 3.1. It tests against a long output to ensure the subtle numerical differences
|
||||
from llama 3.1.'s RoPE can be detected
|
||||
"""
|
||||
# diff on `EXPECTED_TEXT`:
|
||||
# 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results.
|
||||
expected_base_text = (
|
||||
"Tell me about the french revolution. The french revolution was a period of radical political and social "
|
||||
"upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked "
|
||||
"by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the "
|
||||
"First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative "
|
||||
"assembly that had not met since 1614. The Third Estate, which represented the common people, "
|
||||
"demanded greater representation and eventually broke away to form the National Assembly. This marked "
|
||||
"the beginning of the end of the absolute monarchy and the rise of the middle class.\n"
|
||||
)
|
||||
expected_texts = Expectations(
|
||||
{
|
||||
("rocm", (9, 5)): expected_base_text.replace("political and social", "social and political"),
|
||||
("cuda", None): expected_base_text,
|
||||
("rocm", (9, 5)): 'Tell me about the french revolution. The french revolution was a period of radical social and political upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative assembly that had not met since 1614. The Third Estate, which represented the common people, demanded greater representation and eventually broke away to form the National Assembly. This marked the beginning of the end of the absolute monarchy and the rise of the middle class.\n',
|
||||
("cuda", None): 'Tell me about the french revolution. The french revolution was a period of radical political and social upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative assembly that had not met since 1614. The Third Estate, which represented the common people, demanded greater representation and eventually broke away to form the National Assembly. The National Assembly adopted the Declaration of the Rights of Man and of the Citizen, which enshr',
|
||||
}
|
||||
) # fmt: skip
|
||||
EXPECTED_TEXT = expected_texts.get_expectation()
|
||||
@ -142,7 +135,6 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(generated_text, EXPECTED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_model_7b_logits_bf16(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
|
||||
@ -191,7 +183,6 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_model_7b_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
|
||||
@ -240,6 +231,9 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: check why we have the following strange situation.
|
||||
# without running in subprocess, this test causes subsequent tests failing with `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!`
|
||||
@run_test_using_subprocess
|
||||
@slow
|
||||
def test_model_7b_dola_generation(self):
|
||||
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
|
||||
@ -265,7 +259,6 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
@require_read_token
|
||||
def test_compile_static_cache(self):
|
||||
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
|
||||
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
|
||||
@ -306,7 +299,6 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_export_static_cache(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.4.0"):
|
||||
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||
|
@ -407,6 +407,8 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
self.tokenizer.add_eos_token = False
|
||||
self.rust_tokenizer.add_eos_token = False
|
||||
|
||||
# See internal discussion: https://huggingface.slack.com/archives/C01NE71C4F7/p1750680376085749?thread_ts=1750676268.233309&cid=C01NE71C4F7
|
||||
@unittest.skip("failing, won't fix")
|
||||
@slow
|
||||
def test_conversion(self):
|
||||
# This is excruciatingly slow since it has to recreate the entire merge
|
||||
|
Loading…
Reference in New Issue
Block a user