From a6752a7d3c23d03c5d4456c51e992250ebabfc1b Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 13 Apr 2023 23:45:22 +0200 Subject: [PATCH] Fix `serving_output` for TF composite models (encoder-decoder like models) (#22743) * fix * style * fix --------- Co-authored-by: ydshieh --- .../modeling_tf_encoder_decoder.py | 16 ++++++++++------ .../modeling_tf_vision_encoder_decoder.py | 16 ++++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index be6d03a1318..1c90245b696 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -633,14 +633,18 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): ) def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None + dec_hs = ( + tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None + ) + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None + enc_hs = ( + tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None + ) + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None cross_attns = ( tf.convert_to_tensor(output.cross_attentions) - if self.config.output_attentions and output.cross_attentions is not None + if self.config.decoder.output_attentions and output.cross_attentions is not None else None ) diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index 5af7c195ff0..439c5d668a9 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -662,14 +662,18 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos ) def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None + dec_hs = ( + tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None + ) + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None + enc_hs = ( + tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None + ) + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None cross_attns = ( tf.convert_to_tensor(output.cross_attentions) - if self.config.output_attentions and output.cross_attentions is not None + if self.config.decoder.output_attentions and output.cross_attentions is not None else None )