# Gemma2

[Gemma 2](https://huggingface.co/papers/2408.00118) is a family of language models with pretrained and instruction-tuned variants, available in 2B, 9B, and 27B parameters. The architecture is similar to the previous Gemma, except it features interleaved local attention (4096 tokens) and global attention (8192 tokens), and grouped-query attention (GQA) for faster inference. The 2B and 9B models are trained with knowledge distillation, and the instruction-tuned variant was post-trained with supervised fine-tuning and reinforcement learning.

You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) collection.

> [!TIP]
> Click on the Gemma 2 models in the right sidebar for more examples of how to apply Gemma 2 to different language tasks.

The examples below demonstrate how to generate text with the [`Pipeline`] or the [`AutoModel`] class, and from the command line.

```python
import torch
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="google/gemma-2-9b",
    torch_dtype=torch.bfloat16,
    device="cuda",
)

pipe("Explain quantum computing simply.", max_new_tokens=50)
```

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

input_text = "Explain quantum computing simply."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

```bash
echo -e "Explain quantum computing simply." | transformers run --task text-generation --model google/gemma-2-2b --device 0
```

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-27b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
    quantization_config=quantization_config,
)

input_text = "Explain quantum computing simply."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.

```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer

visualizer = AttentionMaskVisualizer("google/gemma-2-2b")
visualizer("You are an assistant. Make sure you print me")
```
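The interleaved local/global attention and GQA described above can also be inspected from the model config. The sketch below is illustrative and assumes the current [`Gemma2Config`] attribute names (`sliding_window`, `max_position_embeddings`, `num_attention_heads`, `num_key_value_heads`); see the config reference below if they differ.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/gemma-2-2b")

# Local (sliding window) attention span used on alternating layers
print(config.sliding_window)            # assumption: 4096
# Maximum (global) context length
print(config.max_position_embeddings)   # assumption: 8192
# Fewer key/value heads than query heads -> grouped-query attention (GQA)
print(config.num_attention_heads, config.num_key_value_heads)
```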
## Notes

- Use a [`HybridCache`] instance to enable caching in Gemma 2. Gemma 2 doesn't support KV cache strategies like [`DynamicCache`] or tuples of tensors because it uses sliding window attention every second layer. See the sketch at the end of this page for using this cache with `generate`.

    ```python
    from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

    inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

    # Preallocate the cache for the prompt plus up to 10 new tokens
    max_generated_length = inputs.input_ids.shape[1] + 10
    past_key_values = HybridCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=max_generated_length,
        device=model.device,
        dtype=model.dtype,
    )
    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    ```

## Gemma2Config

[[autodoc]] Gemma2Config

## Gemma2Model

[[autodoc]] Gemma2Model
    - forward

## Gemma2ForCausalLM

[[autodoc]] Gemma2ForCausalLM
    - forward

## Gemma2ForSequenceClassification

[[autodoc]] Gemma2ForSequenceClassification
    - forward

## Gemma2ForTokenClassification

[[autodoc]] Gemma2ForTokenClassification
    - forward
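The [`HybridCache`] note above only runs a single forward pass. A minimal generation sketch, assuming [`~GenerationMixin.generate`] accepts a preallocated cache through `past_key_values` (as in recent Transformers releases):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

inputs = tokenizer("My name is Gemma", return_tensors="pt")
max_generated_length = inputs.input_ids.shape[1] + 10

# Preallocate the hybrid (sliding + global) cache and hand it to generate()
past_key_values = HybridCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=max_generated_length,
    device=model.device,
    dtype=model.dtype,
)
outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```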