mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix GitModelIntegrationTest.test_batched_generation
device issue (#21362)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
73a2ff6974
commit
a582cfce3c
@ -508,9 +508,8 @@ class GitModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# we have to prepare `input_ids` with the same batch size as `pixel_values`
|
||||
start_token_id = model.config.bos_token_id
|
||||
generated_ids = model.generate(
|
||||
pixel_values=pixel_values, input_ids=torch.tensor([[start_token_id], [start_token_id]]), max_length=50
|
||||
)
|
||||
input_ids = torch.tensor([[start_token_id], [start_token_id]], device=torch_device)
|
||||
generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
|
||||
generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
self.assertEquals(generated_captions, ["two cats sleeping on a pink blanket next to remotes."] * 2)
|
||||
|
Loading…
Reference in New Issue
Block a user