Mirror of https://github.com/huggingface/transformers.git
TF: GPT-J compatible with XLA generation (#17986)
commit 360719a6a4, parent bf37e5c7f6
@@ -60,14 +60,12 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


-def fixed_pos_embedding(x: tf.Tensor, seq_dim: int = 1, seq_len: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]:
-    dim = shape_list(x)[-1]
-    if seq_len is None:
-        seq_len = shape_list(x)[seq_dim]
+def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor:
     inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32)
-    seq_len_range = tf.cast(tf.range(seq_len), tf.float32)
-    sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", seq_len_range, inv_freq), tf.float32)
-    return tf.cast(tf.sin(sinusoid_inp), dtype=x.dtype), tf.cast(tf.cos(sinusoid_inp), dtype=x.dtype)
+    sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", tf.range(num_pos, dtype=tf.float32), inv_freq), tf.float32)
+    sin, cos = tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)
+    out = tf.concat((sin, cos), axis=1)
+    return out


 def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
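Note: the new helper precomputes the whole sin/cos table once from static sizes, instead of deriving it from the runtime tensor as fixed_pos_embedding did. A minimal standalone sketch of the same recipe (assuming TensorFlow 2.x; the shape check is illustrative):

import tensorflow as tf

def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor:
    # One row per absolute position; the first dim//2 columns hold sines, the rest cosines.
    inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32)
    sinusoid_inp = tf.einsum("i , j -> i j", tf.range(num_pos, dtype=tf.float32), inv_freq)
    return tf.concat((tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)), axis=1)

table = create_sinusoidal_positions(num_pos=2048, dim=64)
print(table.shape)  # (2048, 64) -- a static table that position_ids can index into later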
@@ -77,11 +75,11 @@ def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
     return rotate_half_tensor


-def apply_rotary_pos_emb(x: tf.Tensor, sincos: tf.Tensor, offset: int = 0) -> tf.Tensor:
+def apply_rotary_pos_emb(tensor: tf.Tensor, sincos: tf.Tensor) -> tf.Tensor:
     sin_pos, cos_pos = sincos
-    sin_pos = tf.repeat(sin_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3)
-    cos_pos = tf.repeat(cos_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3)
-    return (x * cos_pos) + (rotate_every_two(x) * sin_pos)
+    sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, 3)
+    cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, 3)
+    return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)


 class TFGPTJAttention(tf.keras.layers.Layer):
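For reference, a self-contained sketch of how the per-position sin/cos rows get applied now that the offset arithmetic is gone. Only the return line of rotate_every_two is visible in this hunk, so its body below is a reconstruction and should be treated as an approximation of the module's helper:

import tensorflow as tf

def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
    # (x1, x2, x3, x4, ...) -> (-x2, x1, -x4, x3, ...): pairs adjacent channels for the rotation.
    rotated = tf.stack((-x[..., 1::2], x[..., ::2]), axis=-1)
    return tf.reshape(rotated, tf.shape(x))

def apply_rotary_pos_emb(tensor: tf.Tensor, sincos) -> tf.Tensor:
    # sincos is a (sin, cos) pair shaped [batch, seq, rotary_dim // 2], already gathered per position.
    sin_pos, cos_pos = sincos
    sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, 3)
    cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, 3)
    return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)

# Toy usage with a locally built table (same recipe as create_sinusoidal_positions above):
dim = 8
inv_freq = 1.0 / (10000 ** (tf.range(0, dim, 2, dtype=tf.float32) / dim))
angles = tf.einsum("i,j->ij", tf.range(16, dtype=tf.float32), inv_freq)
table = tf.concat((tf.sin(angles), tf.cos(angles)), axis=1)

position_ids = tf.constant([[0, 1, 2, 3, 4]])
sincos = tf.split(tf.gather(table, position_ids, axis=0), 2, axis=-1)
q = tf.random.normal((1, 5, 4, dim))  # [batch, seq, heads, rotary_dim]
q_rot = apply_rotary_pos_emb(q, sincos)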
@@ -132,6 +130,8 @@ class TFGPTJAttention(tf.keras.layers.Layer):
             tf.cast(tf.experimental.numpy.tril(tf.ones((self.max_positions, self.max_positions))), tf.int8),
             (1, 1, self.max_positions, self.max_positions),
         )
+        pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.embed_positions = create_sinusoidal_positions(self.max_positions, pos_embd_dim)

     def get_causal_mask(self, key_length, query_length) -> tf.Tensor:
         return tf.cast(self.lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool)
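The table is now built once in __init__ (sized by rotary_dim when set, otherwise embed_dim), which is what makes the lookup XLA-friendly: nothing about it depends on runtime tensor shapes. A hypothetical toy layer illustrating the pattern, not the actual attention class:

import tensorflow as tf

class ToyRotaryTable(tf.keras.layers.Layer):
    """Hypothetical mini-layer: a static sinusoid table created at construction time."""

    def __init__(self, max_positions: int = 32, dim: int = 8):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (tf.range(0, dim, 2, dtype=tf.float32) / dim))
        angles = tf.einsum("i,j->ij", tf.range(max_positions, dtype=tf.float32), inv_freq)
        self.embed_positions = tf.concat((tf.sin(angles), tf.cos(angles)), axis=1)

    def call(self, position_ids: tf.Tensor) -> tf.Tensor:
        # A plain gather: safe inside an XLA-compiled graph because the table shape is fixed.
        return tf.gather(self.embed_positions, position_ids, axis=0)

layer = ToyRotaryTable()
compiled = tf.function(layer, jit_compile=True)
print(compiled(tf.constant([[0, 1, 2]])).shape)  # (1, 3, 8)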
@@ -207,8 +207,9 @@ class TFGPTJAttention(tf.keras.layers.Layer):
     def call(
         self,
         hidden_states: tf.Tensor,
-        attention_mask: Optional[tf.Tensor] = None,
         layer_past: Optional[Tuple[tf.Tensor, tf.Tensor]] = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
         head_mask: Optional[tf.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
@@ -221,13 +222,8 @@ class TFGPTJAttention(tf.keras.layers.Layer):
         key = self._split_heads(key, True)
         value = self._split_heads(value, False)

-        seq_len = shape_list(key)[1]
-        offset = 0
-
-        if layer_past is not None:
-            offset = shape_list(layer_past[0])[-2]
-            seq_len += offset
-
+        sincos = tf.gather(self.embed_positions, position_ids, axis=0)
+        sincos = tf.split(sincos, 2, axis=-1)
         if self.rotary_dim is not None:
             k_rot = key[:, :, :, : self.rotary_dim]
             k_pass = key[:, :, :, self.rotary_dim :]
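The removed offset/seq_len bookkeeping is replaced by indexing the static table with position_ids. A quick equivalence check with toy sizes (assumed values, not the model's):

import tensorflow as tf

dim, num_pos = 8, 32
inv_freq = 1.0 / (10000 ** (tf.range(0, dim, 2, dtype=tf.float32) / dim))
angles = tf.einsum("i,j->ij", tf.range(num_pos, dtype=tf.float32), inv_freq)
table = tf.concat((tf.sin(angles), tf.cos(angles)), axis=1)

past_length = 7  # pretend seven tokens already sit in the cache
row_by_offset = table[past_length]                                   # old style: slice at the offset
row_by_ids = tf.gather(table, tf.constant([[past_length]]), axis=0)  # new style: gather by position_ids
tf.debugging.assert_near(row_by_offset, row_by_ids[0, 0])            # same rows, no dynamic shapes
sin, cos = tf.split(row_by_ids, 2, axis=-1)                          # the same split the attention layer performs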
@@ -235,16 +231,14 @@ class TFGPTJAttention(tf.keras.layers.Layer):
             q_rot = query[:, :, :, : self.rotary_dim]
             q_pass = query[:, :, :, self.rotary_dim :]

-            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
-            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
-            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
+            k_rot = apply_rotary_pos_emb(k_rot, sincos)
+            q_rot = apply_rotary_pos_emb(q_rot, sincos)

             key = tf.concat((k_rot, k_pass), axis=-1)
             query = tf.concat((q_rot, q_pass), axis=-1)
         else:
-            sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
-            key = apply_rotary_pos_emb(key, sincos, offset=offset)
-            query = apply_rotary_pos_emb(query, sincos, offset=offset)
+            key = apply_rotary_pos_emb(key, sincos)
+            query = apply_rotary_pos_emb(query, sincos)

         key = tf.transpose(key, (0, 2, 1, 3))
         query = tf.transpose(query, (0, 2, 1, 3))
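When rotary_dim is smaller than the head dimension, only the leading channels are rotated and the rest pass through unchanged, exactly as before; the only change is that sincos now arrives pre-gathered. A toy illustration of the split-and-concat (made-up sizes):

import tensorflow as tf

rotary_dim = 4
key = tf.random.normal((1, 5, 2, 8))  # [batch, seq, heads, head_dim]
k_rot, k_pass = key[..., :rotary_dim], key[..., rotary_dim:]
# k_rot = apply_rotary_pos_emb(k_rot, sincos)  # rotation touches only the first rotary_dim channels
key = tf.concat((k_rot, k_pass), axis=-1)
print(key.shape)  # (1, 5, 2, 8) -- unchanged overall shape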
@@ -310,6 +304,7 @@ class TFGPTJBlock(tf.keras.layers.Layer):
         hidden_states: tf.Tensor,
         layer_past: Optional[tf.Tensor] = None,
         attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
         head_mask: Optional[tf.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
@@ -317,9 +312,10 @@ class TFGPTJBlock(tf.keras.layers.Layer):
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
-            hidden_states,
+            hidden_states=hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -466,12 +462,13 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
                 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

             outputs = block(
-                hidden_states,
-                layer_past,
-                attention_mask,
-                head_mask[i],
-                use_cache,
-                output_attentions,
+                hidden_states=hidden_states,
+                layer_past=layer_past,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
                 training=training,
             )

@@ -722,8 +719,6 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
         self.lm_head = tf.keras.layers.Dense(
             config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head"
         )
-        # TODO (Joao): investigate why GPTJ has numerical issues in XLA generate
-        self.supports_xla_generation = False

     def get_output_embeddings(self):
         return self.lm_head
@@ -731,25 +726,21 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
-        # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
-        # tests will need to be fixed after the change
-
+    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+        token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past:
             inputs = tf.expand_dims(inputs[:, -1], -1)
+            if token_type_ids is not None:
+                token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)

-        # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
-        # for a future PR to not change too many things for now.
-        # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
-        position_ids = None
-        attention_mask = None
-        if use_xla:
-            attention_mask = kwargs.get("attention_mask", None)
-            if past is not None and attention_mask is not None:
-                position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
-            elif attention_mask is not None:
-                position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
+        position_ids = kwargs.get("position_ids", None)
+        attention_mask = kwargs.get("attention_mask", None)
+
+        if attention_mask is not None and position_ids is None:
+            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+            if past:
+                position_ids = tf.expand_dims(position_ids[:, -1], -1)

         return {
             "input_ids": inputs,
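The rewritten prepare_inputs_for_generation derives position_ids from the attention mask with an exclusive cumsum and, once a cache exists, keeps only the newest position. A small worked example with a made-up left-padded batch:

import tensorflow as tf

attention_mask = tf.constant([[0, 0, 1, 1, 1],   # left-padded sequence
                              [1, 1, 1, 1, 1]])  # full-length sequence
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
print(position_ids.numpy())
# [[0 0 0 1 2]
#  [0 1 2 3 4]]   padding positions stay at 0, real tokens count up from 0

# With `past` populated, only the position of the token currently being generated is fed in:
print(tf.expand_dims(position_ids[:, -1], -1).numpy())
# [[2]
#  [4]]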
@@ -757,6 +748,7 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
             "position_ids": position_ids,
             "past": past,
             "use_cache": use_cache,
+            "token_type_ids": token_type_ids,
         }

     @unpack_inputs
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import datetime
 import unittest

 from transformers import AutoTokenizer, GPTJConfig, is_tf_available
@@ -48,6 +47,7 @@ class TFGPTJModelTester:
         self.use_mc_token_ids = True
         self.vocab_size = 99
         self.hidden_size = 32
+        self.rotary_dim = 4
         self.num_hidden_layers = 5
         self.num_attention_heads = 4
         self.intermediate_size = 37
@@ -103,6 +103,7 @@ class TFGPTJModelTester:
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             pad_token_id=self.pad_token_id,
+            rotary_dim=self.rotary_dim,
             return_dict=True,
         )

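The tester now pins rotary_dim so the partial-rotation path is exercised. A hedged sketch of a comparably tiny configuration (only the values visible in the diff come from the tester; the others are illustrative):

from transformers import GPTJConfig, TFGPTJModel

config = GPTJConfig(
    vocab_size=99,
    n_positions=64,   # illustrative
    n_embd=32,
    n_layer=5,
    n_head=4,
    rotary_dim=4,     # rotate only 4 of the 8 channels per head
)
model = TFGPTJModel(config)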
@@ -359,10 +360,10 @@ class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):


 @require_tf
+@tooslow
+# Marked as @tooslow due to GPU OOM -- but still useful to run locally. Requires ~39GB of RAM.
 class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
-    @tooslow
     def test_lm_generate_gptj(self):
-        # Marked as @tooslow due to GPU OOM
         model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True)
         input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
         # fmt: off
@@ -372,74 +373,20 @@ class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
         output_ids = model.generate(input_ids, do_sample=False)
         self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)

-    @tooslow
     def test_gptj_sample(self):
-        # Marked as @tooslow due to GPU OOM (issue #13676)
         tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
         model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)

-        tf.random.set_seed(0)
-        tokenized = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True)
-        input_ids, token_type_ids = tokenized.input_ids, tokenized.token_type_ids
-        output_ids = model.generate(input_ids, do_sample=True)
+        tokenized = tokenizer("Today is a nice day and", return_tensors="tf")
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            output_ids = model.generate(**tokenized, do_sample=True, seed=[42, 0])
         output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)

-        output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
-        output_seq_tt = model.generate(
-            input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
-        )
-        output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
-        output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
-
-        EXPECTED_OUTPUT_STR = "Today is a nice day and I am taking an hour to sit in the hammock and just enjoy"
-
+        EXPECTED_OUTPUT_STR = "Today is a nice day and I’m going to go for a walk. I’"
         self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
-        self.assertTrue(
-            all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
-        )  # token_type_ids should change output

-    @slow
-    @unittest.skip(reason="TF generate currently has no time-based stopping criteria")
-    def test_gptj_sample_max_time(self):
-        tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
-        model = TFGPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random", from_pt=True)
-
-        input_ids = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True).input_ids
-
-        MAX_TIME = 0.5
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-    @tooslow
-    def test_batch_generation(self):
-        # Marked as @tooslow due to GPU OOM
+    def _get_beam_search_test_objects(self):
         model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)
         tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")

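The sampling test now pins the draw with a seed pair (seed=[42, 0]) instead of tf.random.set_seed, so the result does not depend on prior op ordering. Presumably this maps onto TF's stateless sampling primitives; a minimal illustration of that mechanism, not necessarily the exact call inside generate:

import tensorflow as tf

logits = tf.math.log([[0.1, 0.6, 0.3]])
a = tf.random.stateless_categorical(logits, num_samples=5, seed=[42, 0])
b = tf.random.stateless_categorical(logits, num_samples=5, seed=[42, 0])
print(a.numpy(), b.numpy())  # identical draws: the seed pair fully determines the sample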
@@ -454,42 +401,46 @@ class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
             "Hello, my dog is a little",
             "Today, I",
         ]
+        expected_output_sentences = [
+            "Hello, my dog is a little over a year old and has been diagnosed with hip dysplasia",
+            "Today, I’m going to be talking about a topic that’",
+        ]
+        return model, tokenizer, sentences, expected_output_sentences

+    def test_batch_beam_search(self):
+        # Confirms that we get the expected results with left-padded beam search
+        model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
+
         inputs = tokenizer(sentences, return_tensors="tf", padding=True)
-        input_ids = inputs["input_ids"]
-        token_type_ids = tf.concat(
-            [
-                tf.zeros((input_ids.shape[0], input_ids.shape[1] - 1), dtype=tf.int64),
-                500 * tf.ones((input_ids.shape[0], 1), dtype=tf.int64),
-            ],
-            axis=-1,
-        )
-
-        outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"])
-        outputs_tt = model.generate(
-            input_ids=input_ids,
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=token_type_ids,
-        )
-
-        inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
-        output_non_padded = model.generate(input_ids=inputs_non_padded)
-
-        num_paddings = (
-            shape_list(inputs_non_padded)[-1] - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
-        )
-        inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
-        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
-
+        outputs = model.generate(**inputs, do_sample=False, num_beams=2)
         batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-        batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
+        self.assertListEqual(expected_output_sentences, batch_out_sentence)
+
+    def test_batch_left_padding(self):
+        # Confirms that left-padding is working properly
+        model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
+
+        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="tf")
+        output_non_padded = model.generate(**inputs_non_padded, do_sample=False, num_beams=2)
+        num_paddings = (
+            shape_list(inputs_non_padded["input_ids"])[-1]
+            - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
+        )
+        inputs_padded = tokenizer(sentences[1], return_tensors="tf")
+        output_padded = model.generate(
+            **inputs_padded, do_sample=False, num_beams=2, max_length=model.config.max_length - num_paddings
+        )
         non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
         padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+        self.assertListEqual(expected_output_sentences, [non_padded_sentence, padded_sentence])

-        expected_output_sentence = [
-            "Hello, my dog is a little over a year old and has been diagnosed with a heart murmur",
-            "Today, I’m going to share with you a few of my favorite",
-        ]
-        self.assertListEqual(expected_output_sentence, batch_out_sentence)
-        self.assertTrue(batch_out_sentence_tt != batch_out_sentence)  # token_type_ids should change output
-        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+    def test_xla_beam_search(self):
+        # Confirms that XLA is working properly
+        model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
+
+        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        outputs_xla = xla_generate(**inputs, do_sample=False, num_beams=2)
+        xla_sentence = tokenizer.batch_decode(outputs_xla, skip_special_tokens=True)
+        self.assertListEqual(expected_output_sentences, xla_sentence)