mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
Aurevoir PyTorch 1 (#35358)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
4567ee8057
commit
05de764e9c
@ -21,39 +21,6 @@ jobs:
|
||||
echo "$(python3 -c 'print(int(${{ github.run_number }}) % 10)')"
|
||||
echo "run_number=$(python3 -c 'print(int(${{ github.run_number }}) % 10)')" >> $GITHUB_OUTPUT
|
||||
|
||||
run_past_ci_pytorch_1-13:
|
||||
name: PyTorch 1.13
|
||||
needs: get_number
|
||||
if: needs.get_number.outputs.run_number == 0 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
|
||||
uses: ./.github/workflows/self-past-caller.yml
|
||||
with:
|
||||
framework: pytorch
|
||||
version: "1.13"
|
||||
sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
run_past_ci_pytorch_1-12:
|
||||
name: PyTorch 1.12
|
||||
needs: get_number
|
||||
if: needs.get_number.outputs.run_number == 1 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
|
||||
uses: ./.github/workflows/self-past-caller.yml
|
||||
with:
|
||||
framework: pytorch
|
||||
version: "1.12"
|
||||
sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
run_past_ci_pytorch_1-11:
|
||||
name: PyTorch 1.11
|
||||
needs: get_number
|
||||
if: needs.get_number.outputs.run_number == 2 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
|
||||
uses: ./.github/workflows/self-past-caller.yml
|
||||
with:
|
||||
framework: pytorch
|
||||
version: "1.11"
|
||||
sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
run_past_ci_tensorflow_2-11:
|
||||
name: TensorFlow 2.11
|
||||
needs: get_number
|
||||
|
@ -249,7 +249,7 @@ The model itself is a regular [Pytorch `nn.Module`](https://pytorch.org/docs/sta
|
||||
|
||||
### With pip
|
||||
|
||||
This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, and TensorFlow 2.6+.
|
||||
This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, and TensorFlow 2.6+.
|
||||
|
||||
You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
|
||||
|
@ -245,7 +245,7 @@ limitations under the License.
|
||||
|
||||
### باستخدام pip
|
||||
|
||||
تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 1.11+، و TensorFlow 2.6+.
|
||||
تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، و TensorFlow 2.6+.
|
||||
|
||||
يجب تثبيت 🤗 Transformers في [بيئة افتراضية](https://docs.python.org/3/library/venv.html). إذا كنت غير معتاد على البيئات الافتراضية Python، فراجع [دليل المستخدم](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
|
||||
|
@ -246,7 +246,7 @@ Das Modell selbst ist ein reguläres [PyTorch `nn.Module`](https://pytorch.org/d
|
||||
|
||||
### Mit pip
|
||||
|
||||
Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ und TensorFlow 2.6+ getestet.
|
||||
Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ und TensorFlow 2.6+ getestet.
|
||||
|
||||
Sie sollten 🤗 Transformers in einer [virtuellen Umgebung](https://docs.python.org/3/library/venv.html) installieren. Wenn Sie mit virtuellen Python-Umgebungen nicht vertraut sind, schauen Sie sich den [Benutzerleitfaden](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) an.
|
||||
|
||||
|
@ -222,7 +222,7 @@ El modelo en si es un [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.h
|
||||
|
||||
### Con pip
|
||||
|
||||
Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ y TensorFlow 2.6+.
|
||||
Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ y TensorFlow 2.6+.
|
||||
|
||||
Deberías instalar 🤗 Transformers en un [entorno virtual](https://docs.python.org/3/library/venv.html). Si no estas familiarizado con los entornos virtuales de Python, consulta la [guía de usuario](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
|
||||
|
@ -243,7 +243,7 @@ Le modèle lui-même est un module [`nn.Module` PyTorch](https://pytorch.org/doc
|
||||
|
||||
### Avec pip
|
||||
|
||||
Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ et TensorFlow 2.6+.
|
||||
Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ et TensorFlow 2.6+.
|
||||
|
||||
Vous devriez installer 🤗 Transformers dans un [environnement virtuel](https://docs.python.org/3/library/venv.html). Si vous n'êtes pas familier avec les environnements virtuels Python, consultez le [guide utilisateur](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
|
||||
|
@ -198,7 +198,7 @@ checkpoint: जाँच बिंदु
|
||||
|
||||
### पिप का उपयोग करना
|
||||
|
||||
इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ और TensorFlow 2.6+ के तहत किया गया है।
|
||||
इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ और TensorFlow 2.6+ के तहत किया गया है।
|
||||
|
||||
आप [वर्चुअल एनवायरनमेंट](https://docs.python.org/3/library/venv.html) में 🤗 ट्रांसफॉर्मर इंस्टॉल कर सकते हैं। यदि आप अभी तक पायथन के वर्चुअल एनवायरनमेंट से परिचित नहीं हैं, तो कृपया इसे [उपयोगकर्ता निर्देश](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) पढ़ें।
|
||||
|
||||
|
@ -256,7 +256,7 @@ Hugging Faceチームによって作られた **[トランスフォーマーを
|
||||
|
||||
### pipにて
|
||||
|
||||
このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, TensorFlow 2.6+ でテストされています。
|
||||
このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+ でテストされています。
|
||||
|
||||
🤗Transformersは[仮想環境](https://docs.python.org/3/library/venv.html)にインストールする必要があります。Pythonの仮想環境に慣れていない場合は、[ユーザーガイド](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)を確認してください。
|
||||
|
||||
|
@ -242,7 +242,7 @@ Transformers에 달린 100,000개의 별을 축하하기 위해, 우리는 커
|
||||
|
||||
### pip로 설치하기
|
||||
|
||||
이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, TensorFlow 2.6+에서 테스트 되었습니다.
|
||||
이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+에서 테스트 되었습니다.
|
||||
|
||||
[가상 환경](https://docs.python.org/3/library/venv.html)에 🤗 Transformers를 설치하세요. Python 가상 환경에 익숙하지 않다면, [사용자 가이드](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)를 확인하세요.
|
||||
|
||||
|
@ -253,7 +253,7 @@ O modelo em si é um [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.ht
|
||||
|
||||
### Com pip
|
||||
|
||||
Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ e TensorFlow 2.6+.
|
||||
Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ e TensorFlow 2.6+.
|
||||
|
||||
Você deve instalar o 🤗 Transformers em um [ambiente virtual](https://docs.python.org/3/library/venv.html). Se você não está familiarizado com ambientes virtuais em Python, confira o [guia do usuário](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
|
||||
|
@ -244,7 +244,7 @@ Hugging Face Hub. Мы хотим, чтобы Transformers позволил ра
|
||||
|
||||
### С помощью pip
|
||||
|
||||
Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ и TensorFlow 2.6+.
|
||||
Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ и TensorFlow 2.6+.
|
||||
|
||||
Устанавливать 🤗 Transformers следует в [виртуальной среде](https://docs.python.org/3/library/venv.html). Если вы не знакомы с виртуальными средами Python, ознакомьтесь с [руководством пользователя](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
|
||||
|
@ -246,7 +246,7 @@ limitations under the License.
|
||||
|
||||
### పిప్ తో
|
||||
|
||||
ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 1.11+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.
|
||||
ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 2.0+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.
|
||||
|
||||
మీరు [వర్చువల్ వాతావరణం](https://docs.python.org/3/library/venv.html)లో 🤗 ట్రాన్స్ఫార్మర్లను ఇన్స్టాల్ చేయాలి. మీకు పైథాన్ వర్చువల్ పరిసరాల గురించి తెలియకుంటే, [యూజర్ గైడ్](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) చూడండి.
|
||||
|
||||
|
@ -259,7 +259,7 @@ limitations under the License.
|
||||
|
||||
#### ‏ pip کے ساتھ
|
||||
|
||||
یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 1.11+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔
|
||||
یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔
|
||||
|
||||
آپ کو 🤗 Transformers کو ایک [ورچوئل ماحول](https://docs.python.org/3/library/venv.html) میں انسٹال کرنا چاہیے۔ اگر آپ Python ورچوئل ماحول سے واقف نہیں ہیں، تو [یوزر گائیڈ](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) دیکھیں۔
|
||||
|
||||
|
@ -245,7 +245,7 @@ Chính mô hình là một [Pytorch `nn.Module`](https://pytorch.org/docs/stable
|
||||
|
||||
### Sử dụng pip
|
||||
|
||||
Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ và TensorFlow 2.6+.
|
||||
Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ và TensorFlow 2.6+.
|
||||
|
||||
Bạn nên cài đặt 🤗 Transformers trong một [môi trường ảo Python](https://docs.python.org/3/library/venv.html). Nếu bạn chưa quen với môi trường ảo Python, hãy xem [hướng dẫn sử dụng](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
|
||||
|
@ -198,7 +198,7 @@ checkpoint: 检查点
|
||||
|
||||
### 使用 pip
|
||||
|
||||
这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 1.11+ 和 TensorFlow 2.6+ 下经过测试。
|
||||
这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下经过测试。
|
||||
|
||||
你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
|
||||
|
||||
|
@ -210,7 +210,7 @@ Tokenizer 為所有的預訓練模型提供了預處理,並可以直接轉換
|
||||
|
||||
### 使用 pip
|
||||
|
||||
這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 1.11+ 和 TensorFlow 2.6+ 下經過測試。
|
||||
這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下經過測試。
|
||||
|
||||
你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
|
||||
|
||||
|
@ -106,7 +106,6 @@ if is_torch_available():
|
||||
XLMWithLMHeadModel,
|
||||
XLNetLMHeadModel,
|
||||
)
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
@ -279,7 +278,7 @@ def convert_pt_checkpoint_to_tf(
|
||||
if compare_with_pt_model:
|
||||
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
|
||||
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
state_dict = torch.load(
|
||||
pytorch_checkpoint_path,
|
||||
map_location="cpu",
|
||||
|
@ -63,8 +63,6 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
||||
else:
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
||||
@ -73,7 +71,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
||||
)
|
||||
raise
|
||||
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
|
||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
||||
|
||||
@ -246,13 +244,11 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
|
||||
import torch
|
||||
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
|
||||
|
||||
# Load the index
|
||||
flax_state_dict = {}
|
||||
for shard_file in shard_filenames:
|
||||
# load using msgpack utils
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
|
||||
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
|
||||
pt_state_dict = {
|
||||
|
@ -180,8 +180,6 @@ def load_pytorch_checkpoint_in_tf2_model(
|
||||
import tensorflow as tf # noqa: F401
|
||||
import torch # noqa: F401
|
||||
from safetensors.torch import load_file as safe_load_file # noqa: F401
|
||||
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||
@ -201,7 +199,7 @@ def load_pytorch_checkpoint_in_tf2_model(
|
||||
if pt_path.endswith(".safetensors"):
|
||||
state_dict = safe_load_file(pt_path)
|
||||
else:
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
|
||||
|
||||
pt_state_dict.update(state_dict)
|
||||
|
@ -54,7 +54,6 @@ from .pytorch_utils import ( # noqa: F401
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
id_tensor_storage,
|
||||
is_torch_greater_or_equal_than_1_13,
|
||||
prune_conv1d_layer,
|
||||
prune_layer,
|
||||
prune_linear_layer,
|
||||
@ -476,7 +475,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
||||
error_message += f"\nMissing key(s): {str_unexpected_keys}."
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
|
||||
|
||||
for shard_file in shard_files:
|
||||
@ -532,7 +531,7 @@ def load_state_dict(
|
||||
and is_zipfile(checkpoint_file)
|
||||
):
|
||||
extra_args = {"mmap": True}
|
||||
weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
weights_only_kwarg = {"weights_only": weights_only}
|
||||
return torch.load(
|
||||
checkpoint_file,
|
||||
map_location=map_location,
|
||||
|
@ -38,7 +38,6 @@ from ...modeling_outputs import (
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
@ -815,14 +814,6 @@ class FalconPreTrainedModel(PreTrainedModel):
|
||||
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
|
||||
@classmethod
|
||||
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
|
||||
# NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0).
|
||||
if hard_check_only:
|
||||
if not is_torch_greater_or_equal_than_2_0:
|
||||
raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.")
|
||||
|
||||
if not is_torch_greater_or_equal_than_2_0:
|
||||
return config
|
||||
|
||||
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
|
||||
if _is_bettertransformer:
|
||||
return config
|
||||
|
@ -36,7 +36,6 @@ from ...modeling_outputs import (
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
@ -56,9 +55,6 @@ if is_flash_attn_2_available():
|
||||
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
||||
# It means that the function will not be traced through and simply appear as a node in the graph.
|
||||
if is_torch_fx_available():
|
||||
if not is_torch_greater_or_equal_than_1_13:
|
||||
import torch.fx
|
||||
|
||||
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
||||
|
||||
|
||||
|
@ -33,7 +33,6 @@ from ...modeling_outputs import (
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@ -51,9 +50,6 @@ if is_flash_attn_2_available():
|
||||
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
||||
# It means that the function will not be traced through and simply appear as a node in the graph.
|
||||
if is_torch_fx_available():
|
||||
if not is_torch_greater_or_equal_than_1_13:
|
||||
import torch.fx
|
||||
|
||||
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
||||
|
||||
|
||||
|
@ -25,7 +25,6 @@ from transformers.modeling_outputs import (
|
||||
)
|
||||
from transformers.models.superpoint.configuration_superpoint import SuperPointConfig
|
||||
|
||||
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
@ -314,7 +313,7 @@ class SuperPointDescriptorDecoder(nn.Module):
|
||||
divisor = divisor.to(keypoints)
|
||||
keypoints /= divisor
|
||||
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
|
||||
kwargs = {"align_corners": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
kwargs = {"align_corners": True}
|
||||
# [batch_size, num_channels, num_keypoints, 2] -> [batch_size, num_channels, num_keypoints, 2]
|
||||
keypoints = keypoints.view(batch_size, 1, -1, 2)
|
||||
descriptors = nn.functional.grid_sample(descriptors, keypoints, mode="bilinear", **kwargs)
|
||||
|
@ -31,7 +31,6 @@ from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import (
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
is_torch_greater_or_equal_than_1_12,
|
||||
prune_linear_layer,
|
||||
)
|
||||
from ...utils import (
|
||||
@ -46,12 +45,6 @@ from .configuration_tapas import TapasConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if not is_torch_greater_or_equal_than_1_12:
|
||||
logger.warning(
|
||||
f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
|
||||
"TapasModel. Please upgrade torch."
|
||||
)
|
||||
|
||||
_CONFIG_FOR_DOC = "TapasConfig"
|
||||
_CHECKPOINT_FOR_DOC = "google/tapas-base"
|
||||
|
||||
|
@ -38,7 +38,6 @@ from ...modeling_outputs import (
|
||||
XVectorOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
@ -1590,7 +1589,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
state_dict = torch.load(
|
||||
weight_path,
|
||||
map_location="cpu",
|
||||
|
@ -34,9 +34,6 @@ is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse(
|
||||
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
|
||||
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
|
||||
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
||||
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
||||
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
|
||||
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
|
||||
|
||||
# Cache this result has it's a C FFI call which can be pretty time-consuming
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
|
@ -75,7 +75,6 @@ from .optimization import Adafactor, get_scheduler
|
||||
from .processing_utils import ProcessorMixin
|
||||
from .pytorch_utils import (
|
||||
ALL_LAYERNORM_LAYERS,
|
||||
is_torch_greater_or_equal_than_1_13,
|
||||
is_torch_greater_or_equal_than_2_3,
|
||||
)
|
||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||
@ -2778,7 +2777,7 @@ class Trainer:
|
||||
)
|
||||
|
||||
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
# If the model is on the GPU, it still works!
|
||||
if is_sagemaker_mp_enabled():
|
||||
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
|
||||
@ -2899,7 +2898,7 @@ class Trainer:
|
||||
or os.path.exists(best_safe_adapter_model_path)
|
||||
):
|
||||
has_been_loaded = True
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
if is_sagemaker_mp_enabled():
|
||||
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
|
||||
# If the 'user_content.pt' file exists, load with the new smp api.
|
||||
|
@ -56,12 +56,7 @@ if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
if is_torch_available():
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
|
||||
if is_torch_greater_or_equal_than_2_0:
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
else:
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -71,8 +71,6 @@ if is_torch_available():
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.state import AcceleratorState, PartialState
|
||||
from accelerate.utils import DistributedType
|
||||
@ -1157,7 +1155,7 @@ class TrainingArguments:
|
||||
},
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None if not is_torch_available() or is_torch_greater_or_equal_than_2_0 else 2,
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Number of batches loaded in advance by each worker. "
|
||||
@ -1702,14 +1700,6 @@ class TrainingArguments:
|
||||
raise ValueError(
|
||||
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
|
||||
)
|
||||
elif not is_torch_xpu_available():
|
||||
# xpu
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_12
|
||||
|
||||
if not is_torch_greater_or_equal_than_1_12:
|
||||
raise ValueError(
|
||||
"Your setup doesn't support bf16/xpu. You need torch>=1.12, using Intel XPU/GPU with IPEX installed"
|
||||
)
|
||||
|
||||
if self.fp16 and self.bf16:
|
||||
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
|
||||
@ -2056,11 +2046,7 @@ class TrainingArguments:
|
||||
if self.use_cpu:
|
||||
self.dataloader_pin_memory = False
|
||||
|
||||
if (
|
||||
(not is_torch_available() or is_torch_greater_or_equal_than_2_0)
|
||||
and self.dataloader_num_workers == 0
|
||||
and self.dataloader_prefetch_factor is not None
|
||||
):
|
||||
if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
|
||||
raise ValueError(
|
||||
"--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
|
||||
" when --dataloader_num_workers > 1."
|
||||
|
@ -60,7 +60,6 @@ from ..models.auto.modeling_auto import (
|
||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_MAPPING_NAMES,
|
||||
)
|
||||
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
from .import_utils import (
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
TORCH_FX_REQUIRED_VERSION,
|
||||
@ -635,10 +634,9 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||
operator.getitem: operator_getitem,
|
||||
}
|
||||
|
||||
if is_torch_greater_or_equal_than_2_0:
|
||||
_MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = (
|
||||
torch_nn_functional_scaled_dot_product_attention
|
||||
)
|
||||
_MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = (
|
||||
torch_nn_functional_scaled_dot_product_attention
|
||||
)
|
||||
|
||||
|
||||
class HFProxy(Proxy):
|
||||
|
@ -45,8 +45,7 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
@ -43,9 +43,6 @@ if is_torch_available():
|
||||
FalconMambaModel,
|
||||
)
|
||||
from transformers.cache_utils import MambaCache
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
# Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba
|
||||
@ -246,9 +243,6 @@ class FalconMambaModelTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
||||
)
|
||||
@require_torch
|
||||
# Copied from transformers.tests.models.mamba.MambaModelTest with Mamba->Falcon,mamba->falcon_mamba,FalconMambaCache->MambaCache
|
||||
class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
@ -37,9 +37,6 @@ if is_torch_available():
|
||||
GPTBigCodeModel,
|
||||
)
|
||||
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_12 = False
|
||||
|
||||
|
||||
class GPTBigCodeModelTester:
|
||||
@ -504,10 +501,6 @@ class GPTBigCodeMHAModelTest(GPTBigCodeModelTest):
|
||||
multi_query = False
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_12,
|
||||
reason="`GPTBigCode` checkpoints use `PytorchGELUTanh` which requires `torch>=1.12.0`.",
|
||||
)
|
||||
@slow
|
||||
@require_torch
|
||||
class GPTBigCodeModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
@ -41,9 +41,6 @@ if is_torch_available():
|
||||
GPTJForSequenceClassification,
|
||||
GPTJModel,
|
||||
)
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_12 = False
|
||||
|
||||
|
||||
class GPTJModelTester:
|
||||
@ -363,15 +360,9 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
test_model_parallel = False
|
||||
test_head_masking = False
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_12, reason="PR #22069 made changes that require torch v1.12+."
|
||||
)
|
||||
def test_torch_fx(self):
|
||||
super().test_torch_fx()
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_12, reason="PR #22069 made changes that require torch v1.12+."
|
||||
)
|
||||
def test_torch_fx_output_loss(self):
|
||||
super().test_torch_fx_output_loss()
|
||||
|
||||
|
@ -44,9 +44,6 @@ if is_torch_available():
|
||||
|
||||
from transformers import IdeficsForVisionText2Text, IdeficsModel, IdeficsProcessor
|
||||
from transformers.models.idefics.configuration_idefics import IdeficsPerceiverConfig, IdeficsVisionConfig
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
@ -327,7 +324,6 @@ class IdeficsModelTester:
|
||||
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
|
||||
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
|
||||
@require_torch
|
||||
class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (IdeficsModel, IdeficsForVisionText2Text) if is_torch_available() else ()
|
||||
@ -594,7 +590,6 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
|
||||
@require_torch
|
||||
class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else ()
|
||||
@ -818,7 +813,6 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
|
||||
@require_torch
|
||||
@require_vision
|
||||
class IdeficsModelIntegrationTest(TestCasePlus):
|
||||
|
@ -48,8 +48,6 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
@ -40,8 +40,6 @@ if is_torch_available():
|
||||
Idefics3ForConditionalGeneration,
|
||||
Idefics3Model,
|
||||
)
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
@ -43,8 +43,7 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
@ -48,8 +48,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
@ -48,8 +48,6 @@ from ...test_modeling_common import (
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
@ -48,8 +48,6 @@ from ...test_modeling_common import (
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
@ -38,9 +38,6 @@ if is_torch_available():
|
||||
MambaModel,
|
||||
)
|
||||
from transformers.models.mamba.modeling_mamba import MambaCache
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
class MambaModelTester:
|
||||
@ -239,9 +236,6 @@ class MambaModelTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
||||
)
|
||||
@require_torch
|
||||
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
|
||||
|
@ -37,9 +37,6 @@ if is_torch_available():
|
||||
Mamba2Model,
|
||||
)
|
||||
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
class Mamba2ModelTester:
|
||||
@ -214,9 +211,6 @@ class Mamba2ModelTester:
|
||||
self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3))
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
||||
)
|
||||
@require_torch
|
||||
class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Mamba2Model, Mamba2ForCausalLM) if is_torch_available() else ()
|
||||
|
@ -40,8 +40,7 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
@ -33,8 +33,7 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
pass
|
||||
|
@ -41,8 +41,6 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
class Qwen2AudioModelTester:
|
||||
|
@ -47,8 +47,6 @@ from ...test_modeling_common import (
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
@ -33,9 +33,6 @@ if is_torch_available():
|
||||
RwkvForCausalLM,
|
||||
RwkvModel,
|
||||
)
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
class RwkvModelTester:
|
||||
@ -231,9 +228,6 @@ class RwkvModelTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
||||
)
|
||||
@require_torch
|
||||
class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else ()
|
||||
@ -440,9 +434,6 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
||||
)
|
||||
@slow
|
||||
class RWKVIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
@ -60,9 +60,6 @@ if is_torch_available():
|
||||
reduce_mean,
|
||||
reduce_sum,
|
||||
)
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_12 = False
|
||||
|
||||
|
||||
class TapasModelTester:
|
||||
@ -411,7 +408,6 @@ class TapasModelTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@require_torch
|
||||
class TapasModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
@ -578,7 +574,6 @@ def prepare_tapas_batch_inputs_for_training():
|
||||
return table, queries, answer_coordinates, answer_text, float_answer
|
||||
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@require_torch
|
||||
class TapasModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
@ -930,10 +925,6 @@ class TapasModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(outputs.logits, expected_tensor, atol=0.05))
|
||||
|
||||
|
||||
# Below: tests for Tapas utilities which are defined in modeling_tapas.py.
|
||||
# These are based on segmented_tensor_test.py of the original implementation.
|
||||
# URL: https://github.com/google-research/tapas/blob/master/tapas/models/segmented_tensor_test.py
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@require_torch
|
||||
class TapasUtilitiesTest(unittest.TestCase):
|
||||
def _prepare_tables(self):
|
||||
|
@ -23,7 +23,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AddedToken, is_torch_available
|
||||
from transformers import AddedToken
|
||||
from transformers.models.tapas.tokenization_tapas import (
|
||||
VOCAB_FILES_NAMES,
|
||||
BasicTokenizer,
|
||||
@ -45,12 +45,6 @@ from transformers.testing_utils import (
|
||||
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_12 = False
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
@require_pandas
|
||||
class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
@ -1048,7 +1042,6 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
# Do the same test as modeling common.
|
||||
self.assertIn(0, output["token_type_ids"][0])
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@require_torch
|
||||
@slow
|
||||
def test_torch_encode_plus_sent_to_model(self):
|
||||
|
@ -41,8 +41,6 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
@ -20,7 +20,6 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
TableQuestionAnsweringPipeline,
|
||||
TFAutoModelForTableQuestionAnswering,
|
||||
is_torch_available,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
@ -33,12 +32,6 @@ from transformers.testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_12 = False
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class TQAPipelineTests(unittest.TestCase):
|
||||
# Putting it there for consistency, but TQA do not have fast tokenizer
|
||||
@ -150,7 +143,6 @@ class TQAPipelineTests(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@require_torch
|
||||
def test_small_model_pt(self, torch_dtype="float32"):
|
||||
model_id = "lysandre/tiny-tapas-random-wtq"
|
||||
@ -253,12 +245,10 @@ class TQAPipelineTests(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@require_torch
|
||||
def test_small_model_pt_fp16(self):
|
||||
self.test_small_model_pt(torch_dtype="float16")
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@require_torch
|
||||
def test_slow_tokenizer_sqa_pt(self, torch_dtype="float32"):
|
||||
model_id = "lysandre/tiny-tapas-random-sqa"
|
||||
@ -378,7 +368,6 @@ class TQAPipelineTests(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@require_torch
|
||||
def test_slow_tokenizer_sqa_pt_fp16(self):
|
||||
self.test_slow_tokenizer_sqa_pt(torch_dtype="float16")
|
||||
@ -505,7 +494,6 @@ class TQAPipelineTests(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@slow
|
||||
@require_torch
|
||||
def test_integration_wtq_pt(self, torch_dtype="float32"):
|
||||
@ -551,7 +539,6 @@ class TQAPipelineTests(unittest.TestCase):
|
||||
]
|
||||
self.assertListEqual(results, expected_results)
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@slow
|
||||
@require_torch
|
||||
def test_integration_wtq_pt_fp16(self):
|
||||
@ -606,7 +593,6 @@ class TQAPipelineTests(unittest.TestCase):
|
||||
]
|
||||
self.assertListEqual(results, expected_results)
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@slow
|
||||
@require_torch
|
||||
def test_integration_sqa_pt(self, torch_dtype="float32"):
|
||||
@ -632,7 +618,6 @@ class TQAPipelineTests(unittest.TestCase):
|
||||
]
|
||||
self.assertListEqual(results, expected_results)
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
||||
@slow
|
||||
@require_torch
|
||||
def test_integration_sqa_pt_fp16(self):
|
||||
|
Loading…
Reference in New Issue
Block a user