[docs] Llama3 (#31662)

Move the quick usage example to the top of the doc.
Steven Liu 2024-06-27 10:32:51 -07:00 committed by GitHub
parent e44b878c02
commit 464aa74659


@@ -16,6 +16,15 @@ rendered properly in your Markdown viewer.
# Llama3
```py3
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B"

# Load the model in bfloat16 and place it automatically on the available device(s)
pipeline = transformers.pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
pipeline("Hey how are you doing today?")
```
## Overview
@@ -66,20 +75,7 @@ model = AutoModelForCausalLM.from_pretrained("/output/path")
Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even though the biggest versions come in several checkpoints, each checkpoint contains a part of each of the model's weights, so all of them need to be loaded in RAM). For the 70B model, that is 145GB of RAM.
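
That figure follows from simple arithmetic: float16 uses two bytes per parameter, so a 70B-parameter model takes on the order of 140GB before any overhead. A quick back-of-the-envelope check (plain Python, nothing Transformers-specific):

```py3
# Rough float16 memory estimate for the 70B checkpoint: 2 bytes per parameter.
num_params = 70e9        # Llama 3 70B
bytes_per_param = 2      # float16
print(f"~{num_params * bytes_per_param / 1e9:.0f} GB")  # ~140 GB, plus buffers/overhead -> ~145 GB of CPU RAM
```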
- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method and use Automatic Mixed-Precision training. When using `Trainer`, simply set either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because Flash Attention only supports the `fp16` and `bf16` data types (see the sketch below).
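
A minimal sketch of that note, not an official example from the doc: the model is loaded without `torch_dtype` so the weights stay in float32, and mixed precision comes from either `TrainingArguments(bf16=True)` or a `torch.autocast` context. The output directory name is a placeholder, and the dataset/tokenizer setup is omitted.

```py3
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

# Load with Flash Attention 2 but without torch_dtype: the weights stay in
# float32 and mixed precision handles the bf16 computation.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    attn_implementation="flash_attention_2",
)

# With Trainer, setting bf16=True (or fp16=True) enables automatic mixed precision.
training_args = TrainingArguments(output_dir="./llama3-sft", bf16=True)
# trainer = Trainer(model=model, args=training_args, train_dataset=...)  # dataset omitted in this sketch

# Without Trainer, wrap the forward pass in torch.autocast instead:
# with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#     outputs = model(input_ids)
```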
## Quick usage
```py3
import transformers
import torch
model_id = "meta-llama/Meta-Llama-3-8B"
pipeline = transformers.pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
pipeline("Hey how are you doing today?")
```
## Resources
A ton of cool resources are already available on the documentation page of [Llama2](./llama2), and contributors are invited to add new resources curated for Llama3 here! 🤗