[WIP] Add type hints for Lxmert (TF) (#19441)

* Add type hints for Lxmert (TF)

* Update src/transformers/models/lxmert/modeling_tf_lxmert.py

Co-authored-by: Emmanuel Lusenji <elusenji@Emmanuels-MacBook-Pro.local>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Emmanuel Lusenji 2022-10-13 16:53:27 +02:00 committed by GitHub
parent 036e808517
commit f06a6f7e37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,14 +18,22 @@
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from transformers.tf_utils import stable_softmax
from ...activations_tf import get_tf_activation
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list, unpack_inputs
from ...modeling_tf_utils import (
TFModelInputType,
TFPreTrainedModel,
get_initializer,
keras_serializable,
shape_list,
unpack_inputs,
)
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
@ -698,7 +706,6 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
return_dict=None,
training=False,
):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -951,18 +958,18 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
)
def call(
self,
input_ids=None,
visual_feats=None,
visual_pos=None,
attention_mask=None,
visual_attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
input_ids: Optional[TFModelInputType] = None,
visual_feats: Optional[tf.Tensor] = None,
visual_pos: Optional[tf.Tensor] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
visual_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple, TFLxmertModelOutput]:
outputs = self.lxmert(
input_ids,
visual_feats,