Fix phi4_multimodal tests (#38816)

* fix
* fix
* fix
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 3526e25d3d
commit 309e8c96f2
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import tempfile
 import unittest
 
@@ -31,7 +30,15 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import backend_empty_cache, require_soundfile, require_torch, slow, torch_device
+from transformers.testing_utils import (
+    Expectations,
+    cleanup,
+    require_soundfile,
+    require_torch,
+    require_torch_large_accelerator,
+    slow,
+    torch_device,
+)
 from transformers.utils import is_soundfile_available
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -276,13 +283,14 @@ class Phi4MultimodalModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
 @slow
 class Phi4MultimodalIntegrationTest(unittest.TestCase):
     checkpoint_path = "microsoft/Phi-4-multimodal-instruct"
+    revision = "refs/pr/70"
     image_url = "https://www.ilankelman.org/stopsigns/australia.jpg"
     audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"
 
     def setUp(self):
-        self.processor = AutoProcessor.from_pretrained(self.checkpoint_path)
+        # Currently, the Phi-4 checkpoint on the hub is not working with the latest Phi-4 code, so the slow integration tests
+        # won't pass without using the correct revision (refs/pr/70)
+        self.processor = AutoProcessor.from_pretrained(self.checkpoint_path, revision=self.revision)
         self.generation_config = GenerationConfig(max_new_tokens=20, do_sample=False)
         self.user_token = "<|user|>"
         self.assistant_token = "<|assistant|>"
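The new `revision` class attribute pins every `from_pretrained` call in this test class to the Hub PR ref `refs/pr/70`, since the checkpoint's main branch does not yet work with the current Phi-4 modeling code. A minimal sketch of the same pattern; any branch, tag, PR ref, or commit SHA is accepted as a `revision`:

```python
from transformers import AutoProcessor

# Pin the download to a specific Hub revision; "refs/pr/70" is the PR branch
# used by these tests, but a tag or commit SHA works the same way.
processor = AutoProcessor.from_pretrained(
    "microsoft/Phi-4-multimodal-instruct",
    revision="refs/pr/70",
)
```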
@@ -294,13 +302,14 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
             tmp.seek(0)
             self.audio, self.sampling_rate = soundfile.read(tmp.name)
 
+        cleanup(torch_device, gc_collect=True)
+
     def tearDown(self):
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)
 
     def test_text_only_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
        )
 
         prompt = f"{self.user_token}What is the answer for 1+1? Explain it.{self.end_token}{self.assistant_token}"
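The manual `gc.collect()` / `backend_empty_cache(torch_device)` pair in `tearDown` is folded into the `cleanup` helper from `transformers.testing_utils`, which is now also called at the end of `setUp` so each test starts with a freed accelerator cache. A sketch of what the helper stands in for, based on the removed lines rather than the helper's exact source:

```python
import gc

from transformers.testing_utils import backend_empty_cache, torch_device


def manual_cleanup():
    # What the tests did by hand before: drop dead Python references first,
    # then release the accelerator's cached memory for the current backend
    # (e.g. torch.cuda.empty_cache() on CUDA devices).
    gc.collect()
    backend_empty_cache(torch_device)
```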
@@ -319,7 +328,7 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
 
     def test_vision_text_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )
 
         prompt = f"{self.user_token}<|image|>What is shown in this image?{self.end_token}{self.assistant_token}"
@@ -332,13 +341,20 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
         output = output[:, inputs["input_ids"].shape[1] :]
         response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
-        EXPECTED_RESPONSE = "The image shows a vibrant scene at a street intersection in a city with a Chinese-influenced architectural"
+        EXPECTED_RESPONSES = Expectations(
+            {
+                ("cuda", 7): 'The image shows a vibrant scene at a traditional Chinese-style street entrance, known as a "gate"',
+                ("cuda", 8): 'The image shows a vibrant scene at a street intersection in a city with a Chinese-influenced architectural',
+            }
+        )  # fmt: skip
+        EXPECTED_RESPONSE = EXPECTED_RESPONSES.get_expectation()
 
         self.assertEqual(response, EXPECTED_RESPONSE)
 
+    @require_torch_large_accelerator
     def test_multi_image_vision_text_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )
 
         images = []
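Instead of a single hard-coded string, the expected output is now an `Expectations` mapping keyed by `(device type, version)`; for `"cuda"`, the version appears to be the major compute capability (7 for V100-class cards, 8 for A100-class cards), and `get_expectation()` returns the entry that best matches the machine running the test. A sketch with hypothetical expected strings:

```python
from transformers.testing_utils import Expectations

# Hypothetical values; the real tests store the generation observed on each
# device class, since fp16 outputs can differ across GPU generations.
EXPECTED_OUTPUTS = Expectations(
    {
        ("cuda", 7): "text observed on compute capability 7.x",
        ("cuda", 8): "text observed on compute capability 8.x",
    }
)
expected = EXPECTED_OUTPUTS.get_expectation()  # best match for the current device
```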
@@ -365,7 +381,7 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase):
     @require_soundfile
     def test_audio_text_generation(self):
         model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
+            self.checkpoint_path, revision=self.revision, torch_dtype=torch.float16, device_map=torch_device
         )
 
         prompt = f"{self.user_token}<|audio|>What is happening in this audio?{self.end_token}{self.assistant_token}"
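For reference, the generation pattern all four integration tests share looks roughly like the sketch below. `end_token` is defined elsewhere in the file, so the `<|end|>` literal here is an assumption based on the Phi-4 chat format:

```python
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

checkpoint, revision = "microsoft/Phi-4-multimodal-instruct", "refs/pr/70"
processor = AutoProcessor.from_pretrained(checkpoint, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, revision=revision, torch_dtype=torch.float16, device_map="cuda"
)

# <|end|> assumed as the turn terminator (self.end_token in the tests).
prompt = "<|user|>What is the answer for 1+1? Explain it.<|end|><|assistant|>"
inputs = processor(text=prompt, return_tensors="pt").to("cuda")

generation_config = GenerationConfig(max_new_tokens=20, do_sample=False)
output = model.generate(**inputs, generation_config=generation_config)

# Decode only the newly generated tokens, as the tests do.
response = processor.batch_decode(
    output[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)[0]
```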