<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Causal language modeling

[[open-in-colab]]

There are two types of language modeling: causal and masked. This guide illustrates causal language modeling.
Causal language models are frequently used for text generation. You can use these models for creative applications like
choose-your-own-adventure text games, or for intelligent coding assistants like Copilot or CodeParrot.

<Youtube id="Vpjb1lu0MDk"/>

Causal language modeling predicts the next token in a sequence of tokens, and the model can only attend to tokens on
the left. This means the model cannot see future tokens. GPT-2 is an example of a causal language model.
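
The "attend only to the left" rule is usually enforced with a causal (lower-triangular) attention mask. The following minimal pure-Python sketch builds such a mask for illustration only. It is not how 🤗 Transformers constructs its masks internally, because those operate on tensors:

```python
def causal_mask(seq_len):
    """Return a seq_len x seq_len mask where mask[i][j] == 1 means position i may attend to position j."""
    # Position i can see itself and every earlier position (j <= i), but nothing to its right.
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]


for row in causal_mask(4):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```

Row `i` lists what the token at position `i` can see. The first token sees only itself, and the last token sees the whole sequence.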

This guide will show you how to:

1. Finetune [DistilGPT2](https://huggingface.co/distilbert/distilgpt2) on a subset of the [ELI5-Category](https://huggingface.co/datasets/eli5_category) dataset.
2. Use your finetuned model for inference.

<Tip>
You can finetune other architectures for causal language modeling following the same steps in this guide.
Choose one of the following architectures:

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [Cohere](../model_doc/cohere), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [Jamba](../model_doc/jamba), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MusicGen Melody](../model_doc/musicgen_melody), [MVP](../model_doc/mvp), [OLMo](../model_doc/olmo), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [RecurrentGemma](../model_doc/recurrent_gemma), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), 
[RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)

<!--End of the generated tip-->

</Tip>

Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install transformers datasets evaluate
```

We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:

```py
>>> from huggingface_hub import notebook_login

>>> notebook_login()
```

## Load ELI5 dataset

Start by loading the first 5000 examples from the [ELI5-Category](https://huggingface.co/datasets/eli5_category) dataset with the 🤗 Datasets library. This gives you a chance to experiment and make sure everything works before spending more time training on the full dataset.

```py
>>> from datasets import load_dataset

>>> eli5 = load_dataset("eli5_category", split="train[:5000]")
```

Split the dataset's `train` split into a train and test set with the [`~datasets.Dataset.train_test_split`] method:

```py
>>> eli5 = eli5.train_test_split(test_size=0.2)
```

Then take a look at an example. The split is shuffled, so your first example may differ:

```py
>>> eli5["train"][0]
{'q_id': '7h191n',
 'title': 'What does the tax bill that was passed today mean? How will it affect Americans in each tax bracket?',
 'selftext': '',
 'category': 'Economics',
 'subreddit': 'explainlikeimfive',
 'answers': {'a_id': ['dqnds8l', 'dqnd1jl', 'dqng3i1', 'dqnku5x'],
  'text': ["The tax bill is 500 pages long and there were a lot of changes still going on right to the end. It's not just an adjustment to the income tax brackets, it's a whole bunch of changes. As such there is no good answer to your question. The big take aways are: - Big reduction in corporate income tax rate will make large companies very happy. - Pass through rate change will make certain styles of business (law firms, hedge funds) extremely happy - Income tax changes are moderate, and are set to expire (though it's the kind of thing that might just always get re-applied without being made permanent) - People in high tax states (California, New York) lose out, and many of them will end up with their taxes raised.",
   'None yet. It has to be reconciled with a vastly different house bill and then passed again.',
   'Also: does this apply to 2017 taxes? Or does it start with 2018 taxes?',
   'This article explains both the House and senate bills, including the proposed changes to your income taxes based on your income level. URL_0'],
  'score': [21, 19, 5, 3],
  'text_urls': [[],
   [],
   [],
   ['https://www.investopedia.com/news/trumps-tax-reform-what-can-be-done/']]},
 'title_urls': ['url'],
 'selftext_urls': ['url']}
```

While this may look like a lot, you're only really interested in the `text` field. A nice property of language modeling
tasks is that you don't need separate labels (which is why this is called an unsupervised task): the next word *is* the label.
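
To make this concrete, the small sketch below shows how a single sequence yields one training target per position: every prefix is paired with the token that follows it. For readability, whole words stand in for tokens here. The real tokenizer produces subword token ids.

```python
def next_token_examples(tokens):
    """Pair each prefix of `tokens` with the token that immediately follows it."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


tokens = ["The", "immune", "system", "adapts"]
for context, target in next_token_examples(tokens):
    print(context, "->", target)
# ['The'] -> immune
# ['The', 'immune'] -> system
# ['The', 'immune', 'system'] -> adapts
```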

## Preprocess

<Youtube id="ma1TrR7gE7I"/>

The next step is to load a DistilGPT2 tokenizer to process the `text` subfield:

```py
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
```

As the example above shows, the `text` field is nested inside `answers`. This means you'll need to
extract the `text` subfield from its nested structure with the [`flatten`](https://huggingface.co/docs/datasets/process#flatten) method:

```py
>>> eli5 = eli5.flatten()
>>> eli5["train"][0]
{'q_id': '7h191n',
 'title': 'What does the tax bill that was passed today mean? How will it affect Americans in each tax bracket?',
 'selftext': '',
 'category': 'Economics',
 'subreddit': 'explainlikeimfive',
 'answers.a_id': ['dqnds8l', 'dqnd1jl', 'dqng3i1', 'dqnku5x'],
 'answers.text': ["The tax bill is 500 pages long and there were a lot of changes still going on right to the end. It's not just an adjustment to the income tax brackets, it's a whole bunch of changes. As such there is no good answer to your question. The big take aways are: - Big reduction in corporate income tax rate will make large companies very happy. - Pass through rate change will make certain styles of business (law firms, hedge funds) extremely happy - Income tax changes are moderate, and are set to expire (though it's the kind of thing that might just always get re-applied without being made permanent) - People in high tax states (California, New York) lose out, and many of them will end up with their taxes raised.",
  'None yet. It has to be reconciled with a vastly different house bill and then passed again.',
  'Also: does this apply to 2017 taxes? Or does it start with 2018 taxes?',
  'This article explains both the House and senate bills, including the proposed changes to your income taxes based on your income level. URL_0'],
 'answers.score': [21, 19, 5, 3],
 'answers.text_urls': [[],
  [],
  [],
  ['https://www.investopedia.com/news/trumps-tax-reform-what-can-be-done/']],
 'title_urls': ['url'],
 'selftext_urls': ['url']}
```
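
Conceptually, `flatten` promotes each field of a nested column to its own top-level column and joins the names with a dot. The sketch below mimics that renaming on a plain Python dict. It is an illustration only, since 🤗 Datasets flattens the underlying Arrow table rather than individual records, and the toy record is made up:

```python
def flatten_record(record, prefix=""):
    """Recursively turn nested dicts into a single dict with dotted keys."""
    flat = {}
    for key, value in record.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            # Nested field: recurse and prefix its keys with "<name>."
            flat.update(flatten_record(value, prefix=f"{name}."))
        else:
            flat[name] = value
    return flat


toy_record = {"q_id": "q1", "answers": {"a_id": ["a1", "a2"], "text": ["First.", "Second."], "score": [3, 1]}}
print(flatten_record(toy_record))
# {'q_id': 'q1', 'answers.a_id': ['a1', 'a2'], 'answers.text': ['First.', 'Second.'], 'answers.score': [3, 1]}
```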

Each subfield is now a separate column, as indicated by the `answers.` prefix, and the `text` field is now a list. Instead
of tokenizing each answer separately, convert the list to a single string so you can tokenize the answers jointly.

Here is a first preprocessing function to join the list of strings for each example and tokenize the result:

```py
>>> def preprocess_function(examples):
...     return tokenizer([" ".join(x) for x in examples["answers.text"]])
```

To apply this preprocessing function over the entire dataset, use the 🤗 Datasets [`~datasets.Dataset.map`] method. You can speed up the `map` function by setting `batched=True` to process multiple elements of the dataset at once, and by increasing the number of processes with `num_proc`. Remove any columns you don't need:

```py
>>> tokenized_eli5 = eli5.map(
...     preprocess_function,
...     batched=True,
...     num_proc=4,
...     remove_columns=eli5["train"].column_names,
... )
```
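
With `batched=True`, `map` calls `preprocess_function` on a batch of examples (1,000 by default). The batch arrives as a dictionary that maps each column name to a list of values. The sketch below runs only the joining step, without the tokenizer call, on a made-up two-example batch. It shows what `examples["answers.text"]` looks like and what the list comprehension produces:

```python
# A made-up batch in the shape that map(batched=True) passes in: column name -> list of values.
toy_batch = {
    "answers.text": [
        ["First answer.", "Second answer."],  # example 0 has two answers
        ["Only answer."],  # example 1 has one answer
    ]
}


def join_answers(examples):
    # Same list comprehension as in preprocess_function, minus the tokenizer call.
    return [" ".join(x) for x in examples["answers.text"]]


print(join_answers(toy_batch))
# ['First answer. Second answer.', 'Only answer.']
```

Each example becomes one string, so the tokenizer returns one list of `input_ids` per example in the batch.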

This dataset contains the token sequences, but some of these are longer than the maximum input length for the model.

You can now use a second preprocessing function to:

- concatenate all the sequences
- split the concatenated sequences into shorter chunks defined by `block_size`, which should be no longer than the model's maximum input length and small enough to fit in your GPU memory.

```py
>>> block_size = 128

>>> def group_texts(examples):
...     # Concatenate all texts.
...     concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
...     total_length = len(concatenated_examples[list(examples.keys())[0]])
...     # Drop the small remainder that doesn't fill a whole block. You could pad it instead if the model
...     # supported padding; customize this part to your needs.
...     if total_length >= block_size:
...         total_length = (total_length // block_size) * block_size
...     # Split into chunks of block_size.
...     result = {
...         k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
...         for k, t in concatenated_examples.items()
...     }
...     result["labels"] = result["input_ids"].copy()
...     return result
```
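
To check what `group_texts` does, you can run the same logic on a made-up batch with a tiny block size of 4. The function below is a self-contained copy of `group_texts` that takes `block_size` as a parameter instead of reading the global variable. The token ids are arbitrary:

```python
def group_texts(examples, block_size):
    # Concatenate each column's lists into one long list.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the remainder that doesn't fill a whole block.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split every column into block_size chunks.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # For causal LM, the labels start out as a copy of the inputs.
    result["labels"] = result["input_ids"].copy()
    return result


toy_batch = {
    "input_ids": [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1], [1, 1]],
}
grouped = group_texts(toy_batch, block_size=4)
print(grouped["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

The ten tokens become two blocks of four, and the last two tokens (9 and 10) are dropped. Because `map(batched=True)` processes the dataset in batches (1,000 examples by default), this remainder is dropped once per batch rather than once for the whole dataset.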

Apply the `group_texts` function over the entire dataset:

```py
>>> lm_dataset = tokenized_eli5.map(group_texts, batched=True, num_proc=4)
```

Now create a batch of examples using [`DataCollatorForLanguageModeling`]. It's more efficient to *dynamically pad* the
sequences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length.

<frameworkcontent>
<pt>
Use the end-of-sequence token as the padding token and set `mlm=False`. The collator then copies the inputs into `labels` and sets padding positions to `-100` so the loss ignores them. The model shifts the labels by one position internally, so each token is trained to predict the next one:

```py
>>> from transformers import DataCollatorForLanguageModeling

>>> tokenizer.pad_token = tokenizer.eos_token
>>> data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
```

</pt>
<tf>
Use the end-of-sequence token as the padding token and set `mlm=False`. The collator then copies the inputs into `labels` and sets padding positions to `-100` so the loss ignores them. The model shifts the labels by one position internally, so each token is trained to predict the next one:

```py
>>> from transformers import DataCollatorForLanguageModeling

>>> tokenizer.pad_token = tokenizer.eos_token
>>> data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="tf")
```

</tf>
</frameworkcontent>
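
The following pure-Python sketch approximates what `DataCollatorForLanguageModeling(mlm=False)` does to a batch, assuming right padding as used by the GPT-2 tokenizer. It is an illustration, not the library's implementation: the real collator works on tensors and calls the tokenizer's `pad` method. `50256` is GPT-2's end-of-sequence id, reused here as the padding id:

```python
def collate_for_clm(batch_input_ids, pad_token_id):
    """Rough sketch of causal-LM collation: right-pad to the longest sequence, mask padding out of the labels."""
    max_len = max(len(ids) for ids in batch_input_ids)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for ids in batch_input_ids:
        n_pad = max_len - len(ids)
        padded = ids + [pad_token_id] * n_pad
        batch["input_ids"].append(padded)
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
        # Labels copy the inputs, but every pad id becomes -100, the index the loss ignores.
        batch["labels"].append([-100 if tok == pad_token_id else tok for tok in padded])
    return batch


batch = collate_for_clm([[5, 6, 7], [8, 50256]], pad_token_id=50256)
print(batch["labels"])  # [[5, 6, 7], [8, -100, -100]]
```

The genuine end-of-sequence token in the second sequence is also turned into `-100`. Because the padding token is the end-of-sequence token, the collator cannot tell the two apart. The one-position shift that pairs each input with the next token happens later, inside the model's loss computation.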

## Train

<frameworkcontent>
<pt>
<Tip>

If you aren't familiar with finetuning a model with the [`Trainer`], take a look at the [basic tutorial](../training#train-with-pytorch-trainer)!

</Tip>

You're ready to start training your model now! Load DistilGPT2 with [`AutoModelForCausalLM`]:

```py
>>> from transformers import AutoModelForCausalLM, TrainingArguments, Trainer

>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
```

At this point, only three steps remain:

1. Define your training hyperparameters in [`TrainingArguments`]. The only required parameter is `output_dir`, which specifies where to save your model. You'll push this model to the Hub by setting `push_to_hub=True` (you need to be signed in to Hugging Face to upload your model).
2. Pass the training arguments to [`Trainer`] along with the model, datasets, and data collator.
3. Call [`~Trainer.train`] to finetune your model.

```py
>>> training_args = TrainingArguments(
...     output_dir="my_awesome_eli5_clm-model",
...     evaluation_strategy="epoch",
...     learning_rate=2e-5,
...     weight_decay=0.01,
...     push_to_hub=True,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=lm_dataset["train"],
...     eval_dataset=lm_dataset["test"],
...     data_collator=data_collator,
... )

>>> trainer.train()
```

Once training is completed, use the [`~transformers.Trainer.evaluate`] method to evaluate your model and compute its perplexity, which is the exponential of the evaluation loss:

```py
>>> import math

>>> eval_results = trainer.evaluate()
>>> print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
Perplexity: 49.61
```
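
Perplexity is the exponential of the mean cross-entropy, which is the negative log-likelihood of the correct next tokens that `eval_loss` reports. The sketch below works through the formula by hand, using made-up probabilities that a model might assign to each correct next token:

```python
import math


def mean_nll(target_probs):
    """Mean negative log-likelihood of the probabilities assigned to each correct next token."""
    return sum(-math.log(p) for p in target_probs) / len(target_probs)


def perplexity(mean_loss):
    # Perplexity is exp(mean cross-entropy loss), as in the math.exp(eval_loss) call above.
    return math.exp(mean_loss)


# A model that gives every correct token probability 0.25 is as uncertain as a uniform choice among 4 tokens.
print(round(perplexity(mean_nll([0.25, 0.25, 0.25, 0.25])), 2))  # 4.0
# Perplexity is the inverse geometric mean of the probabilities: (0.5 * 0.125) ** -0.5 == 4.
print(round(perplexity(mean_nll([0.5, 0.125])), 2))  # 4.0
```

Read the same way, a perplexity of about 49.61 suggests the model is, on average, roughly as uncertain as if it were choosing uniformly among about 50 tokens.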

Then share your model to the Hub with the [`~transformers.Trainer.push_to_hub`] method so everyone can use your model:

```py
>>> trainer.push_to_hub()
```
</pt>
<tf>
<Tip>

If you aren't familiar with finetuning a model with Keras, take a look at the [basic tutorial](../training#train-a-tensorflow-model-with-keras)!

</Tip>

To finetune a model in TensorFlow, start by setting up an optimizer and some training hyperparameters:

```py
>>> from transformers import AdamWeightDecay

>>> optimizer = AdamWeightDecay(learning_rate=2e-5, weight_decay_rate=0.01)
```

Then you can load DistilGPT2 with [`TFAutoModelForCausalLM`]:

```py
>>> from transformers import TFAutoModelForCausalLM

>>> model = TFAutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
```

Convert your datasets to the `tf.data.Dataset` format with [`~transformers.TFPreTrainedModel.prepare_tf_dataset`]:

```py
>>> tf_train_set = model.prepare_tf_dataset(
...     lm_dataset["train"],
...     shuffle=True,
...     batch_size=16,
...     collate_fn=data_collator,
... )

>>> tf_test_set = model.prepare_tf_dataset(
...     lm_dataset["test"],
...     shuffle=False,
...     batch_size=16,
...     collate_fn=data_collator,
... )
```

Configure the model for training with [`compile`](https://keras.io/api/models/model_training_apis/#compile-method). Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:

```py
>>> import tensorflow as tf

>>> model.compile(optimizer=optimizer)  # No loss argument!
```

The last thing to set up before you start training is a way to push your model to the Hub. You can do this by specifying where to push your model and tokenizer in the [`~transformers.PushToHubCallback`]:

```py
>>> from transformers.keras_callbacks import PushToHubCallback

>>> callback = PushToHubCallback(
...     output_dir="my_awesome_eli5_clm-model",
...     tokenizer=tokenizer,
... )
```

Finally, you're ready to start training your model! Call [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) with your training and validation datasets, the number of epochs, and your callback to finetune the model:

```py
>>> model.fit(x=tf_train_set, validation_data=tf_test_set, epochs=3, callbacks=[callback])
```

Once training is completed, your model is automatically uploaded to the Hub so everyone can use it!
</tf>
</frameworkcontent>

<Tip>

For a more in-depth example of how to finetune a model for causal language modeling, take a look at the corresponding
[PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb)
or [TensorFlow notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).

</Tip>

## Inference

Great, now that you've finetuned a model, you can use it for inference!

Come up with a prompt you'd like to generate text from:

```py
>>> prompt = "Somatic hypermutation allows the immune system to"
```

The simplest way to try out your finetuned model for inference is to use it in a [`pipeline`]. Instantiate a `pipeline` for text generation with your model (replace `username` with your Hugging Face username), and pass your text to it:

```py
>>> from transformers import pipeline

>>> generator = pipeline("text-generation", model="username/my_awesome_eli5_clm-model")
>>> generator(prompt)
[{'generated_text': "Somatic hypermutation allows the immune system to be able to effectively reverse the damage caused by an infection.\n\n\nThe damage caused by an infection is caused by the immune system's ability to perform its own self-correcting tasks."}]
```

<frameworkcontent>
<pt>
Tokenize the text and return the `input_ids` as PyTorch tensors:

```py
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("username/my_awesome_eli5_clm-model")
>>> inputs = tokenizer(prompt, return_tensors="pt").input_ids
```

Use the [`~generation.GenerationMixin.generate`] method to generate text.
For more details about the different text generation strategies and parameters for controlling generation, check out the [Text generation strategies](../generation_strategies) page.

```py
>>> from transformers import AutoModelForCausalLM

>>> model = AutoModelForCausalLM.from_pretrained("username/my_awesome_eli5_clm-model")
>>> outputs = model.generate(inputs, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
```

Decode the generated token ids back into text:

```py
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["Somatic hypermutation allows the immune system to react to drugs with the ability to adapt to a different environmental situation. In other words, a system of 'hypermutation' can help the immune system to adapt to a different environmental situation or in some cases even a single life. In contrast, researchers at the University of Massachusetts-Boston have found that 'hypermutation' is much stronger in mice than in humans but can be found in humans, and that it's not completely unknown to the immune system. A study on how the immune system"]
```
</pt>
<tf>
Tokenize the text and return the `input_ids` as TensorFlow tensors:

```py
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("username/my_awesome_eli5_clm-model")
>>> inputs = tokenizer(prompt, return_tensors="tf").input_ids
```

Use the [`~generation.TFGenerationMixin.generate`] method to generate text. For more details about the different text generation strategies and parameters for controlling generation, check out the [Text generation strategies](../generation_strategies) page.

```py
>>> from transformers import TFAutoModelForCausalLM

>>> model = TFAutoModelForCausalLM.from_pretrained("username/my_awesome_eli5_clm-model")
>>> outputs = model.generate(input_ids=inputs, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
```

Decode the generated token ids back into text:

```py
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Somatic hypermutation allows the immune system to detect the presence of other viruses as they become more prevalent. Therefore, researchers have identified a high proportion of human viruses. The proportion of virus-associated viruses in our study increases with age. Therefore, we propose a simple algorithm to detect the presence of these new viruses in our samples as a sign of improved immunity. A first study based on this algorithm, which will be published in Science on Friday, aims to show that this finding could translate into the development of a better vaccine that is more effective for']
```
</tf>
</frameworkcontent>
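
Both `generate` calls above sample with `do_sample=True`, `top_k=50`, and `top_p=0.95`. The sketch below is a rough illustration of what those two filters do, applied to a made-up next-token distribution over four tokens:

- Top-k keeps the `k` most likely tokens.
- Top-p (nucleus sampling) then keeps the smallest set of those tokens whose cumulative probability reaches `p`.
- The surviving tokens are renormalized before sampling.

The sketch assumes top-k is applied before top-p. It is not the library's implementation, which works on logits:

```python
import random


def top_k_top_p_filter(probs, top_k, top_p):
    """Keep the top_k most likely tokens, then the smallest prefix of them whose mass reaches top_p."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    # Renormalize over the top-k survivors before applying the nucleus cutoff.
    total = sum(p for _, p in ranked)
    ranked = [(tok, p / total) for tok, p in ranked]
    kept, cumulative = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    # Renormalize again so the kept probabilities sum to 1.
    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}


next_token_probs = {"adapt": 0.5, "respond": 0.3, "fight": 0.15, "rest": 0.05}
filtered = top_k_top_p_filter(next_token_probs, top_k=3, top_p=0.8)
print({tok: round(p, 3) for tok, p in filtered.items()})  # {'adapt': 0.625, 'respond': 0.375}

# Sample one token from the filtered distribution; seeded only to make the sketch repeatable.
rng = random.Random(0)
token = rng.choices(list(filtered), weights=list(filtered.values()), k=1)[0]
print(token)
```

With `top_k=3`, "rest" is dropped. With `top_p=0.8`, "adapt" and "respond" already cover more than 80% of the remaining mass, so "fight" is dropped too. Because sampling is random, the generated text in this guide will differ from run to run.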