mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Aurevoir PyTorch 1 (#35358)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
4567ee8057
commit
05de764e9c
@ -21,39 +21,6 @@ jobs:
|
|||||||
echo "$(python3 -c 'print(int(${{ github.run_number }}) % 10)')"
|
echo "$(python3 -c 'print(int(${{ github.run_number }}) % 10)')"
|
||||||
echo "run_number=$(python3 -c 'print(int(${{ github.run_number }}) % 10)')" >> $GITHUB_OUTPUT
|
echo "run_number=$(python3 -c 'print(int(${{ github.run_number }}) % 10)')" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
run_past_ci_pytorch_1-13:
|
|
||||||
name: PyTorch 1.13
|
|
||||||
needs: get_number
|
|
||||||
if: needs.get_number.outputs.run_number == 0 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
|
|
||||||
uses: ./.github/workflows/self-past-caller.yml
|
|
||||||
with:
|
|
||||||
framework: pytorch
|
|
||||||
version: "1.13"
|
|
||||||
sha: ${{ github.sha }}
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
run_past_ci_pytorch_1-12:
|
|
||||||
name: PyTorch 1.12
|
|
||||||
needs: get_number
|
|
||||||
if: needs.get_number.outputs.run_number == 1 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
|
|
||||||
uses: ./.github/workflows/self-past-caller.yml
|
|
||||||
with:
|
|
||||||
framework: pytorch
|
|
||||||
version: "1.12"
|
|
||||||
sha: ${{ github.sha }}
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
run_past_ci_pytorch_1-11:
|
|
||||||
name: PyTorch 1.11
|
|
||||||
needs: get_number
|
|
||||||
if: needs.get_number.outputs.run_number == 2 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
|
|
||||||
uses: ./.github/workflows/self-past-caller.yml
|
|
||||||
with:
|
|
||||||
framework: pytorch
|
|
||||||
version: "1.11"
|
|
||||||
sha: ${{ github.sha }}
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
run_past_ci_tensorflow_2-11:
|
run_past_ci_tensorflow_2-11:
|
||||||
name: TensorFlow 2.11
|
name: TensorFlow 2.11
|
||||||
needs: get_number
|
needs: get_number
|
||||||
|
@ -249,7 +249,7 @@ The model itself is a regular [Pytorch `nn.Module`](https://pytorch.org/docs/sta
|
|||||||
|
|
||||||
### With pip
|
### With pip
|
||||||
|
|
||||||
This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, and TensorFlow 2.6+.
|
This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, and TensorFlow 2.6+.
|
||||||
|
|
||||||
You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||||
|
|
||||||
|
@ -245,7 +245,7 @@ limitations under the License.
|
|||||||
|
|
||||||
### باستخدام pip
|
### باستخدام pip
|
||||||
|
|
||||||
تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 1.11+، و TensorFlow 2.6+.
|
تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، و TensorFlow 2.6+.
|
||||||
|
|
||||||
يجب تثبيت 🤗 Transformers في [بيئة افتراضية](https://docs.python.org/3/library/venv.html). إذا كنت غير معتاد على البيئات الافتراضية Python، فراجع [دليل المستخدم](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
يجب تثبيت 🤗 Transformers في [بيئة افتراضية](https://docs.python.org/3/library/venv.html). إذا كنت غير معتاد على البيئات الافتراضية Python، فراجع [دليل المستخدم](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||||
|
|
||||||
|
@ -246,7 +246,7 @@ Das Modell selbst ist ein reguläres [PyTorch `nn.Module`](https://pytorch.org/d
|
|||||||
|
|
||||||
### Mit pip
|
### Mit pip
|
||||||
|
|
||||||
Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ und TensorFlow 2.6+ getestet.
|
Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ und TensorFlow 2.6+ getestet.
|
||||||
|
|
||||||
Sie sollten 🤗 Transformers in einer [virtuellen Umgebung](https://docs.python.org/3/library/venv.html) installieren. Wenn Sie mit virtuellen Python-Umgebungen nicht vertraut sind, schauen Sie sich den [Benutzerleitfaden](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) an.
|
Sie sollten 🤗 Transformers in einer [virtuellen Umgebung](https://docs.python.org/3/library/venv.html) installieren. Wenn Sie mit virtuellen Python-Umgebungen nicht vertraut sind, schauen Sie sich den [Benutzerleitfaden](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) an.
|
||||||
|
|
||||||
|
@ -222,7 +222,7 @@ El modelo en si es un [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.h
|
|||||||
|
|
||||||
### Con pip
|
### Con pip
|
||||||
|
|
||||||
Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ y TensorFlow 2.6+.
|
Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ y TensorFlow 2.6+.
|
||||||
|
|
||||||
Deberías instalar 🤗 Transformers en un [entorno virtual](https://docs.python.org/3/library/venv.html). Si no estas familiarizado con los entornos virtuales de Python, consulta la [guía de usuario](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
Deberías instalar 🤗 Transformers en un [entorno virtual](https://docs.python.org/3/library/venv.html). Si no estas familiarizado con los entornos virtuales de Python, consulta la [guía de usuario](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||||
|
|
||||||
|
@ -243,7 +243,7 @@ Le modèle lui-même est un module [`nn.Module` PyTorch](https://pytorch.org/doc
|
|||||||
|
|
||||||
### Avec pip
|
### Avec pip
|
||||||
|
|
||||||
Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ et TensorFlow 2.6+.
|
Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ et TensorFlow 2.6+.
|
||||||
|
|
||||||
Vous devriez installer 🤗 Transformers dans un [environnement virtuel](https://docs.python.org/3/library/venv.html). Si vous n'êtes pas familier avec les environnements virtuels Python, consultez le [guide utilisateur](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
Vous devriez installer 🤗 Transformers dans un [environnement virtuel](https://docs.python.org/3/library/venv.html). Si vous n'êtes pas familier avec les environnements virtuels Python, consultez le [guide utilisateur](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||||
|
|
||||||
|
@ -198,7 +198,7 @@ checkpoint: जाँच बिंदु
|
|||||||
|
|
||||||
### पिप का उपयोग करना
|
### पिप का उपयोग करना
|
||||||
|
|
||||||
इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ और TensorFlow 2.6+ के तहत किया गया है।
|
इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ और TensorFlow 2.6+ के तहत किया गया है।
|
||||||
|
|
||||||
आप [वर्चुअल एनवायरनमेंट](https://docs.python.org/3/library/venv.html) में 🤗 ट्रांसफॉर्मर इंस्टॉल कर सकते हैं। यदि आप अभी तक पायथन के वर्चुअल एनवायरनमेंट से परिचित नहीं हैं, तो कृपया इसे [उपयोगकर्ता निर्देश](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) पढ़ें।
|
आप [वर्चुअल एनवायरनमेंट](https://docs.python.org/3/library/venv.html) में 🤗 ट्रांसफॉर्मर इंस्टॉल कर सकते हैं। यदि आप अभी तक पायथन के वर्चुअल एनवायरनमेंट से परिचित नहीं हैं, तो कृपया इसे [उपयोगकर्ता निर्देश](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) पढ़ें।
|
||||||
|
|
||||||
|
@ -256,7 +256,7 @@ Hugging Faceチームによって作られた **[トランスフォーマーを
|
|||||||
|
|
||||||
### pipにて
|
### pipにて
|
||||||
|
|
||||||
このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, TensorFlow 2.6+ でテストされています。
|
このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+ でテストされています。
|
||||||
|
|
||||||
🤗Transformersは[仮想環境](https://docs.python.org/3/library/venv.html)にインストールする必要があります。Pythonの仮想環境に慣れていない場合は、[ユーザーガイド](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)を確認してください。
|
🤗Transformersは[仮想環境](https://docs.python.org/3/library/venv.html)にインストールする必要があります。Pythonの仮想環境に慣れていない場合は、[ユーザーガイド](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)を確認してください。
|
||||||
|
|
||||||
|
@ -242,7 +242,7 @@ Transformers에 달린 100,000개의 별을 축하하기 위해, 우리는 커
|
|||||||
|
|
||||||
### pip로 설치하기
|
### pip로 설치하기
|
||||||
|
|
||||||
이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, TensorFlow 2.6+에서 테스트 되었습니다.
|
이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+에서 테스트 되었습니다.
|
||||||
|
|
||||||
[가상 환경](https://docs.python.org/3/library/venv.html)에 🤗 Transformers를 설치하세요. Python 가상 환경에 익숙하지 않다면, [사용자 가이드](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)를 확인하세요.
|
[가상 환경](https://docs.python.org/3/library/venv.html)에 🤗 Transformers를 설치하세요. Python 가상 환경에 익숙하지 않다면, [사용자 가이드](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)를 확인하세요.
|
||||||
|
|
||||||
|
@ -253,7 +253,7 @@ O modelo em si é um [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.ht
|
|||||||
|
|
||||||
### Com pip
|
### Com pip
|
||||||
|
|
||||||
Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ e TensorFlow 2.6+.
|
Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ e TensorFlow 2.6+.
|
||||||
|
|
||||||
Você deve instalar o 🤗 Transformers em um [ambiente virtual](https://docs.python.org/3/library/venv.html). Se você não está familiarizado com ambientes virtuais em Python, confira o [guia do usuário](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
Você deve instalar o 🤗 Transformers em um [ambiente virtual](https://docs.python.org/3/library/venv.html). Se você não está familiarizado com ambientes virtuais em Python, confira o [guia do usuário](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||||
|
|
||||||
|
@ -244,7 +244,7 @@ Hugging Face Hub. Мы хотим, чтобы Transformers позволил ра
|
|||||||
|
|
||||||
### С помощью pip
|
### С помощью pip
|
||||||
|
|
||||||
Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ и TensorFlow 2.6+.
|
Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ и TensorFlow 2.6+.
|
||||||
|
|
||||||
Устанавливать 🤗 Transformers следует в [виртуальной среде](https://docs.python.org/3/library/venv.html). Если вы не знакомы с виртуальными средами Python, ознакомьтесь с [руководством пользователя](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
Устанавливать 🤗 Transformers следует в [виртуальной среде](https://docs.python.org/3/library/venv.html). Если вы не знакомы с виртуальными средами Python, ознакомьтесь с [руководством пользователя](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||||
|
|
||||||
|
@ -246,7 +246,7 @@ limitations under the License.
|
|||||||
|
|
||||||
### పిప్ తో
|
### పిప్ తో
|
||||||
|
|
||||||
ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 1.11+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.
|
ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 2.0+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.
|
||||||
|
|
||||||
మీరు [వర్చువల్ వాతావరణం](https://docs.python.org/3/library/venv.html)లో 🤗 ట్రాన్స్ఫార్మర్లను ఇన్స్టాల్ చేయాలి. మీకు పైథాన్ వర్చువల్ పరిసరాల గురించి తెలియకుంటే, [యూజర్ గైడ్](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) చూడండి.
|
మీరు [వర్చువల్ వాతావరణం](https://docs.python.org/3/library/venv.html)లో 🤗 ట్రాన్స్ఫార్మర్లను ఇన్స్టాల్ చేయాలి. మీకు పైథాన్ వర్చువల్ పరిసరాల గురించి తెలియకుంటే, [యూజర్ గైడ్](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) చూడండి.
|
||||||
|
|
||||||
|
@ -259,7 +259,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#### ‏ pip کے ساتھ
|
#### ‏ pip کے ساتھ
|
||||||
|
|
||||||
یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 1.11+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔
|
یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔
|
||||||
|
|
||||||
آپ کو 🤗 Transformers کو ایک [ورچوئل ماحول](https://docs.python.org/3/library/venv.html) میں انسٹال کرنا چاہیے۔ اگر آپ Python ورچوئل ماحول سے واقف نہیں ہیں، تو [یوزر گائیڈ](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) دیکھیں۔
|
آپ کو 🤗 Transformers کو ایک [ورچوئل ماحول](https://docs.python.org/3/library/venv.html) میں انسٹال کرنا چاہیے۔ اگر آپ Python ورچوئل ماحول سے واقف نہیں ہیں، تو [یوزر گائیڈ](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) دیکھیں۔
|
||||||
|
|
||||||
|
@ -245,7 +245,7 @@ Chính mô hình là một [Pytorch `nn.Module`](https://pytorch.org/docs/stable
|
|||||||
|
|
||||||
### Sử dụng pip
|
### Sử dụng pip
|
||||||
|
|
||||||
Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ và TensorFlow 2.6+.
|
Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ và TensorFlow 2.6+.
|
||||||
|
|
||||||
Bạn nên cài đặt 🤗 Transformers trong một [môi trường ảo Python](https://docs.python.org/3/library/venv.html). Nếu bạn chưa quen với môi trường ảo Python, hãy xem [hướng dẫn sử dụng](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
Bạn nên cài đặt 🤗 Transformers trong một [môi trường ảo Python](https://docs.python.org/3/library/venv.html). Nếu bạn chưa quen với môi trường ảo Python, hãy xem [hướng dẫn sử dụng](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||||
|
|
||||||
|
@ -198,7 +198,7 @@ checkpoint: 检查点
|
|||||||
|
|
||||||
### 使用 pip
|
### 使用 pip
|
||||||
|
|
||||||
这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 1.11+ 和 TensorFlow 2.6+ 下经过测试。
|
这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下经过测试。
|
||||||
|
|
||||||
你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
|
你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
|
||||||
|
|
||||||
|
@ -210,7 +210,7 @@ Tokenizer 為所有的預訓練模型提供了預處理,並可以直接轉換
|
|||||||
|
|
||||||
### 使用 pip
|
### 使用 pip
|
||||||
|
|
||||||
這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 1.11+ 和 TensorFlow 2.6+ 下經過測試。
|
這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下經過測試。
|
||||||
|
|
||||||
你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
|
你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
|
||||||
|
|
||||||
|
@ -106,7 +106,6 @@ if is_torch_available():
|
|||||||
XLMWithLMHeadModel,
|
XLMWithLMHeadModel,
|
||||||
XLNetLMHeadModel,
|
XLNetLMHeadModel,
|
||||||
)
|
)
|
||||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
|
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
@ -279,7 +278,7 @@ def convert_pt_checkpoint_to_tf(
|
|||||||
if compare_with_pt_model:
|
if compare_with_pt_model:
|
||||||
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
|
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
|
||||||
|
|
||||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
weights_only_kwarg = {"weights_only": True}
|
||||||
state_dict = torch.load(
|
state_dict = torch.load(
|
||||||
pytorch_checkpoint_path,
|
pytorch_checkpoint_path,
|
||||||
map_location="cpu",
|
map_location="cpu",
|
||||||
|
@ -63,8 +63,6 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
|
|
||||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
||||||
@ -73,7 +71,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
weights_only_kwarg = {"weights_only": True}
|
||||||
pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
|
pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
|
||||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
||||||
|
|
||||||
@ -246,13 +244,11 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||||||
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
|
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
|
|
||||||
|
|
||||||
# Load the index
|
# Load the index
|
||||||
flax_state_dict = {}
|
flax_state_dict = {}
|
||||||
for shard_file in shard_filenames:
|
for shard_file in shard_filenames:
|
||||||
# load using msgpack utils
|
# load using msgpack utils
|
||||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
weights_only_kwarg = {"weights_only": True}
|
||||||
pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
|
pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
|
||||||
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
|
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
|
||||||
pt_state_dict = {
|
pt_state_dict = {
|
||||||
|
@ -180,8 +180,6 @@ def load_pytorch_checkpoint_in_tf2_model(
|
|||||||
import tensorflow as tf # noqa: F401
|
import tensorflow as tf # noqa: F401
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
from safetensors.torch import load_file as safe_load_file # noqa: F401
|
from safetensors.torch import load_file as safe_load_file # noqa: F401
|
||||||
|
|
||||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||||
@ -201,7 +199,7 @@ def load_pytorch_checkpoint_in_tf2_model(
|
|||||||
if pt_path.endswith(".safetensors"):
|
if pt_path.endswith(".safetensors"):
|
||||||
state_dict = safe_load_file(pt_path)
|
state_dict = safe_load_file(pt_path)
|
||||||
else:
|
else:
|
||||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
weights_only_kwarg = {"weights_only": True}
|
||||||
state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
|
state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
|
||||||
|
|
||||||
pt_state_dict.update(state_dict)
|
pt_state_dict.update(state_dict)
|
||||||
|
@ -54,7 +54,6 @@ from .pytorch_utils import ( # noqa: F401
|
|||||||
apply_chunking_to_forward,
|
apply_chunking_to_forward,
|
||||||
find_pruneable_heads_and_indices,
|
find_pruneable_heads_and_indices,
|
||||||
id_tensor_storage,
|
id_tensor_storage,
|
||||||
is_torch_greater_or_equal_than_1_13,
|
|
||||||
prune_conv1d_layer,
|
prune_conv1d_layer,
|
||||||
prune_layer,
|
prune_layer,
|
||||||
prune_linear_layer,
|
prune_linear_layer,
|
||||||
@ -476,7 +475,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
|||||||
error_message += f"\nMissing key(s): {str_unexpected_keys}."
|
error_message += f"\nMissing key(s): {str_unexpected_keys}."
|
||||||
raise RuntimeError(error_message)
|
raise RuntimeError(error_message)
|
||||||
|
|
||||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
weights_only_kwarg = {"weights_only": True}
|
||||||
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
|
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
|
||||||
|
|
||||||
for shard_file in shard_files:
|
for shard_file in shard_files:
|
||||||
@ -532,7 +531,7 @@ def load_state_dict(
|
|||||||
and is_zipfile(checkpoint_file)
|
and is_zipfile(checkpoint_file)
|
||||||
):
|
):
|
||||||
extra_args = {"mmap": True}
|
extra_args = {"mmap": True}
|
||||||
weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
|
weights_only_kwarg = {"weights_only": weights_only}
|
||||||
return torch.load(
|
return torch.load(
|
||||||
checkpoint_file,
|
checkpoint_file,
|
||||||
map_location=map_location,
|
map_location=map_location,
|
||||||
|
@ -38,7 +38,6 @@ from ...modeling_outputs import (
|
|||||||
)
|
)
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@ -815,14 +814,6 @@ class FalconPreTrainedModel(PreTrainedModel):
|
|||||||
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
|
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
|
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
|
||||||
# NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0).
|
|
||||||
if hard_check_only:
|
|
||||||
if not is_torch_greater_or_equal_than_2_0:
|
|
||||||
raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.")
|
|
||||||
|
|
||||||
if not is_torch_greater_or_equal_than_2_0:
|
|
||||||
return config
|
|
||||||
|
|
||||||
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
|
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
|
||||||
if _is_bettertransformer:
|
if _is_bettertransformer:
|
||||||
return config
|
return config
|
||||||
|
@ -36,7 +36,6 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@ -56,9 +55,6 @@ if is_flash_attn_2_available():
|
|||||||
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
||||||
# It means that the function will not be traced through and simply appear as a node in the graph.
|
# It means that the function will not be traced through and simply appear as a node in the graph.
|
||||||
if is_torch_fx_available():
|
if is_torch_fx_available():
|
||||||
if not is_torch_greater_or_equal_than_1_13:
|
|
||||||
import torch.fx
|
|
||||||
|
|
||||||
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,7 +33,6 @@ from ...modeling_outputs import (
|
|||||||
)
|
)
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
@ -51,9 +50,6 @@ if is_flash_attn_2_available():
|
|||||||
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
||||||
# It means that the function will not be traced through and simply appear as a node in the graph.
|
# It means that the function will not be traced through and simply appear as a node in the graph.
|
||||||
if is_torch_fx_available():
|
if is_torch_fx_available():
|
||||||
if not is_torch_greater_or_equal_than_1_13:
|
|
||||||
import torch.fx
|
|
||||||
|
|
||||||
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +25,6 @@ from transformers.modeling_outputs import (
|
|||||||
)
|
)
|
||||||
from transformers.models.superpoint.configuration_superpoint import SuperPointConfig
|
from transformers.models.superpoint.configuration_superpoint import SuperPointConfig
|
||||||
|
|
||||||
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@ -314,7 +313,7 @@ class SuperPointDescriptorDecoder(nn.Module):
|
|||||||
divisor = divisor.to(keypoints)
|
divisor = divisor.to(keypoints)
|
||||||
keypoints /= divisor
|
keypoints /= divisor
|
||||||
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
|
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
|
||||||
kwargs = {"align_corners": True} if is_torch_greater_or_equal_than_1_13 else {}
|
kwargs = {"align_corners": True}
|
||||||
# [batch_size, num_channels, num_keypoints, 2] -> [batch_size, num_channels, num_keypoints, 2]
|
# [batch_size, num_channels, num_keypoints, 2] -> [batch_size, num_channels, num_keypoints, 2]
|
||||||
keypoints = keypoints.view(batch_size, 1, -1, 2)
|
keypoints = keypoints.view(batch_size, 1, -1, 2)
|
||||||
descriptors = nn.functional.grid_sample(descriptors, keypoints, mode="bilinear", **kwargs)
|
descriptors = nn.functional.grid_sample(descriptors, keypoints, mode="bilinear", **kwargs)
|
||||||
|
@ -31,7 +31,6 @@ from ...modeling_utils import PreTrainedModel
|
|||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import (
|
||||||
apply_chunking_to_forward,
|
apply_chunking_to_forward,
|
||||||
find_pruneable_heads_and_indices,
|
find_pruneable_heads_and_indices,
|
||||||
is_torch_greater_or_equal_than_1_12,
|
|
||||||
prune_linear_layer,
|
prune_linear_layer,
|
||||||
)
|
)
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@ -46,12 +45,6 @@ from .configuration_tapas import TapasConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
if not is_torch_greater_or_equal_than_1_12:
|
|
||||||
logger.warning(
|
|
||||||
f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
|
|
||||||
"TapasModel. Please upgrade torch."
|
|
||||||
)
|
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "TapasConfig"
|
_CONFIG_FOR_DOC = "TapasConfig"
|
||||||
_CHECKPOINT_FOR_DOC = "google/tapas-base"
|
_CHECKPOINT_FOR_DOC = "google/tapas-base"
|
||||||
|
|
||||||
|
@ -38,7 +38,6 @@ from ...modeling_outputs import (
|
|||||||
XVectorOutput,
|
XVectorOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@ -1590,7 +1589,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
weights_only_kwarg = {"weights_only": True}
|
||||||
state_dict = torch.load(
|
state_dict = torch.load(
|
||||||
weight_path,
|
weight_path,
|
||||||
map_location="cpu",
|
map_location="cpu",
|
||||||
|
@ -34,9 +34,6 @@ is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse(
|
|||||||
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
|
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
|
||||||
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
|
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
|
||||||
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
||||||
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
|
||||||
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
|
|
||||||
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
|
|
||||||
|
|
||||||
# Cache this result has it's a C FFI call which can be pretty time-consuming
|
# Cache this result has it's a C FFI call which can be pretty time-consuming
|
||||||
_torch_distributed_available = torch.distributed.is_available()
|
_torch_distributed_available = torch.distributed.is_available()
|
||||||
|
@ -75,7 +75,6 @@ from .optimization import Adafactor, get_scheduler
|
|||||||
from .processing_utils import ProcessorMixin
|
from .processing_utils import ProcessorMixin
|
||||||
from .pytorch_utils import (
|
from .pytorch_utils import (
|
||||||
ALL_LAYERNORM_LAYERS,
|
ALL_LAYERNORM_LAYERS,
|
||||||
is_torch_greater_or_equal_than_1_13,
|
|
||||||
is_torch_greater_or_equal_than_2_3,
|
is_torch_greater_or_equal_than_2_3,
|
||||||
)
|
)
|
||||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
@ -2778,7 +2777,7 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
|
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
|
||||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
weights_only_kwarg = {"weights_only": True}
|
||||||
# If the model is on the GPU, it still works!
|
# If the model is on the GPU, it still works!
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
|
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
|
||||||
@ -2899,7 +2898,7 @@ class Trainer:
|
|||||||
or os.path.exists(best_safe_adapter_model_path)
|
or os.path.exists(best_safe_adapter_model_path)
|
||||||
):
|
):
|
||||||
has_been_loaded = True
|
has_been_loaded = True
|
||||||
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
|
weights_only_kwarg = {"weights_only": True}
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
|
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
|
||||||
# If the 'user_content.pt' file exists, load with the new smp api.
|
# If the 'user_content.pt' file exists, load with the new smp api.
|
||||||
|
@ -56,12 +56,7 @@ if is_torch_xla_available():
|
|||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .pytorch_utils import is_torch_greater_or_equal_than_2_0
|
|
||||||
|
|
||||||
if is_torch_greater_or_equal_than_2_0:
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
else:
|
|
||||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
@ -71,8 +71,6 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from .pytorch_utils import is_torch_greater_or_equal_than_2_0
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate.state import AcceleratorState, PartialState
|
from accelerate.state import AcceleratorState, PartialState
|
||||||
from accelerate.utils import DistributedType
|
from accelerate.utils import DistributedType
|
||||||
@ -1157,7 +1155,7 @@ class TrainingArguments:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
dataloader_prefetch_factor: Optional[int] = field(
|
dataloader_prefetch_factor: Optional[int] = field(
|
||||||
default=None if not is_torch_available() or is_torch_greater_or_equal_than_2_0 else 2,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": (
|
"help": (
|
||||||
"Number of batches loaded in advance by each worker. "
|
"Number of batches loaded in advance by each worker. "
|
||||||
@ -1702,14 +1700,6 @@ class TrainingArguments:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
|
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
|
||||||
)
|
)
|
||||||
elif not is_torch_xpu_available():
|
|
||||||
# xpu
|
|
||||||
from .pytorch_utils import is_torch_greater_or_equal_than_1_12
|
|
||||||
|
|
||||||
if not is_torch_greater_or_equal_than_1_12:
|
|
||||||
raise ValueError(
|
|
||||||
"Your setup doesn't support bf16/xpu. You need torch>=1.12, using Intel XPU/GPU with IPEX installed"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.fp16 and self.bf16:
|
if self.fp16 and self.bf16:
|
||||||
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
|
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
|
||||||
@ -2056,11 +2046,7 @@ class TrainingArguments:
|
|||||||
if self.use_cpu:
|
if self.use_cpu:
|
||||||
self.dataloader_pin_memory = False
|
self.dataloader_pin_memory = False
|
||||||
|
|
||||||
if (
|
if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
|
||||||
(not is_torch_available() or is_torch_greater_or_equal_than_2_0)
|
|
||||||
and self.dataloader_num_workers == 0
|
|
||||||
and self.dataloader_prefetch_factor is not None
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
|
"--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
|
||||||
" when --dataloader_num_workers > 1."
|
" when --dataloader_num_workers > 1."
|
||||||
|
@ -60,7 +60,6 @@ from ..models.auto.modeling_auto import (
|
|||||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
MODEL_MAPPING_NAMES,
|
MODEL_MAPPING_NAMES,
|
||||||
)
|
)
|
||||||
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
|
|
||||||
from .import_utils import (
|
from .import_utils import (
|
||||||
ENV_VARS_TRUE_VALUES,
|
ENV_VARS_TRUE_VALUES,
|
||||||
TORCH_FX_REQUIRED_VERSION,
|
TORCH_FX_REQUIRED_VERSION,
|
||||||
@ -635,10 +634,9 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
|||||||
operator.getitem: operator_getitem,
|
operator.getitem: operator_getitem,
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_torch_greater_or_equal_than_2_0:
|
_MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = (
|
||||||
_MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = (
|
|
||||||
torch_nn_functional_scaled_dot_product_attention
|
torch_nn_functional_scaled_dot_product_attention
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HFProxy(Proxy):
|
class HFProxy(Proxy):
|
||||||
|
@ -45,8 +45,7 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -43,9 +43,6 @@ if is_torch_available():
|
|||||||
FalconMambaModel,
|
FalconMambaModel,
|
||||||
)
|
)
|
||||||
from transformers.cache_utils import MambaCache
|
from transformers.cache_utils import MambaCache
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba
|
# Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba
|
||||||
@ -246,9 +243,6 @@ class FalconMambaModelTester:
|
|||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
|
||||||
)
|
|
||||||
@require_torch
|
@require_torch
|
||||||
# Copied from transformers.tests.models.mamba.MambaModelTest with Mamba->Falcon,mamba->falcon_mamba,FalconMambaCache->MambaCache
|
# Copied from transformers.tests.models.mamba.MambaModelTest with Mamba->Falcon,mamba->falcon_mamba,FalconMambaCache->MambaCache
|
||||||
class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
|
@ -37,9 +37,6 @@ if is_torch_available():
|
|||||||
GPTBigCodeModel,
|
GPTBigCodeModel,
|
||||||
)
|
)
|
||||||
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention
|
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_1_12 = False
|
|
||||||
|
|
||||||
|
|
||||||
class GPTBigCodeModelTester:
|
class GPTBigCodeModelTester:
|
||||||
@ -504,10 +501,6 @@ class GPTBigCodeMHAModelTest(GPTBigCodeModelTest):
|
|||||||
multi_query = False
|
multi_query = False
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
not is_torch_greater_or_equal_than_1_12,
|
|
||||||
reason="`GPTBigCode` checkpoints use `PytorchGELUTanh` which requires `torch>=1.12.0`.",
|
|
||||||
)
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
class GPTBigCodeModelLanguageGenerationTest(unittest.TestCase):
|
class GPTBigCodeModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
@ -41,9 +41,6 @@ if is_torch_available():
|
|||||||
GPTJForSequenceClassification,
|
GPTJForSequenceClassification,
|
||||||
GPTJModel,
|
GPTJModel,
|
||||||
)
|
)
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_1_12 = False
|
|
||||||
|
|
||||||
|
|
||||||
class GPTJModelTester:
|
class GPTJModelTester:
|
||||||
@ -363,15 +360,9 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
test_model_parallel = False
|
test_model_parallel = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
not is_torch_greater_or_equal_than_1_12, reason="PR #22069 made changes that require torch v1.12+."
|
|
||||||
)
|
|
||||||
def test_torch_fx(self):
|
def test_torch_fx(self):
|
||||||
super().test_torch_fx()
|
super().test_torch_fx()
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
not is_torch_greater_or_equal_than_1_12, reason="PR #22069 made changes that require torch v1.12+."
|
|
||||||
)
|
|
||||||
def test_torch_fx_output_loss(self):
|
def test_torch_fx_output_loss(self):
|
||||||
super().test_torch_fx_output_loss()
|
super().test_torch_fx_output_loss()
|
||||||
|
|
||||||
|
@ -44,9 +44,6 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import IdeficsForVisionText2Text, IdeficsModel, IdeficsProcessor
|
from transformers import IdeficsForVisionText2Text, IdeficsModel, IdeficsProcessor
|
||||||
from transformers.models.idefics.configuration_idefics import IdeficsPerceiverConfig, IdeficsVisionConfig
|
from transformers.models.idefics.configuration_idefics import IdeficsPerceiverConfig, IdeficsVisionConfig
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -327,7 +324,6 @@ class IdeficsModelTester:
|
|||||||
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
|
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (IdeficsModel, IdeficsForVisionText2Text) if is_torch_available() else ()
|
all_model_classes = (IdeficsModel, IdeficsForVisionText2Text) if is_torch_available() else ()
|
||||||
@ -594,7 +590,6 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase):
|
class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else ()
|
all_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else ()
|
||||||
@ -818,7 +813,6 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_vision
|
@require_vision
|
||||||
class IdeficsModelIntegrationTest(TestCasePlus):
|
class IdeficsModelIntegrationTest(TestCasePlus):
|
||||||
|
@ -48,8 +48,6 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -40,8 +40,6 @@ if is_torch_available():
|
|||||||
Idefics3ForConditionalGeneration,
|
Idefics3ForConditionalGeneration,
|
||||||
Idefics3Model,
|
Idefics3Model,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -43,8 +43,7 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -48,8 +48,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
|
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -48,8 +48,6 @@ from ...test_modeling_common import (
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -48,8 +48,6 @@ from ...test_modeling_common import (
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -38,9 +38,6 @@ if is_torch_available():
|
|||||||
MambaModel,
|
MambaModel,
|
||||||
)
|
)
|
||||||
from transformers.models.mamba.modeling_mamba import MambaCache
|
from transformers.models.mamba.modeling_mamba import MambaCache
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
|
|
||||||
class MambaModelTester:
|
class MambaModelTester:
|
||||||
@ -239,9 +236,6 @@ class MambaModelTester:
|
|||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
|
||||||
)
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
|
all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
|
||||||
|
@ -37,9 +37,6 @@ if is_torch_available():
|
|||||||
Mamba2Model,
|
Mamba2Model,
|
||||||
)
|
)
|
||||||
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
|
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
|
|
||||||
class Mamba2ModelTester:
|
class Mamba2ModelTester:
|
||||||
@ -214,9 +211,6 @@ class Mamba2ModelTester:
|
|||||||
self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3))
|
self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3))
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
|
||||||
)
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Mamba2Model, Mamba2ForCausalLM) if is_torch_available() else ()
|
all_model_classes = (Mamba2Model, Mamba2ForCausalLM) if is_torch_available() else ()
|
||||||
|
@ -40,8 +40,7 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -33,8 +33,7 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
pass
|
pass
|
||||||
|
@ -41,8 +41,6 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2AudioModelTester:
|
class Qwen2AudioModelTester:
|
||||||
|
@ -47,8 +47,6 @@ from ...test_modeling_common import (
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -33,9 +33,6 @@ if is_torch_available():
|
|||||||
RwkvForCausalLM,
|
RwkvForCausalLM,
|
||||||
RwkvModel,
|
RwkvModel,
|
||||||
)
|
)
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
|
|
||||||
class RwkvModelTester:
|
class RwkvModelTester:
|
||||||
@ -231,9 +228,6 @@ class RwkvModelTester:
|
|||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
|
||||||
)
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else ()
|
all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else ()
|
||||||
@ -440,9 +434,6 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
|
||||||
)
|
|
||||||
@slow
|
@slow
|
||||||
class RWKVIntegrationTests(unittest.TestCase):
|
class RWKVIntegrationTests(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -60,9 +60,6 @@ if is_torch_available():
|
|||||||
reduce_mean,
|
reduce_mean,
|
||||||
reduce_sum,
|
reduce_sum,
|
||||||
)
|
)
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_1_12 = False
|
|
||||||
|
|
||||||
|
|
||||||
class TapasModelTester:
|
class TapasModelTester:
|
||||||
@ -411,7 +408,6 @@ class TapasModelTester:
|
|||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TapasModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class TapasModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
@ -578,7 +574,6 @@ def prepare_tapas_batch_inputs_for_training():
|
|||||||
return table, queries, answer_coordinates, answer_text, float_answer
|
return table, queries, answer_coordinates, answer_text, float_answer
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TapasModelIntegrationTest(unittest.TestCase):
|
class TapasModelIntegrationTest(unittest.TestCase):
|
||||||
@cached_property
|
@cached_property
|
||||||
@ -930,10 +925,6 @@ class TapasModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(outputs.logits, expected_tensor, atol=0.05))
|
self.assertTrue(torch.allclose(outputs.logits, expected_tensor, atol=0.05))
|
||||||
|
|
||||||
|
|
||||||
# Below: tests for Tapas utilities which are defined in modeling_tapas.py.
|
|
||||||
# These are based on segmented_tensor_test.py of the original implementation.
|
|
||||||
# URL: https://github.com/google-research/tapas/blob/master/tapas/models/segmented_tensor_test.py
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TapasUtilitiesTest(unittest.TestCase):
|
class TapasUtilitiesTest(unittest.TestCase):
|
||||||
def _prepare_tables(self):
|
def _prepare_tables(self):
|
||||||
|
@ -23,7 +23,7 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import AddedToken, is_torch_available
|
from transformers import AddedToken
|
||||||
from transformers.models.tapas.tokenization_tapas import (
|
from transformers.models.tapas.tokenization_tapas import (
|
||||||
VOCAB_FILES_NAMES,
|
VOCAB_FILES_NAMES,
|
||||||
BasicTokenizer,
|
BasicTokenizer,
|
||||||
@ -45,12 +45,6 @@ from transformers.testing_utils import (
|
|||||||
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
|
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_1_12 = False
|
|
||||||
|
|
||||||
|
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
@require_pandas
|
@require_pandas
|
||||||
class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
@ -1048,7 +1042,6 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
# Do the same test as modeling common.
|
# Do the same test as modeling common.
|
||||||
self.assertIn(0, output["token_type_ids"][0])
|
self.assertIn(0, output["token_type_ids"][0])
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_torch_encode_plus_sent_to_model(self):
|
def test_torch_encode_plus_sent_to_model(self):
|
||||||
|
@ -41,8 +41,6 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -20,7 +20,6 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
TableQuestionAnsweringPipeline,
|
TableQuestionAnsweringPipeline,
|
||||||
TFAutoModelForTableQuestionAnswering,
|
TFAutoModelForTableQuestionAnswering,
|
||||||
is_torch_available,
|
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@ -33,12 +32,6 @@ from transformers.testing_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
|
|
||||||
else:
|
|
||||||
is_torch_greater_or_equal_than_1_12 = False
|
|
||||||
|
|
||||||
|
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
class TQAPipelineTests(unittest.TestCase):
|
class TQAPipelineTests(unittest.TestCase):
|
||||||
# Putting it there for consistency, but TQA do not have fast tokenizer
|
# Putting it there for consistency, but TQA do not have fast tokenizer
|
||||||
@ -150,7 +143,6 @@ class TQAPipelineTests(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt(self, torch_dtype="float32"):
|
def test_small_model_pt(self, torch_dtype="float32"):
|
||||||
model_id = "lysandre/tiny-tapas-random-wtq"
|
model_id = "lysandre/tiny-tapas-random-wtq"
|
||||||
@ -253,12 +245,10 @@ class TQAPipelineTests(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt_fp16(self):
|
def test_small_model_pt_fp16(self):
|
||||||
self.test_small_model_pt(torch_dtype="float16")
|
self.test_small_model_pt(torch_dtype="float16")
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_slow_tokenizer_sqa_pt(self, torch_dtype="float32"):
|
def test_slow_tokenizer_sqa_pt(self, torch_dtype="float32"):
|
||||||
model_id = "lysandre/tiny-tapas-random-sqa"
|
model_id = "lysandre/tiny-tapas-random-sqa"
|
||||||
@ -378,7 +368,6 @@ class TQAPipelineTests(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_slow_tokenizer_sqa_pt_fp16(self):
|
def test_slow_tokenizer_sqa_pt_fp16(self):
|
||||||
self.test_slow_tokenizer_sqa_pt(torch_dtype="float16")
|
self.test_slow_tokenizer_sqa_pt(torch_dtype="float16")
|
||||||
@ -505,7 +494,6 @@ class TQAPipelineTests(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_integration_wtq_pt(self, torch_dtype="float32"):
|
def test_integration_wtq_pt(self, torch_dtype="float32"):
|
||||||
@ -551,7 +539,6 @@ class TQAPipelineTests(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
self.assertListEqual(results, expected_results)
|
self.assertListEqual(results, expected_results)
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_integration_wtq_pt_fp16(self):
|
def test_integration_wtq_pt_fp16(self):
|
||||||
@ -606,7 +593,6 @@ class TQAPipelineTests(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
self.assertListEqual(results, expected_results)
|
self.assertListEqual(results, expected_results)
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_integration_sqa_pt(self, torch_dtype="float32"):
|
def test_integration_sqa_pt(self, torch_dtype="float32"):
|
||||||
@ -632,7 +618,6 @@ class TQAPipelineTests(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
self.assertListEqual(results, expected_results)
|
self.assertListEqual(results, expected_results)
|
||||||
|
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_integration_sqa_pt_fp16(self):
|
def test_integration_sqa_pt_fp16(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user