mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
add separator between data import and train
This commit is contained in:
parent
a424892fab
commit
e4e0ee14bd
@ -52,6 +52,10 @@ def set_seed(args):
|
|||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Load dataset
|
||||||
|
# ------------
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
class TextDataset(Dataset):
|
||||||
""" Abstracts the dataset used to train seq2seq models.
|
""" Abstracts the dataset used to train seq2seq models.
|
||||||
|
|
||||||
@ -212,6 +216,11 @@ def load_and_cache_examples(args, tokenizer):
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Train
|
||||||
|
# ------------
|
||||||
|
|
||||||
|
|
||||||
def train(args, train_dataset, model, tokenizer):
|
def train(args, train_dataset, model, tokenizer):
|
||||||
""" Fine-tune the pretrained model on the corpus. """
|
""" Fine-tune the pretrained model on the corpus. """
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
Loading…
Reference in New Issue
Block a user