Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-18 20:18:24 +06:00)

commit 1b3dba9417 (parent 0753134f4d)

Make Gemma work with torch.compile (#30775)

* fix
* [run-slow] gemma
* add test
* add `test_compile_static_cache`
* fix
* style
* remove subprocess
* use attribute
* fix
* style
* update
* [run-slow] dbrx,gemma,jetmoe,phi3,recurrent_gemma

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
@@ -55,15 +55,14 @@ class DbrxRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
@@ -104,15 +104,14 @@ class GemmaRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
@@ -397,15 +397,14 @@ class JetMoeRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
@@ -99,15 +99,14 @@ class Phi3RotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
@@ -68,16 +68,14 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
         super().__init__()
         self.dim = dim
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
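
The five hunks above apply one and the same rewrite to the DBRX, Gemma, JetMoe, Phi-3 and RecurrentGemma rotary embeddings: `inv_freq` is computed once in `__init__` and registered as a non-persistent buffer, instead of being registered as `None` and lazily filled in on the first `forward` call. Below is a minimal, self-contained sketch of the resulting pattern; it is not the Transformers code itself, and the class name and toy shapes are made up for illustration. The point is that `forward` no longer branches on `self.inv_freq is None` or assigns a new tensor to a module attribute, which is exactly the kind of mutation that causes graph breaks or recompilation under `torch.compile(..., fullgraph=True)`.

# Minimal sketch (not the Transformers code): the buffer is built eagerly in __init__, so the
# forward pass contains no `if self.inv_freq is None` branch and no attribute assignment.
import torch
import torch.nn as nn


class CompileFriendlyRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        # Computed once at construction time, never rebuilt inside forward.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_heads, seq_len, head_dim]; position_ids: [bs, seq_len]
        inv_freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        pos = position_ids[:, None, :].float()
        freqs = (inv_freq @ pos).transpose(1, 2)          # [bs, seq_len, dim // 2]
        emb = torch.cat((freqs, freqs), dim=-1)           # [bs, seq_len, dim]
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


rope = CompileFriendlyRotaryEmbedding(dim=8)
compiled = torch.compile(rope, fullgraph=True)  # no graph breaks from lazy buffer init
x = torch.randn(1, 2, 4, 8)
position_ids = torch.arange(4)[None, :]
cos, sin = compiled(x, position_ids)
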
@@ -17,6 +17,7 @@ import tempfile
 import unittest
 
 import pytest
+from packaging import version
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.testing_utils import (
@@ -40,7 +41,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
     import torch
 
-    from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel
+    from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer
 
 
 class GemmaModelTester:
@@ -302,6 +303,9 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.6]
 
+    # used in `test_torch_compile`
+    _torch_compile_test_ckpt = "google/gemma-2b"
+
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
         self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
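
`_torch_compile_test_ckpt` is the opt-in hook for the shared `test_torch_compile` that this commit adds to `ModelTesterMixin` further down in the diff: test classes that define the attribute run the compile smoke test against that checkpoint, while classes without it skip the test. A hypothetical sketch of how another model suite would opt in (the class name and checkpoint are placeholders, not part of this commit):

# Hypothetical opt-in sketch: the class name and checkpoint are placeholders. The shared test
# added to ModelTesterMixin below reads `self._torch_compile_test_ckpt` and skips itself when
# the attribute is missing.
import unittest

from ...test_modeling_common import ModelTesterMixin  # repo-internal import used by model test files


class MyModelTest(ModelTesterMixin, unittest.TestCase):
    _torch_compile_test_ckpt = "my-org/my-causal-lm"  # checkpoint used by `test_torch_compile`
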
@@ -801,3 +805,51 @@ class GemmaIntegrationTest(unittest.TestCase):
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
 
         self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
+
+    @slow
+    @require_torch_gpu
+    @require_read_token
+    def test_compile_static_cache(self):
+        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
+        # work as intended. See https://github.com/pytorch/pytorch/issues/121943
+        if version.parse(torch.__version__) < version.parse("2.3.0"):
+            self.skipTest("This test requires torch >= 2.3 to run.")
+
+        NUM_TOKENS_TO_GENERATE = 40
+        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
+        # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
+        EXPECTED_TEXT_COMPLETION = {
+            8: [
+                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
+                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
+            ],
+            7: [
+                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
+                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
+            ],
+        }
+
+        prompts = ["Hello I am doing", "Hi today"]
+        tokenizer = GemmaTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
+        model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
+
+        # Dynamic Cache
+        generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
+        dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text)  # Both GPU architectures have the same output
+
+        # Static Cache
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
+
+        # Static Cache + compile
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
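
The new `test_compile_static_cache` pins down the end-to-end behaviour this commit enables: greedy generation with the default dynamic cache, with a static cache, and with a static cache plus a compiled `forward` must all decode to the same text. Below is a condensed sketch of the user-facing recipe the test exercises; it is an illustration rather than the test itself, and it assumes a CUDA GPU, torch >= 2.3, and access to the gated google/gemma-2b checkpoint.

# Condensed sketch of the usage pattern the test exercises (assumptions: CUDA GPU, torch >= 2.3,
# access to "google/gemma-2b"). With greedy decoding, the compiled static-cache run should
# reproduce the dynamic-cache output token for token.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, device_map="cuda")
inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)

# Baseline: default dynamic KV cache
dynamic_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# Static KV cache + compiled forward: fixed-shape cache tensors keep the compiled graph stable
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
static_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False, cache_implementation="static")

assert torch.equal(dynamic_ids, static_ids)  # greedy outputs should match token for token

`mode="reduce-overhead"` enables CUDA graph capture for the per-token decode steps, which is where the fixed-shape static cache pays off.
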
@@ -312,6 +312,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.7, 0.8]
 
+    # used in `test_torch_compile`
+    _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
+
     def setUp(self):
         self.model_tester = LlamaModelTester(self)
         self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
@@ -27,6 +27,7 @@ from collections import defaultdict
 from typing import Dict, List, Tuple
 
 import numpy as np
+from packaging import version
 from parameterized import parameterized
 from pytest import mark
 
@@ -35,6 +36,7 @@ from transformers import (
     AutoModel,
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
+    AutoTokenizer,
     PretrainedConfig,
     PreTrainedModel,
     is_torch_available,
@@ -71,6 +73,7 @@ from transformers.testing_utils import (
     require_accelerate,
     require_bitsandbytes,
     require_flash_attn,
+    require_read_token,
     require_safetensors,
     require_torch,
     require_torch_gpu,
@@ -4399,6 +4402,38 @@ class ModelTesterMixin:
         normalized_1 = F.softmax(out_shared_prefix_last_tokens)
         torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
 
+    # For now, Let's focus only on GPU for `torch.compile`
+    @slow
+    @require_torch_gpu
+    @require_read_token
+    def test_torch_compile(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest("This test requires torch >= 2.3 to run.")
+
+        if not hasattr(self, "_torch_compile_test_ckpt"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
+        ckpt = self._torch_compile_test_ckpt
+
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        batch_size = 1
+        n_iter = 3
+
+        tokenizer = AutoTokenizer.from_pretrained(ckpt)
+        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
+
+        model.generation_config.max_new_tokens = 4
+        model.generation_config.max_new_tokens = 4
+
+        model.generation_config.cache_implementation = "static"
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+        input_text = "Why dogs are cute?"
+        input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)
+
+        for i in range(n_iter):
+            _ = model.generate(**input_ids, do_sample=False)
+
 
 global_rng = random.Random()
 