mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
BatchFeature
: Convert List[np.ndarray]
to np.ndarray
before converting to pytorch tensors (#14306)
* update * style fix * retrigger checks * check first element * fix syntax error * Update src/transformers/feature_extraction_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove import Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
46d0cdae40
commit
321eb56222
@ -138,7 +138,11 @@ class BatchFeature(UserDict):
|
||||
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||
import torch
|
||||
|
||||
as_tensor = torch.tensor
|
||||
def as_tensor(value):
|
||||
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
|
||||
value = np.array(value)
|
||||
return torch.tensor(value)
|
||||
|
||||
is_tensor = torch.is_tensor
|
||||
elif tensor_type == TensorType.JAX:
|
||||
if not is_flax_available():
|
||||
|
Loading…
Reference in New Issue
Block a user