[tokenizers] convert_to_tensors: don't reconvert when the type is already right (#8283)

* don't reconvert when the type is already right

* better name

* adjust logic as suggested

* merge
This commit is contained in:
Stas Bekman 2020-11-19 12:06:01 -08:00 committed by GitHub
parent 20b658607e
commit 42111f1d56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 9 deletions

View File

@ -53,6 +53,15 @@ if is_torch_available():
if is_flax_available():
import jax.numpy as jnp
def _is_numpy(x):
return isinstance(x, np.ndarray)
def _is_jax(x):
return isinstance(x, jnp.ndarray)
if is_tokenizers_available():
from tokenizers import AddedToken
from tokenizers import Encoding as EncodingFast
@ -705,16 +714,20 @@ class BatchEncoding(UserDict):
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
)
as_tensor = tf.constant
is_tensor = tf.is_tensor
elif tensor_type == TensorType.PYTORCH:
if not is_torch_available():
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
as_tensor = torch.tensor
is_tensor = torch.is_tensor
elif tensor_type == TensorType.JAX:
if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
as_tensor = jnp.array
is_tensor = _is_jax
else:
as_tensor = np.asarray
is_tensor = _is_numpy
# (mfuntowicz: This code is unreachable)
# else:
# raise ImportError(
@ -727,16 +740,17 @@ class BatchEncoding(UserDict):
if prepend_batch_axis:
value = [value]
tensor = as_tensor(value)
if not is_tensor(value):
tensor = as_tensor(value)
# Removing this for now in favor of controlling the shape with `prepend_batch_axis`
# # at-least2d
# if tensor.ndim > 2:
# tensor = tensor.squeeze(0)
# elif tensor.ndim < 2:
# tensor = tensor[None, :]
# Removing this for now in favor of controlling the shape with `prepend_batch_axis`
# # at-least2d
# if tensor.ndim > 2:
# tensor = tensor.squeeze(0)
# elif tensor.ndim < 2:
# tensor = tensor[None, :]
self[key] = tensor
self[key] = tensor
except: # noqa E722
if key == "overflowing_tokens":
raise ValueError(

View File

@ -20,7 +20,7 @@ import numpy as np
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType, TokenSpan
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow
from transformers.testing_utils import CaptureStderr, require_flax, require_tf, require_tokenizers, require_torch, slow
class TokenizerUtilsTest(unittest.TestCase):
@ -156,6 +156,10 @@ class TokenizerUtilsTest(unittest.TestCase):
tensor_batch = batch.convert_to_tensors(tensor_type="np")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))
# test converting the converted
with CaptureStderr() as cs:
tensor_batch = batch.convert_to_tensors(tensor_type="np")
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="np", prepend_batch_axis=True)
@ -168,6 +172,10 @@ class TokenizerUtilsTest(unittest.TestCase):
tensor_batch = batch.convert_to_tensors(tensor_type="pt")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))
# test converting the converted
with CaptureStderr() as cs:
tensor_batch = batch.convert_to_tensors(tensor_type="pt")
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="pt", prepend_batch_axis=True)
@ -180,12 +188,32 @@ class TokenizerUtilsTest(unittest.TestCase):
tensor_batch = batch.convert_to_tensors(tensor_type="tf")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))
# test converting the converted
with CaptureStderr() as cs:
tensor_batch = batch.convert_to_tensors(tensor_type="tf")
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="tf", prepend_batch_axis=True)
self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
self.assertEqual(tensor_batch["labels"].shape, (1,))
@require_flax
def test_batch_encoding_with_labels_jax(self):
batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
tensor_batch = batch.convert_to_tensors(tensor_type="jax")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))
# test converting the converted
with CaptureStderr() as cs:
tensor_batch = batch.convert_to_tensors(tensor_type="jax")
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="jax", prepend_batch_axis=True)
self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
self.assertEqual(tensor_batch["labels"].shape, (1,))
def test_padding_accepts_tensors(self):
features = [{"input_ids": np.array([0, 1, 2])}, {"input_ids": np.array([0, 1, 2, 3])}]
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")