mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-16 11:08:23 +06:00
Fix TFGroupViT
CI (#19461)
* Fix TFGroupViT CI Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
a293a0e8a3
commit
d7dc774a79
@ -33,6 +33,7 @@ RUN echo torch=$VERSION
|
|||||||
RUN [ "$PYTORCH" != "pre" ] && python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA || python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA
|
RUN [ "$PYTORCH" != "pre" ] && python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA || python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA
|
||||||
|
|
||||||
RUN python3 -m pip install --no-cache-dir -U tensorflow
|
RUN python3 -m pip install --no-cache-dir -U tensorflow
|
||||||
|
RUN python3 -m pip install --no-cache-dir -U tensorflow_probability
|
||||||
RUN python3 -m pip uninstall -y flax jax
|
RUN python3 -m pip uninstall -y flax jax
|
||||||
|
|
||||||
# Use installed torch version for `torch-scatter` to avid to deal with PYTORCH='pre'.
|
# Use installed torch version for `torch-scatter` to avid to deal with PYTORCH='pre'.
|
||||||
|
@ -26,7 +26,13 @@ import numpy as np
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
from transformers import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
|
from transformers import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
|
||||||
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_vision, slow
|
from transformers.testing_utils import (
|
||||||
|
is_pt_tf_cross_test,
|
||||||
|
require_tensorflow_probability,
|
||||||
|
require_tf,
|
||||||
|
require_vision,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
from transformers.utils import is_tf_available, is_vision_available
|
from transformers.utils import is_tf_available, is_vision_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@ -155,6 +161,16 @@ class TFGroupViTVisionModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
"""
|
||||||
|
During saving, TensorFlow will also run with `training=True` which trigger `gumbel_softmax` that requires
|
||||||
|
`tensorflow-probability`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@require_tensorflow_probability
|
||||||
|
@slow
|
||||||
|
def test_saved_model_creation(self):
|
||||||
|
super().test_saved_model_creation()
|
||||||
|
|
||||||
@unittest.skip(reason="GroupViT does not use inputs_embeds")
|
@unittest.skip(reason="GroupViT does not use inputs_embeds")
|
||||||
def test_graph_mode_with_inputs_embeds(self):
|
def test_graph_mode_with_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
@ -295,6 +311,10 @@ class TFGroupViTVisionModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
model = TFGroupViTVisionModel.from_pretrained(model_name)
|
model = TFGroupViTVisionModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"TFGroupViTVisionModel does not convert `hidden_states` and `attentions` to tensors as they are all of"
|
||||||
|
" different dimensions, and we get `Got a non-Tensor value` error when saving the model."
|
||||||
|
)
|
||||||
@slow
|
@slow
|
||||||
def test_saved_model_creation_extended(self):
|
def test_saved_model_creation_extended(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@ -578,6 +598,10 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@require_tensorflow_probability
|
||||||
|
def test_keras_fit(self):
|
||||||
|
super().test_keras_fit()
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
@is_pt_tf_cross_test
|
||||||
def test_pt_tf_model_equivalence(self):
|
def test_pt_tf_model_equivalence(self):
|
||||||
# `GroupViT` computes some indices using argmax, uses them as
|
# `GroupViT` computes some indices using argmax, uses them as
|
||||||
|
Loading…
Reference in New Issue
Block a user