Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 18:22:34 +06:00
Only resize embeddings when necessary (#20043)
* Only resize embeddings when necessary

* Add comment
This commit is contained in:
parent 9080607b2c
commit 06886d5a68
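
Every hunk below makes the same change: the unconditional `model.resize_token_embeddings(len(tokenizer))` call is replaced by a resize that runs only when the tokenizer vocabulary is larger than the model's current embedding matrix. A minimal sketch of the resulting pattern, assuming a Hugging Face model/tokenizer pair (the `gpt2` checkpoint and the added token are illustrative only, not part of this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Adding tokens grows the tokenizer vocabulary beyond the model's embedding matrix.
tokenizer.add_tokens(["<my_new_token>"])

# The pattern introduced by this commit: resize only when necessary to avoid index errors.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))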
@@ -387,7 +387,11 @@ def main():
         n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
         logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))

     # Preprocessing the datasets.
     # First we tokenize all the texts.
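
One case the new check deliberately leaves alone (an illustrative assumption, not stated in the commit): a from-scratch config whose vocab_size is padded above the tokenizer length, as is sometimes done for efficiency. The old unconditional call would shrink the embedding matrix back to len(tokenizer); with the check, nothing happens:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")               # len(tokenizer) == 50257
config = AutoConfig.from_pretrained("gpt2", vocab_size=50304)   # padded to a multiple of 64
model = AutoModelForCausalLM.from_config(config)

embedding_size = model.get_input_embeddings().weight.shape[0]   # 50304
if len(tokenizer) > embedding_size:                             # False, so the matrix keeps its size
    model.resize_token_embeddings(len(tokenizer))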
@@ -378,7 +378,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForCausalLM.from_config(config)

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))

     # Preprocessing the datasets.
     # First we tokenize all the texts.
@@ -389,7 +389,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForMaskedLM.from_config(config)

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))

     # Preprocessing the datasets.
     # First we tokenize all the texts.
@@ -383,7 +383,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForMaskedLM.from_config(config)

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))

     # Preprocessing the datasets.
     # First we tokenize all the texts.
@@ -376,7 +376,11 @@ def main():
         logger.info("Training new model from scratch")
         model = XLNetLMHeadModel(config)

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))

     # Preprocessing the datasets.
     # First we tokenize all the texts.
@@ -398,7 +398,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForMultipleChoice.from_config(config)

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))

     # Preprocessing the datasets.
     # First we tokenize all the texts.
@@ -380,7 +380,11 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))

     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
@@ -422,7 +422,11 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))

     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
         if isinstance(tokenizer, MBartTokenizer):
@@ -439,7 +439,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForSeq2SeqLM.from_config(config)

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

@@ -414,7 +414,13 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForTokenClassification.from_config(config)

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))

     # Model has labels -> use them.
     if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
@@ -380,7 +380,11 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))

     # Set decoder_start_token_id
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
@@ -411,7 +411,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForSeq2SeqLM.from_config(config)

-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))

     # Set decoder_start_token_id
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
@@ -473,7 +473,11 @@ def main():
             logger.info("Training new model from scratch")
             model = TFAutoModelForCausalLM.from_config(config)

-        model.resize_token_embeddings(len(tokenizer))
+        # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+        # on a small vocab and want a smaller embedding size, remove this test.
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
         # endregion

         # region TF Dataset preparation
@@ -489,7 +489,11 @@ def main():
             logger.info("Training new model from scratch")
             model = TFAutoModelForMaskedLM.from_config(config)

-        model.resize_token_embeddings(len(tokenizer))
+        # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+        # on a small vocab and want a smaller embedding size, remove this test.
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
         # endregion

         # region TF Dataset preparation
@@ -516,7 +516,11 @@ def main():
             use_auth_token=True if model_args.use_auth_token else None,
         )

-        model.resize_token_embeddings(len(tokenizer))
+        # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+        # on a small vocab and want a smaller embedding size, remove this test.
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
         # endregion

         # region Prepare TF Dataset objects
@@ -385,7 +385,11 @@ def main():
             logger.info("Training new model from scratch")
             model = TFAutoModelForTokenClassification.from_config(config)

-        model.resize_token_embeddings(len(tokenizer))
+        # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+        # on a small vocab and want a smaller embedding size, remove this test.
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
         # endregion

         # region Create TF datasets
@@ -469,7 +469,11 @@ def main():
             use_auth_token=True if model_args.use_auth_token else None,
         )

-        model.resize_token_embeddings(len(tokenizer))
+        # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+        # on a small vocab and want a smaller embedding size, remove this test.
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
         if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
             model.config.forced_bos_token_id = forced_bos_token_id
         # endregion