# Gemma
[Gemma](https://huggingface.co/papers/2403.08295) is a family of lightweight language models with pretrained and instruction-tuned variants, available in 2B and 7B parameter sizes. The architecture is a decoder-only transformer. It uses multi-query attention in the 2B model (the 7B model uses standard multi-head attention), rotary positional embeddings (RoPE), GeGLU activation functions, and RMSNorm layer normalization.
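The sketch below illustrates the RMSNorm, GeGLU, and RoPE components on plain Python lists. It is not the Transformers implementation, which operates on batched torch tensors. It assumes Gemma's RMSNorm convention of scaling by `(1 + weight)`, so a zero-initialized weight is a no-op. It also assumes the tanh approximation of GELU, an epsilon of `1e-6`, and a RoPE base of `10000`. The `rope` helper uses the "rotate half" pairing, where element `i` is paired with element `i + d/2`.

```py
import math


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm: divide by the root-mean-square of x (no mean subtraction).

    Gemma-style scaling by (1 + weight), so weight=0 leaves the normalized input unchanged.
    """
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [(v / rms) * (1.0 + w) for v, w in zip(x, weight)]


def gelu_tanh(v):
    """GELU using the tanh approximation."""
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v**3)))


def geglu(gate, up):
    """GeGLU gating: GELU(gate) * up, element-wise.

    In the full MLP, `gate` and `up` come from two separate linear projections of the
    same hidden state, and the product is projected back down afterwards.
    """
    return [gelu_tanh(g) * u for g, u in zip(gate, up)]


def rope(x, position, base=10000.0):
    """Rotary embedding: rotate each pair (x[i], x[i + d/2]) by a position-dependent angle."""
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        theta = position * base ** (-2 * i / len(x))
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out


x = [1.0, 2.0, 3.0, 4.0]
print(rms_norm(x, weight=[0.0] * 4))
print(geglu(gate=[-1.0, 0.5], up=[2.0, 2.0]))
print(rope(x, position=3))
```

Because every pair is rotated by an angle proportional to its position, the dot product between a rotated query and a rotated key depends only on their relative offset. That is the property RoPE relies on.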
The instruction-tuned variant was fine-tuned with supervised learning on instruction-following data, followed by reinforcement learning from human feedback (RLHF) to align the model outputs with human preferences.
You can find all the original Gemma checkpoints under the [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) release.
> [!TIP]
> Click on the Gemma models in the right sidebar for more examples of how to apply Gemma to different language tasks.
The examples below demonstrate how to generate text with [`Pipeline`], with the [`AutoModel`] class, and from the command line.
```py
import torch
from transformers import pipeline

pipeline = pipeline(
    task="text-generation",
    model="google/gemma-2b",
    torch_dtype=torch.bfloat16,
    device="cuda",
)
# Print the generated text so the output is visible when run as a script
print(pipeline("LLMs generate text through a process known as", max_new_tokens=50))
```
```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

input_text = "LLMs generate text through a process known as"
# Move the tokenized inputs to the device that device_map="auto" placed the model on
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
```bash
echo -e "LLMs generate text through a process known as" | transformers run --task text-generation --model google/gemma-2b --device 0
```
Quantization reduces the memory burden of large models by representing the weights in lower precision. Refer to the [Quantization](../quantization/overview) overview for other supported quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize only the weights to 4-bit NF4.
```py
#!pip install bitsandbytes
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Store weights in 4-bit NF4 and run computations in bfloat16
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b",
    quantization_config=quantization_config,
    device_map="auto",
    attn_implementation="sdpa",
)

input_text = "LLMs generate text through a process known as"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    cache_implementation="static",
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
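The sketch below shows where the memory savings come from. It estimates weight storage at a given bit width and applies a simple symmetric absmax quantizer, mapping each block of weights to integers in `[-7, 7]` with one scale per block. This is a plain int4 scheme for illustration only. The NF4 type used by bitsandbytes relies on a different, non-uniform set of levels. The `7e9` parameter count is a round placeholder, not Gemma's exact size, and the estimate ignores the per-block scales and other overhead.

```py
def weight_memory_gb(num_params, bits):
    """Approximate weight storage in GB, ignoring quantization constants and overhead."""
    return num_params * bits / 8 / 1e9


def quantize_absmax_int4(block):
    """Symmetric absmax quantization of one block of floats to integer codes in [-7, 7].

    Returns the codes and the scale needed to dequantize them.
    """
    absmax = max(abs(v) for v in block)
    scale = absmax / 7 if absmax > 0 else 1.0
    codes = [max(-7, min(7, round(v / scale))) for v in block]
    return codes, scale


def dequantize(codes, scale):
    """Map integer codes back to approximate float values."""
    return [c * scale for c in codes]


def quantize_blockwise(weights, block_size=64):
    """Quantize fixed-size blocks independently so one outlier only affects its own block."""
    return [quantize_absmax_int4(weights[i:i + block_size]) for i in range(0, len(weights), block_size)]


weights = [0.12, -0.5, 0.03, 0.9, -0.27, 0.0, 0.44, -0.81]
codes, scale = quantize_absmax_int4(weights)
print(codes, [round(v, 3) for v in dequantize(codes, scale)])
# 7e9 is a placeholder parameter count, not Gemma's exact size
print(f"bf16: {weight_memory_gb(7e9, 16):.1f} GB, 4-bit: {weight_memory_gb(7e9, 4):.1f} GB")
```

Rounding to the nearest code bounds the per-weight error by half the block's scale. This is why smaller blocks, which have smaller and more local scales, usually preserve accuracy better.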
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
```py
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/gemma-2b")
visualizer("LLMs generate text through a process known as")
```
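Gemma is a decoder-only model, so each token attends only to itself and the tokens before it. The hand-rolled sketch below builds that lower-triangular causal mask and prints it as a text grid. It does not reproduce the visualizer's own code or output format. It also splits the prompt on whitespace instead of using the Gemma tokenizer, so its "tokens" are illustrative only.

```py
def causal_mask(seq_len):
    """mask[i][j] is True when query position i may attend to key position j."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def render(tokens):
    """Return one row per token: 'x' marks an attended position, '.' a masked one."""
    mask = causal_mask(len(tokens))
    width = max(len(t) for t in tokens)
    lines = []
    for tok, row in zip(tokens, mask):
        cells = " ".join("x" if allowed else "." for allowed in row)
        lines.append(f"{tok:>{width}} {cells}")
    return "\n".join(lines)


# Whitespace split is a stand-in for real tokenization
tokens = "LLMs generate text through".split()
print(render(tokens))
```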
## Notes
- The original Gemma models support the standard key-value (KV) caching used by most transformer-based language models. You can use the default [`DynamicCache`] instance or a tuple of tensors for past key values during generation, which keeps Gemma compatible with typical autoregressive generation workflows.
```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

input_text = "LLMs generate text through a process known as"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
# The cache grows as new tokens are generated
past_key_values = DynamicCache()
outputs = model.generate(**inputs, max_new_tokens=50, past_key_values=past_key_values)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
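A KV cache stores one key tensor and one value tensor per layer for every processed token, so its size grows linearly with sequence length. The sketch below estimates that size. It uses the layer counts, key-value head counts, and head dimension reported in the Gemma technical report: 18 layers with a single key-value head for 2B, and 28 layers with 16 key-value heads for 7B, both with a head dimension of 256. It assumes bf16 values and a batch size of 1, and ignores framework overhead. `ToyKVCache` is a hypothetical illustration of the append-per-step idea, not the [`DynamicCache`] API.

```py
from dataclasses import dataclass


@dataclass
class AttentionShape:
    num_layers: int
    num_kv_heads: int
    head_dim: int


# Values as reported in the Gemma technical report
GEMMA_2B = AttentionShape(num_layers=18, num_kv_heads=1, head_dim=256)  # multi-query attention
GEMMA_7B = AttentionShape(num_layers=28, num_kv_heads=16, head_dim=256)  # multi-head attention


def kv_cache_bytes(shape, seq_len, bytes_per_value=2):
    """Bytes needed to cache keys and values for one sequence (bf16 = 2 bytes per value)."""
    # Leading factor 2: one key tensor and one value tensor per layer
    return 2 * shape.num_layers * shape.num_kv_heads * shape.head_dim * seq_len * bytes_per_value


class ToyKVCache:
    """Hypothetical growing cache for illustration; not the transformers DynamicCache API."""

    def __init__(self, num_layers):
        self.keys = [[] for _ in range(num_layers)]
        self.values = [[] for _ in range(num_layers)]

    def update(self, layer_idx, key, value):
        """Append this step's key/value and return everything cached so far for the layer."""
        self.keys[layer_idx].append(key)
        self.values[layer_idx].append(value)
        return self.keys[layer_idx], self.values[layer_idx]

    def seq_length(self, layer_idx=0):
        return len(self.keys[layer_idx])


for name, shape in [("gemma-2b", GEMMA_2B), ("gemma-7b", GEMMA_7B)]:
    mib = kv_cache_bytes(shape, seq_len=8192) / 2**20
    print(f"{name}: {mib:.0f} MiB of KV cache for 8192 tokens")
```

Because all query heads in the 2B model share a single key-value head, its cache is far smaller per token than the 7B model's, which keeps a separate key-value head for every query head.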
## GemmaConfig
[[autodoc]] GemmaConfig
## GemmaTokenizer
[[autodoc]] GemmaTokenizer
## GemmaTokenizerFast
[[autodoc]] GemmaTokenizerFast
## GemmaModel
[[autodoc]] GemmaModel
- forward
## GemmaForCausalLM
[[autodoc]] GemmaForCausalLM
- forward
## GemmaForSequenceClassification
[[autodoc]] GemmaForSequenceClassification
- forward
## GemmaForTokenClassification
[[autodoc]] GemmaForTokenClassification
- forward
## FlaxGemmaModel
[[autodoc]] FlaxGemmaModel
- __call__
## FlaxGemmaForCausalLM
[[autodoc]] FlaxGemmaForCausalLM
- __call__