Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)

set test_torchscript = False for Blip2 testing (#35972)

* just skip
* fix
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent 0a9923a609
commit dd16acb8a3
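For context: test_torchscript is a class-level flag that the torchscript check reads before doing any tracing, so setting it to False makes the test skip instead of run (the override removed below opens with exactly that guard). Below is a minimal, self-contained sketch of the pattern using plain unittest and a toy module rather than the actual transformers test mixins; the class and model names are illustrative only.

import unittest

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2


class TorchscriptFlagTest(unittest.TestCase):
    # Flipping this flag to False is what the commit does for the Blip2 test
    # classes: the guard below then skips instead of tracing the model.
    test_torchscript = False

    def test_torchscript_trace(self):
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to `False`")
        # Only reached when the flag is True: trace and compare outputs.
        traced = torch.jit.trace(TinyModel(), torch.ones(2))
        self.assertTrue(torch.equal(traced(torch.ones(2)), torch.ones(2) * 2))


if __name__ == "__main__":
    unittest.main()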
@@ -15,7 +15,6 @@
 """Testing suite for the PyTorch BLIP-2 model."""
 
 import inspect
-import os
 import tempfile
 import unittest
 
@@ -36,7 +35,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import is_torch_available, is_torch_sdpa_available, is_vision_available
+from transformers.utils import is_torch_available, is_vision_available
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -477,7 +476,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
     test_pruning = False
     test_resize_embeddings = False
     test_attention_outputs = False
-    test_torchscript = True
+    test_torchscript = False
     _is_composite = True
 
     def setUp(self):
@@ -494,116 +493,6 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
 
-    def _create_and_check_torchscript(self, config, inputs_dict):
-        # overwrite because BLIP requires input ids and pixel values as input
-        if not self.test_torchscript:
-            self.skipTest(reason="test_torchscript is set to `False`")
-
-        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
-        configs_no_init.torchscript = True
-        for model_class in self.all_model_classes:
-            for attn_implementation in ["eager", "sdpa"]:
-                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
-                    continue
-
-                configs_no_init._attn_implementation = attn_implementation
-                model = model_class(config=configs_no_init)
-                model.to(torch_device)
-                model.eval()
-                inputs = self._prepare_for_class(inputs_dict, model_class)
-
-                main_input_name = model_class.main_input_name
-
-                try:
-                    if model.config.is_encoder_decoder:
-                        model.config.use_cache = False  # FSMT still requires this hack -> FSMT should probably be refactored similar to BART afterward
-                        main_input = inputs[main_input_name]
-                        input_ids = inputs["input_ids"]
-                        attention_mask = inputs["attention_mask"]
-                        decoder_input_ids = inputs["decoder_input_ids"]
-                        decoder_attention_mask = inputs["decoder_attention_mask"]
-                        model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
-                        traced_model = torch.jit.trace(
-                            model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
-                        )
-                    else:
-                        main_input = inputs[main_input_name]
-                        input_ids = inputs["input_ids"]
-
-                        if model.config._attn_implementation == "sdpa":
-                            trace_input = {main_input_name: main_input, "input_ids": input_ids}
-
-                            if "attention_mask" in inputs:
-                                trace_input["attention_mask"] = inputs["attention_mask"]
-                            else:
-                                self.skipTest(reason="testing SDPA without attention_mask is not supported")
-
-                            model(main_input, attention_mask=inputs["attention_mask"])
-                            # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
-                            traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
-                        else:
-                            model(main_input, input_ids)
-                            traced_model = torch.jit.trace(model, (main_input, input_ids))
-                except RuntimeError:
-                    self.fail("Couldn't trace module.")
-
-                with tempfile.TemporaryDirectory() as tmp_dir_name:
-                    pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
-
-                    try:
-                        torch.jit.save(traced_model, pt_file_name)
-                    except Exception:
-                        self.fail("Couldn't save module.")
-
-                    try:
-                        loaded_model = torch.jit.load(pt_file_name)
-                    except Exception:
-                        self.fail("Couldn't load module.")
-
-                    model.to(torch_device)
-                    model.eval()
-
-                    loaded_model.to(torch_device)
-                    loaded_model.eval()
-
-                    model_state_dict = model.state_dict()
-                    loaded_model_state_dict = loaded_model.state_dict()
-
-                    non_persistent_buffers = {}
-                    for key in loaded_model_state_dict.keys():
-                        if key not in model_state_dict.keys():
-                            non_persistent_buffers[key] = loaded_model_state_dict[key]
-
-                    loaded_model_state_dict = {
-                        key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
-                    }
-
-                    self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
-
-                    model_buffers = list(model.buffers())
-                    for non_persistent_buffer in non_persistent_buffers.values():
-                        found_buffer = False
-                        for i, model_buffer in enumerate(model_buffers):
-                            if torch.equal(non_persistent_buffer, model_buffer):
-                                found_buffer = True
-                                break
-
-                        self.assertTrue(found_buffer)
-                        model_buffers.pop(i)
-
-                    models_equal = True
-                    for layer_name, p1 in model_state_dict.items():
-                        if layer_name in loaded_model_state_dict:
-                            p2 = loaded_model_state_dict[layer_name]
-                            if p1.data.ne(p2.data).sum() > 0:
-                                models_equal = False
-
-                    self.assertTrue(models_equal)
-
-                    # Avoid memory leak. Without this, each call increases RAM usage by ~20MB.
-                    # (Even with this call, there is still a memory leak of ~0.04MB)
-                    self.clear_torch_jit_class_registry()
-
     @unittest.skip(reason="Hidden_states is tested in individual model tests")
     def test_hidden_states_output(self):
         pass
@@ -1015,7 +904,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
     test_pruning = False
     test_resize_embeddings = True
     test_attention_outputs = False
-    test_torchscript = True
+    test_torchscript = False
     _is_composite = True
 
     # TODO: Fix the failed tests
@@ -1049,116 +938,6 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
 
-    def _create_and_check_torchscript(self, config, inputs_dict):
-        # overwrite because BLIP requires input ids and pixel values as input
-        if not self.test_torchscript:
-            self.skipTest(reason="test_torchscript is set to `False`")
-
-        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
-        configs_no_init.torchscript = True
-        for model_class in self.all_model_classes:
-            for attn_implementation in ["eager", "sdpa"]:
-                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
-                    continue
-
-                configs_no_init._attn_implementation = attn_implementation
-                model = model_class(config=configs_no_init)
-                model.to(torch_device)
-                model.eval()
-                inputs = self._prepare_for_class(inputs_dict, model_class)
-
-                main_input_name = model_class.main_input_name
-
-                try:
-                    if model.config.is_encoder_decoder:
-                        model.config.use_cache = False  # FSMT still requires this hack -> FSMT should probably be refactored similar to BART afterward
-                        main_input = inputs[main_input_name]
-                        input_ids = inputs["input_ids"]
-                        attention_mask = inputs["attention_mask"]
-                        decoder_input_ids = inputs["decoder_input_ids"]
-                        decoder_attention_mask = inputs["decoder_attention_mask"]
-                        model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
-                        traced_model = torch.jit.trace(
-                            model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
-                        )
-                    else:
-                        main_input = inputs[main_input_name]
-                        input_ids = inputs["input_ids"]
-
-                        if model.config._attn_implementation == "sdpa":
-                            trace_input = {main_input_name: main_input, "input_ids": input_ids}
-
-                            if "attention_mask" in inputs:
-                                trace_input["attention_mask"] = inputs["attention_mask"]
-                            else:
-                                self.skipTest(reason="testing SDPA without attention_mask is not supported")
-
-                            model(main_input, attention_mask=inputs["attention_mask"])
-                            # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
-                            traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
-                        else:
-                            model(main_input, input_ids)
-                            traced_model = torch.jit.trace(model, (main_input, input_ids))
-                except RuntimeError:
-                    self.fail("Couldn't trace module.")
-
-                with tempfile.TemporaryDirectory() as tmp_dir_name:
-                    pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
-
-                    try:
-                        torch.jit.save(traced_model, pt_file_name)
-                    except Exception:
-                        self.fail("Couldn't save module.")
-
-                    try:
-                        loaded_model = torch.jit.load(pt_file_name)
-                    except Exception:
-                        self.fail("Couldn't load module.")
-
-                    model.to(torch_device)
-                    model.eval()
-
-                    loaded_model.to(torch_device)
-                    loaded_model.eval()
-
-                    model_state_dict = model.state_dict()
-                    loaded_model_state_dict = loaded_model.state_dict()
-
-                    non_persistent_buffers = {}
-                    for key in loaded_model_state_dict.keys():
-                        if key not in model_state_dict.keys():
-                            non_persistent_buffers[key] = loaded_model_state_dict[key]
-
-                    loaded_model_state_dict = {
-                        key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
-                    }
-
-                    self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
-
-                    model_buffers = list(model.buffers())
-                    for non_persistent_buffer in non_persistent_buffers.values():
-                        found_buffer = False
-                        for i, model_buffer in enumerate(model_buffers):
-                            if torch.equal(non_persistent_buffer, model_buffer):
-                                found_buffer = True
-                                break
-
-                        self.assertTrue(found_buffer)
-                        model_buffers.pop(i)
-
-                    models_equal = True
-                    for layer_name, p1 in model_state_dict.items():
-                        if layer_name in loaded_model_state_dict:
-                            p2 = loaded_model_state_dict[layer_name]
-                            if p1.data.ne(p2.data).sum() > 0:
-                                models_equal = False
-
-                    self.assertTrue(models_equal)
-
-                    # Avoid memory leak. Without this, each call increases RAM usage by ~20MB.
-                    # (Even with this call, there is still a memory leak of ~0.04MB)
-                    self.clear_torch_jit_class_registry()
-
     @unittest.skip(reason="Hidden_states is tested in individual model tests")
     def test_hidden_states_output(self):
         pass