Remove fast tokenization warning in Data Collators (#28213)

This commit is contained in:
Daniel Bustamante Ospina 2024-01-02 13:32:23 -05:00 committed by GitHub
parent 5be46dfc09
commit aa4a0f8ef3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -49,6 +49,28 @@ class DataCollatorMixin:
raise ValueError(f"Framework '{return_tensors}' not recognized!")
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
"""
Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
"""
# To avoid errors when using Feature extractors
if not hasattr(tokenizer, "deprecation_warnings"):
return tokenizer.pad(*pad_args, **pad_kwargs)
# Save the state of the warning, then disable it
warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
try:
padded = tokenizer.pad(*pad_args, **pad_kwargs)
finally:
# Restore the state of the warning.
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
return padded
def default_data_collator(features: List[InputDataClass], return_tensors="pt") -> Dict[str, Any]:
"""
Very simple data collator that simply collates batches of dict-like objects and performs special handling for
@ -246,7 +268,8 @@ class DataCollatorWithPadding:
return_tensors: str = "pt"
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
batch = self.tokenizer.pad(
batch = pad_without_fast_tokenizer_warning(
self.tokenizer,
features,
padding=self.padding,
max_length=self.max_length,
@ -307,7 +330,8 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
batch = self.tokenizer.pad(
batch = pad_without_fast_tokenizer_warning(
self.tokenizer,
no_labels_features,
padding=self.padding,
max_length=self.max_length,
@ -343,7 +367,8 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
batch = self.tokenizer.pad(
batch = pad_without_fast_tokenizer_warning(
self.tokenizer,
features,
padding=self.padding,
max_length=self.max_length,
@ -372,7 +397,8 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
def numpy_call(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
batch = self.tokenizer.pad(
batch = pad_without_fast_tokenizer_warning(
self.tokenizer,
features,
padding=self.padding,
max_length=self.max_length,
@ -583,7 +609,8 @@ class DataCollatorForSeq2Seq:
else:
feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
features = self.tokenizer.pad(
features = pad_without_fast_tokenizer_warning(
self.tokenizer,
features,
padding=self.padding,
max_length=self.max_length,
@ -692,7 +719,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], Mapping):
batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
batch = pad_without_fast_tokenizer_warning(
self.tokenizer, examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of
)
else:
batch = {
"input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
@ -729,7 +758,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], Mapping):
batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
batch = pad_without_fast_tokenizer_warning(
self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
)
else:
batch = {
"input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
@ -784,7 +815,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], Mapping):
batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
batch = pad_without_fast_tokenizer_warning(
self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
)
else:
batch = {
"input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)