diff --git a/examples/run_openai_gpt.py b/examples/run_openai_gpt.py index 546c11b528d..bddfcc4e0ad 100644 --- a/examples/run_openai_gpt.py +++ b/examples/run_openai_gpt.py @@ -163,7 +163,7 @@ def main(): datasets = (train_dataset, eval_dataset) encoded_datasets = tokenize_and_encode(datasets) - # Compute the mex input length for the Transformer + # Compute the max input length for the Transformer max_length = model.config.n_positions // 2 - 2 input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3 \ for dataset in encoded_datasets for story, cont1, cont2, _ in dataset)