Fix phi4_multimodal tests (#38816)

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Author: Yih-Dar (committed by GitHub)
Date: 2025-06-18 09:39:17 +02:00
Commit: 309e8c96f2
Parent: 3526e25d3d


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import gc
 import tempfile
 import unittest
@@ -31,7 +30,15 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import backend_empty_cache, require_soundfile, require_torch, slow, torch_device
+from transformers.testing_utils import (
+    Expectations,
+    cleanup,
+    require_soundfile,
+    require_torch,
+    require_torch_large_accelerator,
+    slow,
+    torch_device,
+)
 from transformers.utils import is_soundfile_available

 from ...generation.test_utils import GenerationTesterMixin
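
The reworked import swaps `backend_empty_cache` for the single `cleanup` helper, which replaces the manual `gc.collect()` / `backend_empty_cache(torch_device)` pair used further down. A minimal sketch of the behavior the tests rely on, assuming `cleanup` simply chains the two calls (`cleanup_sketch` is an illustrative name, not the library code):

    import gc

    from transformers.testing_utils import backend_empty_cache


    def cleanup_sketch(device: str, gc_collect: bool = False) -> None:
        # Assumption: mirrors transformers.testing_utils.cleanup.
        # Optionally run Python's garbage collector first, so dropped
        # tensors are actually released before the cache is emptied.
        if gc_collect:
            gc.collect()
        # Free cached memory on the active accelerator backend.
        backend_empty_cache(device)
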
@@ -276,13 +283,14 @@ class Phi4MultimodalModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
 @slow
 class Phi4MultimodalIntegrationTest(unittest.TestCase):
     checkpoint_path = "microsoft/Phi-4-multimodal-instruct"
+    revision = "refs/pr/70"
     image_url = "https://www.ilankelman.org/stopsigns/australia.jpg"
     audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"

     def setUp(self):
         # Currently, the Phi-4 checkpoint on the hub is not working with the latest Phi-4 code, so the slow integration tests
         # won't pass without using the correct revision (refs/pr/70)
-        self.processor = AutoProcessor.from_pretrained(self.checkpoint_path)
+        self.processor = AutoProcessor.from_pretrained(self.checkpoint_path, revision=self.revision)
         self.generation_config = GenerationConfig(max_new_tokens=20, do_sample=False)
         self.user_token = "<|user|>"
         self.assistant_token = "<|assistant|>"
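
Pinning `revision` once as a class attribute keeps every Hub download in these tests on the same ref. A short usage sketch with the same checkpoint and PR ref as above:

    from transformers import AutoModelForCausalLM, AutoProcessor

    checkpoint = "microsoft/Phi-4-multimodal-instruct"
    revision = "refs/pr/70"  # Hub PR ref carrying the fixed checkpoint

    # Load the processor and the model from the same revision so the
    # tokenizer, feature extractor, and model config stay consistent.
    processor = AutoProcessor.from_pretrained(checkpoint, revision=revision)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, revision=revision)
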
@@ -294,13 +302,14 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
             tmp.seek(0)
             self.audio, self.sampling_rate = soundfile.read(tmp.name)

+        cleanup(torch_device, gc_collect=True)
+
     def tearDown(self):
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)

     def test_text_only_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )

         prompt = f"{self.user_token}What is the answer for 1+1? Explain it.{self.end_token}{self.assistant_token}"
@@ -319,7 +328,7 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
     def test_vision_text_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )

         prompt = f"{self.user_token}<|image|>What is shown in this image?{self.end_token}{self.assistant_token}"
@@ -332,13 +341,20 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
         output = output[:, inputs["input_ids"].shape[1] :]
         response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

-        EXPECTED_RESPONSE = "The image shows a vibrant scene at a street intersection in a city with a Chinese-influenced architectural"
+        EXPECTED_RESPONSES = Expectations(
+            {
+                ("cuda", 7): 'The image shows a vibrant scene at a traditional Chinese-style street entrance, known as a "gate"',
+                ("cuda", 8): 'The image shows a vibrant scene at a street intersection in a city with a Chinese-influenced architectural',
+            }
+        )  # fmt: skip
+        EXPECTED_RESPONSE = EXPECTED_RESPONSES.get_expectation()

         self.assertEqual(response, EXPECTED_RESPONSE)

+    @require_torch_large_accelerator
     def test_multi_image_vision_text_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )

         images = []
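
`Expectations` keys reference outputs by `(device_type, major_version)` — here CUDA compute capability 7.x vs 8.x GPUs, whose float16 kernels can make greedy decoding diverge — and `get_expectation()` resolves the entry matching the current machine. A minimal sketch of the pattern, with placeholder strings standing in for the full responses above:

    from transformers.testing_utils import Expectations

    # ("cuda", 7): compute capability 7.x (e.g. V100);
    # ("cuda", 8): compute capability 8.x (e.g. A100).
    EXPECTED_RESPONSES = Expectations(
        {
            ("cuda", 7): "reference text observed on cc 7.x GPUs",
            ("cuda", 8): "reference text observed on cc 8.x GPUs",
        }
    )
    # Resolves to the value whose key best matches the running device.
    EXPECTED_RESPONSE = EXPECTED_RESPONSES.get_expectation()
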
@@ -365,7 +381,7 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
     @require_soundfile
     def test_audio_text_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )

         prompt = f"{self.user_token}<|audio|>What is happening in this audio?{self.end_token}{self.assistant_token}"