use block_size instead of max_seq_length in tf run_clm example (#15036)

* use block_size instead of max_seq_length

* fixup

* remove pad_to_block_size

Co-authored-by: Russell Klopfer <russell@kloper.us>
Russell Klopfer 2022-01-12 08:57:00 -05:00 committed by GitHub
parent 68cc4ccde2
commit 27b819b0e3

examples/tensorflow/language-modeling/run_clm.py

@@ -148,11 +148,12 @@ class DataTrainingArguments:
             "help": "The percentage of the train set used as validation set in case there's no validation split"
         },
     )
-    max_seq_length: Optional[int] = field(
+    block_size: Optional[int] = field(
         default=None,
         metadata={
-            "help": "The maximum total input sequence length after tokenization. Sequences longer "
-            "than this will be truncated."
+            "help": "Optional input sequence length after tokenization. "
+            "The training dataset will be truncated in block of this size for training. "
+            "Default to the model max input length for single sentence inputs (take into account special tokens)."
         },
     )
     preprocessing_num_workers: Optional[int] = field(
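
The rename also changes the script's command-line flag: HfArgumentParser builds the CLI from the dataclass fields, so the option becomes --block_size rather than --max_seq_length. Below is a minimal, self-contained sketch of that mechanism; the trimmed-down DataTrainingArguments stand-in and the hard-coded argument list are illustrative only, not code from the example script.

# Sketch: HfArgumentParser exposes each dataclass field as a flag of the same name.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:  # trimmed stand-in for the real dataclass
    block_size: Optional[int] = field(
        default=None,
        metadata={"help": "Optional input sequence length after tokenization."},
    )


parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses(["--block_size", "512"])
print(data_args.block_size)  # 512
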
@@ -166,13 +167,6 @@
         default=False,
         metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
     )
-    pad_to_max_length: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to pad all samples to `max_seq_length`. "
-            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
-        },
-    )
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -259,10 +253,6 @@ def main():
     if training_args.output_dir is not None:
         training_args.output_dir = Path(training_args.output_dir)
         os.makedirs(training_args.output_dir, exist_ok=True)
-    if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
-        logger.warning("We are training on TPU - forcing pad_to_max_length")
-        data_args.pad_to_max_length = True
     # endregion

     # region Checkpoints
@@ -364,22 +354,6 @@
     column_names = raw_datasets["train"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
-
-    if data_args.max_seq_length is None:
-        max_seq_length = tokenizer.model_max_length
-        if max_seq_length > 1024:
-            logger.warning(
-                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
-                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
-            )
-            max_seq_length = 1024
-    else:
-        if data_args.max_seq_length > tokenizer.model_max_length:
-            logger.warning(
-                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
-                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
-            )
-        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

     # First we tokenize all the texts.
     column_names = raw_datasets["train"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
@@ -396,13 +370,21 @@
             desc="Running tokenizer on dataset",
         )

-    block_size = tokenizer.model_max_length
-    if block_size > 1024:
-        logger.warning(
-            f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
-            "Picking 1024 instead. You can reduce that value by passing --block_size xxx."
-        )
-        block_size = 1024
+    if data_args.block_size is None:
+        block_size = tokenizer.model_max_length
+        if block_size > 1024:
+            logger.warning(
+                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
+                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
+            )
+            block_size = 1024
+    else:
+        if data_args.block_size > tokenizer.model_max_length:
+            logger.warning(
+                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
+                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
+            )
+        block_size = min(data_args.block_size, tokenizer.model_max_length)

     # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
     def group_texts(examples):
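
For reference, the grouping step that consumes block_size concatenates all tokenized texts and slices them into fixed-size blocks, which is also why the padding options removed above are no longer needed. The function below is a self-contained sketch of that pattern, assuming examples is a dict of lists of token-id lists as produced by a batched tokenizer map; it is not a verbatim copy of the script's group_texts (which reads block_size from the enclosing scope).

# Sketch of the concatenate-and-chunk pattern behind group_texts.
def group_texts(examples, block_size=1024):
    # Concatenate all texts for each key (input_ids, attention_mask, ...).
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the small remainder so every block has exactly block_size tokens.
    total_length = (total_length // block_size) * block_size
    # Split the concatenated sequences into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # For causal LM the labels are the inputs; the model shifts them internally.
    result["labels"] = result["input_ids"].copy()
    return result


# Example: two tokenized texts of lengths 5 and 4 become two blocks of 4 tokens,
# with the single leftover token dropped.
batch = {"input_ids": [[1, 2, 3, 4, 5], [6, 7, 8, 9]]}
print(group_texts(batch, block_size=4)["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
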