Added OpenAI GPT to the generative fine-tuning example.

LysandreJik 2019-08-06 12:14:18 -04:00
parent 47975ed53e
commit 3e3e145497
2 changed files with 4 additions and 4 deletions


@@ -30,7 +30,8 @@ from torch.utils.data.distributed import DistributedSampler
 from tensorboardX import SummaryWriter
 from tqdm import tqdm, trange
-from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,)
+from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
+                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
 from pytorch_transformers import AdamW, WarmupLinearSchedule
 from utils_lm import WikiTextDataset
@@ -40,7 +41,8 @@ logger = logging.getLogger(__name__)
 ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config,)), ())
 MODEL_CLASSES = {
-    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
+    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
+    'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
 }
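
With both entries in MODEL_CLASSES, the script can resolve the config, model, and tokenizer classes from a single model-type key, following the usual pytorch_transformers pattern. A minimal sketch of that lookup (the load_model helper and its argument names are illustrative, not taken from this diff):

from pytorch_transformers import (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)

MODEL_CLASSES = {
    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
}

def load_model(model_type, model_name_or_path):
    # Resolve the three classes for the requested architecture, then
    # instantiate each one from the same pretrained checkpoint name.
    config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
    config = config_class.from_pretrained(model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
    model = model_class.from_pretrained(model_name_or_path, config=config)
    return config, tokenizer, model

# e.g. load_model('gpt2', 'gpt2') or load_model('openai-gpt', 'openai-gpt')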

utils_lm.py

@@ -28,8 +28,6 @@ class WikiTextDataset(Dataset):
         # Sort the array by example length.
         self.examples.sort(key=len)
-        print("nice")
-
     def __len__(self):
         return len(self.examples)
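
For context, WikiTextDataset follows the standard torch.utils.data.Dataset protocol. A minimal self-contained sketch of that pattern (a hypothetical stand-in; the full utils_lm.py is not shown in this diff, and the example data here is illustrative):

import torch
from torch.utils.data import Dataset

class ToyLMDataset(Dataset):
    # Stand-in for WikiTextDataset: holds one tensor of token ids per
    # example, sorted by length as in the hunk above.
    def __init__(self, tokenized_examples):
        self.examples = [torch.tensor(ids, dtype=torch.long) for ids in tokenized_examples]
        # Sort the array by example length.
        self.examples.sort(key=len)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

# e.g. ToyLMDataset([[5, 2, 9], [7, 1], [3, 8, 4, 6]]) yields the
# examples shortest-first, which keeps padding small within a batch.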