mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
TF clearer model variable naming: xlnet (#16150)
This commit is contained in:
parent
a23a7c0cd6
commit
015de6f081
@ -42,8 +42,8 @@ from ...modeling_tf_utils import (
|
||||
TFSharedEmbeddings,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list
|
||||
from ...utils import logging
|
||||
@ -578,6 +578,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
return pos_emb
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@ -596,63 +597,34 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
target_mapping=target_mapping,
|
||||
token_type_ids=token_type_ids,
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_mems=use_mems,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if training and inputs["use_mems"] is None:
|
||||
inputs["use_mems"] = self.use_mems_train
|
||||
if training and use_mems is None:
|
||||
use_mems = self.use_mems_train
|
||||
else:
|
||||
inputs["use_mems"] = self.use_mems_eval
|
||||
use_mems = self.use_mems_eval
|
||||
|
||||
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
|
||||
# but we want a unified interface in the library with the batch size on the first dimension
|
||||
# so we move here the first dimension (batch) to the end
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
|
||||
qlen, bsz = shape_list(inputs["input_ids"])[:2]
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
|
||||
qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
|
||||
elif input_ids is not None:
|
||||
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
||||
qlen, bsz = shape_list(input_ids)[:2]
|
||||
elif inputs_embeds is not None:
|
||||
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
|
||||
qlen, bsz = shape_list(inputs_embeds)[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
inputs["token_type_ids"] = (
|
||||
tf.transpose(inputs["token_type_ids"], perm=(1, 0)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
inputs["input_mask"] = (
|
||||
tf.transpose(inputs["input_mask"], perm=(1, 0)) if inputs["input_mask"] is not None else None
|
||||
)
|
||||
inputs["attention_mask"] = (
|
||||
tf.transpose(inputs["attention_mask"], perm=(1, 0)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
inputs["perm_mask"] = (
|
||||
tf.transpose(inputs["perm_mask"], perm=(1, 2, 0)) if inputs["perm_mask"] is not None else None
|
||||
)
|
||||
inputs["target_mapping"] = (
|
||||
tf.transpose(inputs["target_mapping"], perm=(1, 2, 0)) if inputs["target_mapping"] is not None else None
|
||||
)
|
||||
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
|
||||
input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
|
||||
attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
|
||||
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
|
||||
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
|
||||
|
||||
mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
|
||||
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
|
||||
klen = mlen + qlen
|
||||
|
||||
# Attention mask
|
||||
@ -666,19 +638,19 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
raise ValueError(f"Unsupported attention type: {self.attn_type}")
|
||||
|
||||
# data mask: input mask & perm mask
|
||||
assert inputs["input_mask"] is None or inputs["attention_mask"] is None, (
|
||||
assert input_mask is None or attention_mask is None, (
|
||||
"You can only use one of input_mask (uses 1 for padding) "
|
||||
"or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
|
||||
)
|
||||
if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
|
||||
if input_mask is None and attention_mask is not None:
|
||||
one_cst = tf.constant(1.0)
|
||||
inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
|
||||
if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
|
||||
data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
|
||||
elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
|
||||
data_mask = inputs["input_mask"][None]
|
||||
elif inputs["input_mask"] is None and inputs["perm_mask"] is not None:
|
||||
data_mask = inputs["perm_mask"]
|
||||
input_mask = 1.0 - tf.cast(attention_mask, dtype=one_cst.dtype)
|
||||
if input_mask is not None and perm_mask is not None:
|
||||
data_mask = input_mask[None] + perm_mask
|
||||
elif input_mask is not None and perm_mask is None:
|
||||
data_mask = input_mask[None]
|
||||
elif input_mask is None and perm_mask is not None:
|
||||
data_mask = perm_mask
|
||||
else:
|
||||
data_mask = None
|
||||
|
||||
@ -704,33 +676,33 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
non_tgt_mask = None
|
||||
|
||||
# Word embeddings and prepare h & g hidden states
|
||||
if inputs["inputs_embeds"] is not None:
|
||||
word_emb_k = inputs["inputs_embeds"]
|
||||
if inputs_embeds is not None:
|
||||
word_emb_k = inputs_embeds
|
||||
else:
|
||||
word_emb_k = self.word_embedding(inputs["input_ids"])
|
||||
output_h = self.dropout(word_emb_k, training=inputs["training"])
|
||||
if inputs["target_mapping"] is not None:
|
||||
word_emb_q = tf.tile(self.mask_emb, [shape_list(inputs["target_mapping"])[0], bsz, 1])
|
||||
word_emb_k = self.word_embedding(input_ids)
|
||||
output_h = self.dropout(word_emb_k, training=training)
|
||||
if target_mapping is not None:
|
||||
word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
|
||||
# else: # We removed the inp_q input which was same as target mapping
|
||||
# inp_q_ext = inp_q[:, :, None]
|
||||
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
|
||||
output_g = self.dropout(word_emb_q, training=inputs["training"])
|
||||
output_g = self.dropout(word_emb_q, training=training)
|
||||
else:
|
||||
output_g = None
|
||||
|
||||
# Segment embedding
|
||||
if inputs["token_type_ids"] is not None:
|
||||
if token_type_ids is not None:
|
||||
# Convert `token_type_ids` to one-hot `seg_mat`
|
||||
if mlen > 0:
|
||||
mem_pad = tf.zeros([mlen, bsz], dtype=inputs["token_type_ids"].dtype)
|
||||
cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
|
||||
mem_pad = tf.zeros([mlen, bsz], dtype=token_type_ids.dtype)
|
||||
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
|
||||
else:
|
||||
cat_ids = inputs["token_type_ids"]
|
||||
cat_ids = token_type_ids
|
||||
|
||||
# `1` indicates not in the same segment [qlen x klen x bsz]
|
||||
seg_mat = tf.cast(
|
||||
tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])),
|
||||
dtype=inputs["token_type_ids"].dtype,
|
||||
tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
|
||||
dtype=token_type_ids.dtype,
|
||||
)
|
||||
seg_mat = tf.one_hot(seg_mat, 2)
|
||||
else:
|
||||
@ -738,29 +710,29 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
# Positional encoding
|
||||
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
|
||||
pos_emb = self.dropout(pos_emb, training=inputs["training"])
|
||||
pos_emb = self.dropout(pos_emb, training=training)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
|
||||
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
|
||||
if inputs["head_mask"] is not None:
|
||||
if head_mask is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
inputs["head_mask"] = [None] * self.n_layer
|
||||
head_mask = [None] * self.n_layer
|
||||
|
||||
new_mems = ()
|
||||
if inputs["mems"] is None:
|
||||
inputs["mems"] = [None] * len(self.layer)
|
||||
if mems is None:
|
||||
mems = [None] * len(self.layer)
|
||||
|
||||
attentions = [] if inputs["output_attentions"] else None
|
||||
hidden_states = [] if inputs["output_hidden_states"] else None
|
||||
attentions = [] if output_attentions else None
|
||||
hidden_states = [] if output_hidden_states else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
# cache new mems
|
||||
if inputs["use_mems"]:
|
||||
new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
|
||||
if inputs["output_hidden_states"]:
|
||||
if use_mems:
|
||||
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
||||
if output_hidden_states:
|
||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||
|
||||
outputs = layer_module(
|
||||
@ -770,34 +742,34 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
attn_mask,
|
||||
pos_emb,
|
||||
seg_mat,
|
||||
inputs["mems"][i],
|
||||
inputs["target_mapping"],
|
||||
inputs["head_mask"][i],
|
||||
inputs["output_attentions"],
|
||||
training=inputs["training"],
|
||||
mems[i],
|
||||
target_mapping,
|
||||
head_mask[i],
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
output_h, output_g = outputs[:2]
|
||||
if inputs["output_attentions"]:
|
||||
if output_attentions:
|
||||
attentions.append(outputs[2])
|
||||
|
||||
# Add last hidden state
|
||||
if inputs["output_hidden_states"]:
|
||||
if output_hidden_states:
|
||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||
|
||||
output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"])
|
||||
output = self.dropout(output_g if output_g is not None else output_h, training=training)
|
||||
|
||||
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
||||
output = tf.transpose(output, perm=(1, 0, 2))
|
||||
|
||||
if not inputs["use_mems"]:
|
||||
if not use_mems:
|
||||
new_mems = None
|
||||
if inputs["output_hidden_states"]:
|
||||
if output_hidden_states:
|
||||
if output_g is not None:
|
||||
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
|
||||
else:
|
||||
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
|
||||
if inputs["output_attentions"]:
|
||||
if inputs["target_mapping"] is not None:
|
||||
if output_attentions:
|
||||
if target_mapping is not None:
|
||||
# when target_mapping is provided, there are 2-tuple of attentions
|
||||
attentions = tuple(
|
||||
tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
|
||||
@ -805,7 +777,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)
|
||||
|
||||
return TFXLNetModelOutput(
|
||||
@ -1154,6 +1126,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.transformer = TFXLNetMainLayer(config, name="transformer")
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -1179,9 +1152,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
outputs = self.transformer(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
@ -1196,23 +1167,6 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
mems=inputs["mems"],
|
||||
perm_mask=inputs["perm_mask"],
|
||||
target_mapping=inputs["target_mapping"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
input_mask=inputs["input_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_mems=inputs["use_mems"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
@ -1286,6 +1240,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
|
||||
return inputs
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
@ -1349,9 +1304,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
... 0
|
||||
>>> ] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
|
||||
```"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
@ -1365,34 +1318,16 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
mems=inputs["mems"],
|
||||
perm_mask=inputs["perm_mask"],
|
||||
target_mapping=inputs["target_mapping"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
input_mask=inputs["input_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_mems=inputs["use_mems"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_state = transformer_outputs[0]
|
||||
logits = self.lm_loss(hidden_state, training=inputs["training"])
|
||||
logits = self.lm_loss(hidden_state, training=training)
|
||||
|
||||
loss = None
|
||||
if inputs["labels"] is not None:
|
||||
loss = self.hf_compute_loss(inputs["labels"], logits)
|
||||
if labels is not None:
|
||||
loss = self.hf_compute_loss(labels, logits)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
@ -1432,6 +1367,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
|
||||
)
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -1464,9 +1400,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
@ -1480,34 +1414,16 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
mems=inputs["mems"],
|
||||
perm_mask=inputs["perm_mask"],
|
||||
target_mapping=inputs["target_mapping"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
input_mask=inputs["input_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_mems=inputs["use_mems"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
|
||||
output = self.sequence_summary(output)
|
||||
logits = self.logits_proj(output)
|
||||
|
||||
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
|
||||
loss = None if labels is None else self.hf_compute_loss(labels, logits)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
@ -1558,6 +1474,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
"""
|
||||
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -1590,72 +1507,45 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
|
||||
"""
|
||||
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
target_mapping=target_mapping,
|
||||
token_type_ids=token_type_ids,
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_mems=use_mems,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_input_mask = (
|
||||
tf.reshape(inputs["input_mask"], (-1, seq_length)) if inputs["input_mask"] is not None else None
|
||||
)
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
inputs["mems"],
|
||||
inputs["perm_mask"],
|
||||
inputs["target_mapping"],
|
||||
mems,
|
||||
perm_mask,
|
||||
target_mapping,
|
||||
flat_token_type_ids,
|
||||
flat_input_mask,
|
||||
inputs["head_mask"],
|
||||
head_mask,
|
||||
flat_inputs_embeds,
|
||||
inputs["use_mems"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
use_mems,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
logits = self.sequence_summary(output)
|
||||
logits = self.logits_proj(logits)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
|
||||
loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
@ -1706,6 +1596,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
)
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -1737,9 +1628,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
||||
"""
|
||||
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
@ -1753,31 +1642,13 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
mems=inputs["mems"],
|
||||
perm_mask=inputs["perm_mask"],
|
||||
target_mapping=inputs["target_mapping"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
input_mask=inputs["input_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_mems=inputs["use_mems"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
logits = self.classifier(output)
|
||||
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
|
||||
loss = None if labels is None else self.hf_compute_loss(labels, logits)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
@ -1812,6 +1683,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
)
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -1849,9 +1721,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
@ -1865,26 +1735,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
mems=inputs["mems"],
|
||||
perm_mask=inputs["perm_mask"],
|
||||
target_mapping=inputs["target_mapping"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
input_mask=inputs["input_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_mems=inputs["use_mems"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = transformer_outputs[0]
|
||||
|
||||
@ -1894,12 +1745,12 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
loss = self.hf_compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user