mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Black preview (#17217)
* Black preview * Fixup too! * Fix check copies * Use the same version as the CI * Bump black
This commit is contained in:
parent
9bd67ac7bb
commit
afe5d42d8d
@ -854,7 +854,7 @@ jobs:
|
|||||||
key: v0.4-code_quality-{{ checksum "setup.py" }}
|
key: v0.4-code_quality-{{ checksum "setup.py" }}
|
||||||
paths:
|
paths:
|
||||||
- '~/.cache/pip'
|
- '~/.cache/pip'
|
||||||
- run: black --check examples tests src utils
|
- run: black --check --preview examples tests src utils
|
||||||
- run: isort --check-only examples tests src utils
|
- run: isort --check-only examples tests src utils
|
||||||
- run: python utils/custom_init_isort.py --check_only
|
- run: python utils/custom_init_isort.py --check_only
|
||||||
- run: flake8 examples tests src utils
|
- run: flake8 examples tests src utils
|
||||||
|
6
Makefile
6
Makefile
@ -9,7 +9,7 @@ modified_only_fixup:
|
|||||||
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
|
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
|
||||||
@if test -n "$(modified_py_files)"; then \
|
@if test -n "$(modified_py_files)"; then \
|
||||||
echo "Checking/fixing $(modified_py_files)"; \
|
echo "Checking/fixing $(modified_py_files)"; \
|
||||||
black $(modified_py_files); \
|
black --preview $(modified_py_files); \
|
||||||
isort $(modified_py_files); \
|
isort $(modified_py_files); \
|
||||||
flake8 $(modified_py_files); \
|
flake8 $(modified_py_files); \
|
||||||
else \
|
else \
|
||||||
@ -45,7 +45,7 @@ repo-consistency:
|
|||||||
# this target runs checks on all files
|
# this target runs checks on all files
|
||||||
|
|
||||||
quality:
|
quality:
|
||||||
black --check $(check_dirs)
|
black --check --preview $(check_dirs)
|
||||||
isort --check-only $(check_dirs)
|
isort --check-only $(check_dirs)
|
||||||
python utils/custom_init_isort.py --check_only
|
python utils/custom_init_isort.py --check_only
|
||||||
flake8 $(check_dirs)
|
flake8 $(check_dirs)
|
||||||
@ -60,7 +60,7 @@ extra_style_checks:
|
|||||||
# this target runs checks on all files and potentially modifies some of them
|
# this target runs checks on all files and potentially modifies some of them
|
||||||
|
|
||||||
style:
|
style:
|
||||||
black $(check_dirs)
|
black --preview $(check_dirs)
|
||||||
isort $(check_dirs)
|
isort $(check_dirs)
|
||||||
${MAKE} autogenerate_code
|
${MAKE} autogenerate_code
|
||||||
${MAKE} extra_style_checks
|
${MAKE} extra_style_checks
|
||||||
|
@ -42,14 +42,18 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
encoder_model_name_or_path: str = field(
|
encoder_model_name_or_path: str = field(
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The encoder model checkpoint for weights initialization."
|
"help": (
|
||||||
|
"The encoder model checkpoint for weights initialization."
|
||||||
"Don't set if you want to train an encoder model from scratch."
|
"Don't set if you want to train an encoder model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
decoder_model_name_or_path: str = field(
|
decoder_model_name_or_path: str = field(
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The decoder model checkpoint for weights initialization."
|
"help": (
|
||||||
|
"The decoder model checkpoint for weights initialization."
|
||||||
"Don't set if you want to train a decoder model from scratch."
|
"Don't set if you want to train a decoder model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
encoder_config_name: Optional[str] = field(
|
encoder_config_name: Optional[str] = field(
|
||||||
|
@ -175,14 +175,19 @@ class ModelArguments:
|
|||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -222,38 +227,48 @@ class DataTrainingArguments:
|
|||||||
max_target_length: Optional[int] = field(
|
max_target_length: Optional[int] = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
val_max_target_length: Optional[int] = field(
|
val_max_target_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||||
"This argument is also used to override the `max_length` param of `model.generate`, which is used "
|
"This argument is also used to override the `max_length` param of `model.generate`, which is used "
|
||||||
"during evaluation."
|
"during evaluation."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
@ -266,8 +281,10 @@ class DataTrainingArguments:
|
|||||||
num_beams: Optional[int] = field(
|
num_beams: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
|
"help": (
|
||||||
|
"Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
|
||||||
"which is used during evaluation."
|
"which is used during evaluation."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -623,7 +640,7 @@ def main():
|
|||||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||||
if training_args.block_size % train_batch_size > 0 or training_args.block_size % eval_batch_size > 0:
|
if training_args.block_size % train_batch_size > 0 or training_args.block_size % eval_batch_size > 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`training_args.block_size` needs to be a multiple of the global train/eval batch size."
|
"`training_args.block_size` needs to be a multiple of the global train/eval batch size."
|
||||||
f"Got {training_args.block_size}, {train_batch_size} and {eval_batch_size} respectively instead."
|
f"Got {training_args.block_size}, {train_batch_size} and {eval_batch_size} respectively instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1136,7 +1153,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# train
|
# train
|
||||||
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
||||||
|
|
||||||
cur_step += 1
|
cur_step += 1
|
||||||
batch = next(train_batches)
|
batch = next(train_batches)
|
||||||
@ -1150,7 +1167,10 @@ def main():
|
|||||||
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
|
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
|
||||||
|
|
||||||
_train_metric = unreplicate(train_metric)
|
_train_metric = unreplicate(train_metric)
|
||||||
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
|
desc = (
|
||||||
|
f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |"
|
||||||
|
f" Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
|
||||||
|
)
|
||||||
epochs.desc = desc
|
epochs.desc = desc
|
||||||
epochs.write(desc)
|
epochs.write(desc)
|
||||||
|
|
||||||
|
@ -138,8 +138,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -162,14 +163,19 @@ class ModelArguments:
|
|||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -194,15 +200,19 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -217,9 +227,11 @@ class DataTrainingArguments:
|
|||||||
block_size: Optional[int] = field(
|
block_size: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Optional input sequence length after tokenization. "
|
"help": (
|
||||||
|
"Optional input sequence length after tokenization. "
|
||||||
"The training dataset will be truncated in block of this size for training. "
|
"The training dataset will be truncated in block of this size for training. "
|
||||||
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -505,7 +517,8 @@ def main():
|
|||||||
# clm input could be much much longer than block_size
|
# clm input could be much much longer than block_size
|
||||||
if "Token indices sequence length is longer than the" in cl.out:
|
if "Token indices sequence length is longer than the" in cl.out:
|
||||||
tok_logger.warning(
|
tok_logger.warning(
|
||||||
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
|
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
|
||||||
|
" before being passed to the model."
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -735,7 +748,8 @@ def main():
|
|||||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate'].mean()})"
|
||||||
)
|
)
|
||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
@ -762,7 +776,10 @@ def main():
|
|||||||
eval_metrics["perplexity"] = float("inf")
|
eval_metrics["perplexity"] = float("inf")
|
||||||
|
|
||||||
# Print metrics and update progress bar
|
# Print metrics and update progress bar
|
||||||
desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
|
desc = (
|
||||||
|
f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:"
|
||||||
|
f" {eval_metrics['perplexity']})"
|
||||||
|
)
|
||||||
epochs.write(desc)
|
epochs.write(desc)
|
||||||
epochs.desc = desc
|
epochs.desc = desc
|
||||||
|
|
||||||
|
@ -136,8 +136,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -160,14 +161,19 @@ class ModelArguments:
|
|||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -209,8 +215,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated. Default to the max input length of the model."
|
"than this will be truncated. Default to the max input length of the model."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
@ -223,8 +231,10 @@ class DataTrainingArguments:
|
|||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
|
"Whether to pad all samples to `max_seq_length`. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
line_by_line: bool = field(
|
line_by_line: bool = field(
|
||||||
@ -764,7 +774,8 @@ def main():
|
|||||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
|
@ -135,8 +135,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -159,14 +160,19 @@ class ModelArguments:
|
|||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -208,7 +214,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization and masking. Sequences longer than this"
|
||||||
|
" will be truncated. Default to the max input length of the model."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
@ -337,12 +346,14 @@ class FlaxDataCollatorForT5MLM:
|
|||||||
|
|
||||||
if batch["input_ids"].shape[-1] != self.input_length:
|
if batch["input_ids"].shape[-1] != self.input_length:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.target_length}."
|
f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
|
||||||
|
f" should be {self.target_length}."
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch["labels"].shape[-1] != self.target_length:
|
if batch["labels"].shape[-1] != self.target_length:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
|
f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
|
||||||
|
f" {self.target_length}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# to check that tokens are correctly preprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
|
# to check that tokens are correctly preprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
|
||||||
@ -884,7 +895,8 @@ def main():
|
|||||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate'].mean()})"
|
||||||
)
|
)
|
||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
|
@ -157,14 +157,19 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -200,37 +205,46 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=384,
|
default=384,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
|
"Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
|
||||||
"be faster on GPU but will be slower on TPU)."
|
" batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
version_2_with_negative: bool = field(
|
version_2_with_negative: bool = field(
|
||||||
@ -239,9 +253,11 @@ class DataTrainingArguments:
|
|||||||
null_score_diff_threshold: float = field(
|
null_score_diff_threshold: float = field(
|
||||||
default=0.0,
|
default=0.0,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The threshold used to select the null answer: if the best answer has a score that is less than "
|
"help": (
|
||||||
|
"The threshold used to select the null answer: if the best answer has a score that is less than "
|
||||||
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
||||||
"Only useful when `version_2_with_negative=True`."
|
"Only useful when `version_2_with_negative=True`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
doc_stride: int = field(
|
doc_stride: int = field(
|
||||||
@ -255,8 +271,10 @@ class DataTrainingArguments:
|
|||||||
max_answer_length: int = field(
|
max_answer_length: int = field(
|
||||||
default=30,
|
default=30,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum length of an answer that can be generated. This is needed because the start "
|
"help": (
|
||||||
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
"and end predictions are not conditioned on one another."
|
"and end predictions are not conditioned on one another."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -498,9 +516,9 @@ def main():
|
|||||||
# region Tokenizer check: this script requires a fast tokenizer.
|
# region Tokenizer check: this script requires a fast tokenizer.
|
||||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This example script only works for models that have a fast tokenizer. Checkout the big table of models "
|
"This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
|
||||||
"at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
|
" https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
|
||||||
"requirement"
|
" this requirement"
|
||||||
)
|
)
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
@ -928,7 +946,8 @@ def main():
|
|||||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
|
@ -149,8 +149,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -173,14 +174,19 @@ class ModelArguments:
|
|||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -217,45 +223,57 @@ class DataTrainingArguments:
|
|||||||
max_source_length: Optional[int] = field(
|
max_source_length: Optional[int] = field(
|
||||||
default=1024,
|
default=1024,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_target_length: Optional[int] = field(
|
max_target_length: Optional[int] = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
val_max_target_length: Optional[int] = field(
|
val_max_target_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||||
"This argument is also used to override the `max_length` param of `model.generate`, which is used "
|
"This argument is also used to override the `max_length` param of `model.generate`, which is used "
|
||||||
"during evaluation."
|
"during evaluation."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
@ -271,8 +289,10 @@ class DataTrainingArguments:
|
|||||||
num_beams: Optional[int] = field(
|
num_beams: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
|
"help": (
|
||||||
|
"Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
|
||||||
"which is used during evaluation."
|
"which is used during evaluation."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -831,7 +851,8 @@ def main():
|
|||||||
train_metric = unreplicate(train_metric)
|
train_metric = unreplicate(train_metric)
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ======================== Evaluating ==============================
|
# ======================== Evaluating ==============================
|
||||||
|
@ -103,8 +103,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -148,29 +150,37 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. If set, sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. If set, sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -585,7 +595,8 @@ def main():
|
|||||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
|
@ -150,8 +150,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,36 +198,46 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. If set, sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. If set, sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
label_all_tokens: bool = field(
|
label_all_tokens: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
|
"help": (
|
||||||
|
"Whether to put the label for one word on all tokens of generated by that word or just on the "
|
||||||
"one (in which case the other tokens will have a padding index)."
|
"one (in which case the other tokens will have a padding index)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return_entity_level_metrics: bool = field(
|
return_entity_level_metrics: bool = field(
|
||||||
@ -693,7 +705,8 @@ def main():
|
|||||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
@ -744,7 +757,8 @@ def main():
|
|||||||
logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}")
|
logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}")
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})"
|
f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc:"
|
||||||
|
f" {eval_metrics['accuracy']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
|
@ -134,8 +134,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -151,14 +152,19 @@ class ModelArguments:
|
|||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -179,15 +185,19 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -509,7 +519,8 @@ def main():
|
|||||||
|
|
||||||
train_step_progress_bar.close()
|
train_step_progress_bar.close()
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ======================== Evaluating ==============================
|
# ======================== Evaluating ==============================
|
||||||
|
@ -78,8 +78,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -102,7 +104,8 @@ def main():
|
|||||||
and not training_args.overwrite_output_dir
|
and not training_args.overwrite_output_dir
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
|
||||||
|
" --overwrite_output_dir to overcome."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
|
@ -182,7 +182,7 @@ if is_tf_available():
|
|||||||
)
|
)
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
|
for ex_index, ex in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
|
||||||
if ex_index % 10000 == 0:
|
if ex_index % 10000 == 0:
|
||||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||||
|
|
||||||
@ -297,7 +297,7 @@ class RaceProcessor(DataProcessor):
|
|||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training and dev sets."""
|
||||||
examples = []
|
examples = []
|
||||||
for (_, data_raw) in enumerate(lines):
|
for _, data_raw in enumerate(lines):
|
||||||
race_id = "%s-%s" % (set_type, data_raw["race_id"])
|
race_id = "%s-%s" % (set_type, data_raw["race_id"])
|
||||||
article = data_raw["article"]
|
article = data_raw["article"]
|
||||||
for i in range(len(data_raw["answers"])):
|
for i in range(len(data_raw["answers"])):
|
||||||
@ -518,7 +518,7 @@ def convert_examples_to_features(
|
|||||||
label_map = {label: i for i, label in enumerate(label_list)}
|
label_map = {label: i for i, label in enumerate(label_list)}
|
||||||
|
|
||||||
features = []
|
features = []
|
||||||
for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
||||||
if ex_index % 10000 == 0:
|
if ex_index % 10000 == 0:
|
||||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||||
choices_inputs = []
|
choices_inputs = []
|
||||||
|
@ -312,8 +312,10 @@ def add_generic_args(parser, root_dir) -> None:
|
|||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O2",
|
default="O2",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help=(
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
|
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
|
||||||
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
||||||
|
@ -148,8 +148,10 @@ class GLUETransformer(BaseTransformer):
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help=(
|
||||||
"than this will be truncated, sequences shorter will be padded.",
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -173,8 +173,10 @@ class NERTransformer(BaseTransformer):
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help=(
|
||||||
"than this will be truncated, sequences shorter will be padded.",
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -551,8 +551,10 @@ def main():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=384,
|
default=384,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
help=(
|
||||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
"The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||||
|
"longer than this will be truncated, and sequences shorter than this will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--doc_stride",
|
"--doc_stride",
|
||||||
@ -564,8 +566,10 @@ def main():
|
|||||||
"--max_query_length",
|
"--max_query_length",
|
||||||
default=64,
|
default=64,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
help=(
|
||||||
"be truncated to this length.",
|
"The maximum number of tokens for the question. Questions longer than this will "
|
||||||
|
"be truncated to this length."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
@ -610,20 +614,27 @@ def main():
|
|||||||
"--max_answer_length",
|
"--max_answer_length",
|
||||||
default=30,
|
default=30,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
help=(
|
||||||
"and end predictions are not conditioned on one another.",
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
|
"and end predictions are not conditioned on one another."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose_logging",
|
"--verbose_logging",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="If true, all of the warnings related to data processing will be printed. "
|
help=(
|
||||||
"A number of warnings are expected for a normal SQuAD evaluation.",
|
"If true, all of the warnings related to data processing will be printed. "
|
||||||
|
"A number of warnings are expected for a normal SQuAD evaluation."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang_id",
|
"--lang_id",
|
||||||
default=0,
|
default=0,
|
||||||
type=int,
|
type=int,
|
||||||
help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
|
help=(
|
||||||
|
"language id of input for language-specific xlm models (see"
|
||||||
|
" tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
||||||
@ -652,8 +663,10 @@ def main():
|
|||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O1",
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help=(
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
|
@ -84,7 +84,8 @@ def main():
|
|||||||
and not training_args.overwrite_output_dir
|
and not training_args.overwrite_output_dir
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
|
||||||
|
" --overwrite_output_dir to overcome."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
|
@ -68,7 +68,10 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch."
|
"help": (
|
||||||
|
"The model checkpoint for weights initialization. Leave None if you want to train a model from"
|
||||||
|
" scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -99,8 +102,10 @@ class DataTrainingArguments:
|
|||||||
train_data_files: Optional[str] = field(
|
train_data_files: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The input training data files (multiple files in glob format). "
|
"help": (
|
||||||
|
"The input training data files (multiple files in glob format). "
|
||||||
"Very often splitting large files to smaller files can prevent tokenizer going out of memory"
|
"Very often splitting large files to smaller files can prevent tokenizer going out of memory"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
eval_data_file: Optional[str] = field(
|
eval_data_file: Optional[str] = field(
|
||||||
@ -130,7 +135,10 @@ class DataTrainingArguments:
|
|||||||
plm_probability: float = field(
|
plm_probability: float = field(
|
||||||
default=1 / 6,
|
default=1 / 6,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
|
"help": (
|
||||||
|
"Ratio of length of a span of masked tokens to surrounding context length for permutation language"
|
||||||
|
" modeling."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_span_length: int = field(
|
max_span_length: int = field(
|
||||||
@ -140,9 +148,11 @@ class DataTrainingArguments:
|
|||||||
block_size: int = field(
|
block_size: int = field(
|
||||||
default=-1,
|
default=-1,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Optional input sequence length after tokenization."
|
"help": (
|
||||||
|
"Optional input sequence length after tokenization."
|
||||||
"The training dataset will be truncated in block of this size for training."
|
"The training dataset will be truncated in block of this size for training."
|
||||||
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -206,7 +216,8 @@ def main():
|
|||||||
and not training_args.overwrite_output_dir
|
and not training_args.overwrite_output_dir
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
|
||||||
|
" --overwrite_output_dir to overcome."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
@ -253,8 +264,8 @@ def main():
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
|
"You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another"
|
||||||
"and load it from here, using --tokenizer_name"
|
" script, save it,and load it from here, using --tokenizer_name"
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_args.model_name_or_path:
|
if model_args.model_name_or_path:
|
||||||
|
@ -126,15 +126,15 @@ def main():
|
|||||||
"--max_steps",
|
"--max_steps",
|
||||||
default=-1,
|
default=-1,
|
||||||
type=int,
|
type=int,
|
||||||
help="If > 0: set total number of training \
|
help=(
|
||||||
steps to perform. Override num_train_epochs.",
|
"If > 0: set total number of training steps to perform. Override num_train_epochs."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--gradient_accumulation_steps",
|
"--gradient_accumulation_steps",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="Number of updates steps to accumulate before\
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
performing a backward/update pass.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument("--learning_rate", type=float, default=6.25e-5)
|
parser.add_argument("--learning_rate", type=float, default=6.25e-5)
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
@ -516,8 +516,10 @@ def main():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=384,
|
default=384,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences "
|
help=(
|
||||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
"The maximum total input sequence length after tokenization. Sequences "
|
||||||
|
"longer than this will be truncated, and sequences shorter than this will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
@ -576,8 +578,10 @@ def main():
|
|||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O1",
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help=(
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
|
@ -90,31 +90,39 @@ class DataTrainingArguments:
|
|||||||
max_source_length: Optional[int] = field(
|
max_source_length: Optional[int] = field(
|
||||||
default=1024,
|
default=1024,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_target_length: Optional[int] = field(
|
max_target_length: Optional[int] = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
val_max_target_length: Optional[int] = field(
|
val_max_target_length: Optional[int] = field(
|
||||||
default=142,
|
default=142,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded. "
|
"than this will be truncated, sequences shorter will be padded. "
|
||||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||||
"during ``evaluate`` and ``predict``."
|
"during ``evaluate`` and ``predict``."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
test_max_target_length: Optional[int] = field(
|
test_max_target_length: Optional[int] = field(
|
||||||
default=142,
|
default=142,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for test target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
|
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
|
||||||
|
@ -22,15 +22,30 @@ from utils import calculate_rouge
|
|||||||
|
|
||||||
|
|
||||||
PRED = [
|
PRED = [
|
||||||
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe depression" German airline confirms it knew of Andreas Lubitz\'s depression years before he took control.',
|
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the'
|
||||||
"The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the body.",
|
' final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe'
|
||||||
"Amnesty International releases its annual report on the death penalty. The report catalogs the use of state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital punishment.",
|
" depression\" German airline confirms it knew of Andreas Lubitz's depression years before he took control.",
|
||||||
|
"The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal"
|
||||||
|
" accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's"
|
||||||
|
" founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the"
|
||||||
|
" body.",
|
||||||
|
"Amnesty International releases its annual report on the death penalty. The report catalogs the use of"
|
||||||
|
" state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the"
|
||||||
|
" world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital"
|
||||||
|
" punishment.",
|
||||||
]
|
]
|
||||||
|
|
||||||
TGT = [
|
TGT = [
|
||||||
'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says .',
|
'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .'
|
||||||
"Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June . Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .",
|
' Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz'
|
||||||
"Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to death . Organization claims that governments around the world are using the threat of terrorism to advance executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death sentences up by 28% .",
|
" had informed his Lufthansa training school of an episode of severe depression, airline says .",
|
||||||
|
"Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June ."
|
||||||
|
" Israel and the United States opposed the move, which could open the door to war crimes investigations against"
|
||||||
|
" Israelis .",
|
||||||
|
"Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to"
|
||||||
|
" death . Organization claims that governments around the world are using the threat of terrorism to advance"
|
||||||
|
" executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death"
|
||||||
|
" sentences up by 28% .",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -65,7 +80,8 @@ def test_single_sent_scores_dont_depend_on_newline_sep():
|
|||||||
]
|
]
|
||||||
tgt = [
|
tgt = [
|
||||||
"Margot Frank, died in 1945, a month earlier than previously thought.",
|
"Margot Frank, died in 1945, a month earlier than previously thought.",
|
||||||
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525.',
|
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of'
|
||||||
|
" the final seconds on board Flight 9525.",
|
||||||
]
|
]
|
||||||
assert calculate_rouge(pred, tgt, newline_sep=True) == calculate_rouge(pred, tgt, newline_sep=False)
|
assert calculate_rouge(pred, tgt, newline_sep=True) == calculate_rouge(pred, tgt, newline_sep=False)
|
||||||
|
|
||||||
|
@ -121,7 +121,10 @@ def run_generate(verbose=True):
|
|||||||
nargs="?",
|
nargs="?",
|
||||||
type=str,
|
type=str,
|
||||||
const=datetime_now(),
|
const=datetime_now(),
|
||||||
help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
|
help=(
|
||||||
|
"use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g."
|
||||||
|
" lang=en-ru. If no value is passed, the current datetime string will be used."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
|
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
|
@ -35,7 +35,7 @@ def parse_search_arg(search):
|
|||||||
groups = search.split()
|
groups = search.split()
|
||||||
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
|
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
|
||||||
entry_names = list(entries.keys())
|
entry_names = list(entries.keys())
|
||||||
sets = [list((f"--{k} {v}") for v in vs.split(":")) for k, vs in entries.items()]
|
sets = [list(f"--{k} {v}" for v in vs.split(":")) for k, vs in entries.items()]
|
||||||
matrix = [list(x) for x in itertools.product(*sets)]
|
matrix = [list(x) for x in itertools.product(*sets)]
|
||||||
return matrix, entry_names
|
return matrix, entry_names
|
||||||
|
|
||||||
@ -66,7 +66,10 @@ def run_search():
|
|||||||
prog = sys.argv[0]
|
prog = sys.argv[0]
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list."
|
usage=(
|
||||||
|
"\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore"
|
||||||
|
" refer to `run_eval.py -h` for the complete list."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--search",
|
"--search",
|
||||||
@ -83,7 +86,10 @@ def run_search():
|
|||||||
nargs="?",
|
nargs="?",
|
||||||
type=str,
|
type=str,
|
||||||
const=datetime_now(),
|
const=datetime_now(),
|
||||||
help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.",
|
help=(
|
||||||
|
"add custom notes to be printed before the results table. If no value is passed, the current datetime"
|
||||||
|
" string will be used."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
args, args_main = parser.parse_known_args()
|
args, args_main = parser.parse_known_args()
|
||||||
# we share some of the args
|
# we share some of the args
|
||||||
|
@ -57,9 +57,10 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
assert isinstance(
|
assert isinstance(self.model, PreTrainedModel), (
|
||||||
self.model, PreTrainedModel
|
"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is"
|
||||||
), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
|
f" {self.model.__class__}"
|
||||||
|
)
|
||||||
self.config = self.model.config
|
self.config = self.model.config
|
||||||
else:
|
else:
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -68,13 +69,15 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
|
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
|
||||||
|
|
||||||
if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
|
if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
|
||||||
assert (
|
assert self.config.pad_token_id is not None, (
|
||||||
self.config.pad_token_id is not None
|
"Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss"
|
||||||
), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
|
" calculation or doing label smoothing."
|
||||||
|
)
|
||||||
|
|
||||||
if self.config.pad_token_id is None and self.config.eos_token_id is not None:
|
if self.config.pad_token_id is None and self.config.eos_token_id is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
|
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for"
|
||||||
|
" padding.."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.label_smoothing == 0:
|
if self.args.label_smoothing == 0:
|
||||||
@ -248,7 +251,8 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
|
|
||||||
if pad_token_id is None:
|
if pad_token_id is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}"
|
"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be"
|
||||||
|
f" padded to `max_length`={max_length}"
|
||||||
)
|
)
|
||||||
|
|
||||||
padded_tensor = pad_token_id * torch.ones(
|
padded_tensor = pad_token_id * torch.ones(
|
||||||
|
@ -39,9 +39,7 @@ def parse_args():
|
|||||||
"""
|
"""
|
||||||
parser = ArgumentParser(
|
parser = ArgumentParser(
|
||||||
description=(
|
description=(
|
||||||
"PyTorch TPU distributed training launch "
|
"PyTorch TPU distributed training launch helper utility that will spawn up multiple distributed processes"
|
||||||
"helper utility that will spawn up "
|
|
||||||
"multiple distributed processes"
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -168,8 +168,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -215,7 +217,8 @@ def main():
|
|||||||
and not training_args.overwrite_output_dir
|
and not training_args.overwrite_output_dir
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
|
||||||
|
" --overwrite_output_dir to overcome."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
|
@ -87,8 +87,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -116,7 +118,8 @@ def main():
|
|||||||
and not training_args.overwrite_output_dir
|
and not training_args.overwrite_output_dir
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
|
||||||
|
" --overwrite_output_dir to overcome."
|
||||||
)
|
)
|
||||||
|
|
||||||
module = import_module("tasks")
|
module = import_module("tasks")
|
||||||
|
@ -88,8 +88,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -111,7 +113,8 @@ def main():
|
|||||||
and not training_args.overwrite_output_dir
|
and not training_args.overwrite_output_dir
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
|
||||||
|
" --overwrite_output_dir to overcome."
|
||||||
)
|
)
|
||||||
|
|
||||||
module = import_module("tasks")
|
module = import_module("tasks")
|
||||||
|
@ -103,7 +103,7 @@ class TokenClassificationTask:
|
|||||||
label_map = {label: i for i, label in enumerate(label_list)}
|
label_map = {label: i for i, label in enumerate(label_list)}
|
||||||
|
|
||||||
features = []
|
features = []
|
||||||
for (ex_index, example) in enumerate(examples):
|
for ex_index, example in enumerate(examples):
|
||||||
if ex_index % 10_000 == 0:
|
if ex_index % 10_000 == 0:
|
||||||
logger.info("Writing example %d of %d", ex_index, len(examples))
|
logger.info("Writing example %d of %d", ex_index, len(examples))
|
||||||
|
|
||||||
|
@ -86,8 +86,9 @@ class DataTrainingArguments:
|
|||||||
eval_split_name: str = field(
|
eval_split_name: str = field(
|
||||||
default="validation",
|
default="validation",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The name of the training data set split to use (via the datasets library). Defaults to "
|
"help": (
|
||||||
"'validation'"
|
"The name of the training data set split to use (via the datasets library). Defaults to 'validation'"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
audio_column_name: str = field(
|
audio_column_name: str = field(
|
||||||
@ -100,15 +101,19 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_length_seconds: float = field(
|
max_length_seconds: float = field(
|
||||||
@ -149,8 +154,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
freeze_feature_extractor: Optional[bool] = field(
|
freeze_feature_extractor: Optional[bool] = field(
|
||||||
|
@ -89,8 +89,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
freeze_vision_model: bool = field(
|
freeze_vision_model: bool = field(
|
||||||
@ -132,22 +134,28 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: Optional[int] = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
|
@ -93,15 +93,19 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -140,8 +144,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -62,7 +62,10 @@ def parse_args():
|
|||||||
"--dataset_name",
|
"--dataset_name",
|
||||||
type=str,
|
type=str,
|
||||||
default="cifar10",
|
default="cifar10",
|
||||||
help="The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private, dataset).",
|
help=(
|
||||||
|
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
||||||
|
" dataset)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--train_dir", type=str, default=None, help="A folder containing the training data.")
|
parser.add_argument("--train_dir", type=str, default=None, help="A folder containing the training data.")
|
||||||
parser.add_argument("--validation_dir", type=str, default=None, help="A folder containing the validation data.")
|
parser.add_argument("--validation_dir", type=str, default=None, help="A folder containing the validation data.")
|
||||||
@ -70,15 +73,19 @@ def parse_args():
|
|||||||
"--max_train_samples",
|
"--max_train_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="For debugging purposes or quicker training, truncate the number of training examples to this "
|
help=(
|
||||||
"value if set.",
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
|
"value if set."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_eval_samples",
|
"--max_eval_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
help=(
|
||||||
"value if set.",
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
|
"value if set."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train_val_split",
|
"--train_val_split",
|
||||||
|
@ -74,15 +74,19 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -104,8 +108,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: str = field(
|
model_name_or_path: str = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config_name: Optional[str] = field(
|
config_name: Optional[str] = field(
|
||||||
@ -114,8 +119,10 @@ class ModelArguments:
|
|||||||
config_overrides: Optional[str] = field(
|
config_overrides: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Override some existing default config settings when a model is trained from scratch. Example: "
|
"help": (
|
||||||
|
"Override some existing default config settings when a model is trained from scratch. Example: "
|
||||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
cache_dir: Optional[str] = field(
|
cache_dir: Optional[str] = field(
|
||||||
@ -129,8 +136,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
mask_ratio: float = field(
|
mask_ratio: float = field(
|
||||||
|
@ -87,15 +87,19 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -117,9 +121,11 @@ class ModelArguments:
|
|||||||
model_name_or_path: str = field(
|
model_name_or_path: str = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a "
|
"help": (
|
||||||
|
"The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a "
|
||||||
"checkpoint identifier on the hub. "
|
"checkpoint identifier on the hub. "
|
||||||
"Don't set if you want to train a model from scratch."
|
"Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -132,8 +138,10 @@ class ModelArguments:
|
|||||||
config_overrides: Optional[str] = field(
|
config_overrides: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Override some existing default config settings when a model is trained from scratch. Example: "
|
"help": (
|
||||||
|
"Override some existing default config settings when a model is trained from scratch. Example: "
|
||||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
cache_dir: Optional[str] = field(
|
cache_dir: Optional[str] = field(
|
||||||
@ -148,20 +156,26 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
image_size: Optional[int] = field(
|
image_size: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The size (resolution) of each image. If not specified, will use `image_size` of the configuration."
|
"help": (
|
||||||
|
"The size (resolution) of each image. If not specified, will use `image_size` of the configuration."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
patch_size: Optional[int] = field(
|
patch_size: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration."
|
"help": (
|
||||||
|
"The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
encoder_stride: Optional[int] = field(
|
encoder_stride: Optional[int] = field(
|
||||||
|
@ -73,8 +73,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -84,8 +85,10 @@ class ModelArguments:
|
|||||||
config_overrides: Optional[str] = field(
|
config_overrides: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Override some existing default config settings when a model is trained from scratch. Example: "
|
"help": (
|
||||||
|
"Override some existing default config settings when a model is trained from scratch. Example: "
|
||||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config_name: Optional[str] = field(
|
config_name: Optional[str] = field(
|
||||||
@ -109,8 +112,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -141,24 +146,30 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
block_size: Optional[int] = field(
|
block_size: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Optional input sequence length after tokenization. "
|
"help": (
|
||||||
|
"Optional input sequence length after tokenization. "
|
||||||
"The training dataset will be truncated in block of this size for training. "
|
"The training dataset will be truncated in block of this size for training. "
|
||||||
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -390,7 +401,8 @@ def main():
|
|||||||
# clm input could be much much longer than block_size
|
# clm input could be much much longer than block_size
|
||||||
if "Token indices sequence length is longer than the" in cl.out:
|
if "Token indices sequence length is longer than the" in cl.out:
|
||||||
tok_logger.warning(
|
tok_logger.warning(
|
||||||
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
|
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
|
||||||
|
" before being passed to the model."
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -168,7 +168,11 @@ def parse_args():
|
|||||||
"--block_size",
|
"--block_size",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Optional input sequence length after tokenization. The training dataset will be truncated in block of this size for training. Default to the model max input length for single sentence inputs (take into account special tokens).",
|
help=(
|
||||||
|
"Optional input sequence length after tokenization. The training dataset will be truncated in block of"
|
||||||
|
" this size for training. Default to the model max input length for single sentence inputs (take into"
|
||||||
|
" account special tokens)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--preprocessing_num_workers",
|
"--preprocessing_num_workers",
|
||||||
|
@ -70,8 +70,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -81,8 +82,10 @@ class ModelArguments:
|
|||||||
config_overrides: Optional[str] = field(
|
config_overrides: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Override some existing default config settings when a model is trained from scratch. Example: "
|
"help": (
|
||||||
|
"Override some existing default config settings when a model is trained from scratch. Example: "
|
||||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config_name: Optional[str] = field(
|
config_name: Optional[str] = field(
|
||||||
@ -106,8 +109,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -147,8 +152,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated."
|
"than this will be truncated."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
@ -165,22 +172,28 @@ class DataTrainingArguments:
|
|||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
|
"Whether to pad all samples to `max_seq_length`. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -171,7 +171,9 @@ def parse_args():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.",
|
help=(
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer than this will be truncated."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--line_by_line",
|
"--line_by_line",
|
||||||
|
@ -63,8 +63,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config_name: Optional[str] = field(
|
config_name: Optional[str] = field(
|
||||||
@ -73,8 +74,10 @@ class ModelArguments:
|
|||||||
config_overrides: Optional[str] = field(
|
config_overrides: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Override some existing default config settings when a model is trained from scratch. Example: "
|
"help": (
|
||||||
|
"Override some existing default config settings when a model is trained from scratch. Example: "
|
||||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
tokenizer_name: Optional[str] = field(
|
tokenizer_name: Optional[str] = field(
|
||||||
@ -95,8 +98,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -136,8 +141,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=512,
|
default=512,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated."
|
"than this will be truncated."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
@ -147,8 +154,10 @@ class DataTrainingArguments:
|
|||||||
plm_probability: float = field(
|
plm_probability: float = field(
|
||||||
default=1 / 6,
|
default=1 / 6,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Ratio of length of a span of masked tokens to surrounding context length for "
|
"help": (
|
||||||
|
"Ratio of length of a span of masked tokens to surrounding context length for "
|
||||||
"permutation language modeling."
|
"permutation language modeling."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_span_length: int = field(
|
max_span_length: int = field(
|
||||||
@ -161,22 +170,28 @@ class DataTrainingArguments:
|
|||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
|
"Whether to pad all samples to `max_seq_length`. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -82,8 +82,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -109,30 +111,38 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. If passed, sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. If passed, sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to the maximum sentence length. "
|
"help": (
|
||||||
|
"Whether to pad all samples to the maximum sentence length. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||||
"efficient on GPU but very bad for TPU."
|
"efficient on GPU but very bad for TPU."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -81,8 +81,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -118,37 +120,46 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=384,
|
default=384,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
|
"Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
|
||||||
"be faster on GPU but will be slower on TPU)."
|
" batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
version_2_with_negative: bool = field(
|
version_2_with_negative: bool = field(
|
||||||
@ -157,9 +168,11 @@ class DataTrainingArguments:
|
|||||||
null_score_diff_threshold: float = field(
|
null_score_diff_threshold: float = field(
|
||||||
default=0.0,
|
default=0.0,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The threshold used to select the null answer: if the best answer has a score that is less than "
|
"help": (
|
||||||
|
"The threshold used to select the null answer: if the best answer has a score that is less than "
|
||||||
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
||||||
"Only useful when `version_2_with_negative=True`."
|
"Only useful when `version_2_with_negative=True`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
doc_stride: int = field(
|
doc_stride: int = field(
|
||||||
@ -173,8 +186,10 @@ class DataTrainingArguments:
|
|||||||
max_answer_length: int = field(
|
max_answer_length: int = field(
|
||||||
default=30,
|
default=30,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum length of an answer that can be generated. This is needed because the start "
|
"help": (
|
||||||
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
"and end predictions are not conditioned on one another."
|
"and end predictions are not conditioned on one another."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -319,9 +334,9 @@ def main():
|
|||||||
# Tokenizer check: this script requires a fast tokenizer.
|
# Tokenizer check: this script requires a fast tokenizer.
|
||||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This example script only works for models that have a fast tokenizer. Checkout the big table of models "
|
"This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
|
||||||
"at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
|
" https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
|
||||||
"requirement"
|
" this requirement"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
|
@ -80,8 +80,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -117,37 +119,46 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=384,
|
default=384,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
|
"Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
|
||||||
"be faster on GPU but will be slower on TPU)."
|
" batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
version_2_with_negative: bool = field(
|
version_2_with_negative: bool = field(
|
||||||
@ -156,9 +167,11 @@ class DataTrainingArguments:
|
|||||||
null_score_diff_threshold: float = field(
|
null_score_diff_threshold: float = field(
|
||||||
default=0.0,
|
default=0.0,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The threshold used to select the null answer: if the best answer has a score that is less than "
|
"help": (
|
||||||
|
"The threshold used to select the null answer: if the best answer has a score that is less than "
|
||||||
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
||||||
"Only useful when `version_2_with_negative=True`."
|
"Only useful when `version_2_with_negative=True`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
doc_stride: int = field(
|
doc_stride: int = field(
|
||||||
@ -172,8 +185,10 @@ class DataTrainingArguments:
|
|||||||
max_answer_length: int = field(
|
max_answer_length: int = field(
|
||||||
default=30,
|
default=30,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum length of an answer that can be generated. This is needed because the start "
|
"help": (
|
||||||
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
"and end predictions are not conditioned on one another."
|
"and end predictions are not conditioned on one another."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -116,8 +116,10 @@ def parse_args():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=384,
|
default=384,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
|
help=(
|
||||||
" sequences shorter will be padded if `--pad_to_max_lengh` is passed.",
|
"The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
|
||||||
|
" sequences shorter will be padded if `--pad_to_max_lengh` is passed."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pad_to_max_length",
|
"--pad_to_max_length",
|
||||||
@ -190,9 +192,11 @@ def parse_args():
|
|||||||
"--null_score_diff_threshold",
|
"--null_score_diff_threshold",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.0,
|
default=0.0,
|
||||||
help="The threshold used to select the null answer: if the best answer has a score that is less than "
|
help=(
|
||||||
|
"The threshold used to select the null answer: if the best answer has a score that is less than "
|
||||||
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
||||||
"Only useful when `version_2_with_negative=True`.",
|
"Only useful when `version_2_with_negative=True`."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--version_2_with_negative",
|
"--version_2_with_negative",
|
||||||
@ -203,22 +207,28 @@ def parse_args():
|
|||||||
"--max_answer_length",
|
"--max_answer_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=30,
|
default=30,
|
||||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
help=(
|
||||||
"and end predictions are not conditioned on one another.",
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
|
"and end predictions are not conditioned on one another."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_train_samples",
|
"--max_train_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="For debugging purposes or quicker training, truncate the number of training examples to this "
|
help=(
|
||||||
"value if set.",
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
|
"value if set."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_eval_samples",
|
"--max_eval_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
help=(
|
||||||
"value if set.",
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
|
"value if set."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
||||||
|
@ -121,8 +121,10 @@ def parse_args():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=384,
|
default=384,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
|
help=(
|
||||||
" sequences shorter will be padded if `--pad_to_max_lengh` is passed.",
|
"The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
|
||||||
|
" sequences shorter will be padded if `--pad_to_max_lengh` is passed."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pad_to_max_length",
|
"--pad_to_max_length",
|
||||||
@ -212,9 +214,11 @@ def parse_args():
|
|||||||
"--null_score_diff_threshold",
|
"--null_score_diff_threshold",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.0,
|
default=0.0,
|
||||||
help="The threshold used to select the null answer: if the best answer has a score that is less than "
|
help=(
|
||||||
|
"The threshold used to select the null answer: if the best answer has a score that is less than "
|
||||||
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
||||||
"Only useful when `version_2_with_negative=True`.",
|
"Only useful when `version_2_with_negative=True`."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--version_2_with_negative",
|
"--version_2_with_negative",
|
||||||
@ -225,22 +229,28 @@ def parse_args():
|
|||||||
"--max_answer_length",
|
"--max_answer_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=30,
|
default=30,
|
||||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
help=(
|
||||||
"and end predictions are not conditioned on one another.",
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
|
"and end predictions are not conditioned on one another."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_train_samples",
|
"--max_train_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="For debugging purposes or quicker training, truncate the number of training examples to this "
|
help=(
|
||||||
"value if set.",
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
|
"value if set."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_eval_samples",
|
"--max_eval_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
help=(
|
||||||
"value if set.",
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
|
"value if set."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
||||||
|
@ -81,8 +81,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -130,53 +132,66 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=384,
|
default=384,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_answer_length: int = field(
|
max_answer_length: int = field(
|
||||||
default=30,
|
default=30,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum length of an answer that can be generated. This is needed because the start "
|
"help": (
|
||||||
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
"and end predictions are not conditioned on one another."
|
"and end predictions are not conditioned on one another."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
val_max_answer_length: Optional[int] = field(
|
val_max_answer_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`."
|
"than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`."
|
||||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||||
"during ``evaluate`` and ``predict``."
|
"during ``evaluate`` and ``predict``."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
|
"Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
|
||||||
"be faster on GPU but will be slower on TPU)."
|
" batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
version_2_with_negative: bool = field(
|
version_2_with_negative: bool = field(
|
||||||
@ -185,9 +200,11 @@ class DataTrainingArguments:
|
|||||||
null_score_diff_threshold: float = field(
|
null_score_diff_threshold: float = field(
|
||||||
default=0.0,
|
default=0.0,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The threshold used to select the null answer: if the best answer has a score that is less than "
|
"help": (
|
||||||
|
"The threshold used to select the null answer: if the best answer has a score that is less than "
|
||||||
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
||||||
"Only useful when `version_2_with_negative=True`."
|
"Only useful when `version_2_with_negative=True`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
doc_stride: int = field(
|
doc_stride: int = field(
|
||||||
@ -201,8 +218,10 @@ class DataTrainingArguments:
|
|||||||
num_beams: Optional[int] = field(
|
num_beams: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
"help": (
|
||||||
|
"Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||||
"which is used during ``evaluate`` and ``predict``."
|
"which is used during ``evaluate`` and ``predict``."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ignore_pad_token_for_loss: bool = field(
|
ignore_pad_token_for_loss: bool = field(
|
||||||
|
@ -194,15 +194,19 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
reduce_labels: Optional[bool] = field(
|
reduce_labels: Optional[bool] = field(
|
||||||
@ -241,8 +245,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -219,7 +219,10 @@ def parse_args():
|
|||||||
"--pad_to_multiple_of",
|
"--pad_to_multiple_of",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).",
|
help=(
|
||||||
|
"If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the"
|
||||||
|
" use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adam_beta1",
|
"--adam_beta1",
|
||||||
@ -440,7 +443,7 @@ def main():
|
|||||||
# only normalized-inputs-training is supported
|
# only normalized-inputs-training is supported
|
||||||
if not feature_extractor.do_normalize:
|
if not feature_extractor.do_normalize:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Training is only supported for normalized inputs. " "Make sure ``feature_extractor.do_normalize == True``"
|
"Training is only supported for normalized inputs. Make sure ``feature_extractor.do_normalize == True``"
|
||||||
)
|
)
|
||||||
|
|
||||||
# set max & min audio length in number of samples
|
# set max & min audio length in number of samples
|
||||||
@ -496,7 +499,8 @@ def main():
|
|||||||
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
|
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
|
||||||
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
|
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
|
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
|
||||||
|
" ``config.feat_extract_norm='layer'"
|
||||||
)
|
)
|
||||||
|
|
||||||
# initialize random model
|
# initialize random model
|
||||||
@ -615,7 +619,7 @@ def main():
|
|||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
elif accelerator.is_local_main_process:
|
elif accelerator.is_local_main_process:
|
||||||
progress_bar.write(
|
progress_bar.write(
|
||||||
"Gradients have overflown - skipping update step... " f"Updating gradient scale to {scale}..."
|
f"Gradients have overflown - skipping update step... Updating gradient scale to {scale}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
# update gumbel temperature
|
# update gumbel temperature
|
||||||
|
@ -101,9 +101,11 @@ class ModelArguments:
|
|||||||
mask_time_prob: float = field(
|
mask_time_prob: float = field(
|
||||||
default=0.05,
|
default=0.05,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
|
"help": (
|
||||||
|
"Probability of each feature vector along the time axis to be chosen as the start of the vector"
|
||||||
"span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
|
"span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
|
||||||
"vectors will be masked along the time axis."
|
"vectors will be masked along the time axis."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
mask_time_length: int = field(
|
mask_time_length: int = field(
|
||||||
@ -113,8 +115,11 @@ class ModelArguments:
|
|||||||
mask_feature_prob: float = field(
|
mask_feature_prob: float = field(
|
||||||
default=0.0,
|
default=0.0,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
|
"help": (
|
||||||
"span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
|
"Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
|
||||||
|
" to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
|
||||||
|
" bins will be masked along the time axis."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
mask_feature_length: int = field(
|
mask_feature_length: int = field(
|
||||||
@ -146,8 +151,10 @@ class DataTrainingArguments:
|
|||||||
train_split_name: str = field(
|
train_split_name: str = field(
|
||||||
default="train+validation",
|
default="train+validation",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The name of the training data set split to use (via the datasets library). Defaults to "
|
"help": (
|
||||||
|
"The name of the training data set split to use (via the datasets library). Defaults to "
|
||||||
"'train+validation'"
|
"'train+validation'"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
eval_split_name: str = field(
|
eval_split_name: str = field(
|
||||||
@ -174,15 +181,19 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
chars_to_ignore: Optional[List[str]] = list_field(
|
chars_to_ignore: Optional[List[str]] = list_field(
|
||||||
@ -196,7 +207,10 @@ class DataTrainingArguments:
|
|||||||
max_duration_in_seconds: float = field(
|
max_duration_in_seconds: float = field(
|
||||||
default=20.0,
|
default=20.0,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
|
"help": (
|
||||||
|
"Filter audio files that are longer than `max_duration_in_seconds` seconds to"
|
||||||
|
" 'max_duration_in_seconds`"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
min_duration_in_seconds: float = field(
|
min_duration_in_seconds: float = field(
|
||||||
@ -205,17 +219,21 @@ class DataTrainingArguments:
|
|||||||
preprocessing_only: bool = field(
|
preprocessing_only: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to only do data preprocessing and skip training. "
|
"help": (
|
||||||
"This is especially useful when data preprocessing errors out in distributed training due to timeout. "
|
"Whether to only do data preprocessing and skip training. This is especially useful when data"
|
||||||
"In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
|
" preprocessing errors out in distributed training due to timeout. In this case, one should run the"
|
||||||
"so that the cached datasets can consequently be loaded in distributed training"
|
" preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
|
||||||
|
" can consequently be loaded in distributed training"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "If :obj:`True`, will use the token generated when running"
|
"help": (
|
||||||
|
"If :obj:`True`, will use the token generated when running"
|
||||||
":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
|
":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
unk_token: str = field(
|
unk_token: str = field(
|
||||||
@ -233,10 +251,12 @@ class DataTrainingArguments:
|
|||||||
phoneme_language: Optional[str] = field(
|
phoneme_language: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The target language that should be used be"
|
"help": (
|
||||||
|
"The target language that should be used be"
|
||||||
" passed to the tokenizer for tokenization. Note that"
|
" passed to the tokenizer for tokenization. Note that"
|
||||||
" this is only relevant if the model classifies the"
|
" this is only relevant if the model classifies the"
|
||||||
" input audio to a sequence of phoneme sequences."
|
" input audio to a sequence of phoneme sequences."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -720,7 +740,10 @@ def main():
|
|||||||
"finetuned_from": model_args.model_name_or_path,
|
"finetuned_from": model_args.model_name_or_path,
|
||||||
"tasks": "speech-recognition",
|
"tasks": "speech-recognition",
|
||||||
"tags": ["automatic-speech-recognition", data_args.dataset_name],
|
"tags": ["automatic-speech-recognition", data_args.dataset_name],
|
||||||
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
|
"dataset_args": (
|
||||||
|
f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
|
||||||
|
f" {data_args.eval_split_name}"
|
||||||
|
),
|
||||||
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
|
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
|
||||||
}
|
}
|
||||||
if "common_voice" in data_args.dataset_name:
|
if "common_voice" in data_args.dataset_name:
|
||||||
|
@ -87,8 +87,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
freeze_feature_encoder: bool = field(
|
freeze_feature_encoder: bool = field(
|
||||||
@ -122,15 +124,19 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
audio_column_name: str = field(
|
audio_column_name: str = field(
|
||||||
@ -144,7 +150,10 @@ class DataTrainingArguments:
|
|||||||
max_duration_in_seconds: float = field(
|
max_duration_in_seconds: float = field(
|
||||||
default=20.0,
|
default=20.0,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
|
"help": (
|
||||||
|
"Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
|
||||||
|
" 'max_duration_in_seconds`"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
min_duration_in_seconds: float = field(
|
min_duration_in_seconds: float = field(
|
||||||
@ -153,10 +162,12 @@ class DataTrainingArguments:
|
|||||||
preprocessing_only: bool = field(
|
preprocessing_only: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to only do data preprocessing and skip training. "
|
"help": (
|
||||||
"This is especially useful when data preprocessing errors out in distributed training due to timeout. "
|
"Whether to only do data preprocessing and skip training. This is especially useful when data"
|
||||||
"In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
|
" preprocessing errors out in distributed training due to timeout. In this case, one should run the"
|
||||||
"so that the cached datasets can consequently be loaded in distributed training"
|
" preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
|
||||||
|
" can consequently be loaded in distributed training"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
train_split_name: str = field(
|
train_split_name: str = field(
|
||||||
|
@ -101,15 +101,19 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
resize_position_embeddings: Optional[bool] = field(
|
resize_position_embeddings: Optional[bool] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
|
"help": (
|
||||||
|
"Whether to automatically resize the position embeddings if `max_source_length` exceeds "
|
||||||
"the model's position embeddings."
|
"the model's position embeddings."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -142,14 +146,15 @@ class DataTrainingArguments:
|
|||||||
validation_file: Optional[str] = field(
|
validation_file: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
|
"help": (
|
||||||
"(a jsonlines or csv file)."
|
"An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
test_file: Optional[str] = field(
|
test_file: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
|
"help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -162,60 +167,76 @@ class DataTrainingArguments:
|
|||||||
max_source_length: Optional[int] = field(
|
max_source_length: Optional[int] = field(
|
||||||
default=1024,
|
default=1024,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_target_length: Optional[int] = field(
|
max_target_length: Optional[int] = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
val_max_target_length: Optional[int] = field(
|
val_max_target_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||||
"during ``evaluate`` and ``predict``."
|
"during ``evaluate`` and ``predict``."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
"help": (
|
||||||
|
"Whether to pad all samples to model maximum sentence length. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||||
"efficient on GPU but very bad for TPU."
|
"efficient on GPU but very bad for TPU."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
num_beams: Optional[int] = field(
|
num_beams: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
"help": (
|
||||||
|
"Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||||
"which is used during ``evaluate`` and ``predict``."
|
"which is used during ``evaluate`` and ``predict``."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ignore_pad_token_for_loss: bool = field(
|
ignore_pad_token_for_loss: bool = field(
|
||||||
@ -231,9 +252,11 @@ class DataTrainingArguments:
|
|||||||
forced_bos_token: Optional[str] = field(
|
forced_bos_token: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The token to force as the first generated token after the decoder_start_token_id."
|
"help": (
|
||||||
|
"The token to force as the first generated token after the decoder_start_token_id."
|
||||||
"Useful for multilingual models like mBART where the first generated token"
|
"Useful for multilingual models like mBART where the first generated token"
|
||||||
"needs to be the target language token (Usually it is the target language token)"
|
"needs to be the target language token (Usually it is the target language token)"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -410,17 +433,18 @@ def main():
|
|||||||
):
|
):
|
||||||
if model_args.resize_position_embeddings is None:
|
if model_args.resize_position_embeddings is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} "
|
"Increasing the model's number of position embedding vectors from"
|
||||||
f"to {data_args.max_source_length}."
|
f" {model.config.max_position_embeddings} to {data_args.max_source_length}."
|
||||||
)
|
)
|
||||||
model.resize_position_embeddings(data_args.max_source_length)
|
model.resize_position_embeddings(data_args.max_source_length)
|
||||||
elif model_args.resize_position_embeddings:
|
elif model_args.resize_position_embeddings:
|
||||||
model.resize_position_embeddings(data_args.max_source_length)
|
model.resize_position_embeddings(data_args.max_source_length)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}"
|
f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
|
||||||
f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically "
|
f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
|
||||||
"resize the model's position encodings by passing `--resize_position_embeddings`."
|
f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the"
|
||||||
|
" model's position encodings by passing `--resize_position_embeddings`."
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||||
|
@ -111,20 +111,22 @@ def parse_args():
|
|||||||
"--ignore_pad_token_for_loss",
|
"--ignore_pad_token_for_loss",
|
||||||
type=bool,
|
type=bool,
|
||||||
default=True,
|
default=True,
|
||||||
help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
|
help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_source_length",
|
"--max_source_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=1024,
|
default=1024,
|
||||||
help="The maximum total input sequence length after "
|
help=(
|
||||||
"tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
|
"The maximum total input sequence length after "
|
||||||
|
"tokenization.Sequences longer than this will be truncated, sequences shorter will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--source_prefix",
|
"--source_prefix",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="A prefix to add before every source text " "(useful for T5 models).",
|
help="A prefix to add before every source text (useful for T5 models).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--preprocessing_num_workers",
|
"--preprocessing_num_workers",
|
||||||
@ -139,18 +141,22 @@ def parse_args():
|
|||||||
"--max_target_length",
|
"--max_target_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=128,
|
default=128,
|
||||||
help="The maximum total sequence length for target text after "
|
help=(
|
||||||
|
"The maximum total sequence length for target text after "
|
||||||
"tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
|
"tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
|
||||||
"during ``evaluate`` and ``predict``.",
|
"during ``evaluate`` and ``predict``."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--val_max_target_length",
|
"--val_max_target_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="The maximum total sequence length for validation "
|
help=(
|
||||||
|
"The maximum total sequence length for validation "
|
||||||
"target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
|
"target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
|
||||||
"padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
|
"padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
|
||||||
"param of ``model.generate``, which is used during ``evaluate`` and ``predict``.",
|
"param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_length",
|
"--max_length",
|
||||||
@ -165,8 +171,10 @@ def parse_args():
|
|||||||
"--num_beams",
|
"--num_beams",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Number of beams to use for evaluation. This argument will be "
|
help=(
|
||||||
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
|
"Number of beams to use for evaluation. This argument will be "
|
||||||
|
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pad_to_max_length",
|
"--pad_to_max_length",
|
||||||
|
@ -89,8 +89,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -99,29 +101,37 @@ class DataTrainingArguments:
|
|||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
|
"Whether to pad all samples to `max_seq_length`. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
train_file: Optional[str] = field(
|
train_file: Optional[str] = field(
|
||||||
@ -180,8 +190,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -67,8 +67,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: Optional[int] = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -77,29 +79,37 @@ class DataTrainingArguments:
|
|||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
|
"Whether to pad all samples to `max_seq_length`. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
|
server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
|
||||||
@ -146,8 +156,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -81,8 +81,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -127,44 +129,56 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. If set, sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. If set, sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
"help": (
|
||||||
|
"Whether to pad all samples to model maximum sentence length. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||||
"efficient on GPU but very bad for TPU."
|
"efficient on GPU but very bad for TPU."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
label_all_tokens: bool = field(
|
label_all_tokens: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
|
"help": (
|
||||||
|
"Whether to put the label for one word on all tokens of generated by that word or just on the "
|
||||||
"one (in which case the other tokens will have a padding index)."
|
"one (in which case the other tokens will have a padding index)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return_entity_level_metrics: bool = field(
|
return_entity_level_metrics: bool = field(
|
||||||
@ -355,9 +369,9 @@ def main():
|
|||||||
# Tokenizer check: this script requires a fast tokenizer.
|
# Tokenizer check: this script requires a fast tokenizer.
|
||||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This example script only works for models that have a fast tokenizer. Checkout the big table of models "
|
"This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
|
||||||
"at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
|
" https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
|
||||||
"requirement"
|
" this requirement"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Model has labels -> use them.
|
# Model has labels -> use them.
|
||||||
@ -373,8 +387,8 @@ def main():
|
|||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||||
f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
|
f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:"
|
||||||
"\nIgnoring the model labels as a result.",
|
f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set the correspondences label/ID inside the model config
|
# Set the correspondences label/ID inside the model config
|
||||||
|
@ -403,8 +403,8 @@ def main():
|
|||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||||
f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
|
f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:"
|
||||||
"\nIgnoring the model labels as a result.",
|
f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set the correspondences label/ID inside the model config
|
# Set the correspondences label/ID inside the model config
|
||||||
|
@ -91,8 +91,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -116,15 +118,12 @@ class DataTrainingArguments:
|
|||||||
validation_file: Optional[str] = field(
|
validation_file: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
|
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on a jsonlines file."
|
||||||
"a jsonlines file."
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
test_file: Optional[str] = field(
|
test_file: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "An optional input test data file to evaluate the metrics (sacreblue) on a jsonlines file."},
|
||||||
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||||
@ -136,60 +135,76 @@ class DataTrainingArguments:
|
|||||||
max_source_length: Optional[int] = field(
|
max_source_length: Optional[int] = field(
|
||||||
default=1024,
|
default=1024,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_target_length: Optional[int] = field(
|
max_target_length: Optional[int] = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
val_max_target_length: Optional[int] = field(
|
val_max_target_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||||
"during ``evaluate`` and ``predict``."
|
"during ``evaluate`` and ``predict``."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
"help": (
|
||||||
|
"Whether to pad all samples to model maximum sentence length. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||||
"efficient on GPU but very bad for TPU."
|
"efficient on GPU but very bad for TPU."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
num_beams: Optional[int] = field(
|
num_beams: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
"help": (
|
||||||
|
"Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||||
"which is used during ``evaluate`` and ``predict``."
|
"which is used during ``evaluate`` and ``predict``."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ignore_pad_token_for_loss: bool = field(
|
ignore_pad_token_for_loss: bool = field(
|
||||||
@ -204,9 +219,11 @@ class DataTrainingArguments:
|
|||||||
forced_bos_token: Optional[str] = field(
|
forced_bos_token: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
|
"help": (
|
||||||
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
|
"The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for"
|
||||||
"needs to be the target language token.(Usually it is the target language token)"
|
" multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
|
||||||
|
" be the target language token.(Usually it is the target language token)"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -95,41 +95,51 @@ def parse_args():
|
|||||||
"--num_beams",
|
"--num_beams",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Number of beams to use for evaluation. This argument will be "
|
help=(
|
||||||
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
|
"Number of beams to use for evaluation. This argument will be "
|
||||||
|
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_source_length",
|
"--max_source_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=1024,
|
default=1024,
|
||||||
help="The maximum total input sequence length after "
|
help=(
|
||||||
"tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
|
"The maximum total input sequence length after "
|
||||||
|
"tokenization.Sequences longer than this will be truncated, sequences shorter will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_target_length",
|
"--max_target_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=128,
|
default=128,
|
||||||
help="The maximum total sequence length for target text after "
|
help=(
|
||||||
|
"The maximum total sequence length for target text after "
|
||||||
"tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
|
"tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
|
||||||
"during ``evaluate`` and ``predict``.",
|
"during ``evaluate`` and ``predict``."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--val_max_target_length",
|
"--val_max_target_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="The maximum total sequence length for validation "
|
help=(
|
||||||
|
"The maximum total sequence length for validation "
|
||||||
"target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
|
"target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
|
||||||
"padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
|
"padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
|
||||||
"param of ``model.generate``, which is used during ``evaluate`` and ``predict``.",
|
"param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pad_to_max_length",
|
"--pad_to_max_length",
|
||||||
type=bool,
|
type=bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="Whether to pad all samples to model maximum sentence "
|
help=(
|
||||||
|
"Whether to pad all samples to model maximum sentence "
|
||||||
"length. If False, will pad the samples dynamically when batching to the maximum length in the batch. More"
|
"length. If False, will pad the samples dynamically when batching to the maximum length in the batch. More"
|
||||||
"efficient on GPU but very bad for TPU.",
|
"efficient on GPU but very bad for TPU."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||||
@ -138,7 +148,7 @@ def parse_args():
|
|||||||
"--ignore_pad_token_for_loss",
|
"--ignore_pad_token_for_loss",
|
||||||
type=bool,
|
type=bool,
|
||||||
default=True,
|
default=True,
|
||||||
help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
|
help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--source_lang", type=str, default=None, help="Source language id for translation.")
|
parser.add_argument("--source_lang", type=str, default=None, help="Source language id for translation.")
|
||||||
parser.add_argument("--target_lang", type=str, default=None, help="Target language id for translation.")
|
parser.add_argument("--target_lang", type=str, default=None, help="Target language id for translation.")
|
||||||
@ -146,7 +156,7 @@ def parse_args():
|
|||||||
"--source_prefix",
|
"--source_prefix",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="A prefix to add before every source text " "(useful for T5 models).",
|
help="A prefix to add before every source text (useful for T5 models).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--preprocessing_num_workers",
|
"--preprocessing_num_workers",
|
||||||
|
@ -39,9 +39,7 @@ def parse_args():
|
|||||||
"""
|
"""
|
||||||
parser = ArgumentParser(
|
parser = ArgumentParser(
|
||||||
description=(
|
description=(
|
||||||
"PyTorch TPU distributed training launch "
|
"PyTorch TPU distributed training launch helper utility that will spawn up multiple distributed processes"
|
||||||
"helper utility that will spawn up "
|
|
||||||
"multiple distributed processes"
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -77,8 +77,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -110,7 +112,8 @@ def main():
|
|||||||
and not training_args.overwrite_output_dir
|
and not training_args.overwrite_output_dir
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
|
||||||
|
" --overwrite_output_dir to overcome."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
|
@ -197,7 +197,7 @@ if is_tf_available():
|
|||||||
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
|
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
|
for ex_index, ex in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
|
||||||
if ex_index % 10000 == 0:
|
if ex_index % 10000 == 0:
|
||||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||||
|
|
||||||
@ -268,7 +268,7 @@ class HansProcessor(DataProcessor):
|
|||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training and dev sets."""
|
||||||
examples = []
|
examples = []
|
||||||
for (i, line) in enumerate(lines):
|
for i, line in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
continue
|
continue
|
||||||
guid = "%s-%s" % (set_type, line[0])
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
@ -303,7 +303,7 @@ def hans_convert_examples_to_features(
|
|||||||
label_map = {label: i for i, label in enumerate(label_list)}
|
label_map = {label: i for i, label in enumerate(label_list)}
|
||||||
|
|
||||||
features = []
|
features = []
|
||||||
for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
||||||
if ex_index % 10000 == 0:
|
if ex_index % 10000 == 0:
|
||||||
logger.info("Writing example %d" % (ex_index))
|
logger.info("Writing example %d" % (ex_index))
|
||||||
|
|
||||||
|
@ -84,7 +84,10 @@ class AlbertModelWithPabee(AlbertModel):
|
|||||||
|
|
||||||
def log_stats(self):
|
def log_stats(self):
|
||||||
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
|
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
|
||||||
message = f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
|
message = (
|
||||||
|
f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up ="
|
||||||
|
f" {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
|
||||||
|
)
|
||||||
print(message)
|
print(message)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
|
||||||
|
@ -89,7 +89,10 @@ class BertModelWithPabee(BertModel):
|
|||||||
|
|
||||||
def log_stats(self):
|
def log_stats(self):
|
||||||
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
|
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
|
||||||
message = f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
|
message = (
|
||||||
|
f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up ="
|
||||||
|
f" {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
|
||||||
|
)
|
||||||
print(message)
|
print(message)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
||||||
|
@ -483,8 +483,10 @@ def main():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help=(
|
||||||
"than this will be truncated, sequences shorter will be padded.",
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
@ -574,8 +576,10 @@ def main():
|
|||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O1",
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help=(
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--local_rank",
|
"--local_rank",
|
||||||
|
@ -325,7 +325,8 @@ def main():
|
|||||||
|
|
||||||
if not documents_dir_is_valid(args.documents_dir):
|
if not documents_dir_is_valid(args.documents_dir):
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
|
"We could not find the directory you specified for the documents to summarize, or it was empty. Please"
|
||||||
|
" specify a valid path."
|
||||||
)
|
)
|
||||||
os.makedirs(args.summaries_output_dir, exist_ok=True)
|
os.makedirs(args.summaries_output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
@ -338,8 +338,10 @@ def main():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
help=(
|
||||||
"Sequences longer than this will be truncated, sequences shorter padded.",
|
"The maximum total input sequence length after WordPiece tokenization. \n"
|
||||||
|
"Sequences longer than this will be truncated, sequences shorter padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
||||||
|
|
||||||
|
@ -314,8 +314,10 @@ def main():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
help=(
|
||||||
"Sequences longer than this will be truncated, sequences shorter padded.",
|
"The maximum total input sequence length after WordPiece tokenization. \n"
|
||||||
|
"Sequences longer than this will be truncated, sequences shorter padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
||||||
|
|
||||||
|
@ -112,7 +112,10 @@ class HumanEvalArguments:
|
|||||||
device_int: Optional[int] = field(
|
device_int: Optional[int] = field(
|
||||||
default=-1,
|
default=-1,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Determine which device to run the `text-generation` Pipeline on. -1 is CPU and any zero or positive number corresponds to which GPU device id to run on."
|
"help": (
|
||||||
|
"Determine which device to run the `text-generation` Pipeline on. -1 is CPU and any zero or positive"
|
||||||
|
" number corresponds to which GPU device id to run on."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -186,7 +186,8 @@ def main():
|
|||||||
_ = code_eval_metric.compute(references=[""], predictions=[[""]])
|
_ = code_eval_metric.compute(references=[""], predictions=[[""]])
|
||||||
except ValueError as exception:
|
except ValueError as exception:
|
||||||
print(
|
print(
|
||||||
'Code evaluation not enabled. Read the warning below carefully and then use `--HF_ALLOW_CODE_EVAL="1"` flag to enable code evaluation.'
|
'Code evaluation not enabled. Read the warning below carefully and then use `--HF_ALLOW_CODE_EVAL="1"`'
|
||||||
|
" flag to enable code evaluation."
|
||||||
)
|
)
|
||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
|
@ -459,8 +459,10 @@ def main():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help=(
|
||||||
"than this will be truncated, sequences shorter will be padded.",
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
@ -529,8 +531,10 @@ def main():
|
|||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O1",
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help=(
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||||
|
@ -60,7 +60,7 @@ class GroupedBatchSampler(BatchSampler):
|
|||||||
def __init__(self, sampler, group_ids, batch_size):
|
def __init__(self, sampler, group_ids, batch_size):
|
||||||
if not isinstance(sampler, Sampler):
|
if not isinstance(sampler, Sampler):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
"sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
||||||
)
|
)
|
||||||
self.sampler = sampler
|
self.sampler = sampler
|
||||||
self.group_ids = group_ids
|
self.group_ids = group_ids
|
||||||
|
@ -518,7 +518,10 @@ def main():
|
|||||||
"--teacher_type",
|
"--teacher_type",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
|
help=(
|
||||||
|
"Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
|
||||||
|
" distillation."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--teacher_name_or_path",
|
"--teacher_name_or_path",
|
||||||
@ -590,8 +593,10 @@ def main():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=384,
|
default=384,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
help=(
|
||||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
"The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||||
|
"longer than this will be truncated, and sequences shorter than this will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--doc_stride",
|
"--doc_stride",
|
||||||
@ -603,8 +608,10 @@ def main():
|
|||||||
"--max_query_length",
|
"--max_query_length",
|
||||||
default=64,
|
default=64,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
help=(
|
||||||
"be truncated to this length.",
|
"The maximum number of tokens for the question. Questions longer than this will "
|
||||||
|
"be truncated to this length."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
@ -649,14 +656,18 @@ def main():
|
|||||||
"--max_answer_length",
|
"--max_answer_length",
|
||||||
default=30,
|
default=30,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
help=(
|
||||||
"and end predictions are not conditioned on one another.",
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
|
"and end predictions are not conditioned on one another."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose_logging",
|
"--verbose_logging",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="If true, all of the warnings related to data processing will be printed. "
|
help=(
|
||||||
"A number of warnings are expected for a normal SQuAD evaluation.",
|
"If true, all of the warnings related to data processing will be printed. "
|
||||||
|
"A number of warnings are expected for a normal SQuAD evaluation."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
@ -685,8 +696,10 @@ def main():
|
|||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O1",
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help=(
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
|
@ -25,7 +25,10 @@ from transformers import GPT2LMHeadModel, RobertaForMaskedLM
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
|
description=(
|
||||||
|
"Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned"
|
||||||
|
" Distillation"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
|
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
|
||||||
parser.add_argument("--model_name", default="roberta-large", type=str)
|
parser.add_argument("--model_name", default="roberta-large", type=str)
|
||||||
|
@ -25,7 +25,10 @@ from transformers import BertForMaskedLM
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
|
description=(
|
||||||
|
"Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned"
|
||||||
|
" Distillation"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
parser.add_argument("--model_type", default="bert", choices=["bert"])
|
parser.add_argument("--model_type", default="bert", choices=["bert"])
|
||||||
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
|
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
|
||||||
|
@ -207,8 +207,10 @@ def main():
|
|||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O1",
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help=(
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
|
parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
|
||||||
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
|
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
|
||||||
@ -226,8 +228,8 @@ def main():
|
|||||||
if os.path.exists(args.dump_path):
|
if os.path.exists(args.dump_path):
|
||||||
if not args.force:
|
if not args.force:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it"
|
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite"
|
||||||
"Use `--force` if you want to overwrite it"
|
" itUse `--force` if you want to overwrite it"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
shutil.rmtree(args.dump_path)
|
shutil.rmtree(args.dump_path)
|
||||||
|
@ -48,7 +48,8 @@ class FSNERTokenizerUtils(object):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Type of parameter x was not recognized! Only `list of strings` for query or `list of lists of strings` for supports are supported."
|
"Type of parameter x was not recognized! Only `list of strings` for query or `list of lists of"
|
||||||
|
" strings` for supports are supported."
|
||||||
)
|
)
|
||||||
|
|
||||||
return d
|
return d
|
||||||
|
@ -75,8 +75,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -99,7 +100,10 @@ class ModelArguments:
|
|||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -141,8 +145,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated. Default to the max input length of the model."
|
"than this will be truncated. Default to the max input length of the model."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
@ -155,8 +161,10 @@ class DataTrainingArguments:
|
|||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
|
"Whether to pad all samples to `max_seq_length`. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
line_by_line: bool = field(
|
line_by_line: bool = field(
|
||||||
@ -575,7 +583,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if step % training_args.logging_steps == 0 and step > 0:
|
if step % training_args.logging_steps == 0 and step > 0:
|
||||||
steps.write(
|
steps.write(
|
||||||
f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate'].mean()})"
|
||||||
)
|
)
|
||||||
train_time += time.time() - train_start
|
train_time += time.time() - train_start
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
@ -604,7 +613,10 @@ if __name__ == "__main__":
|
|||||||
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||||
|
|
||||||
# Update progress bar
|
# Update progress bar
|
||||||
steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
steps.desc = (
|
||||||
|
f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc:"
|
||||||
|
f" {eval_metrics['accuracy']})"
|
||||||
|
)
|
||||||
|
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
write_eval_metric(summary_writer, eval_metrics, step)
|
write_eval_metric(summary_writer, eval_metrics, step)
|
||||||
|
@ -77,14 +77,18 @@ class ModelArguments:
|
|||||||
|
|
||||||
text_model_name_or_path: str = field(
|
text_model_name_or_path: str = field(
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The text model checkpoint for weights initialization."
|
"help": (
|
||||||
|
"The text model checkpoint for weights initialization."
|
||||||
"Don't set if you want to train a model from scratch."
|
"Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
vision_model_name_or_path: str = field(
|
vision_model_name_or_path: str = field(
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The vision model checkpoint for weights initialization."
|
"help": (
|
||||||
|
"The vision model checkpoint for weights initialization."
|
||||||
"Don't set if you want to train a model from scratch."
|
"Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
from_pt: bool = field(
|
from_pt: bool = field(
|
||||||
@ -107,7 +111,10 @@ class ModelArguments:
|
|||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -129,22 +136,28 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: Optional[int] = field(
|
||||||
default=72,
|
default=72,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -519,7 +532,8 @@ def main():
|
|||||||
|
|
||||||
train_step_progress_bar.close()
|
train_step_progress_bar.close()
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ======================== Evaluating ==============================
|
# ======================== Evaluating ==============================
|
||||||
|
@ -69,8 +69,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -93,7 +94,10 @@ class ModelArguments:
|
|||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -118,15 +122,19 @@ class DataTrainingArguments:
|
|||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -141,9 +149,11 @@ class DataTrainingArguments:
|
|||||||
block_size: Optional[int] = field(
|
block_size: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Optional input sequence length after tokenization. "
|
"help": (
|
||||||
|
"Optional input sequence length after tokenization. "
|
||||||
"The training dataset will be truncated in block of this size for training. "
|
"The training dataset will be truncated in block of this size for training. "
|
||||||
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
@ -334,7 +344,8 @@ def main():
|
|||||||
# clm input could be much much longer than block_size
|
# clm input could be much much longer than block_size
|
||||||
if "Token indices sequence length is longer than the" in cl.out:
|
if "Token indices sequence length is longer than the" in cl.out:
|
||||||
tok_logger.warning(
|
tok_logger.warning(
|
||||||
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
|
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
|
||||||
|
" before being passed to the model."
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -606,7 +617,8 @@ def main():
|
|||||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
@ -632,7 +644,8 @@ def main():
|
|||||||
eval_metrics["perplexity"] = float("inf")
|
eval_metrics["perplexity"] = float("inf")
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
|
f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity:"
|
||||||
|
f" {eval_metrics['perplexity']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
||||||
|
@ -64,7 +64,10 @@ class ModelArguments:
|
|||||||
dtype: Optional[str] = field(
|
dtype: Optional[str] = field(
|
||||||
default="float32",
|
default="float32",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
"help": (
|
||||||
|
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||||
|
" `[float32, float16, bfloat16]`."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -94,7 +97,9 @@ class DataTrainingArguments:
|
|||||||
validation_split_name: Optional[str] = field(
|
validation_split_name: Optional[str] = field(
|
||||||
default="validation",
|
default="validation",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
|
"help": (
|
||||||
|
"The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
speech_file_column: Optional[str] = field(
|
speech_file_column: Optional[str] = field(
|
||||||
@ -120,7 +125,10 @@ class DataTrainingArguments:
|
|||||||
pad_to_multiple_of: Optional[int] = field(
|
pad_to_multiple_of: Optional[int] = field(
|
||||||
default=1024,
|
default=1024,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
|
"help": (
|
||||||
|
"If set will pad the sequence to a multiple of the provided value. This is important to avoid"
|
||||||
|
" triggering recompilations on TPU"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -357,7 +365,8 @@ def main():
|
|||||||
|
|
||||||
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
|
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
|
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
|
||||||
|
" ``config.feat_extract_norm='layer'"
|
||||||
)
|
)
|
||||||
|
|
||||||
model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||||
@ -557,7 +566,8 @@ def main():
|
|||||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
|
||||||
|
f" {train_metric['learning_rate'].mean()})"
|
||||||
)
|
)
|
||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
@ -583,7 +593,8 @@ def main():
|
|||||||
|
|
||||||
# Update progress bar
|
# Update progress bar
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
|
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity:"
|
||||||
|
f" {eval_metrics['codevector_perplexity']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save metrics
|
# Save metrics
|
||||||
|
@ -649,7 +649,7 @@ def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages,
|
|||||||
"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
|
"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
|
||||||
]
|
]
|
||||||
all_res_lists = []
|
all_res_lists = []
|
||||||
for (res_passages, dl) in zip(res_passages_lst, D):
|
for res_passages, dl in zip(res_passages_lst, D):
|
||||||
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
||||||
for r, sc in zip(res_list, dl):
|
for r, sc in zip(res_list, dl):
|
||||||
r["score"] = float(sc)
|
r["score"] = float(sc)
|
||||||
@ -679,7 +679,7 @@ def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passage
|
|||||||
"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
|
"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
|
||||||
]
|
]
|
||||||
all_res_lists = []
|
all_res_lists = []
|
||||||
for (res_passages, dl, il) in zip(res_passages_lst, D, I):
|
for res_passages, dl, il in zip(res_passages_lst, D, I):
|
||||||
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
|
||||||
for r, sc, i in zip(res_list, dl, il):
|
for r, sc, i in zip(res_list, dl, il):
|
||||||
r["passage_id"] = int(i)
|
r["passage_id"] = int(i)
|
||||||
|
@ -101,8 +101,8 @@ def parse_args():
|
|||||||
type=int,
|
type=int,
|
||||||
default=32,
|
default=32,
|
||||||
help=(
|
help=(
|
||||||
"The maximum total input entity length after tokenization (Used only for (M)Luke models). Sequences longer than this will be truncated,"
|
"The maximum total input entity length after tokenization (Used only for (M)Luke models). Sequences longer"
|
||||||
" sequences shorter will be padded if `--pad_to_max_length` is passed."
|
" than this will be truncated, sequences shorter will be padded if `--pad_to_max_length` is passed."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -110,8 +110,8 @@ def parse_args():
|
|||||||
type=int,
|
type=int,
|
||||||
default=30,
|
default=30,
|
||||||
help=(
|
help=(
|
||||||
"The maximum total input mention length after tokenization (Used only for (M)Luke models). Sequences longer than this will be truncated,"
|
"The maximum total input mention length after tokenization (Used only for (M)Luke models). Sequences"
|
||||||
" sequences shorter will be padded if `--pad_to_max_length` is passed."
|
" longer than this will be truncated, sequences shorter will be padded if `--pad_to_max_length` is passed."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -592,7 +592,7 @@ class Matcher(object):
|
|||||||
|
|
||||||
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
|
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
|
||||||
|
|
||||||
for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
|
for l, low, high in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
|
||||||
low_high = (matched_vals >= low) & (matched_vals < high)
|
low_high = (matched_vals >= low) & (matched_vals < high)
|
||||||
match_labels[low_high] = l
|
match_labels[low_high] = l
|
||||||
|
|
||||||
@ -1037,9 +1037,9 @@ class ResNet(Backbone):
|
|||||||
curr_kwargs = {}
|
curr_kwargs = {}
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if k.endswith("_per_block"):
|
if k.endswith("_per_block"):
|
||||||
assert len(v) == num_blocks, (
|
assert (
|
||||||
f"Argument '{k}' of make_stage should have the " f"same length as num_blocks={num_blocks}."
|
len(v) == num_blocks
|
||||||
)
|
), f"Argument '{k}' of make_stage should have the same length as num_blocks={num_blocks}."
|
||||||
newk = k[: -len("_per_block")]
|
newk = k[: -len("_per_block")]
|
||||||
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
|
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
|
||||||
curr_kwargs[newk] = v[i]
|
curr_kwargs[newk] = v[i]
|
||||||
@ -1401,7 +1401,7 @@ class AnchorGenerator(nn.Module):
|
|||||||
|
|
||||||
def grid_anchors(self, grid_sizes):
|
def grid_anchors(self, grid_sizes):
|
||||||
anchors = []
|
anchors = []
|
||||||
for (size, stride, base_anchors) in zip(grid_sizes, self.strides, self.cell_anchors):
|
for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
|
||||||
shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
|
shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
|
||||||
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
|
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
|
||||||
|
|
||||||
@ -1708,10 +1708,9 @@ class GeneralizedRCNN(nn.Module):
|
|||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||||
assert (
|
assert from_tf, (
|
||||||
from_tf
|
"We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint"
|
||||||
), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
|
.format(pretrained_model_name_or_path + ".index")
|
||||||
pretrained_model_name_or_path + ".index"
|
|
||||||
)
|
)
|
||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
else:
|
else:
|
||||||
@ -1798,25 +1797,27 @@ class GeneralizedRCNN(nn.Module):
|
|||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
print(
|
print(
|
||||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
||||||
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
|
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
||||||
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
|
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
|
||||||
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
|
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
||||||
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
|
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
||||||
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
|
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
|
||||||
|
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
print(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
||||||
if len(missing_keys) > 0:
|
if len(missing_keys) > 0:
|
||||||
print(
|
print(
|
||||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
|
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||||
f"and are newly initialized: {missing_keys}\n"
|
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
||||||
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
||||||
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
|
||||||
f"you can already use {model.__class__.__name__} for predictions without further training."
|
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
|
||||||
|
" training."
|
||||||
)
|
)
|
||||||
if len(error_msgs) > 0:
|
if len(error_msgs) > 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -231,9 +231,10 @@ def compare(in_tensor):
|
|||||||
n2 = out_tensor.numpy()[0]
|
n2 = out_tensor.numpy()[0]
|
||||||
print(n1.shape, n1[0, 0, :5])
|
print(n1.shape, n1[0, 0, :5])
|
||||||
print(n2.shape, n2[0, 0, :5])
|
print(n2.shape, n2[0, 0, :5])
|
||||||
assert np.allclose(
|
assert np.allclose(n1, n2, rtol=0.01, atol=0.1), (
|
||||||
n1, n2, rtol=0.01, atol=0.1
|
f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} %"
|
||||||
), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
|
" element-wise mismatch"
|
||||||
|
)
|
||||||
raise Exception("tensors are all good")
|
raise Exception("tensors are all good")
|
||||||
|
|
||||||
# Hugging face functions below
|
# Hugging face functions below
|
||||||
|
@ -61,8 +61,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@ -72,8 +73,10 @@ class ModelArguments:
|
|||||||
config_overrides: Optional[str] = field(
|
config_overrides: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Override some existing default config settings when a model is trained from scratch. Example: "
|
"help": (
|
||||||
|
"Override some existing default config settings when a model is trained from scratch. Example: "
|
||||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config_name: Optional[str] = field(
|
config_name: Optional[str] = field(
|
||||||
@ -97,8 +100,10 @@ class ModelArguments:
|
|||||||
use_auth_token: bool = field(
|
use_auth_token: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
"help": (
|
||||||
|
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
"with private models)."
|
"with private models)."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -146,8 +151,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated. Default to the max input length of the model."
|
"than this will be truncated. Default to the max input length of the model."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
@ -160,8 +167,10 @@ class DataTrainingArguments:
|
|||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
|
"Whether to pad all samples to `max_seq_length`. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -356,8 +356,10 @@ def main():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help=(
|
||||||
"than this will be truncated, sequences shorter will be padded.",
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_image_embeds", default=1, type=int, help="Number of Image Embeddings from the Image Encoder"
|
"--num_image_embeds", default=1, type=int, help="Number of Image Embeddings from the Image Encoder"
|
||||||
@ -423,8 +425,10 @@ def main():
|
|||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O1",
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help=(
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||||
|
@ -103,15 +103,20 @@ if __name__ == "__main__":
|
|||||||
choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
|
choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
|
help=(
|
||||||
|
"Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
|
||||||
|
" sigmoied_threshold = Soft movement pruning)"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--threshold",
|
"--threshold",
|
||||||
type=float,
|
type=float,
|
||||||
required=False,
|
required=False,
|
||||||
help="For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
|
help=(
|
||||||
|
"For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
|
||||||
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
|
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
|
||||||
"Not needed for `l0`",
|
"Not needed for `l0`"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
|
@ -70,15 +70,20 @@ if __name__ == "__main__":
|
|||||||
choices=["l0", "topK", "sigmoied_threshold"],
|
choices=["l0", "topK", "sigmoied_threshold"],
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
|
help=(
|
||||||
|
"Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement"
|
||||||
|
" pruning)"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--threshold",
|
"--threshold",
|
||||||
type=float,
|
type=float,
|
||||||
required=False,
|
required=False,
|
||||||
help="For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
|
help=(
|
||||||
|
"For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
|
||||||
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
|
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
|
||||||
"Not needed for `l0`",
|
"Not needed for `l0`"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--serialization_dir",
|
"--serialization_dir",
|
||||||
|
@ -80,8 +80,8 @@ class BertSelfAttention(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The hidden size (%d) is not a multiple of the number of attention "
|
"The hidden size (%d) is not a multiple of the number of attention heads (%d)"
|
||||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
% (config.hidden_size, config.num_attention_heads)
|
||||||
)
|
)
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
|
|
||||||
|
@ -622,8 +622,10 @@ def main():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help=(
|
||||||
"than this will be truncated, sequences shorter will be padded.",
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
@ -669,22 +671,29 @@ def main():
|
|||||||
"--initial_warmup",
|
"--initial_warmup",
|
||||||
default=1,
|
default=1,
|
||||||
type=int,
|
type=int,
|
||||||
help="Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
|
help=(
|
||||||
"at its `initial_threshold` value (sparsity schedule).",
|
"Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
|
||||||
|
"at its `initial_threshold` value (sparsity schedule)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--final_warmup",
|
"--final_warmup",
|
||||||
default=2,
|
default=2,
|
||||||
type=int,
|
type=int,
|
||||||
help="Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
|
help=(
|
||||||
"at its final_threshold value (sparsity schedule).",
|
"Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
|
||||||
|
"at its final_threshold value (sparsity schedule)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pruning_method",
|
"--pruning_method",
|
||||||
default="topK",
|
default="topK",
|
||||||
type=str,
|
type=str,
|
||||||
help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning).",
|
help=(
|
||||||
|
"Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
|
||||||
|
" sigmoied_threshold = Soft movement pruning)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mask_init",
|
"--mask_init",
|
||||||
@ -717,7 +726,10 @@ def main():
|
|||||||
"--teacher_type",
|
"--teacher_type",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
|
help=(
|
||||||
|
"Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
|
||||||
|
" distillation."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--teacher_name_or_path",
|
"--teacher_name_or_path",
|
||||||
@ -787,8 +799,10 @@ def main():
|
|||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O1",
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help=(
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||||
|
|
||||||
@ -805,7 +819,8 @@ def main():
|
|||||||
and not args.overwrite_output_dir
|
and not args.overwrite_output_dir
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to"
|
||||||
|
" overcome."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup CUDA, GPU & distributed training
|
# Setup CUDA, GPU & distributed training
|
||||||
|
@ -737,8 +737,10 @@ def main():
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=384,
|
default=384,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
help=(
|
||||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
"The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||||
|
"longer than this will be truncated, and sequences shorter than this will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--doc_stride",
|
"--doc_stride",
|
||||||
@ -750,8 +752,10 @@ def main():
|
|||||||
"--max_query_length",
|
"--max_query_length",
|
||||||
default=64,
|
default=64,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
help=(
|
||||||
"be truncated to this length.",
|
"The maximum number of tokens for the question. Questions longer than this will "
|
||||||
|
"be truncated to this length."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
@ -785,22 +789,29 @@ def main():
|
|||||||
"--initial_warmup",
|
"--initial_warmup",
|
||||||
default=1,
|
default=1,
|
||||||
type=int,
|
type=int,
|
||||||
help="Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
|
help=(
|
||||||
"at its `initial_threshold` value (sparsity schedule).",
|
"Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
|
||||||
|
"at its `initial_threshold` value (sparsity schedule)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--final_warmup",
|
"--final_warmup",
|
||||||
default=2,
|
default=2,
|
||||||
type=int,
|
type=int,
|
||||||
help="Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
|
help=(
|
||||||
"at its final_threshold value (sparsity schedule).",
|
"Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
|
||||||
|
"at its final_threshold value (sparsity schedule)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pruning_method",
|
"--pruning_method",
|
||||||
default="topK",
|
default="topK",
|
||||||
type=str,
|
type=str,
|
||||||
help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning).",
|
help=(
|
||||||
|
"Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
|
||||||
|
" sigmoied_threshold = Soft movement pruning)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mask_init",
|
"--mask_init",
|
||||||
@ -833,7 +844,10 @@ def main():
|
|||||||
"--teacher_type",
|
"--teacher_type",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
|
help=(
|
||||||
|
"Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
|
||||||
|
" distillation."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--teacher_name_or_path",
|
"--teacher_name_or_path",
|
||||||
@ -883,20 +897,27 @@ def main():
|
|||||||
"--max_answer_length",
|
"--max_answer_length",
|
||||||
default=30,
|
default=30,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
help=(
|
||||||
"and end predictions are not conditioned on one another.",
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
|
"and end predictions are not conditioned on one another."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose_logging",
|
"--verbose_logging",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="If true, all of the warnings related to data processing will be printed. "
|
help=(
|
||||||
"A number of warnings are expected for a normal SQuAD evaluation.",
|
"If true, all of the warnings related to data processing will be printed. "
|
||||||
|
"A number of warnings are expected for a normal SQuAD evaluation."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang_id",
|
"--lang_id",
|
||||||
default=0,
|
default=0,
|
||||||
type=int,
|
type=int,
|
||||||
help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
|
help=(
|
||||||
|
"language id of input for language-specific xlm models (see"
|
||||||
|
" tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
||||||
@ -925,8 +946,10 @@ def main():
|
|||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O1",
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help=(
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
|
@ -392,13 +392,14 @@ class BeamSearchScorerTS(torch.nn.Module):
|
|||||||
|
|
||||||
if not isinstance(num_beams, int) or num_beams <= 1:
|
if not isinstance(num_beams, int) or num_beams <= 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
|
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
|
||||||
|
" one should make use of `greedy_search` instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
|
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
|
||||||
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
||||||
)
|
)
|
||||||
|
|
||||||
def hypo_len(self, hypo_idx: int):
|
def hypo_len(self, hypo_idx: int):
|
||||||
@ -508,7 +509,8 @@ class BeamSearchScorerTS(torch.nn.Module):
|
|||||||
|
|
||||||
if beam_idx < self.group_size:
|
if beam_idx < self.group_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
|
||||||
|
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if we are done so that we can save a pad step if all(done)
|
# Check if we are done so that we can save a pad step if all(done)
|
||||||
|
@ -53,14 +53,16 @@ def parse_args():
|
|||||||
"--max_length",
|
"--max_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=5,
|
default=5,
|
||||||
help=("The maximum total input sequence length after tokenization."),
|
help="The maximum total input sequence length after tokenization.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_beams",
|
"--num_beams",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Number of beams to use for evaluation. This argument will be "
|
help=(
|
||||||
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
|
"Number of beams to use for evaluation. This argument will be "
|
||||||
|
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
|
@ -535,7 +535,7 @@ class FastAttentionviaLowRankDecomposition(FastAttention):
|
|||||||
assert key.ndim == value.ndim
|
assert key.ndim == value.ndim
|
||||||
for ax in axis:
|
for ax in axis:
|
||||||
if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
|
if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
|
||||||
raise ValueError("Attention axis must be between the batch " "axis and the last-two axes.")
|
raise ValueError("Attention axis must be between the batch axis and the last-two axes.")
|
||||||
n = key.ndim
|
n = key.ndim
|
||||||
|
|
||||||
# Constructing projection tensor.
|
# Constructing projection tensor.
|
||||||
|
@ -98,8 +98,9 @@ class ModelArguments:
|
|||||||
model_name_or_path: Optional[str] = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The model checkpoint for weights initialization."
|
"help": (
|
||||||
"Don't set if you want to train a model from scratch."
|
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
performer: bool = field(
|
performer: bool = field(
|
||||||
@ -159,8 +160,10 @@ class DataTrainingArguments:
|
|||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": (
|
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated. Default to the max input length of the model."
|
"than this will be truncated. Default to the max input length of the model."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
@ -173,8 +176,10 @@ class DataTrainingArguments:
|
|||||||
pad_to_max_length: bool = field(
|
pad_to_max_length: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": (
|
||||||
|
"Whether to pad all samples to `max_seq_length`. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -175,8 +175,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
|
|||||||
test_loss /= len(data_loader.dataset)
|
test_loss /= len(data_loader.dataset)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
"Performance on test set: "
|
"Performance on test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
|
||||||
"Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
|
|
||||||
test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
|
test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -309,7 +308,7 @@ def train_discriminator(
|
|||||||
x.append(seq)
|
x.append(seq)
|
||||||
y.append(d["label"])
|
y.append(d["label"])
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
|
print("Error evaluating / tokenizing line {}, skipping it".format(i))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
full_dataset = Dataset(x, y)
|
full_dataset = Dataset(x, y)
|
||||||
@ -349,7 +348,7 @@ def train_discriminator(
|
|||||||
x.append(seq)
|
x.append(seq)
|
||||||
y.append(int(np.sum(d["label"]) > 0))
|
y.append(int(np.sum(d["label"]) > 0))
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
|
print("Error evaluating / tokenizing line {}, skipping it".format(i))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
full_dataset = Dataset(x, y)
|
full_dataset = Dataset(x, y)
|
||||||
@ -370,7 +369,7 @@ def train_discriminator(
|
|||||||
# class \t text
|
# class \t text
|
||||||
|
|
||||||
if dataset_fp is None:
|
if dataset_fp is None:
|
||||||
raise ValueError("When generic dataset is selected, " "dataset_fp needs to be specified aswell.")
|
raise ValueError("When generic dataset is selected, dataset_fp needs to be specified aswell.")
|
||||||
|
|
||||||
classes = set()
|
classes = set()
|
||||||
with open(dataset_fp) as f:
|
with open(dataset_fp) as f:
|
||||||
@ -490,15 +489,17 @@ if __name__ == "__main__":
|
|||||||
type=str,
|
type=str,
|
||||||
default="SST",
|
default="SST",
|
||||||
choices=("SST", "clickbait", "toxic", "generic"),
|
choices=("SST", "clickbait", "toxic", "generic"),
|
||||||
help="dataset to train the discriminator on."
|
help=(
|
||||||
|
"dataset to train the discriminator on."
|
||||||
"In case of generic, the dataset is expected"
|
"In case of generic, the dataset is expected"
|
||||||
"to be a TSBV file with structure: class \\t text",
|
"to be a TSBV file with structure: class \\t text"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset_fp",
|
"--dataset_fp",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
default="",
|
||||||
help="File path of the dataset to use. " "Needed only in case of generic datadset",
|
help="File path of the dataset to use. Needed only in case of generic datadset",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pretrained_model", type=str, default="gpt2-medium", help="Pretrained model to use as encoder"
|
"--pretrained_model", type=str, default="gpt2-medium", help="Pretrained model to use as encoder"
|
||||||
|
@ -87,8 +87,10 @@ parser.add_argument(
|
|||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
default=384,
|
default=384,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
help=(
|
||||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
"The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||||
|
"longer than this will be truncated, and sequences shorter than this will be padded."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--doc_stride",
|
"--doc_stride",
|
||||||
@ -109,8 +111,10 @@ parser.add_argument(
|
|||||||
"--max_answer_length",
|
"--max_answer_length",
|
||||||
default=30,
|
default=30,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
help=(
|
||||||
"and end predictions are not conditioned on one another.",
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
|
"and end predictions are not conditioned on one another."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user