improve _get_is_as_tensor_fns (#32596)

* improve _get_is_as_tensor_fns

* format
This commit is contained in:
Zhan Rongrui 2024-08-16 22:59:44 +08:00 committed by GitHub
parent a27182b7fc
commit f20d0e81ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -137,8 +137,15 @@ class BatchFeature(UserDict):
import torch # noqa
def as_tensor(value):
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
value = np.array(value)
if isinstance(value, (list, tuple)) and len(value) > 0:
if isinstance(value[0], np.ndarray):
value = np.array(value)
elif (
isinstance(value[0], (list, tuple))
and len(value[0]) > 0
and isinstance(value[0][0], np.ndarray)
):
value = np.array(value)
return torch.tensor(value)
is_tensor = torch.is_tensor