AfafEL cfe666919e
Update model card for Gemma (#37674)
* Update Gemma model card

* Updated after review

* Update following review
2025-04-24 09:58:46 -07:00


PyTorch TensorFlow Flax FlashAttention SDPA

Gemma

Gemma is a family of lightweight language models with pretrained and instruction-tuned variants, available in 2B and 7B parameter sizes. It uses a decoder-only transformer architecture with rotary positional embeddings (RoPE), GeGLU activation functions, and RMSNorm layer normalization. The 2B model uses Multi-Query Attention, while the 7B model uses standard multi-head attention.

The instruction-tuned variant was fine-tuned with supervised learning on instruction-following data, followed by reinforcement learning from human feedback (RLHF) to align the model outputs with human preferences.

You can find all the original Gemma checkpoints under the Gemma release.

Tip

Click on the Gemma models in the right sidebar for more examples of how to apply Gemma to different language tasks.

The examples below demonstrate how to generate text with [Pipeline], with the [AutoModel] class, and from the command line.

Pipeline

```python
import torch
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="google/gemma-2b",
    torch_dtype=torch.bfloat16,
    device="cuda",
)

pipe("LLMs generate text through a process known as", max_new_tokens=50)
```

AutoModel

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)

input_text = "LLMs generate text through a process known as"
# Move the tokenized inputs to whichever device the model was placed on
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

outputs = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

transformers-cli

```bash
echo -e "LLMs generate text through a process known as" | transformers-cli run --task text-generation --model google/gemma-2b --device 0
```

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.

The example below uses bitsandbytes to quantize only the weights to 4 bits, using the NF4 data type.

```python
# pip install bitsandbytes
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Store the weights in 4-bit NF4 and run computations in bfloat16
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4"
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b",
    quantization_config=quantization_config,
    device_map="auto",
    attn_implementation="sdpa"
)

input_text = "LLMs generate text through a process known as"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    cache_implementation="static"
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
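To put rough numbers on the savings, the sketch below estimates weight storage only. It assumes a round 7 billion parameters for illustration, and `weight_memory_gb` is a hypothetical helper. The estimate ignores activations, the KV cache, quantization constants, and any modules bitsandbytes keeps in higher precision, so real memory use will be higher. For a loaded model, `model.get_memory_footprint()` reports the actual figure.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Hypothetical helper: approximate weight storage in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

# Assumed round parameter count for illustration, not an exact Gemma figure
num_params = 7e9

bf16_gb = weight_memory_gb(num_params, 16)  # 16 bits per weight in bfloat16
nf4_gb = weight_memory_gb(num_params, 4)    # 4 bits per weight in NF4
print(f"bf16: {bf16_gb:.1f} GB, 4-bit: {nf4_gb:.1f} GB")  # bf16: 14.0 GB, 4-bit: 3.5 GB
```

The 4x reduction follows directly from storing each weight in 4 bits instead of 16.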

Use the AttentionMaskVisualizer to better understand which tokens the model can and cannot attend to.

```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer

visualizer = AttentionMaskVisualizer("google/gemma-2b")
visualizer("LLMs generate text through a process known as")
```
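As a decoder-only model, Gemma applies a causal mask during text generation: each position can attend to itself and to earlier positions, but not to later ones. The minimal sketch below builds that mask directly with PyTorch. The word-level tokens are illustrative and do not match Gemma's real tokenization. The visualizer above may also show details, such as special tokens, that this sketch leaves out.

```python
import torch

# Illustrative word-level "tokens"; Gemma's tokenizer splits text differently.
tokens = ["LLMs", "generate", "text", "through"]
n = len(tokens)

# Causal mask: entry [i, j] is True when position i may attend to position j, i.e. j <= i.
causal_mask = torch.tril(torch.ones(n, n, dtype=torch.bool))

for i, row in enumerate(causal_mask.tolist()):
    visible = [tok for tok, allowed in zip(tokens, row) if allowed]
    print(f"{tokens[i]!r} can attend to {visible}")
```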

Notes

  • The original Gemma models support the standard key-value (KV) caching used in many transformer-based language models. You can use the default [DynamicCache] instance or a tuple of tensors for past key values during generation. This makes them compatible with typical autoregressive generation workflows.

    ```python
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="sdpa"
    )
    input_text = "LLMs generate text through a process known as"
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    # An empty cache that generate() fills with keys and values as it decodes
    past_key_values = DynamicCache()
    outputs = model.generate(**inputs, max_new_tokens=50, past_key_values=past_key_values)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    ```
    

GemmaConfig

autodoc GemmaConfig

GemmaTokenizer

autodoc GemmaTokenizer

GemmaTokenizerFast

autodoc GemmaTokenizerFast

GemmaModel

autodoc GemmaModel - forward

GemmaForCausalLM

autodoc GemmaForCausalLM - forward

GemmaForSequenceClassification

autodoc GemmaForSequenceClassification - forward

GemmaForTokenClassification

autodoc GemmaForTokenClassification - forward

FlaxGemmaModel

autodoc FlaxGemmaModel - call

FlaxGemmaForCausalLM

autodoc FlaxGemmaForCausalLM - call