diff --git a/examples/run_generative_finetuning.py b/examples/run_generative_finetuning.py
index e9e4545dfe8..458c1235535 100644
--- a/examples/run_generative_finetuning.py
+++ b/examples/run_generative_finetuning.py
@@ -30,7 +30,8 @@
 from torch.utils.data.distributed import DistributedSampler
 from tensorboardX import SummaryWriter
 from tqdm import tqdm, trange
-from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,)
+from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
+                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
 from pytorch_transformers import AdamW, WarmupLinearSchedule
 
 from utils_lm import WikiTextDataset
@@ -40,7 +41,8 @@
 logger = logging.getLogger(__name__)
 
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config,)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig)), ())
 MODEL_CLASSES = {
-    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
+    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
+    'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
 }
 
diff --git a/examples/utils_lm.py b/examples/utils_lm.py
index 2b6c393a91f..4a3bafb7893 100644
--- a/examples/utils_lm.py
+++ b/examples/utils_lm.py
@@ -28,8 +28,6 @@ class WikiTextDataset(Dataset):
 
         # Sort the examples by length.
         self.examples.sort(key=len)
 
-        print("nice")
-
     def __len__(self):
         return len(self.examples)
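
Note for reviewers: the new 'openai-gpt' entry is consumed the same way as the existing 'gpt2' one, via the (config, model, tokenizer) triple in MODEL_CLASSES. The following is a minimal sketch of that load path, not code from this diff; the load_model helper and the literal checkpoint names are illustrative assumptions, while the from_pretrained calls are the standard pytorch_transformers API:

    from pytorch_transformers import (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
                                      OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)

    MODEL_CLASSES = {
        'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
        'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    }

    def load_model(model_type, model_name_or_path):
        # Hypothetical helper: look up the (config, model, tokenizer) triple
        # for the requested architecture, then instantiate each piece from
        # the same pretrained checkpoint name or path.
        config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
        config = config_class.from_pretrained(model_name_or_path)
        tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
        model = model_class.from_pretrained(model_name_or_path, config=config)
        return model, tokenizer

    # e.g. load_model('openai-gpt', 'openai-gpt') or load_model('gpt2', 'gpt2')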