Make Gemma work with torch.compile (#30775)

* fix

* [run-slow] gemma

* add test

* add `test_compile_static_cache`

* fix

* style

* remove subprocess

* use attribute

* fix

* style

* update

* [run-slow] dbrx,gemma,jetmoe,phi3,recurrent_gemma

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Author: Yih-Dar
Date: 2024-05-16 13:41:33 +02:00 (committed by GitHub)
Parent: 0753134f4d
Commit: 1b3dba9417
8 changed files with 110 additions and 26 deletions
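The fix repeated across the rotary embedding classes below is the same in every file: `inv_freq` used to be created lazily inside `forward` (guarded by `if self.inv_freq is None:`), which assigns new module state at trace time and prevents `torch.compile(..., fullgraph=True)` from capturing a single graph. The buffer is now computed once in `__init__` and registered as a non-persistent buffer. A minimal, self-contained sketch of the pattern (simplified for illustration, not the exact transformers code):

import torch
import torch.nn as nn


class RotaryEmbeddingSketch(nn.Module):
    """Simplified rotary-embedding skeleton showing the compile-friendly pattern."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        # Compute inv_freq once at construction time and register it as a
        # non-persistent buffer, so `forward` never creates new module state.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        # No `if self.inv_freq is None:` branch here: data-dependent buffer
        # creation inside forward is what blocked fullgraph compilation.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
        return torch.cat((freqs, freqs), dim=-1)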

@@ -55,15 +55,14 @@ class DbrxRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts

@@ -104,15 +104,14 @@ class GemmaRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts

@@ -397,15 +397,14 @@ class JetMoeRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts

@@ -99,15 +99,14 @@ class Phi3RotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts

@@ -68,16 +68,14 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
         super().__init__()
         self.dim = dim
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts

@@ -17,6 +17,7 @@ import tempfile
 import unittest
 
 import pytest
+from packaging import version
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.testing_utils import (
@@ -40,7 +41,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
     import torch
 
-    from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel
+    from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer
 
 
 class GemmaModelTester:
@@ -302,6 +303,9 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.6]
 
+    # used in `test_torch_compile`
+    _torch_compile_test_ckpt = "google/gemma-2b"
+
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
         self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
@@ -801,3 +805,51 @@ class GemmaIntegrationTest(unittest.TestCase):
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
         self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
+
+    @slow
+    @require_torch_gpu
+    @require_read_token
+    def test_compile_static_cache(self):
+        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
+        # work as intended. See https://github.com/pytorch/pytorch/issues/121943
+        if version.parse(torch.__version__) < version.parse("2.3.0"):
+            self.skipTest("This test requires torch >= 2.3 to run.")
+
+        NUM_TOKENS_TO_GENERATE = 40
+        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
+        # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
+        EXPECTED_TEXT_COMPLETION = {
+            8: [
+                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
+                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
+            ],
+            7: [
+                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
+                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
+            ],
+        }
+
+        prompts = ["Hello I am doing", "Hi today"]
+        tokenizer = GemmaTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
+        model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
+
+        # Dynamic Cache
+        generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
+        dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text)  # Both GPU architectures have the same output
+
+        # Static Cache
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
+
+        # Static Cache + compile
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)

@@ -312,6 +312,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.7, 0.8]
 
+    # used in `test_torch_compile`
+    _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
+
     def setUp(self):
         self.model_tester = LlamaModelTester(self)
         self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
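Any other decoder test suite can opt into the shared `test_torch_compile` the same way, since the common test skips itself when the attribute is missing. A hypothetical example (class name and checkpoint are placeholders, not part of this commit):

import unittest

from ...test_modeling_common import ModelTesterMixin  # same relative import the existing model tests use


class SomeOtherModelTest(ModelTesterMixin, unittest.TestCase):
    # Hypothetical opt-in: `test_torch_compile` runs only when this attribute is defined,
    # so pointing it at a causal-LM checkpoint is the whole opt-in.
    _torch_compile_test_ckpt = "org/some-causal-lm-checkpoint"  # illustrative placeholder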

@@ -27,6 +27,7 @@ from collections import defaultdict
 from typing import Dict, List, Tuple
 
 import numpy as np
+from packaging import version
 from parameterized import parameterized
 from pytest import mark
@@ -35,6 +36,7 @@ from transformers import (
     AutoModel,
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
+    AutoTokenizer,
     PretrainedConfig,
     PreTrainedModel,
     is_torch_available,
@@ -71,6 +73,7 @@ from transformers.testing_utils import (
     require_accelerate,
     require_bitsandbytes,
     require_flash_attn,
+    require_read_token,
     require_safetensors,
     require_torch,
     require_torch_gpu,
@@ -4399,6 +4402,38 @@ class ModelTesterMixin:
             normalized_1 = F.softmax(out_shared_prefix_last_tokens)
             torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
+
+    # For now, Let's focus only on GPU for `torch.compile`
+    @slow
+    @require_torch_gpu
+    @require_read_token
+    def test_torch_compile(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest("This test requires torch >= 2.3 to run.")
+
+        if not hasattr(self, "_torch_compile_test_ckpt"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
+        ckpt = self._torch_compile_test_ckpt
+
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        batch_size = 1
+        n_iter = 3
+
+        tokenizer = AutoTokenizer.from_pretrained(ckpt)
+        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
+
+        model.generation_config.max_new_tokens = 4
+        model.generation_config.max_new_tokens = 4
+
+        model.generation_config.cache_implementation = "static"
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+        input_text = "Why dogs are cute?"
+        input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)
+
+        for i in range(n_iter):
+            _ = model.generate(**input_ids, do_sample=False)
 
 
 global_rng = random.Random()
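For reference, the end-user pattern these tests exercise looks roughly like the sketch below (assumptions: access to the gated google/gemma-2b checkpoint, a CUDA GPU, and torch >= 2.3, mirroring the test requirements above):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda")

# A static KV cache keeps tensor shapes fixed, which is what allows fullgraph compilation.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(["Why dogs are cute?"], return_tensors="pt").to(model.device)
# The first iterations pay the compilation cost; subsequent calls reuse the compiled graph.
for _ in range(3):
    output_ids = model.generate(**inputs, max_new_tokens=4, do_sample=False)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))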