Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
Raise error when using an mlm flag for a clm model + correct TextDataset
parent 569897ce2c
commit f54a5bd37f
@@ -86,6 +86,9 @@ MODEL_CLASSES = {
 class TextDataset(Dataset):
     def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
         assert os.path.isfile(file_path)
+
+        block_size = block_size - (tokenizer.max_len - tokenizer.max_len_single_sentence)
+
         directory, filename = os.path.split(file_path)
         cached_features_file = os.path.join(
             directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
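The block_size correction matters because TextDataset later wraps every cached chunk with tokenizer.build_inputs_with_special_tokens. A minimal sketch of the arithmetic, assuming a transformers release from this period where tokenizers still expose max_len and max_len_single_sentence (for bert-base-uncased these are 512 and 510, so two slots are reserved for [CLS] and [SEP]); the checkpoint name and sample text are illustrative only:

# Sketch (not part of the commit): why block_size must shrink before caching.
# tokenizer.max_len is the model's maximum sequence length and
# tokenizer.max_len_single_sentence is that length minus the special tokens
# added around a single sequence, so their difference is exactly the room
# those special tokens need.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

block_size = 512
num_special = tokenizer.max_len - tokenizer.max_len_single_sentence  # 2 for BERT ([CLS], [SEP])
block_size = block_size - num_special                                # 510 content tokens per block

text = "some long training text " * 200
ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

# TextDataset slices ids into chunks of block_size and wraps each chunk with
# build_inputs_with_special_tokens; with the correction the wrapped block
# still fits within the model's maximum length.
block = tokenizer.build_inputs_with_special_tokens(ids[:block_size])
assert len(block) <= tokenizer.max_len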
@@ -195,6 +198,12 @@ def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
 
 def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
     """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
+
+    if tokenizer.mask_token is None:
+        raise ValueError(
+            "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
+        )
+
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
     probability_matrix = torch.full(labels.shape, args.mlm_probability)
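The new guard fires for tokenizers that define no mask token, i.e. the causal-LM tokenizers such as GPT-2's that the script supports alongside the masked-LM ones. A hedged sketch of how the check behaves, assuming the usual bert-base-uncased and gpt2 checkpoints are available; the helper function below is hypothetical and only mirrors the check added inside mask_tokens():

# Sketch (not part of the commit): the guard distinguishes masked-LM tokenizers,
# which define a mask token, from causal-LM tokenizers, which do not.
from transformers import BertTokenizer, GPT2Tokenizer

def check_mlm_compatible(tokenizer):
    # Hypothetical helper mirroring the check added to mask_tokens(): without a
    # mask token the 80/10/10 masking scheme cannot be applied, so --mlm must be dropped.
    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for "
            "masked language modeling. Remove the --mlm flag if you want to use "
            "this tokenizer."
        )

bert_tok = BertTokenizer.from_pretrained("bert-base-uncased")
gpt2_tok = GPT2Tokenizer.from_pretrained("gpt2")

check_mlm_compatible(bert_tok)      # passes: bert_tok.mask_token == "[MASK]"

try:
    check_mlm_compatible(gpt2_tok)  # GPT-2 is a causal LM with no mask token
except ValueError as err:
    print(err)                      # the error the patched script now raises under --mlm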