mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix return_dict in encodec (#31646)
* fix: use return_dict parameter * fix: type checks * fix: unused imports * update: one-line if else * remove: recursive check
This commit is contained in:
parent
5e89b335ab
commit
82a1fc7256
@ -729,7 +729,7 @@ class EncodecModel(EncodecPreTrainedModel):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
|
||||
"""
|
||||
return_dict = return_dict or self.config.return_dict
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
chunk_length = self.config.chunk_length
|
||||
if chunk_length is None:
|
||||
@ -786,7 +786,7 @@ class EncodecModel(EncodecPreTrainedModel):
|
||||
>>> audio_codes = outputs.audio_codes
|
||||
>>> audio_values = outputs.audio_values
|
||||
```"""
|
||||
return_dict = return_dict or self.config.return_dict
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if padding_mask is None:
|
||||
padding_mask = torch.ones_like(input_values).bool()
|
||||
|
@ -19,7 +19,6 @@ import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from datasets import Audio, load_dataset
|
||||
@ -385,31 +384,21 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(
|
||||
tuple_object.values(), dict_object.values()
|
||||
):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
|
||||
),
|
||||
msg=(
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
|
||||
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
|
||||
),
|
||||
)
|
||||
self.assertTrue(isinstance(tuple_output, tuple))
|
||||
self.assertTrue(isinstance(dict_output, dict))
|
||||
|
||||
recursive_check(tuple_output, dict_output)
|
||||
for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5
|
||||
),
|
||||
msg=(
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:"
|
||||
f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has"
|
||||
f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}."
|
||||
),
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
Loading…
Reference in New Issue
Block a user