From 43891be19b4fd36cba972dfdc6fcfe0ef8cec25e Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 19 May 2021 10:31:17 +0100
Subject: [PATCH] [T5 failing CI] Fix generate test (#11770)

* fix_torch_device_generate_test

* remove @
---
 tests/test_generation_utils.py | 10 +++++++---
 tests/test_modeling_t5.py      | 13 ++++++++-----
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py
index b377bd3fa6b..1134674a80a 100644
--- a/tests/test_generation_utils.py
+++ b/tests/test_generation_utils.py
@@ -1084,9 +1084,13 @@ class GenerationTesterMixin:
                 continue
 
             head_masking = {
-                "head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads),
-                "decoder_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
-                "cross_attn_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
+                "head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads, device=torch_device),
+                "decoder_head_mask": torch.zeros(
+                    config.decoder_layers, config.decoder_attention_heads, device=torch_device
+                ),
+                "cross_attn_head_mask": torch.zeros(
+                    config.decoder_layers, config.decoder_attention_heads, device=torch_device
+                ),
             }
 
             signature = inspect.signature(model.forward)
diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py
index 7f538a36c77..f020447d007 100644
--- a/tests/test_modeling_t5.py
+++ b/tests/test_modeling_t5.py
@@ -605,19 +605,22 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         config = config_and_inputs[0]
         max_length = config_and_inputs[1].shape[-1] + 3
-        model = T5ForConditionalGeneration(config)
+        model = T5ForConditionalGeneration(config).eval()
+        model.to(torch_device)
 
         head_masking = {
-            "head_mask": torch.zeros(config.num_layers, config.num_heads),
-            "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads),
-            "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads),
+            "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device),
+            "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
+            "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
         }
 
         for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
             head_masks = {name: mask}
             # Explicitly pass decoder_head_mask as it is required from T5 model when head_mask specified
             if name == "head_mask":
-                head_masks["decoder_head_mask"] = torch.ones(config.num_decoder_layers, config.num_heads)
+                head_masks["decoder_head_mask"] = torch.ones(
+                    config.num_decoder_layers, config.num_heads, device=torch_device
+                )
 
             out = model.generate(
                 config_and_inputs[1],