mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[tokenizers] convert_to_tensors: don't reconvert when the type is already right (#8283)
* don't reconvert when the type is already right
* better name
* adjust logic as suggested
* merge
This commit is contained in:
parent
20b658607e
commit
42111f1d56
@ -53,6 +53,15 @@ if is_torch_available():
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def _is_numpy(x):
|
||||
return isinstance(x, np.ndarray)
|
||||
|
||||
|
||||
def _is_jax(x):
|
||||
return isinstance(x, jnp.ndarray)
|
||||
|
||||
|
||||
if is_tokenizers_available():
|
||||
from tokenizers import AddedToken
|
||||
from tokenizers import Encoding as EncodingFast
|
||||
@ -705,16 +714,20 @@ class BatchEncoding(UserDict):
|
||||
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
|
||||
)
|
||||
as_tensor = tf.constant
|
||||
is_tensor = tf.is_tensor
|
||||
elif tensor_type == TensorType.PYTORCH:
|
||||
if not is_torch_available():
|
||||
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||
as_tensor = torch.tensor
|
||||
is_tensor = torch.is_tensor
|
||||
elif tensor_type == TensorType.JAX:
|
||||
if not is_flax_available():
|
||||
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
|
||||
as_tensor = jnp.array
|
||||
is_tensor = _is_jax
|
||||
else:
|
||||
as_tensor = np.asarray
|
||||
is_tensor = _is_numpy
|
||||
# (mfuntowicz: This code is unreachable)
|
||||
# else:
|
||||
# raise ImportError(
|
||||
@ -727,16 +740,17 @@ class BatchEncoding(UserDict):
|
||||
if prepend_batch_axis:
|
||||
value = [value]
|
||||
|
||||
tensor = as_tensor(value)
|
||||
if not is_tensor(value):
|
||||
tensor = as_tensor(value)
|
||||
|
||||
# Removing this for now in favor of controlling the shape with `prepend_batch_axis`
|
||||
# # at-least2d
|
||||
# if tensor.ndim > 2:
|
||||
# tensor = tensor.squeeze(0)
|
||||
# elif tensor.ndim < 2:
|
||||
# tensor = tensor[None, :]
|
||||
# Removing this for now in favor of controlling the shape with `prepend_batch_axis`
|
||||
# # at-least2d
|
||||
# if tensor.ndim > 2:
|
||||
# tensor = tensor.squeeze(0)
|
||||
# elif tensor.ndim < 2:
|
||||
# tensor = tensor[None, :]
|
||||
|
||||
self[key] = tensor
|
||||
self[key] = tensor
|
||||
except: # noqa E722
|
||||
if key == "overflowing_tokens":
|
||||
raise ValueError(
|
||||
|
@ -20,7 +20,7 @@ import numpy as np
|
||||
|
||||
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType, TokenSpan
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow
|
||||
from transformers.testing_utils import CaptureStderr, require_flax, require_tf, require_tokenizers, require_torch, slow
|
||||
|
||||
|
||||
class TokenizerUtilsTest(unittest.TestCase):
|
||||
@ -156,6 +156,10 @@ class TokenizerUtilsTest(unittest.TestCase):
|
||||
tensor_batch = batch.convert_to_tensors(tensor_type="np")
|
||||
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
|
||||
self.assertEqual(tensor_batch["labels"].shape, (2,))
|
||||
# test converting the converted
|
||||
with CaptureStderr() as cs:
|
||||
tensor_batch = batch.convert_to_tensors(tensor_type="np")
|
||||
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
|
||||
|
||||
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
|
||||
tensor_batch = batch.convert_to_tensors(tensor_type="np", prepend_batch_axis=True)
|
||||
@ -168,6 +172,10 @@ class TokenizerUtilsTest(unittest.TestCase):
|
||||
tensor_batch = batch.convert_to_tensors(tensor_type="pt")
|
||||
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
|
||||
self.assertEqual(tensor_batch["labels"].shape, (2,))
|
||||
# test converting the converted
|
||||
with CaptureStderr() as cs:
|
||||
tensor_batch = batch.convert_to_tensors(tensor_type="pt")
|
||||
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
|
||||
|
||||
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
|
||||
tensor_batch = batch.convert_to_tensors(tensor_type="pt", prepend_batch_axis=True)
|
||||
@ -180,12 +188,32 @@ class TokenizerUtilsTest(unittest.TestCase):
|
||||
tensor_batch = batch.convert_to_tensors(tensor_type="tf")
|
||||
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
|
||||
self.assertEqual(tensor_batch["labels"].shape, (2,))
|
||||
# test converting the converted
|
||||
with CaptureStderr() as cs:
|
||||
tensor_batch = batch.convert_to_tensors(tensor_type="tf")
|
||||
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
|
||||
|
||||
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
|
||||
tensor_batch = batch.convert_to_tensors(tensor_type="tf", prepend_batch_axis=True)
|
||||
self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
|
||||
self.assertEqual(tensor_batch["labels"].shape, (1,))
|
||||
|
||||
@require_flax
def test_batch_encoding_with_labels_jax(self):
    """Check ``BatchEncoding.convert_to_tensors`` with ``tensor_type="jax"``.

    Covers three things: the shapes produced for a batched input, that
    converting an already-converted batch writes nothing to stderr, and the
    shapes produced for a single example when ``prepend_batch_axis=True``.
    """
    batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
    tensor_batch = batch.convert_to_tensors(tensor_type="jax")
    self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
    self.assertEqual(tensor_batch["labels"].shape, (2,))
    # test converting the converted
    with CaptureStderr() as cs:
        tensor_batch = batch.convert_to_tensors(tensor_type="jax")
    self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")

    # A single (unbatched) example: the batch axis is added by the conversion.
    batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
    tensor_batch = batch.convert_to_tensors(tensor_type="jax", prepend_batch_axis=True)
    self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
    self.assertEqual(tensor_batch["labels"].shape, (1,))
|
||||
|
||||
def test_padding_accepts_tensors(self):
|
||||
features = [{"input_ids": np.array([0, 1, 2])}, {"input_ids": np.array([0, 1, 2, 3])}]
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
|
||||
|
Loading…
Reference in New Issue
Block a user