[docs/gpt-j] add instructions for how to minimize CPU RAM usage (#13795)

* add a note about the tokenizer
* add tips to load the model in less RAM
* fix link
* fix more links

parent 55695df0f7
commit bf6118e70c

@@ -24,15 +24,34 @@ This model was contributed by `Stella Biderman <https://huggingface.co/stellaath
Tips:

- Running `GPT-J <https://huggingface.co/EleutherAI/gpt-j-6B>`__ in float32 precision on GPU requires at least 24 GB of
  RAM. On GPUs with less than 24 GB of RAM, one should therefore load the model in half-precision (see the code block
  below).

- To load `GPT-J <https://huggingface.co/EleutherAI/gpt-j-6B>`__ in float32 one would need at least 2x the model size in
  CPU RAM: 1x for the initial weights and another 1x to load the checkpoint. So for GPT-J it would take at least 48GB of
  CPU RAM just to load the model. To reduce the CPU RAM usage there are a few options: the ``torch_dtype`` argument can
  be used to initialize the model in half-precision, and the ``low_cpu_mem_usage`` argument can be used to keep the RAM
  usage to 1x. There is also a `fp16 branch <https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16>`__ which stores the
  fp16 weights and can be used to further minimize the RAM usage. Combining all of this, it should take roughly 12.1GB
  of CPU RAM to load the model:

.. code-block::

    >>> from transformers import GPTJForCausalLM
    >>> import torch

    >>> # load the fp32 checkpoint and cast the weights to half-precision (fits GPUs with less than 24 GB of RAM)
    >>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)

    >>> # or load the fp16 weights from the dedicated branch; low_cpu_mem_usage keeps CPU RAM close to 1x the model size
    >>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True)
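
The 12.1GB figure can be sanity-checked from the loaded weights themselves. Below is a minimal sketch; it assumes
``model`` has been loaded in fp16 as in the block above, and it only counts parameter memory, ignoring buffers and
framework overhead:

.. code-block::

    >>> # each fp16 parameter takes 2 bytes; GPT-J has roughly 6 billion parameters, so expect about 12.1GB
    >>> param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    >>> print(f"{param_bytes / 1e9:.1f} GB")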

- The model should fit on a 16GB GPU for inference. For training/fine-tuning it would take much more GPU RAM. The Adam
  optimizer, for example, makes four copies of the model: the model itself, the gradients, and the average and squared
  average of the gradients. So it would need at least 4x the model size in GPU memory, even with mixed precision, since
  the gradient updates are in fp32. This does not include the activations and data batches, which require yet more GPU
  RAM. So one should explore solutions such as DeepSpeed to train/fine-tune the model. Another option is to use the
  original codebase to train/fine-tune the model on TPU and then convert the model to the Transformers format for
  inference. Instructions for that can be found `here <https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md>`__.
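
To make the 4x figure concrete, here is a rough back-of-the-envelope sketch; it assumes roughly 6 billion parameters
held in fp32 by the optimizer and deliberately ignores activations, data batches, and framework overhead:

.. code-block::

    >>> n_params = 6e9      # GPT-J has roughly 6 billion parameters
    >>> fp32_bytes = 4      # gradient updates and optimizer states are kept in fp32
    >>> copies = 4          # weights, gradients, and the two Adam gradient averages
    >>> print(f"{n_params * fp32_bytes * copies / 1e9:.0f} GB")  # ~96 GB before activations and batches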

- Although the embedding matrix has a size of 50400, only 50257 entries are used by the GPT-2 tokenizer. These extra
  tokens are added for the sake of efficiency on TPUs. To avoid a mismatch between the embedding matrix size and the
  vocab size, the tokenizer for `GPT-J <https://huggingface.co/EleutherAI/gpt-j-6B>`__ contains 143 extra tokens
  ``<|extratoken_1|>... <|extratoken_143|>``, so the ``vocab_size`` of the tokenizer also becomes 50400.
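
The size bookkeeping above can be checked directly. The following is a small sketch; the expected values assume the
Hub checkpoint ships the 143 ``<|extratoken_*|>`` entries described in this tip:

.. code-block::

    >>> from transformers import AutoTokenizer, GPTJConfig

    >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    >>> config = GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")

    >>> len(tokenizer)     # 50257 GPT-2 entries plus 143 extra tokens
    50400
    >>> config.vocab_size  # matches the size of the embedding matrix
    50400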