Mirror of https://github.com/huggingface/transformers.git
Removed XLMModel inheritance from FlaubertModel (torch+tf) (#19432)
* FlaubertModel inheritance from XLMModel removed
* Fix style and add FlaubertPreTrainedModel to __init__
* Fix formatting issue
* Fix typo and repo-consistency
* Fix style
* Add FlaubertPreTrainedModel to TYPE_HINT
* Fix repo consistency
* Update src/transformers/models/flaubert/modeling_flaubert.py
* Update src/transformers/models/flaubert/modeling_tf_flaubert.py
* Removed redundant "Copied from" comments
* Added missing "Copied from" comments

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent: 5fda1fbd46
commit: ed858f5354
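At the API level, the net effect of this commit is that FlaubertPreTrainedModel is exported from the top-level package and the TF Flaubert task heads are standalone classes instead of thin subclasses of the XLM heads. A minimal sketch of how that can be checked, assuming a transformers install that contains this commit with both PyTorch and TensorFlow available:

    # Sketch only; assumes transformers at or after this commit, with torch and TF installed.
    from transformers import FlaubertModel, FlaubertPreTrainedModel  # FlaubertPreTrainedModel is newly exported
    from transformers import TFFlaubertForSequenceClassification, TFXLMForSequenceClassification

    # The torch model is a FlaubertPreTrainedModel subclass ...
    assert issubclass(FlaubertModel, FlaubertPreTrainedModel)
    # ... while the TF task heads no longer inherit from the XLM heads; the XLM code was copied in instead.
    assert not issubclass(TFFlaubertForSequenceClassification, TFXLMForSequenceClassification)

The hunks below make these changes concrete.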
@@ -1273,6 +1273,7 @@ else:
            "FlaubertForTokenClassification",
            "FlaubertModel",
            "FlaubertWithLMHeadModel",
            "FlaubertPreTrainedModel",
        ]
    )
    _import_structure["models.flava"].extend(
@@ -4141,6 +4142,7 @@ if TYPE_CHECKING:
        FlaubertForSequenceClassification,
        FlaubertForTokenClassification,
        FlaubertModel,
        FlaubertPreTrainedModel,
        FlaubertWithLMHeadModel,
    )
    from .models.flava import (

@@ -41,6 +41,7 @@ else:
        "FlaubertForTokenClassification",
        "FlaubertModel",
        "FlaubertWithLMHeadModel",
        "FlaubertPreTrainedModel",
    ]

try:
@@ -79,6 +80,7 @@ if TYPE_CHECKING:
        FlaubertForSequenceClassification,
        FlaubertForTokenClassification,
        FlaubertModel,
        FlaubertPreTrainedModel,
        FlaubertWithLMHeadModel,
    )

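The two __init__.py hunks above register the same name in two places because of how transformers exposes models lazily: names are listed in an _import_structure dict for runtime lazy loading, and imported again under TYPE_CHECKING so static type checkers can see them. A simplified sketch of that pattern (illustrative only, not the actual file):

    # Simplified sketch of the lazy-export pattern used by these __init__.py files.
    import sys
    from typing import TYPE_CHECKING

    from transformers.utils import _LazyModule

    # 1) Names exported per submodule; this is the dict the first hunk extends.
    _import_structure = {
        "modeling_flaubert": [
            "FlaubertModel",
            "FlaubertWithLMHeadModel",
            "FlaubertPreTrainedModel",  # the name this commit adds
        ],
    }

    if TYPE_CHECKING:
        # 2) Type checkers see real imports; this is what the second hunk extends.
        from transformers.models.flaubert.modeling_flaubert import (  # noqa: F401
            FlaubertModel,
            FlaubertPreTrainedModel,
            FlaubertWithLMHeadModel,
        )
    else:
        # 3) At runtime, the module is swapped for a lazy proxy that imports on first attribute access.
        sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)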
File diff suppressed because it is too large.

@@ -26,28 +26,35 @@ import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_outputs import (
    TFBaseModelOutput,
    TFMultipleChoiceModelOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
    TFModelInputType,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFSequenceSummary,
    TFSharedEmbeddings,
    TFTokenClassificationLoss,
    get_initializer,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
    MULTIPLE_CHOICE_DUMMY_INPUTS,
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from ..xlm.modeling_tf_xlm import (
    TFXLMForMultipleChoice,
    TFXLMForQuestionAnsweringSimple,
    TFXLMForSequenceClassification,
    TFXLMForTokenClassification,
)
from .configuration_flaubert import FlaubertConfig

@@ -218,7 +225,7 @@ class TFFlaubertPreTrainedModel(TFPreTrainedModel):

    @property
    def dummy_inputs(self):
        # Sometimes XLM has language embeddings so don't forget to build them as well if needed
        # Sometimes Flaubert has language embeddings so don't forget to build them as well if needed
        inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
        attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        if self.config.use_lang_emb and self.config.n_langs > 1:

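The hunks that follow replace each thin TFXLM* subclass with a full re-implementation carrying a "# Copied from transformers.models.xlm.modeling_tf_xlm..." marker. That marker is what the repo-consistency tooling (utils/check_copies.py, run via make repo-consistency in the transformers repo) uses to keep the duplicated code in sync with XLM. A toy illustration of the substitution rule the marker encodes, not the checker itself:

    # The marker "# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForSequenceClassification
    # with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert" means: this block must equal the XLM source
    # after applying the listed text replacements. A toy version of that substitution step:
    source = "class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificationLoss):"
    replacements = [("XLM_INPUTS", "FLAUBERT_INPUTS"), ("XLM", "Flaubert")]
    for old, new in replacements:
        source = source.replace(old, new)
    print(source)
    # -> class TFFlaubertForSequenceClassification(TFFlaubertPreTrainedModel, TFSequenceClassificationLoss):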
@@ -862,12 +869,84 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
    """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForSequenceClassification(TFXLMForSequenceClassification):
    config_class = FlaubertConfig

# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class TFFlaubertForSequenceClassification(TFFlaubertPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFFlaubertMainLayer(config, name="transformer")
        self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: Optional[TFModelInputType] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        langs: Optional[Union[np.ndarray, tf.Tensor]] = None,
        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        lengths: Optional[Union[np.ndarray, tf.Tensor]] = None,
        cache: Optional[Dict[str, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
        training: bool = False,
    ):
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        output = transformer_outputs[0]

        logits = self.sequence_summary(output)

        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
    def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
        hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
        attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

        return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)


@add_start_docstrings(

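The behaviour of the new standalone class is unchanged from the inherited version. A minimal usage sketch, assuming TensorFlow is installed and using the public flaubert/flaubert_base_cased checkpoint, whose classification head is randomly initialised here:

    # Minimal sketch; assumes TF is installed and network access to the Hub.
    from transformers import FlaubertTokenizer, TFFlaubertForSequenceClassification

    tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")
    # num_labels only sizes the (untrained) classification head for this sketch.
    # Add from_pt=True if the checkpoint only ships PyTorch weights.
    model = TFFlaubertForSequenceClassification.from_pretrained("flaubert/flaubert_base_cased", num_labels=2)

    inputs = tokenizer("Le camembert est délicieux.", return_tensors="tf")
    outputs = model(**inputs)    # TFSequenceClassifierOutput
    print(outputs.logits.shape)  # (1, 2)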
@@ -877,12 +956,99 @@ class TFFlaubertForSequenceClassification(TFXLMForSequenceClassification):
    """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForQuestionAnsweringSimple(TFXLMForQuestionAnsweringSimple):
    config_class = FlaubertConfig

# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForQuestionAnsweringSimple with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class TFFlaubertForQuestionAnsweringSimple(TFFlaubertPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.init_std), name="qa_outputs"
        )

    @unpack_inputs
    @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: Optional[TFModelInputType] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        langs: Optional[Union[np.ndarray, tf.Tensor]] = None,
        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        lengths: Optional[Union[np.ndarray, tf.Tensor]] = None,
        cache: Optional[Dict[str, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        start_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
        end_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
        training: bool = False,
    ):
        r"""
        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = transformer_outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        loss = None
        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.hf_compute_loss(labels, (start_logits, end_logits))

        if not return_dict:
            output = (start_logits, end_logits) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
    def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
        hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
        attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

        return TFQuestionAnsweringModelOutput(
            start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
        )


@add_start_docstrings(

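As the added code shows, the QA head only splits a Dense(config.num_labels) projection into start and end logits; turning those into an answer string is left to the caller. A rough sketch of that post-processing, under the same checkpoint assumptions as above (the qa_outputs head is untrained here, so the extracted span is meaningless):

    # Sketch; assumes TF installed. Add from_pt=True if the checkpoint only ships PyTorch weights.
    import tensorflow as tf
    from transformers import FlaubertTokenizer, TFFlaubertForQuestionAnsweringSimple

    tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")
    model = TFFlaubertForQuestionAnsweringSimple.from_pretrained("flaubert/flaubert_base_cased")

    question = "Qui est le président ?"
    context = "Emmanuel Macron est le président de la République."
    inputs = tokenizer(question, context, return_tensors="tf")
    outputs = model(**inputs)

    # Pick the most likely start/end token positions and decode the span between them.
    start = int(tf.argmax(outputs.start_logits, axis=-1)[0])
    end = int(tf.argmax(outputs.end_logits, axis=-1)[0])
    answer_ids = inputs["input_ids"][0, start : end + 1]
    print(tokenizer.decode(answer_ids))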
@@ -892,12 +1058,86 @@ class TFFlaubertForQuestionAnsweringSimple(TFXLMForQuestionAnsweringSimple):
    """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForTokenClassification(TFXLMForTokenClassification):
    config_class = FlaubertConfig

# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForTokenClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class TFFlaubertForTokenClassification(TFFlaubertPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFFlaubertMainLayer(config, name="transformer")
        self.dropout = tf.keras.layers.Dropout(config.dropout)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier"
        )

    @unpack_inputs
    @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: Optional[TFModelInputType] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        langs: Optional[Union[np.ndarray, tf.Tensor]] = None,
        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        lengths: Optional[Union[np.ndarray, tf.Tensor]] = None,
        cache: Optional[Dict[str, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
        training: bool = False,
    ):
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = transformer_outputs[0]

        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)

        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
    def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
        hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
        attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

        return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)


@add_start_docstrings(

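Same pattern for token classification: the dropout plus Dense classifier yields per-token logits of shape (batch_size, seq_len, num_labels). A short sketch of reading out per-token predictions, under the same untrained-head and checkpoint assumptions:

    # Sketch; assumes TF installed, classifier head untrained for this checkpoint.
    import tensorflow as tf
    from transformers import FlaubertTokenizer, TFFlaubertForTokenClassification

    tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")
    model = TFFlaubertForTokenClassification.from_pretrained("flaubert/flaubert_base_cased", num_labels=5)

    inputs = tokenizer("Paris est la capitale de la France.", return_tensors="tf")
    logits = model(**inputs).logits                # (1, seq_len, 5)
    predicted_ids = tf.argmax(logits, axis=-1)[0]  # one label id per token
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy().tolist())
    print(list(zip(tokens, predicted_ids.numpy())))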
@@ -907,9 +1147,139 @@ class TFFlaubertForTokenClassification(TFXLMForTokenClassification):
    """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForMultipleChoice(TFXLMForMultipleChoice):
    config_class = FlaubertConfig

# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class TFFlaubertForMultipleChoice(TFFlaubertPreTrainedModel, TFMultipleChoiceLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.transformer = TFFlaubertMainLayer(config, name="transformer")
        self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
        self.logits_proj = tf.keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
        )

    @property
    def dummy_inputs(self):
        """
        Dummy inputs to build the network.

        Returns:
            tf.Tensor with dummy inputs
        """
        # Sometimes Flaubert has language embeddings so don't forget to build them as well if needed
        if self.config.use_lang_emb and self.config.n_langs > 1:
            return {
                "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
                "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
            }
        else:
            return {
                "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
            }

    @unpack_inputs
    @add_start_docstrings_to_model_forward(
        FLAUBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: Optional[TFModelInputType] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        langs: Optional[Union[np.ndarray, tf.Tensor]] = None,
        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        lengths: Optional[Union[np.ndarray, tf.Tensor]] = None,
        cache: Optional[Dict[str, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
        training: bool = False,
    ):
        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
        flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
        flat_inputs_embeds = (
            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )

        if lengths is not None:
            logger.warning(
                "The `lengths` parameter cannot be used with the Flaubert multiple choice models. Please use the "
                "attention mask instead.",
            )
            lengths = None

        transformer_outputs = self.transformer(
            flat_input_ids,
            flat_attention_mask,
            flat_langs,
            flat_token_type_ids,
            flat_position_ids,
            lengths,
            cache,
            head_mask,
            flat_inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        output = transformer_outputs[0]
        logits = self.sequence_summary(output)
        logits = self.logits_proj(logits)
        reshaped_logits = tf.reshape(logits, (-1, num_choices))

        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

        if not return_dict:
            output = (reshaped_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @tf.function(
        input_signature=[
            {
                "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
                "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
                "token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
            }
        ]
    )
    # Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
    def serving(self, inputs: Dict[str, tf.Tensor]):
        output = self.call(input_ids=inputs)

        return self.serving_output(output)

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
    def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
        hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
        attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

        return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns)

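The multiple-choice head expects inputs of shape (batch_size, num_choices, seq_len) and, as the added call() shows, flattens them to (batch_size * num_choices, seq_len) before running the transformer, then reshapes the projected logits back to (batch_size, num_choices). A sketch of building such inputs with the tokenizer, under the same checkpoint assumptions (the logits_proj head is untrained here):

    # Sketch; assumes TF installed. logits has shape (batch_size, num_choices).
    import tensorflow as tf
    from transformers import FlaubertTokenizer, TFFlaubertForMultipleChoice

    tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")
    model = TFFlaubertForMultipleChoice.from_pretrained("flaubert/flaubert_base_cased")

    prompt = "Le chat dort"
    choices = ["sur le canapé.", "dans la voiture.", "au bureau."]
    # Encode the prompt against each choice, then add the num_choices dimension.
    enc = tokenizer([prompt] * len(choices), choices, return_tensors="tf", padding=True)
    inputs = {k: tf.expand_dims(v, 0) for k, v in enc.items()}  # (1, num_choices, seq_len)

    outputs = model(**inputs)
    print(outputs.logits.shape)  # (1, 3)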
@@ -2065,6 +2065,13 @@ class FlaubertModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class FlaubertPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class FlaubertWithLMHeadModel(metaclass=DummyObject):
    _backends = ["torch"]

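The last hunk registers FlaubertPreTrainedModel in utils/dummy_pt_objects.py. These dummy classes are what gets imported when PyTorch is not installed: instead of failing at import time, the name exists but raises a helpful error on use via requires_backends. A rough sketch of how that behaves in a hypothetical TensorFlow-only environment:

    # Hypothetical torch-less environment; with torch installed the real class is imported instead.
    from transformers import FlaubertPreTrainedModel

    try:
        FlaubertPreTrainedModel()  # the dummy's __init__ calls requires_backends(self, ["torch"])
    except ImportError as err:
        print(err)  # explains that this class requires the PyTorch backend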