mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
update Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
c6814b4ee8
commit
69632aadb7
@ -631,8 +631,10 @@ class OffloadedCache(DynamicCache):
|
||||
def prefetch_layer(self, layer_idx: int):
|
||||
"Starts prefetching the next layer cache"
|
||||
if layer_idx < len(self):
|
||||
with self.prefetch_stream if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.stream(
|
||||
with (
|
||||
self.prefetch_stream
|
||||
if is_torch_greater_or_equal("2.7", accept_dev=True)
|
||||
else torch.cuda.stream(self.prefetch_stream)
|
||||
):
|
||||
# Prefetch next layer tensors to GPU
|
||||
device = self.original_device[layer_idx]
|
||||
|
@ -129,7 +129,7 @@ class Qwen2AudioProcessor(ProcessorMixin):
|
||||
if audio is not None:
|
||||
# ensure we have as much audios as audio tokens
|
||||
num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
|
||||
num_audios = 1 if type(audio) == np.ndarray else len(audio)
|
||||
num_audios = 1 if type(audio) is np.ndarray else len(audio)
|
||||
if num_audio_tokens != num_audios:
|
||||
raise ValueError(
|
||||
f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
|
||||
|
Loading…
Reference in New Issue
Block a user