Remove deprecated code (#37059)

* Remove deprecated code

* fix get_loading_attributes

* fix error

* skip test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Author: cyyever
Date: 2025-03-31 17:15:35 +08:00 (committed by GitHub)
Parent: d1efaf0318
Commit: f99c279d20
8 changed files with 13 additions and 66 deletions

View File

@@ -47,7 +47,7 @@ from transformers import (
     Trainer,
     TrainingArguments,
     default_data_collator,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     set_seed,
 )
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -525,7 +525,7 @@ def main():
     if torch.cuda.is_availble():
         pad_factor = 8
-    elif is_torch_tpu_available():
+    elif is_torch_xla_available(check_is_tpu=True):
         pad_factor = 128

     # Add the new tokens to the tokenizer
@@ -795,9 +795,13 @@ def main():
         processing_class=tokenizer,
         # Data collator will default to DataCollatorWithPadding, so we change it.
         data_collator=default_data_collator,
-        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
+        compute_metrics=compute_metrics
+        if training_args.do_eval and not is_torch_xla_available(check_is_tpu=True)
+        else None,
         preprocess_logits_for_metrics=(
-            preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None
+            preprocess_logits_for_metrics
+            if training_args.do_eval and not is_torch_xla_available(check_is_tpu=True)
+            else None
         ),
     )
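For scripts outside this repository that still call the removed helper, the change above is the whole migration. A minimal sketch, assuming a transformers version where only is_torch_xla_available remains; the fallback branch is an assumption added for completeness:

import torch
from transformers import is_torch_xla_available

# Pick a padding multiple for the embedding resize, mirroring the example scripts:
# 8 suits CUDA tensor cores, 128 suits TPU/XLA tiling.
if torch.cuda.is_available():
    pad_factor = 8
elif is_torch_xla_available(check_is_tpu=True):  # replaces the removed is_torch_tpu_available()
    pad_factor = 128
else:
    pad_factor = 1  # assumption: no special padding requirement on other backends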

View File

@@ -52,7 +52,7 @@ from transformers import (
     SchedulerType,
     default_data_collator,
     get_scheduler,
-    is_torch_tpu_available,
+    is_torch_xla_available,
 )
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils import check_min_version, send_example_telemetry
@@ -492,7 +492,7 @@ def main():
     if torch.cuda.is_availble():
         pad_factor = 8
-    elif is_torch_tpu_available():
+    elif is_torch_xla_available(check_is_tpu=True):
         pad_factor = 128

     # Add the new tokens to the tokenizer

View File

@@ -1037,7 +1037,6 @@ _import_structure = {
        "is_torch_musa_available",
        "is_torch_neuroncore_available",
        "is_torch_npu_available",
-        "is_torch_tpu_available",
        "is_torchvision_available",
        "is_torch_xla_available",
        "is_torch_xpu_available",
@@ -6341,7 +6340,6 @@ if TYPE_CHECKING:
        is_torch_musa_available,
        is_torch_neuroncore_available,
        is_torch_npu_available,
-        is_torch_tpu_available,
        is_torch_xla_available,
        is_torch_xpu_available,
        is_torchvision_available,
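Downstream code that imported the deprecated name from the top-level package will now fail at import time; the remaining public helper covers the same check. A hedged before/after sketch:

# Before (import removed by this commit):
# from transformers import is_torch_tpu_available
# on_tpu = is_torch_tpu_available()

# After:
from transformers import is_torch_xla_available

on_tpu = is_torch_xla_available(check_is_tpu=True)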

View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import warnings
 from collections.abc import Collection, Iterable
 from math import ceil
 from typing import Optional, Union
@@ -453,7 +452,6 @@ def center_crop(
     size: tuple[int, int],
     data_format: Optional[Union[str, ChannelDimension]] = None,
     input_data_format: Optional[Union[str, ChannelDimension]] = None,
-    return_numpy: Optional[bool] = None,
 ) -> np.ndarray:
     """
     Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
@@ -474,22 +472,11 @@ def center_crop(
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
             If unset, will use the inferred format of the input image.
-        return_numpy (`bool`, *optional*):
-            Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
-            previous ImageFeatureExtractionMixin method.
-            - Unset: will return the same type as the input image.
-            - `True`: will return a numpy array.
-            - `False`: will return a `PIL.Image.Image` object.

     Returns:
         `np.ndarray`: The cropped image.
     """
     requires_backends(center_crop, ["vision"])

-    if return_numpy is not None:
-        warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)
-
-    return_numpy = True if return_numpy is None else return_numpy
-
     if not isinstance(image, np.ndarray):
         raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
@@ -541,9 +528,6 @@ def center_crop(
     new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
     new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)

-    if not return_numpy:
-        new_image = to_pil_image(new_image)
-
     return new_image
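With return_numpy gone, center_crop always returns an np.ndarray; callers that relied on return_numpy=False can convert explicitly. A small sketch, assuming Pillow is installed (the function requires the vision backend) and using the to_pil_image helper from the same module:

import numpy as np
from transformers.image_transforms import center_crop, to_pil_image

# Dummy channels-first image (3, 300, 400); any real np.ndarray image works the same way.
image = np.random.randint(0, 256, (3, 300, 400), dtype=np.uint8)

cropped = center_crop(image, size=(224, 224))  # now always an np.ndarray
pil_cropped = to_pil_image(cropped)            # old return_numpy=False behaviour, done explicitly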

View File

@@ -228,7 +228,6 @@ from .import_utils import (
     is_torch_sdpa_available,
     is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
-    is_torch_tpu_available,
     is_torch_xla_available,
     is_torch_xpu_available,
     is_torchao_available,

View File

@@ -675,31 +675,6 @@ def is_g2p_en_available():
     return _g2p_en_available


-@lru_cache()
-def is_torch_tpu_available(check_device=True):
-    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
-    warnings.warn(
-        "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
-        "Please use the `is_torch_xla_available` instead.",
-        FutureWarning,
-    )
-    if not _torch_available:
-        return False
-    if importlib.util.find_spec("torch_xla") is not None:
-        if check_device:
-            # We need to check if `xla_device` can be found, will raise a RuntimeError if not
-            try:
-                import torch_xla.core.xla_model as xm
-
-                _ = xm.xla_device()
-                return True
-            except RuntimeError:
-                return False
-        return True
-    return False
-
-
 @lru_cache
 def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
     """

View File

@@ -682,7 +682,6 @@ class GPTQConfig(QuantizationConfigMixin):
         self.use_exllama = use_exllama
         self.max_input_length = max_input_length
         self.exllama_config = exllama_config
-        self.disable_exllama = kwargs.pop("disable_exllama", None)
         self.cache_block_outputs = cache_block_outputs
         self.modules_in_block_to_quantize = modules_in_block_to_quantize
         self.post_init()
@@ -690,7 +689,6 @@ class GPTQConfig(QuantizationConfigMixin):
     def get_loading_attributes(self):
         attibutes_dict = copy.deepcopy(self.__dict__)
         loading_attibutes = [
-            "disable_exllama",
             "use_exllama",
             "exllama_config",
             "use_cuda_fp16",
@@ -739,20 +737,9 @@ class GPTQConfig(QuantizationConfigMixin):
             self.use_exllama = False

         # auto-gptq specific kernel control logic
-        if self.disable_exllama is None and self.use_exllama is None:
+        if self.use_exllama is None:
             # New default behaviour
             self.use_exllama = True
-        elif self.disable_exllama is not None and self.use_exllama is None:
-            # Follow pattern of old config
-            logger.warning(
-                "Using `disable_exllama` is deprecated and will be removed in version 4.37. Use `use_exllama` instead and specify the version with `exllama_config`."
-                "The value of `use_exllama` will be overwritten by `disable_exllama` passed in `GPTQConfig` or stored in your config file."
-            )
-            self.use_exllama = not self.disable_exllama
-            self.disable_exllama = None
-        elif self.disable_exllama is not None and self.use_exllama is not None:
-            # Only happens if user explicitly passes in both arguments
-            raise ValueError("Cannot specify both `disable_exllama` and `use_exllama`. Please use just `use_exllama`")

         if self.exllama_config is None:
             self.exllama_config = {"version": ExllamaVersion.ONE}
@@ -809,7 +796,7 @@ class GPTQConfig(QuantizationConfigMixin):
         if "disable_exllama" in config_dict:
             config_dict["use_exllama"] = not config_dict["disable_exllama"]
             # switch to None to not trigger the warning
-            config_dict["disable_exllama"] = None
+            config_dict.pop("disable_exllama")
         config = cls(**config_dict)

         return config
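For GPTQConfig objects built in user code, the migration is to drop disable_exllama and state the intent with use_exllama; dictionaries that still carry the legacy key are rewritten by the loader shown in the last hunk before the constructor runs. A minimal constructor-side sketch (the exact failure mode for the removed kwarg is not verified here):

from transformers import GPTQConfig

# Before: GPTQConfig(bits=4, disable_exllama=True)  # legacy kwarg, no longer popped
# After: express the same intent directly.
quant_config = GPTQConfig(bits=4, use_exllama=False)   # exllama kernels off
default_config = GPTQConfig(bits=4)                    # use_exllama falls back to True in post_init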

View File

@@ -592,7 +592,7 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
     # TODO: @ydshieh: refer to #34968
     @unittest.skip(reason="Failing on multi-gpu runner")
     def test_retain_grad_hidden_states_attentions(self):
-        pass
+        self.skipTest(reason="Failing on multi-gpu runner")

     @require_torch
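The last change swaps a bare pass for self.skipTest inside an already-skipped test: with the decorator in place the body never runs, but if the decorator is later removed or the method is invoked directly, the test now reports as skipped instead of silently passing. A toy illustration with a hypothetical test case:

import unittest


class ExampleTest(unittest.TestCase):
    @unittest.skip(reason="Failing on multi-gpu runner")
    def test_retain_grad_hidden_states_attentions(self):
        # Raises unittest.SkipTest, so the runner records a skip rather than a pass.
        self.skipTest(reason="Failing on multi-gpu runner")


if __name__ == "__main__":
    unittest.main()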