mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Added GPT to the generative fine-tuning.
This commit is contained in:
parent
47975ed53e
commit
3e3e145497
@ -30,7 +30,8 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,)
|
from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
||||||
|
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
|
||||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||||
|
|
||||||
from utils_lm import WikiTextDataset
|
from utils_lm import WikiTextDataset
|
||||||
@ -40,7 +41,8 @@ logger = logging.getLogger(__name__)
|
|||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config,)), ())
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config,)), ())
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
|
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||||
|
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,8 +28,6 @@ class WikiTextDataset(Dataset):
|
|||||||
# Sort the array by example length.
|
# Sort the array by example length.
|
||||||
self.examples.sort(key=len)
|
self.examples.sort(key=len)
|
||||||
|
|
||||||
print("nice")
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.examples)
|
return len(self.examples)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user