Gemma
Gemma is a family of lightweight language models with pretrained and instruction-tuned variants, available in 2B and 7B parameter sizes. The architecture is a decoder-only transformer that features Multi-Query Attention, rotary positional embeddings (RoPE), GeGLU activation functions, and RMSNorm layer normalization.
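These design choices are reflected in the model configuration. The snippet below is a minimal sketch that prints a few of the relevant hyperparameters; the exact attribute names may vary slightly across transformers versions.
from transformers import AutoConfig

# Load the Gemma 2B configuration from the Hub
config = AutoConfig.from_pretrained("google/gemma-2b")

# Multi-Query Attention: several query heads share a single key/value head
print(config.num_attention_heads, config.num_key_value_heads)

# RoPE base frequency and RMSNorm epsilon
print(config.rope_theta, config.rms_norm_eps)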
The instruction-tuned variant was fine-tuned with supervised learning on instruction-following data, followed by reinforcement learning from human feedback (RLHF) to align the model outputs with human preferences.
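The instruction-tuned checkpoints expect prompts in Gemma's chat format, which the tokenizer's chat template applies for you. The snippet below is a minimal sketch assuming the google/gemma-2b-it checkpoint.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Format the conversation with the tokenizer's built-in chat template
messages = [{"role": "user", "content": "Explain how LLMs generate text."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(input_ids, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))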
You can find all the original Gemma checkpoints under the Gemma release.
Tip
Click on the Gemma models in the right sidebar for more examples of how to apply Gemma to different language tasks.
The example below demonstrates how to generate text with [Pipeline] or the [AutoModel] class, and from the command line.
import torch
from transformers import pipeline
pipeline = pipeline(
    task="text-generation",
    model="google/gemma-2b",
    torch_dtype=torch.bfloat16,
    device="cuda",
)
pipeline("LLMs generate text through a process known as", max_new_tokens=50)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)
input_text = "LLMs generate text through a process known as"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=50, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
echo -e "LLMs generate text through a process known as" | transformers-cli run --task text-generation --model google/gemma-2b --device 0
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.
The example below uses bitsandbytes to quantize only the weights to int4.
#!pip install bitsandbytes
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4"
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b",
    quantization_config=quantization_config,
    device_map="auto",
    attn_implementation="sdpa"
)
input_text = "LLMs generate text through a process known as."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(
    **input_ids,
    max_new_tokens=50,
    cache_implementation="static"
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
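As a quick sanity check, you can estimate how much memory the quantized weights use with get_memory_footprint(); the number below is a rough estimate of weight memory only and depends on the quantization settings above.
# Approximate memory used by the quantized weights, in GiB
print(f"{model.get_memory_footprint() / 1024**3:.2f} GiB")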
Use the AttentionMaskVisualizer to better understand what tokens the model can and cannot attend to.
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/gemma-2b")
visualizer("LLMs generate text through a process known as")

Notes
- The original Gemma models support the standard kv-caching used in many transformer-based language models. You can use the default [DynamicCache] instance or a tuple of tensors for past key values during generation. This makes it compatible with typical autoregressive generation workflows.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)

input_text = "LLMs generate text through a process known as"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

past_key_values = DynamicCache()
outputs = model.generate(**input_ids, max_new_tokens=50, past_key_values=past_key_values)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
GemmaConfig
autodoc GemmaConfig
GemmaTokenizer
autodoc GemmaTokenizer
GemmaTokenizerFast
autodoc GemmaTokenizerFast
GemmaModel
autodoc GemmaModel - forward
GemmaForCausalLM
autodoc GemmaForCausalLM - forward
GemmaForSequenceClassification
autodoc GemmaForSequenceClassification - forward
GemmaForTokenClassification
autodoc GemmaForTokenClassification - forward
FlaxGemmaModel
autodoc FlaxGemmaModel - call
FlaxGemmaForCausalLM
autodoc FlaxGemmaForCausalLM - call