Fix return_dict in encodec (#31646)

* fix: use return_dict parameter

* fix: type checks

* fix: unused imports

* update: one-line if else

* remove: recursive check
This commit is contained in:
Jacky Lee 2024-06-28 04:18:01 -07:00 committed by GitHub
parent 5e89b335ab
commit 82a1fc7256
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 27 deletions

View File

@ -729,7 +729,7 @@ class EncodecModel(EncodecPreTrainedModel):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
return_dict = return_dict or self.config.return_dict
return_dict = return_dict if return_dict is not None else self.config.return_dict
chunk_length = self.config.chunk_length
if chunk_length is None:
@ -786,7 +786,7 @@ class EncodecModel(EncodecPreTrainedModel):
>>> audio_codes = outputs.audio_codes
>>> audio_values = outputs.audio_values
```"""
return_dict = return_dict or self.config.return_dict
return_dict = return_dict if return_dict is not None else self.config.return_dict
if padding_mask is None:
padding_mask = torch.ones_like(input_values).bool()

View File

@ -19,7 +19,6 @@ import inspect
import os
import tempfile
import unittest
from typing import Dict, List, Tuple
import numpy as np
from datasets import Audio, load_dataset
@ -385,31 +384,21 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
self.assertTrue(isinstance(tuple_output, tuple))
self.assertTrue(isinstance(dict_output, dict))
recursive_check(tuple_output, dict_output)
for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:"
f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has"
f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}."
),
)
for model_class in self.all_model_classes:
model = model_class(config)