From e508965df7edac73caa9fe9935f22a5cad143b1d Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Thu, 29 May 2025 15:13:43 +0100 Subject: [PATCH] Cleanup `BatchFeature` and `BatchEncoding` (#38459) * Use dict comprehension to create dict * Fix type annotation Union[Any] doesn't really make any sense * Remove methods that are already implemented in the `UserDict` parent class --- src/transformers/feature_extraction_utils.py | 27 ++++++-------------- src/transformers/tokenization_utils_base.py | 9 ------- 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 51e882aefa8..732f044e077 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -75,7 +75,7 @@ class BatchFeature(UserDict): super().__init__(data) self.convert_to_tensors(tensor_type=tensor_type) - def __getitem__(self, item: str) -> Union[Any]: + def __getitem__(self, item: str) -> Any: """ If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask', etc.). @@ -98,18 +98,6 @@ class BatchFeature(UserDict): if "data" in state: self.data = state["data"] - # Copied from transformers.tokenization_utils_base.BatchEncoding.keys - def keys(self): - return self.data.keys() - - # Copied from transformers.tokenization_utils_base.BatchEncoding.values - def values(self): - return self.data.values() - - # Copied from transformers.tokenization_utils_base.BatchEncoding.items - def items(self): - return self.data.items() - def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None): if tensor_type is None: return None, None @@ -218,7 +206,6 @@ class BatchFeature(UserDict): requires_backends(self, ["torch"]) import torch # noqa - new_data = {} device = kwargs.get("device") non_blocking = kwargs.get("non_blocking", False) # Check if the args are a device or a dtype @@ -233,17 +220,19 @@ class BatchFeature(UserDict): else: # it's something else raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` - for k, v in self.items(): + def maybe_to(v): # check if v is a floating point if isinstance(v, torch.Tensor) and torch.is_floating_point(v): # cast and send to device - new_data[k] = v.to(*args, **kwargs) + return v.to(*args, **kwargs) elif isinstance(v, torch.Tensor) and device is not None: - new_data[k] = v.to(device=device, non_blocking=non_blocking) + return v.to(device=device, non_blocking=non_blocking) else: - new_data[k] = v - self.data = new_data + return v + + self.data = {k: maybe_to(v) for k, v in self.items()} return self diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index ec0bd53bd57..b6ed3c677b6 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -294,15 +294,6 @@ class BatchEncoding(UserDict): if "encodings" in state: self._encodings = state["encodings"] - def keys(self): - return self.data.keys() - - def values(self): - return self.data.values() - - def items(self): - return self.data.items() - # After this point: # Extended properties and methods only available for fast (Rust-based) tokenizers # provided by HuggingFace tokenizers library.
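
For context, here is a minimal sketch of why the removed `keys`/`values`/`items` overrides are safe to drop (`collections.UserDict` already forwards them to `self.data`) and of the dtype behaviour that the new `maybe_to` helper in `BatchFeature.to()` preserves: only floating-point tensors are cast, integer tensors keep their dtype. The `MiniBatch` class and the specific keys (`pixel_values`, `input_ids`) are illustrative only, not part of the library; the snippet assumes a standard `transformers` + `torch` install.

```python
from collections import UserDict

import torch
from transformers import BatchFeature


class MiniBatch(UserDict):
    """Toy stand-in (not from the library): a bare UserDict subclass."""


# UserDict already routes keys()/values()/items() through self.data,
# which is why the copied-over overrides could be deleted outright.
mini = MiniBatch({"input_values": [0.1, 0.2], "attention_mask": [1, 1]})
assert list(mini.keys()) == ["input_values", "attention_mask"]
assert dict(mini.items()) == {"input_values": [0.1, 0.2], "attention_mask": [1, 1]}

# The refactored to() still only casts floating-point tensors, so integer
# tensors such as token ids keep their dtype (they are only moved when a
# device is passed).
feats = BatchFeature(
    {
        "pixel_values": torch.rand(1, 3, 8, 8),    # float tensor -> cast
        "input_ids": torch.tensor([[101, 102]]),   # LongTensor -> untouched
    }
)
feats = feats.to(torch.float16)
assert feats["pixel_values"].dtype == torch.float16
assert feats["input_ids"].dtype == torch.int64
```

The assertions mirror the comment kept in the diff ("We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`"); the dict-comprehension rewrite changes how the new `data` dict is built, not which values get cast or moved.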