mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
improve _get_is_as_tensor_fns (#32596)
* improve _get_is_as_tensor_fns * format
This commit is contained in:
parent
a27182b7fc
commit
f20d0e81ea
@ -137,8 +137,15 @@ class BatchFeature(UserDict):
|
||||
import torch # noqa
|
||||
|
||||
def as_tensor(value):
|
||||
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
|
||||
value = np.array(value)
|
||||
if isinstance(value, (list, tuple)) and len(value) > 0:
|
||||
if isinstance(value[0], np.ndarray):
|
||||
value = np.array(value)
|
||||
elif (
|
||||
isinstance(value[0], (list, tuple))
|
||||
and len(value[0]) > 0
|
||||
and isinstance(value[0][0], np.ndarray)
|
||||
):
|
||||
value = np.array(value)
|
||||
return torch.tensor(value)
|
||||
|
||||
is_tensor = torch.is_tensor
|
||||
|
Loading…
Reference in New Issue
Block a user