diff --git a/transformers/modeling_tf_pytorch_utils.py b/transformers/modeling_tf_pytorch_utils.py
index 510e130c90e..9d2b663dcb3 100644
--- a/transformers/modeling_tf_pytorch_utils.py
+++ b/transformers/modeling_tf_pytorch_utils.py
@@ -78,6 +78,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
     logger.info("Loading PyTorch weights from {}".format(pt_path))
 
     pt_state_dict = torch.load(pt_path, map_location='cpu')
+    logger.info("PyTorch checkpoint contains {:,} parameters".format(sum(t.numel() for t in pt_state_dict.values())))
 
     return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys)
@@ -134,7 +135,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
         start_prefix_to_remove = tf_model.base_model_prefix + '.'
 
     symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
-
+    tf_loaded_numel = 0
     weight_value_tuples = []
     all_pytorch_weights = set(list(pt_state_dict.keys()))
     for symbolic_weight in symbolic_weights:
@@ -159,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
             e.args += (symbolic_weight.shape, array.shape)
             raise e
 
+        tf_loaded_numel += array.size
         # logger.warning("Initialize TF weight {}".format(symbolic_weight.name))
 
         weight_value_tuples.append((symbolic_weight, array))
@@ -169,6 +171,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
     if tf_inputs is not None:
         tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
 
+    logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))
+    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
 
     return tf_model
diff --git a/transformers/modeling_tf_xlm.py b/transformers/modeling_tf_xlm.py
index 6f11b0537df..903a8596c3e 100644
--- a/transformers/modeling_tf_xlm.py
+++ b/transformers/modeling_tf_xlm.py
@@ -460,7 +460,7 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
             langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
         else:
             langs_list = None
-        return [inputs_list, attns_list, langs_list]
+        return {'input_ids': inputs_list, 'attention_mask': attns_list, 'langs': langs_list}
 
 
 XLM_START_DOCSTRING = r"""    The XLM model was proposed in
diff --git a/transformers/modeling_xlm.py b/transformers/modeling_xlm.py
index 257f0da394b..b604ae669d6 100644
--- a/transformers/modeling_xlm.py
+++ b/transformers/modeling_xlm.py
@@ -227,6 +227,16 @@ class XLMPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)
 
+    @property
+    def dummy_inputs(self):
+        inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
+        attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+        if self.config.use_lang_emb and self.config.n_langs > 1:
+            langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+        else:
+            langs_list = None
+        return {'input_ids': inputs_list, 'attention_mask': attns_list, 'langs': langs_list}
+
     def _init_weights(self, module):
         """ Initialize the weights. """
         if isinstance(module, nn.Embedding):
@@ -646,7 +656,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
                                                langs=langs,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               lengths=lengths, 
+                                               lengths=lengths,
                                                cache=cache,
                                                head_mask=head_mask,
                                                inputs_embeds=inputs_embeds)