diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 28558786c56..8007b40f370 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -205,7 +205,7 @@ Flax), PyTorch, and/or TensorFlow. | Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ | | GLPN | ❌ | ❌ | ✅ | ❌ | ❌ | | GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ | -| GPT-J | ❌ | ❌ | ✅ | ❌ | ✅ | +| GPT-J | ❌ | ❌ | ✅ | ✅ | ✅ | | Hubert | ❌ | ❌ | ✅ | ✅ | ❌ | | I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ | | ImageGPT | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/model_doc/gptj.mdx b/docs/source/model_doc/gptj.mdx index 67edd444834..62bd224746a 100644 --- a/docs/source/model_doc/gptj.mdx +++ b/docs/source/model_doc/gptj.mdx @@ -130,6 +130,26 @@ model. [[autodoc]] GPTJForQuestionAnswering - forward +## TFGPTJModel + +[[autodoc]] TFGPTJModel + - call + +## TFGPTJForCausalLM + +[[autodoc]] TFGPTJForCausalLM + - call + +## TFGPTJForSequenceClassification + +[[autodoc]] TFGPTJForSequenceClassification + - call + +## TFGPTJForQuestionAnswering + +[[autodoc]] TFGPTJForQuestionAnswering + - call + ## FlaxGPTJModel [[autodoc]] FlaxGPTJModel diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3b11b50a9a4..ad61a63062c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1929,6 +1929,15 @@ if is_tf_available(): "TFGPT2PreTrainedModel", ] ) + _import_structure["models.gptj"].extend( + [ + "TFGPTJForCausalLM", + "TFGPTJForQuestionAnswering", + "TFGPTJForSequenceClassification", + "TFGPTJModel", + "TFGPTJPreTrainedModel", + ] + ) _import_structure["models.hubert"].extend( [ "TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -4003,6 +4012,13 @@ if TYPE_CHECKING: TFGPT2Model, TFGPT2PreTrainedModel, ) + from .models.gptj import ( + TFGPTJForCausalLM, + TFGPTJForQuestionAnswering, + TFGPTJForSequenceClassification, + TFGPTJModel, + TFGPTJPreTrainedModel, + ) from .models.hubert import ( TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFHubertForCTC, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 34f59393aff..8afa05ba578 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -52,6 +52,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict( ("bert", "TFBertModel"), ("openai-gpt", "TFOpenAIGPTModel"), ("gpt2", "TFGPT2Model"), + ("gptj", "TFGPTJModel"), ("mobilebert", "TFMobileBertModel"), ("transfo-xl", "TFTransfoXLModel"), ("xlnet", "TFXLNetModel"), @@ -123,6 +124,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( ("bert", "TFBertForMaskedLM"), ("openai-gpt", "TFOpenAIGPTLMHeadModel"), ("gpt2", "TFGPT2LMHeadModel"), + ("gptj", "TFGPTJForCausalLM"), ("mobilebert", "TFMobileBertForMaskedLM"), ("transfo-xl", "TFTransfoXLLMHeadModel"), ("xlnet", "TFXLNetLMHeadModel"), @@ -146,6 +148,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("bert", "TFBertLMHeadModel"), ("openai-gpt", "TFOpenAIGPTLMHeadModel"), ("gpt2", "TFGPT2LMHeadModel"), + ("gptj", "TFGPTJForCausalLM"), ("transfo-xl", "TFTransfoXLLMHeadModel"), ("xlnet", "TFXLNetLMHeadModel"), ("xlm", "TFXLMWithLMHeadModel"), @@ -239,6 +242,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("tapas", "TFTapasForSequenceClassification"), ("funnel", "TFFunnelForSequenceClassification"), ("gpt2", "TFGPT2ForSequenceClassification"), + ("gptj", "TFGPTJForSequenceClassification"), ("mpnet", "TFMPNetForSequenceClassification"), ("openai-gpt", "TFOpenAIGPTForSequenceClassification"), ("transfo-xl", "TFTransfoXLForSequenceClassification"), @@ -267,6 +271,7 
@@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("xlm", "TFXLMForQuestionAnsweringSimple"), ("electra", "TFElectraForQuestionAnswering"), ("funnel", "TFFunnelForQuestionAnswering"), + ("gptj", "TFGPTJForQuestionAnswering"), ("mpnet", "TFMPNetForQuestionAnswering"), ] ) diff --git a/src/transformers/models/gptj/__init__.py b/src/transformers/models/gptj/__init__.py index 69ca43f276b..a6b144ab825 100644 --- a/src/transformers/models/gptj/__init__.py +++ b/src/transformers/models/gptj/__init__.py @@ -17,7 +17,7 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import _LazyModule, is_flax_available, is_torch_available +from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available _import_structure = { @@ -34,6 +34,15 @@ if is_torch_available(): "GPTJPreTrainedModel", ] +if is_tf_available(): + _import_structure["modeling_tf_gptj"] = [ + "TFGPTJForCausalLM", + "TFGPTJForQuestionAnswering", + "TFGPTJForSequenceClassification", + "TFGPTJModel", + "TFGPTJPreTrainedModel", + ] + if is_flax_available(): _import_structure["modeling_flax_gptj"] = [ "FlaxGPTJForCausalLM", @@ -55,6 +64,15 @@ if TYPE_CHECKING: GPTJPreTrainedModel, ) + if is_tf_available(): + from .modeling_tf_gptj import ( + TFGPTJForCausalLM, + TFGPTJForQuestionAnswering, + TFGPTJForSequenceClassification, + TFGPTJModel, + TFGPTJPreTrainedModel, + ) + if is_flax_available(): from .modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py new file mode 100644 index 00000000000..6c24d747692 --- /dev/null +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -0,0 +1,1156 @@ +# coding=utf-8 +# Copyright 2022 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" TF 2.0 GPT-J model.""" + +from typing import Optional, Tuple + +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...file_utils import ( + DUMMY_INPUTS, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPast, + TFCausalLMOutputWithPast, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutputWithPast, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFSharedEmbeddings, + get_initializer, + input_processing, + keras_serializable, +) +from ...tf_utils import shape_list +from ...utils import logging +from .configuration_gptj import GPTJConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-j-6B" +_CONFIG_FOR_DOC = "GPTJConfig" +_TOKENIZER_FOR_DOC = "GPTJTokenizer" + +GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "EleutherAI/gpt-j-6B", + # See all GPT-J models at https://huggingface.co/models?filter=gptj +] + + +def fixed_pos_embedding(x: tf.Tensor, seq_dim: int = 1, seq_len: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]: + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32) + seq_len_range = tf.cast(tf.range(seq_len), tf.float32) + sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", seq_len_range, inv_freq), tf.float32) + return tf.sin(sinusoid_inp), tf.cos(sinusoid_inp) + + +def rotate_every_two(x: tf.Tensor) -> tf.Tensor: + rotate_half_tensor = tf.stack((-x[:, :, :, 1::2], x[:, :, :, ::2]), axis=-1) + new_shape = shape_list(rotate_half_tensor)[:-2] + [tf.math.reduce_prod(shape_list(rotate_half_tensor)[-2:])] + rotate_half_tensor = tf.reshape(rotate_half_tensor, new_shape) + return rotate_half_tensor + + +def apply_rotary_pos_emb(x: tf.Tensor, sincos: tf.Tensor, offset: int = 0) -> tf.Tensor: + sin_pos, cos_pos = sincos + sin_pos = tf.repeat(sin_pos[None, offset : x.shape[1] + offset, None, :], 2, 3) + cos_pos = tf.repeat(cos_pos[None, offset : x.shape[1] + offset, None, :], 2, 3) + return (x * cos_pos) + (rotate_every_two(x) * sin_pos) + + +class TFGPTJAttention(tf.keras.layers.Layer): + def __init__(self, config: GPTJConfig, **kwargs): + super().__init__(**kwargs) + + self.embed_dim = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_attention_heads + if self.head_dim * self.num_attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})." 
+ ) + self.scale_attn = self.head_dim**0.5 + self.rotary_dim = config.rotary_dim + + self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop) + self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop) + + self.q_proj = tf.keras.layers.Dense( + self.embed_dim, + use_bias=False, + kernel_initializer=get_initializer(config.initializer_range), + name="q_proj", + ) + self.k_proj = tf.keras.layers.Dense( + self.embed_dim, + use_bias=False, + kernel_initializer=get_initializer(config.initializer_range), + name="k_proj", + ) + self.v_proj = tf.keras.layers.Dense( + self.embed_dim, + use_bias=False, + kernel_initializer=get_initializer(config.initializer_range), + name="v_proj", + ) + self.out_proj = tf.keras.layers.Dense( + self.embed_dim, + use_bias=False, + kernel_initializer=get_initializer(config.initializer_range), + name="out_proj", + ) + + self.max_positions = config.max_position_embeddings + self.lower_triangle_mask = tf.reshape( + tf.cast(tf.experimental.numpy.tril(tf.ones((self.max_positions, self.max_positions))), tf.int8), + (1, 1, self.max_positions, self.max_positions), + ) + + def get_causal_mask(self, key_length, query_length) -> tf.Tensor: + return tf.cast(self.lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool) + + @staticmethod + def get_masked_bias(dtype: tf.DType) -> tf.Tensor: + return tf.cast(tf.constant(-1e9), dtype) + + def _split_heads(self, hidden_states: tf.Tensor, rotary: bool) -> tf.Tensor: + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + new_shape = shape_list(hidden_states)[:-1] + [self.num_attention_heads, self.head_dim] + hidden_states = tf.reshape(hidden_states, new_shape) + if rotary: + return hidden_states + if len(shape_list(hidden_states)) == 4: + return tf.transpose(hidden_states, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) + if len(shape_list(hidden_states)) == 5: + return tf.transpose(hidden_states, (0, 1, 3, 2, 4)) # (batch, blocks, head, block_length, head_features) + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(shape_list(hidden_states))}") + + def _merge_heads(self, hidden_states: tf.Tensor) -> tf.Tensor: + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + if len(shape_list(hidden_states)) == 4: + hidden_states = tf.transpose(hidden_states, (0, 2, 1, 3)) + elif len(shape_list(hidden_states)) == 5: + hidden_states = tf.transpose(hidden_states, (0, 1, 3, 2, 4)) + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(shape_list(hidden_states))}") + new_shape = shape_list(hidden_states)[:-2] + [self.num_attention_heads * self.head_dim] + return tf.reshape(hidden_states, new_shape) + + def _attn( + self, + query: tf.Tensor, + key: tf.Tensor, + value: tf.Tensor, + attention_mask: Optional[tf.Tensor] = None, + head_mask: Optional[tf.Tensor] = None, + ) -> Tuple[tf.Tensor, tf.Tensor]: + # compute causal mask from causal mask buffer + query_length, key_length = query.shape[-2], key.shape[-2] + causal_mask = self.get_causal_mask(key_length, query_length) + + # Keep the attention weights computation in fp32 to avoid overflow issues + query = tf.cast(query, tf.float32) + key = tf.cast(key, tf.float32) + + attn_weights = tf.matmul(query, key, transpose_b=True) + attn_weights = tf.where(causal_mask, attn_weights, self.get_masked_bias(attn_weights.dtype)) + + attn_weights = attn_weights / self.scale_attn + + if attention_mask is not None: + # Apply the attention mask + attn_weights = 
attn_weights + attention_mask + + attn_weights = tf.nn.softmax(attn_weights, axis=-1) + attn_weights = tf.cast(attn_weights, value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = tf.matmul(attn_weights, value) + + return attn_output, attn_weights + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: Optional[tf.Tensor] = None, + layer_past: Optional[Tuple[tf.Tensor, tf.Tensor]] = None, + head_mask: Optional[tf.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, True) + key = self._split_heads(key, True) + value = self._split_heads(value, False) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = tf.concat((k_rot, k_pass), axis=-1) + query = tf.concat((q_rot, q_pass), axis=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = tf.transpose(key, (0, 2, 1, 3)) + query = tf.transpose(query, (0, 2, 1, 3)) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = tf.concat((past_key, key), axis=-2) + value = tf.concat((past_value, value), axis=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class TFGPTJMLP(tf.keras.layers.Layer): + def __init__(self, intermediate_size: int, config: GPTJConfig, **kwargs): + super().__init__(**kwargs) + embed_dim = config.n_embd + + self.fc_in = tf.keras.layers.Dense( + intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="fc_in" + ) + self.fc_out = tf.keras.layers.Dense( + embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="fc_out" + ) + + self.act = get_tf_activation(config.activation_function) + self.dropout = tf.keras.layers.Dropout(config.embd_pdrop) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc_out(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class TFGPTJBlock(tf.keras.layers.Layer): + def __init__(self, config: GPTJConfig, **kwargs): + super().__init__(**kwargs) + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd + self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1") 
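+        # Unlike GPT-2, the GPT-J block has no second LayerNorm: `ln_1` feeds both the
+        # attention and the MLP branch, and their outputs are added to the residual in
+        # parallel (see `call` below).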
+ self.attn = TFGPTJAttention(config, name="attn") + self.mlp = TFGPTJMLP(inner_dim, config, name="mlp") + + def call( + self, + hidden_states: tf.Tensor, + layer_past: Optional[tf.Tensor] = None, + attention_mask: Optional[tf.Tensor] = None, + head_mask: Optional[tf.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) # attn_outputs: attn_output, present, (attentions) + attn_output = attn_outputs[0] + outputs = attn_outputs[1:] + + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + feed_forward_hidden_states + residual + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + return outputs # hidden_states, present, (attentions) + + +@keras_serializable +class TFGPTJMainLayer(tf.keras.layers.Layer): + config_class = GPTJConfig + + def __init__(self, config: GPTJConfig, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + self.config = config + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.use_cache = config.use_cache + self.return_dict = config.use_return_dict + + self.num_hidden_layers = config.n_layer + self.vocab_size = config.vocab_size + self.n_embd = config.n_embd + self.n_positions = config.n_positions + self.initializer_range = config.initializer_range + + self.wte = TFSharedEmbeddings( + config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte" + ) + self.drop = tf.keras.layers.Dropout(config.embd_pdrop) + self.h = [TFGPTJBlock(config, name=f"h_._{i}") for i in range(config.n_layer)] + self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f") + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, value: tf.Tensor): + self.wte.weight = value + self.wte.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + raise NotImplementedError + + def call( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + **kwargs, + ): + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + + if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif inputs["input_ids"] is not None: + input_shape = shape_list(inputs["input_ids"]) + inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]]) + elif inputs["inputs_embeds"] is not None: + input_shape = shape_list(inputs["inputs_embeds"])[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs["past_key_values"] is None: + past_length = 0 + inputs["past_key_values"] = [None] * len(self.h) + else: + past_length = shape_list(inputs["past_key_values"][0][0])[-2] + + if inputs["position_ids"] is None: + inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) + + if inputs["attention_mask"] is not None: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask_shape = shape_list(inputs["attention_mask"]) + inputs["attention_mask"] = tf.reshape( + inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
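+            # For example, an attention_mask of [1, 1, 0] becomes the additive bias
+            # [0.0, 0.0, -10000.0] via the (1.0 - mask) * -10000.0 transform below.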
+ one_cst = tf.constant(1.0) + inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype) + inputs["attention_mask"] = tf.multiply( + tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0) + ) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if inputs["head_mask"] is not None: + raise NotImplementedError + else: + inputs["head_mask"] = [None] * self.num_hidden_layers + # head_mask = tf.constant([0] * self.num_hidden_layers) + + inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]]) + + if inputs["inputs_embeds"] is None: + inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding") + + if inputs["token_type_ids"] is not None: + inputs["token_type_ids"] = tf.reshape( + inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]] + ) + token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding") + else: + token_type_embeds = tf.constant(0.0) + + token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype) + hidden_states = inputs["inputs_embeds"] + token_type_embeds + hidden_states = self.drop(hidden_states, training=inputs["training"]) + + output_shape = input_shape + [shape_list(hidden_states)[-1]] + + presents = () if inputs["use_cache"] else None + all_attentions = () if inputs["output_attentions"] else None + all_hidden_states = () if inputs["output_hidden_states"] else None + for i, (block, layer_past) in enumerate(zip(self.h, inputs["past_key_values"])): + if inputs["output_hidden_states"]: + all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) + + outputs = block( + hidden_states, + layer_past, + inputs["attention_mask"], + inputs["head_mask"][i], + inputs["use_cache"], + inputs["output_attentions"], + training=inputs["training"], + ) + + hidden_states = outputs[0] + if inputs["use_cache"]: + presents = presents + (outputs[1],) + + if inputs["output_attentions"]: + all_attentions = all_attentions + (outputs[2 if inputs["use_cache"] else 1],) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = tf.reshape(hidden_states, output_shape) + # Add last hidden state + if inputs["output_hidden_states"]: + all_hidden_states = all_hidden_states + (hidden_states,) + + if inputs["output_attentions"]: + # let the number of heads free (-1) so we can extract attention even after head pruning + attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] + all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) + + if not inputs["return_dict"]: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +class TFGPTJPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTJConfig + base_model_prefix = "transformer" + # names with a '.' 
represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+    _keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias"]
+
+    @property
+    def dummy_inputs(self):
+        """
+        Dummy inputs to build the network.
+
+        Returns:
+            `Dict[str, tf.Tensor]`: The dummy inputs.
+        """
+        dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
+        return dummy
+
+    @tf.function(
+        input_signature=[
+            {
+                "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+                "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+            }
+        ]
+    )
+    def serving(self, inputs):
+        output = self.call(inputs)
+
+        return self.serving_output(output)
+
+
+GPTJ_START_DOCSTRING = r"""
+
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads, etc.).
+
+    This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage
+    and behavior.
+
+    <Tip>
+
+    TF 2.0 models accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    This second option is useful when using the [`tf.keras.Model.fit`] method, which currently requires having all
+    the tensors in the first argument of the model call function: `model(inputs)`.
+
+    If you choose this second option, there are three possibilities you can use to gather all the input Tensors in
+    the first positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    </Tip>
+
+    Parameters:
+        config ([`GPTJConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+GPTJ_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`):
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
+            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+            If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
+            `input_ids`.
+
+            Indices can be obtained using [`GPTJTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+            [`PreTrainedTokenizer.encode`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        past_key_values (`List[tf.Tensor]` of length `config.n_layers`):
+            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
+            (see the `past_key_values` output below). Can be used to speed up sequential decoding. The token ids
+            which have their past given to this model should not be passed as input ids as they have already been
+            computed.
+        attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in
+            `[0, 1]`:
+
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range
+            `[0, config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
+            config will be used instead.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
+            used instead.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be
+            used in eager mode; in graph mode the value will always be set to `True`.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+""" + + +@add_start_docstrings( + "The bare GPT-J Model transformer outputting raw hidden-states without any specific head on top.", + GPTJ_START_DOCSTRING, +) +class TFGPTJModel(TFGPTJPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFGPTJMainLayer(config, name="transformer") + + @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + **kwargs, + ): + r""" + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past`). Set to `False` during training, `True` during generation + """ + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + outputs = self.transformer( + input_ids=inputs["input_ids"], + past_key_values=inputs["past_key_values"], + attention_mask=inputs["attention_mask"], + token_type_ids=inputs["token_type_ids"], + position_ids=inputs["position_ids"], + head_mask=inputs["head_mask"], + inputs_embeds=inputs["inputs_embeds"], + use_cache=inputs["use_cache"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], + training=inputs["training"], + ) + + return outputs + + def serving_output(self, output): + pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFBaseModelOutputWithPast( + last_hidden_state=output.last_hidden_state, + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + ) + + +@add_start_docstrings( + """ + The GPT-J Model transformer with a language modeling head on top. + """, + GPTJ_START_DOCSTRING, +) +class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFGPTJMainLayer(config, name="transformer") + self.lm_head = tf.keras.layers.Dense( + config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head" + ) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs): + # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. 
NB -- some GPT2
+        # tests will need to be fixed after the change
+
+        # only last token for input_ids if past is defined in kwargs
+        if past:
+            inputs = tf.expand_dims(inputs[:, -1], -1)
+
+        # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
+        # for a future PR to not change too many things for now.
+        # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
+        position_ids = None
+        attention_mask = None
+        if use_xla:
+            attention_mask = kwargs.get("attention_mask", None)
+            if past is not None and attention_mask is not None:
+                position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
+            elif attention_mask is not None:
+                position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
+
+        return {
+            "input_ids": inputs,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "past": past,
+            "use_cache": use_cache,
+        }
+
+    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFCausalLMOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+        **kwargs,
+    ):
+        r"""
+        labels (`np.ndarray` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
+            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+        """
+        inputs = input_processing(
+            func=self.call,
+            config=self.config,
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        transformer_outputs = self.transformer(
+            input_ids=inputs["input_ids"],
+            past_key_values=inputs["past_key_values"],
+            attention_mask=inputs["attention_mask"],
+            token_type_ids=inputs["token_type_ids"],
+            position_ids=inputs["position_ids"],
+            head_mask=inputs["head_mask"],
+            inputs_embeds=inputs["inputs_embeds"],
+            use_cache=inputs["use_cache"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
+        )
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+
+        loss = None
+        if inputs["labels"] is not None:
+            # shift labels to the left and cut last logit token
+            shifted_logits = lm_logits[:, :-1]
+            labels = inputs["labels"][:, 1:]
+            loss = self.hf_compute_loss(labels, shifted_logits)
+
+        if not inputs["return_dict"]:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFCausalLMOutputWithPast(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    def serving_output(self, output):
+        pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None
+        hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+        attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+        return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)
+
+
+@add_start_docstrings(
+    """
+    The GPT-J Model transformer with a sequence classification head on top (linear layer).
+
+    [`TFGPTJForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT, GPT-2, GPT-Neo) do.
+
+    Since it does classification on the last token, it requires knowing the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row.
+    If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess
+    the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value
+    in each row of the batch).
+    """,
+    GPTJ_START_DOCSTRING,
+)
+class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassificationLoss):
+    _keys_to_ignore_on_load_missing = [r"h.\d+.attn.masked_bias", r"h.\d+.attn.bias", r"lm_head.weight"]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+        self.transformer = TFGPTJMainLayer(config, name="transformer")
+        self.score = tf.keras.layers.Dense(
+            self.num_labels,
+            use_bias=False,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="score",
+        )
+
+    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFSequenceClassifierOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+        **kwargs,
+    ):
+        r"""
+        labels (`np.ndarray` or `tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in
+            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
+            (Mean-Square loss). If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
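+
+            A minimal sketch of the last-token pooling described in the class docstring, assuming a
+            `pad_token_id` is set (`input_ids` and `logits` stand in for the tensors computed in `call`):
+
+            ```python
+            sequence_lengths = tf.reduce_sum(tf.cast(tf.math.not_equal(input_ids, pad_token_id), tf.int32), axis=-1) - 1
+            # index of the last non-padding token in each row, then gather its logits per batch entry
+            pooled_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)  # (batch_size, num_labels)
+            ```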
+ """ + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + transformer_outputs = self.transformer( + input_ids=inputs["input_ids"], + past_key_values=inputs["past_key_values"], + attention_mask=inputs["attention_mask"], + token_type_ids=inputs["token_type_ids"], + position_ids=inputs["position_ids"], + head_mask=inputs["head_mask"], + inputs_embeds=inputs["inputs_embeds"], + use_cache=inputs["use_cache"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], + training=inputs["training"], + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + logits_shape = shape_list(logits) + in_logits = None + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if inputs["input_ids"] is not None: + sequence_lengths = ( + tf.reduce_sum( + tf.cast( + tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), + dtype=inputs["input_ids"].dtype, + ), + -1, + keepdims=False, + ) + - 1 + ) + in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + loss = None + + if inputs["labels"] is not None: + assert ( + self.config.pad_token_id is not None or logits_shape[0] == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + + if not tf.is_tensor(sequence_lengths): + in_logits = logits[0 : logits_shape[0], sequence_lengths] + + loss = self.hf_compute_loss( + tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]) + ) + pooled_logits = in_logits if in_logits is not None else logits + + if not inputs["return_dict"]: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def serving_output(self, output): + pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFSequenceClassifierOutputWithPast( + logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns + ) + + +@add_start_docstrings( + """ + The GPT-J Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + GPTJ_START_DOCSTRING, +) +class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss): + _keys_to_ignore_on_load_missing = [r"h.\d+.attn.masked_bias", r"h.\d+.attn.bias", r"lm_head.weight"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + self.transformer = TFGPTJMainLayer(config, name="transformer") + self.qa_outputs = tf.keras.layers.Dense( + self.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + + @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + **kwargs, + ): + r""" + start_positions (`np.ndarray` or `tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`np.ndarray` or `tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + start_positions=start_positions, + end_positions=end_positions, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + transformer_outputs = self.transformer( + input_ids=inputs["input_ids"], + past_key_values=inputs["past_key_values"], + attention_mask=inputs["attention_mask"], + token_type_ids=inputs["token_type_ids"], + position_ids=inputs["position_ids"], + head_mask=inputs["head_mask"], + inputs_embeds=inputs["inputs_embeds"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], + training=inputs["training"], + ) + sequence_output = transformer_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + loss = None + if inputs["start_positions"] is not None and inputs["end_positions"] is not None: + labels = {"start_position": inputs["start_positions"]} + labels["end_position"] = inputs["end_positions"] + loss = self.hf_compute_loss(labels, (start_logits, end_logits)) + + if not inputs["return_dict"]: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFQuestionAnsweringModelOutput( + start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns + ) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 39132d6c6cb..202cd481ebe 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1157,6 +1157,41 @@ class TFGPT2PreTrainedModel(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFGPTJForCausalLM(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFGPTJForQuestionAnswering(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFGPTJForSequenceClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFGPTJModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFGPTJPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/gptj/test_modeling_tf_gptj.py 
b/tests/gptj/test_modeling_tf_gptj.py new file mode 100644 index 00000000000..50bcf1cc8a0 --- /dev/null +++ b/tests/gptj/test_modeling_tf_gptj.py @@ -0,0 +1,490 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import unittest + +from transformers import AutoTokenizer, GPTJConfig, is_tf_available +from transformers.testing_utils import require_tf, slow, tooslow + +from ..test_configuration_common import ConfigTester +from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor +from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin + + +if is_tf_available(): + import tensorflow as tf + + from transformers.models.gptj.modeling_tf_gptj import ( + TFGPTJForCausalLM, + TFGPTJForQuestionAnswering, + TFGPTJForSequenceClassification, + TFGPTJModel, + shape_list, + ) + + +class TFGPTJModelTester: + def __init__(self, parent): + self.parent = parent + self.batch_size = 13 + self.seq_length = 7 + self.is_training = True + self.use_token_type_ids = True + self.use_input_mask = True + self.use_labels = True + self.use_mc_token_ids = True + self.vocab_size = 99 + self.hidden_size = 32 + self.num_hidden_layers = 5 + self.num_attention_heads = 4 + self.intermediate_size = 37 + self.hidden_act = "gelu" + self.hidden_dropout_prob = 0.1 + self.attention_probs_dropout_prob = 0.1 + self.max_position_embeddings = 512 + self.type_vocab_size = 16 + self.type_sequence_label_size = 2 + self.initializer_range = 0.02 + self.num_labels = 3 + self.num_choices = 4 + self.scope = None + self.bos_token_id = self.vocab_size - 1 + self.eos_token_id = self.vocab_size - 1 + self.pad_token_id = self.vocab_size - 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + mc_token_ids = None + if self.use_mc_token_ids: + mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = GPTJConfig( + vocab_size=self.vocab_size, + n_embd=self.hidden_size, + n_layer=self.num_hidden_layers, + n_head=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + n_positions=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + 
bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + return_dict=True, + ) + + head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) + + def create_and_check_gptj_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = TFGPTJModel(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + } + result = model(inputs) + + inputs = [input_ids, None, input_mask] # None is the input for 'past' + result = model(inputs) + + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_gptj_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = TFGPTJModel(config=config) + + # first forward pass + outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids) + outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + output, past = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size) + + # append to next input_ids and token_type_ids + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1) + + output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"] + output_from_past = model(next_tokens, token_type_ids=next_token_types, past=past)["last_hidden_state"] + + # select random slice + random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_gptj_model_attention_mask_past( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args + ): + model = TFGPTJModel(config=config) + + # create attention mask + half_seq_length = self.seq_length // 2 + attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32) + attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32) + attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) + + # first forward pass + output, past = model(input_ids, attention_mask=attn_mask).to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size) + vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change) + condition = tf.transpose( + 
tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size)) + ) + input_ids = tf.where(condition, random_other_next_tokens, input_ids) + + # append to next input_ids and attn_mask + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + attn_mask = tf.concat([attn_mask, tf.ones((shape_list(attn_mask)[0], 1), dtype=tf.int32)], axis=1) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past=past, attention_mask=attn_mask)["last_hidden_state"] + + # select random slice + random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-12) + + def create_and_check_gptj_model_past_large_inputs( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args + ): + model = TFGPTJModel(config=config) + + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + token_type_ids = token_type_ids[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, use_cache=True) + + output, past = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + next_token_types = ids_tensor((self.batch_size, 3), self.type_vocab_size) + + # append to next input_ids and token_type_ids + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1) + + output_from_no_past = model( + next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask + )["last_hidden_state"] + output_from_past = model( + next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past=past + )["last_hidden_state"] + self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_gptj_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = TFGPTJForCausalLM(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "token_type_ids": token_type_ids, + } + result = model(inputs) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": input_mask, + } + return config, inputs_dict + + +@require_tf 
+@require_tf
+class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (TFGPTJForCausalLM, TFGPTJForSequenceClassification, TFGPTJForQuestionAnswering, TFGPTJModel)
+        if is_tf_available()
+        else ()
+    )
+
+    all_generative_model_classes = (TFGPTJForCausalLM,) if is_tf_available() else ()
+    test_onnx = False
+    test_pruning = False
+    test_missing_keys = False
+    test_head_masking = False
+
+    def setUp(self):
+        self.model_tester = TFGPTJModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=GPTJConfig, n_embd=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_gptj_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gptj_model(*config_and_inputs)
+
+    def test_gptj_model_past(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gptj_model_past(*config_and_inputs)
+
+    def test_gptj_model_att_mask_past(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gptj_model_attention_mask_past(*config_and_inputs)
+
+    def test_gptj_model_past_large_inputs(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gptj_model_past_large_inputs(*config_and_inputs)
+
+    def test_gptj_lm_head_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gptj_lm_head_model(*config_and_inputs)
+
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+
+            if model_class in self.all_generative_model_classes:
+                x = model.get_output_embeddings()
+                assert isinstance(x, tf.keras.layers.Layer)
+                name = model.get_bias()
+                assert name is None
+            else:
+                x = model.get_output_embeddings()
+                assert x is None
+                name = model.get_bias()
+                assert name is None
+
+    @slow
+    def test_model_from_pretrained(self):
+        model = TFGPTJModel.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True)
+        self.assertIsNotNone(model)
+
+    @unittest.skip(reason="Currently, model embeddings are going to undergo a major refactor.")
+    def test_resize_token_embeddings(self):
+        super().test_resize_token_embeddings()
+
+
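+# The integration tests below load the PyTorch checkpoint with `from_pt=True`,
+# presumably because no native TF weights are published for EleutherAI/gpt-j-6B at
+# this point. Greedy decoding (`do_sample=False`) keeps the expected token ids
+# deterministic; the sampling test pins `tf.random.set_seed(0)` instead.
+
+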
+@require_tf
+class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
+    @tooslow
+    def test_lm_generate_gptj(self):
+        # Marked as @tooslow due to GPU OOM
+        model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True)
+        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+        # fmt: off
+        # The dog is a man's best friend. It is a loyal companion, and it is a friend
+        expected_output_ids = [464, 3290, 318, 257, 582, 338, 1266, 1545, 13, 632, 318, 257, 9112, 15185, 11, 290, 340, 318, 257, 1545]
+        # fmt: on
+        output_ids = model.generate(input_ids, do_sample=False)
+        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+
+    @tooslow
+    def test_gptj_sample(self):
+        # Marked as @tooslow due to GPU OOM (issue #13676)
+        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
+        model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)
+
+        tf.random.set_seed(0)
+        tokenized = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True)
+        input_ids, token_type_ids = tokenized.input_ids, tokenized.token_type_ids
+        output_ids = model.generate(input_ids, do_sample=True)
+        output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+        output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
+        output_seq_tt = model.generate(
+            input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
+        )
+        output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
+        output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
+
+        EXPECTED_OUTPUT_STR = "Today is a nice day and I am taking an hour to sit in the hammock and just enjoy"
+
+        self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
+        self.assertTrue(
+            all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
+        )  # token_type_ids should change output
+
+    @slow
+    def test_gptj_sample_max_time(self):
+        tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
+        model = TFGPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random", from_pt=True)
+
+        input_ids = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True).input_ids
+
+        MAX_TIME = 0.5
+
+        start = datetime.datetime.now()
+        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
+        duration = datetime.datetime.now() - start
+        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
+        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
+
+        start = datetime.datetime.now()
+        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
+        duration = datetime.datetime.now() - start
+        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
+        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
+
+        start = datetime.datetime.now()
+        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
+        duration = datetime.datetime.now() - start
+        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
+        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
+
+        start = datetime.datetime.now()
+        model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
+        duration = datetime.datetime.now() - start
+        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
+        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
+
+        start = datetime.datetime.now()
+        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
+        duration = datetime.datetime.now() - start
+        self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
+
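+    # Batched-generation sketch: decoder-only models such as GPT-J must be padded on
+    # the left so that every prompt ends where generation begins (a minimal assumed
+    # usage, mirroring the test below):
+    #
+    #     tokenizer.padding_side = "left"
+    #     tokenizer.pad_token = tokenizer.eos_token
+    #     inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+    #     model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])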
+    @tooslow
+    def test_batch_generation(self):
+        # Marked as @tooslow due to GPU OOM
+        model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)
+        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
+
+        tokenizer.padding_side = "left"
+
+        # Define PAD Token = EOS Token = 50256
+        tokenizer.pad_token = tokenizer.eos_token
+        model.config.pad_token_id = model.config.eos_token_id
+
+        # use different length sentences to test batching
+        sentences = [
+            "Hello, my dog is a little",
+            "Today, I",
+        ]
+
+        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+        input_ids = inputs["input_ids"]
+        token_type_ids = tf.concat(
+            [
+                tf.zeros((input_ids.shape[0], input_ids.shape[1] - 1), dtype=tf.int64),
+                500 * tf.ones((input_ids.shape[0], 1), dtype=tf.int64),
+            ],
+            axis=-1,
+        )
+
+        outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"])
+        outputs_tt = model.generate(
+            input_ids=input_ids,
+            attention_mask=inputs["attention_mask"],
+            token_type_ids=token_type_ids,
+        )
+
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
+        output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+        num_paddings = (
+            shape_list(inputs_non_padded)[-1] - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
+        )
+        inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
+        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
+        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+        expected_output_sentence = [
+            "Hello, my dog is a little over a year old and has been diagnosed with a heart murmur",
+            "Today, I’m going to share with you a few of my favorite",
+        ]
+        self.assertListEqual(expected_output_sentence, batch_out_sentence)
+        self.assertTrue(batch_out_sentence_tt != batch_out_sentence)  # token_type_ids should change output
+        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])