mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Added GPT to the generative fine-tuning.
This commit is contained in:
parent
47975ed53e
commit
3e3e145497
@ -30,7 +30,8 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from tensorboardX import SummaryWriter
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,)
|
||||
from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
||||
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
|
||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||
|
||||
from utils_lm import WikiTextDataset
|
||||
@ -40,7 +41,8 @@ logger = logging.getLogger(__name__)
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config,)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
|
||||
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
|
||||
}
|
||||
|
||||
|
||||
|
@ -28,8 +28,6 @@ class WikiTextDataset(Dataset):
|
||||
# Sort the array by example length.
|
||||
self.examples.sort(key=len)
|
||||
|
||||
print("nice")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user