BatchFeature: Convert List[np.ndarray] to np.ndarray before converting to pytorch tensors (#14306)

* update

* style fix

* retrigger checks

* check first element

* fix syntax error

* Update src/transformers/feature_extraction_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove import

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Elad Segal 2021-11-10 05:23:08 +02:00 committed by GitHub
parent 46d0cdae40
commit 321eb56222
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -138,7 +138,11 @@ class BatchFeature(UserDict):
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch
as_tensor = torch.tensor
def as_tensor(value):
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
value = np.array(value)
return torch.tensor(value)
is_tensor = torch.is_tensor
elif tensor_type == TensorType.JAX:
if not is_flax_available():