
PyTorch TensorFlow Flax FlashAttention SDPA

Gemma2

Gemma 2 is a family of language models with pretrained and instruction-tuned variants, available in 2B, 9B, and 27B parameter sizes. The architecture is similar to the previous Gemma, except it features interleaved local attention (4096 tokens) and global attention (8192 tokens), and grouped-query attention (GQA) to improve inference performance.
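
The interleaved local/global attention and GQA are reflected directly in [Gemma2Config], so you can sanity-check these architecture choices without loading any weights. A minimal sketch (attribute names as defined on the config class):

from transformers import Gemma2Config

config = Gemma2Config.from_pretrained("google/gemma-2-9b")
print(config.sliding_window)       # local attention window applied to every other layer (4096)
print(config.num_attention_heads)  # number of query heads
print(config.num_key_value_heads)  # fewer key/value heads than query heads, i.e. grouped-query attention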

The 2B and 9B models were trained with knowledge distillation, and the instruction-tuned variants were post-trained with supervised fine-tuning and reinforcement learning.

You can find all the original Gemma 2 checkpoints under the Gemma 2 collection.

Tip

Click on the Gemma 2 models in the right sidebar for more examples of how to apply Gemma 2 to different language tasks.

The examples below demonstrate how to generate text with [Pipeline], with the [AutoModel] class, and from the command line.

import torch
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="google/gemma-2-9b",
    torch_dtype=torch.bfloat16,
    device="cuda",
)

pipe("Explain quantum computing simply. ", max_new_tokens=50)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)

input_text = "Explain quantum computing simply."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
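
The checkpoints above are base models, so for actual chat you would typically switch to an instruction-tuned variant and build the prompt with the tokenizer's chat template. A minimal sketch, assuming the google/gemma-2-9b-it checkpoint:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [{"role": "user", "content": "Explain quantum computing simply."}]
# apply_chat_template inserts Gemma's turn markers and the generation prompt
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(input_ids, max_new_tokens=64)
# decode only the newly generated tokens
print(tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))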

echo -e "Explain quantum computing simply." | transformers run --task text-generation --model google/gemma-2-2b --device 0

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.

The example below uses bitsandbytes to quantize only the weights to int4.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-27b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
    quantization_config=quantization_config,
)

input_text = "Explain quantum computing simply."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
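
To confirm the savings from 4-bit quantization, compare the loaded model's memory footprint with and without quantization_config (get_memory_footprint is a method on transformers models that reports the size of parameters and buffers in bytes):

print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")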

Use the AttentionMaskVisualizer to better understand what tokens the model can and cannot attend to.

from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/gemma-2-2b")
visualizer("You are an assistant. Make sure you print me")

Notes

  • Use a [HybridCache] instance to enable caching in Gemma 2. Gemma 2 doesn't support kv-caching strategies like [DynamicCache] or tuples of tensors because it uses sliding window attention every second layer.

    from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
    
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
    
    inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
    max_generated_length = inputs.input_ids.shape[1] + 10
    past_key_values = HybridCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=max_generated_length,
        device=model.device,
        dtype=model.dtype,
    )
    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
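
    In recent versions of transformers, generate() can build this cache for you as well; a minimal sketch continuing the example above, assuming the "hybrid" cache implementation is available:

    outputs = model.generate(**inputs, max_new_tokens=10, cache_implementation="hybrid")
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))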
    

Gemma2Config

autodoc Gemma2Config

Gemma2Model

autodoc Gemma2Model - forward

Gemma2ForCausalLM

autodoc Gemma2ForCausalLM - forward

Gemma2ForSequenceClassification

autodoc Gemma2ForSequenceClassification - forward

Gemma2ForTokenClassification

autodoc Gemma2ForTokenClassification - forward