mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fixing XLM conversion tests with dummy input
This commit is contained in:
parent
fafd4c86ec
commit
f19dad61c7
@@ -78,6 +78,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
    logger.info("Loading PyTorch weights from {}".format(pt_path))

    pt_state_dict = torch.load(pt_path, map_location='cpu')
    logger.info("PyTorch checkpoint contains {:,} parameters".format(sum(t.numel() for t in pt_state_dict.values())))

    return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys)

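For reference, a minimal sketch of how this conversion utility could be invoked is shown below. The function name and keyword arguments come from the hunk header above; the import paths, model class, and checkpoint filename are assumptions for illustration only.

from transformers import TFXLMWithLMHeadModel, XLMConfig  # assumed import path
from transformers.modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model  # assumed module path

config = XLMConfig()                     # placeholder configuration
tf_model = TFXLMWithLMHeadModel(config)  # fresh TF 2.0 model, variables not yet built

# With tf_inputs=None the utility presumably falls back to the model's dummy inputs
# (the dict-shaped dummy_inputs touched by this commit) to build the variables
# before the PyTorch weights are copied over.
tf_model = load_pytorch_checkpoint_in_tf2_model(
    tf_model, "xlm_checkpoint.bin", tf_inputs=None, allow_missing_keys=False)
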
@@ -134,7 +135,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
        start_prefix_to_remove = tf_model.base_model_prefix + '.'

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights

    tf_loaded_numel = 0
    weight_value_tuples = []
    all_pytorch_weights = set(list(pt_state_dict.keys()))
    for symbolic_weight in symbolic_weights:

@@ -159,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
            e.args += (symbolic_weight.shape, array.shape)
            raise e

        tf_loaded_numel += array.size
        # logger.warning("Initialize TF weight {}".format(symbolic_weight.name))

        weight_value_tuples.append((symbolic_weight, array))

@@ -169,6 +171,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
    if tf_inputs is not None:
        tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run

    logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))

    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))

    return tf_model

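The loop above accumulates (TF variable, NumPy array) pairs that are then written into the graph in a single pass. The toy snippet below only illustrates that pattern with a throwaway Dense layer: tf.keras.backend.batch_set_value is a real Keras API, but the layer, shapes, and zero-valued arrays are made up and do not reproduce the converter's name matching.

import numpy as np
import tensorflow as tf

# Toy stand-in for the converter's weight-assignment pattern.
layer = tf.keras.layers.Dense(4)
layer.build((None, 3))  # create the kernel and bias variables

weight_value_tuples = []
for symbolic_weight in layer.trainable_weights + layer.non_trainable_weights:
    # The real converter looks the array up in the PyTorch state_dict by
    # (renamed, possibly transposed) weight name; zeros are used here instead.
    array = np.zeros(symbolic_weight.shape.as_list(), dtype=np.float32)
    weight_value_tuples.append((symbolic_weight, array))

tf.keras.backend.batch_set_value(weight_value_tuples)  # single assignment pass
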
@@ -460,7 +460,7 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
            langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        else:
            langs_list = None
        return [inputs_list, attns_list, langs_list]
        return {'input_ids': inputs_list, 'attention_mask': attns_list, 'langs': langs_list}


XLM_START_DOCSTRING = r""" The XLM model was proposed in

@@ -227,6 +227,16 @@ class XLMPreTrainedModel(PreTrainedModel):
    def __init__(self, *inputs, **kwargs):
        super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)

    @property
    def dummy_inputs(self):
        inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
        attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        if self.config.use_lang_emb and self.config.n_langs > 1:
            langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        else:
            langs_list = None
        return {'input_ids': inputs_list, 'attention_mask': attns_list, 'langs': langs_list}

    def _init_weights(self, module):
        """ Initialize the weights. """
        if isinstance(module, nn.Embedding):

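These shared dummy inputs are what let the conversion tests run the PyTorch and TF 2.0 models on identical data. A rough sketch of that kind of parity check, assuming pt_model and tf_model are matching XLM models that have already been instantiated and loaded from the same weights (both names are placeholders):

import numpy as np
import torch

# pt_model / tf_model: assumed to be already-loaded, matching XLM models.
pt_model.eval()
pt_inputs = pt_model.dummy_inputs  # {'input_ids': ..., 'attention_mask': ..., 'langs': ...}
with torch.no_grad():
    pto = pt_model(**pt_inputs)

tfo = tf_model(tf_model.dummy_inputs, training=False)

# Compare the first output (the logits) from the two frameworks.
np.testing.assert_allclose(pto[0].numpy(), tfo[0].numpy(), atol=1e-4)
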
@@ -646,7 +656,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds)