mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Remove deprecated code (#37059)
* Remove deprecated code * fix get_loading_attributes * fix error * skip test --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
parent
d1efaf0318
commit
f99c279d20
@ -47,7 +47,7 @@ from transformers import (
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
default_data_collator,
|
||||
is_torch_tpu_available,
|
||||
is_torch_xla_available,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
@ -525,7 +525,7 @@ def main():
|
||||
if torch.cuda.is_availble():
|
||||
pad_factor = 8
|
||||
|
||||
elif is_torch_tpu_available():
|
||||
elif is_torch_xla_available(check_is_tpu=True):
|
||||
pad_factor = 128
|
||||
|
||||
# Add the new tokens to the tokenizer
|
||||
@ -795,9 +795,13 @@ def main():
|
||||
processing_class=tokenizer,
|
||||
# Data collator will default to DataCollatorWithPadding, so we change it.
|
||||
data_collator=default_data_collator,
|
||||
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
||||
compute_metrics=compute_metrics
|
||||
if training_args.do_eval and not is_torch_xla_available(check_is_tpu=True)
|
||||
else None,
|
||||
preprocess_logits_for_metrics=(
|
||||
preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None
|
||||
preprocess_logits_for_metrics
|
||||
if training_args.do_eval and not is_torch_xla_available(check_is_tpu=True)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers import (
|
||||
SchedulerType,
|
||||
default_data_collator,
|
||||
get_scheduler,
|
||||
is_torch_tpu_available,
|
||||
is_torch_xla_available,
|
||||
)
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import check_min_version, send_example_telemetry
|
||||
@ -492,7 +492,7 @@ def main():
|
||||
if torch.cuda.is_availble():
|
||||
pad_factor = 8
|
||||
|
||||
elif is_torch_tpu_available():
|
||||
elif is_torch_xla_available(check_is_tpu=True):
|
||||
pad_factor = 128
|
||||
|
||||
# Add the new tokens to the tokenizer
|
||||
|
@ -1037,7 +1037,6 @@ _import_structure = {
|
||||
"is_torch_musa_available",
|
||||
"is_torch_neuroncore_available",
|
||||
"is_torch_npu_available",
|
||||
"is_torch_tpu_available",
|
||||
"is_torchvision_available",
|
||||
"is_torch_xla_available",
|
||||
"is_torch_xpu_available",
|
||||
@ -6341,7 +6340,6 @@ if TYPE_CHECKING:
|
||||
is_torch_musa_available,
|
||||
is_torch_neuroncore_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_tpu_available,
|
||||
is_torch_xla_available,
|
||||
is_torch_xpu_available,
|
||||
is_torchvision_available,
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from collections.abc import Collection, Iterable
|
||||
from math import ceil
|
||||
from typing import Optional, Union
|
||||
@ -453,7 +452,6 @@ def center_crop(
|
||||
size: tuple[int, int],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
return_numpy: Optional[bool] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
|
||||
@ -474,22 +472,11 @@ def center_crop(
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input image.
|
||||
return_numpy (`bool`, *optional*):
|
||||
Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
|
||||
previous ImageFeatureExtractionMixin method.
|
||||
- Unset: will return the same type as the input image.
|
||||
- `True`: will return a numpy array.
|
||||
- `False`: will return a `PIL.Image.Image` object.
|
||||
Returns:
|
||||
`np.ndarray`: The cropped image.
|
||||
"""
|
||||
requires_backends(center_crop, ["vision"])
|
||||
|
||||
if return_numpy is not None:
|
||||
warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)
|
||||
|
||||
return_numpy = True if return_numpy is None else return_numpy
|
||||
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
|
||||
|
||||
@ -541,9 +528,6 @@ def center_crop(
|
||||
new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
|
||||
new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
|
||||
|
||||
if not return_numpy:
|
||||
new_image = to_pil_image(new_image)
|
||||
|
||||
return new_image
|
||||
|
||||
|
||||
|
@ -228,7 +228,6 @@ from .import_utils import (
|
||||
is_torch_sdpa_available,
|
||||
is_torch_tensorrt_fx_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torch_xla_available,
|
||||
is_torch_xpu_available,
|
||||
is_torchao_available,
|
||||
|
@ -675,31 +675,6 @@ def is_g2p_en_available():
|
||||
return _g2p_en_available
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def is_torch_tpu_available(check_device=True):
|
||||
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
|
||||
warnings.warn(
|
||||
"`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
|
||||
"Please use the `is_torch_xla_available` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if not _torch_available:
|
||||
return False
|
||||
if importlib.util.find_spec("torch_xla") is not None:
|
||||
if check_device:
|
||||
# We need to check if `xla_device` can be found, will raise a RuntimeError if not
|
||||
try:
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
_ = xm.xla_device()
|
||||
return True
|
||||
except RuntimeError:
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@lru_cache
|
||||
def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
|
||||
"""
|
||||
|
@ -682,7 +682,6 @@ class GPTQConfig(QuantizationConfigMixin):
|
||||
self.use_exllama = use_exllama
|
||||
self.max_input_length = max_input_length
|
||||
self.exllama_config = exllama_config
|
||||
self.disable_exllama = kwargs.pop("disable_exllama", None)
|
||||
self.cache_block_outputs = cache_block_outputs
|
||||
self.modules_in_block_to_quantize = modules_in_block_to_quantize
|
||||
self.post_init()
|
||||
@ -690,7 +689,6 @@ class GPTQConfig(QuantizationConfigMixin):
|
||||
def get_loading_attributes(self):
|
||||
attibutes_dict = copy.deepcopy(self.__dict__)
|
||||
loading_attibutes = [
|
||||
"disable_exllama",
|
||||
"use_exllama",
|
||||
"exllama_config",
|
||||
"use_cuda_fp16",
|
||||
@ -739,20 +737,9 @@ class GPTQConfig(QuantizationConfigMixin):
|
||||
self.use_exllama = False
|
||||
|
||||
# auto-gptq specific kernel control logic
|
||||
if self.disable_exllama is None and self.use_exllama is None:
|
||||
if self.use_exllama is None:
|
||||
# New default behaviour
|
||||
self.use_exllama = True
|
||||
elif self.disable_exllama is not None and self.use_exllama is None:
|
||||
# Follow pattern of old config
|
||||
logger.warning(
|
||||
"Using `disable_exllama` is deprecated and will be removed in version 4.37. Use `use_exllama` instead and specify the version with `exllama_config`."
|
||||
"The value of `use_exllama` will be overwritten by `disable_exllama` passed in `GPTQConfig` or stored in your config file."
|
||||
)
|
||||
self.use_exllama = not self.disable_exllama
|
||||
self.disable_exllama = None
|
||||
elif self.disable_exllama is not None and self.use_exllama is not None:
|
||||
# Only happens if user explicitly passes in both arguments
|
||||
raise ValueError("Cannot specify both `disable_exllama` and `use_exllama`. Please use just `use_exllama`")
|
||||
|
||||
if self.exllama_config is None:
|
||||
self.exllama_config = {"version": ExllamaVersion.ONE}
|
||||
@ -809,7 +796,7 @@ class GPTQConfig(QuantizationConfigMixin):
|
||||
if "disable_exllama" in config_dict:
|
||||
config_dict["use_exllama"] = not config_dict["disable_exllama"]
|
||||
# switch to None to not trigger the warning
|
||||
config_dict["disable_exllama"] = None
|
||||
config_dict.pop("disable_exllama")
|
||||
|
||||
config = cls(**config_dict)
|
||||
return config
|
||||
|
@ -592,7 +592,7 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
|
||||
# TODO: @ydshieh: refer to #34968
|
||||
@unittest.skip(reason="Failing on multi-gpu runner")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
self.skipTest(reason="Failing on multi-gpu runner")
|
||||
|
||||
|
||||
@require_torch
|
||||
|
Loading…
Reference in New Issue
Block a user