mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Fix a bunch of slow tests (#8634)
* CI should install `sentencepiece`
* Requiring TF
* Fixing some TFDPR bugs
* remove return_dict=False/True hack

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
parent 5362bb8a6b
commit f2e07e7272

.github/workflows/self-push.yml (vendored, 8 changes)
@@ -48,7 +48,7 @@ jobs:
       run: |
         source .env/bin/activate
         pip install --upgrade pip
-        pip install .[torch,sklearn,testing,onnxruntime]
+        pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
         pip install git+https://github.com/huggingface/datasets

     - name: Are GPUs recognized by our DL frameworks
@@ -117,7 +117,7 @@ jobs:
       run: |
         source .env/bin/activate
         pip install --upgrade pip
-        pip install .[tf,sklearn,testing,onnxruntime]
+        pip install .[tf,sklearn,testing,onnxruntime,sentencepiece]
         pip install git+https://github.com/huggingface/datasets

     - name: Are GPUs recognized by our DL frameworks
@@ -185,7 +185,7 @@ jobs:
       run: |
         source .env/bin/activate
         pip install --upgrade pip
-        pip install .[torch,sklearn,testing,onnxruntime]
+        pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
         pip install git+https://github.com/huggingface/datasets

     - name: Are GPUs recognized by our DL frameworks
@@ -244,7 +244,7 @@ jobs:
       run: |
         source .env/bin/activate
         pip install --upgrade pip
-        pip install .[tf,sklearn,testing,onnxruntime]
+        pip install .[tf,sklearn,testing,onnxruntime,sentencepiece]
        pip install git+https://github.com/huggingface/datasets

     - name: Are GPUs recognized by our DL frameworks
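
Note: the four hunks above add `sentencepiece` to the CI install extras; several slow tokenizer tests depend on it. A minimal sketch of the kind of availability guard a test suite can use (illustrative helper names, not necessarily the exact transformers utilities):

import importlib.util
import unittest

def is_sentencepiece_available() -> bool:
    # True when the sentencepiece package can be imported.
    return importlib.util.find_spec("sentencepiece") is not None

def require_sentencepiece(test_case):
    # Skip the decorated test case when sentencepiece is missing.
    return unittest.skipUnless(is_sentencepiece_available(), "test requires sentencepiece")(test_case)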

.github/workflows/self-scheduled.yml (vendored, 16 changes)
@@ -50,7 +50,7 @@ jobs:
       run: |
         source .env/bin/activate
         pip install --upgrade pip
-        pip install .[torch,sklearn,testing,onnxruntime]
+        pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
         pip install git+https://github.com/huggingface/datasets
         pip list

@@ -144,7 +144,7 @@ jobs:
       run: |
         source .env/bin/activate
         pip install --upgrade pip
-        pip install .[tf,sklearn,testing,onnxruntime]
+        pip install .[tf,sklearn,testing,onnxruntime,sentencepiece]
         pip install git+https://github.com/huggingface/datasets
         pip list

@@ -223,7 +223,7 @@ jobs:
       run: |
         source .env/bin/activate
         pip install --upgrade pip
-        pip install .[torch,sklearn,testing,onnxruntime]
+        pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
         pip install git+https://github.com/huggingface/datasets
         pip list

@@ -251,11 +251,11 @@ jobs:
         RUN_SLOW: yes
       run: |
         source .env/bin/activate
-        python -m pytest -n 1 --dist=loadfile -s --make-reports=examples_torch_multi_gpu examples
+        python -m pytest -n 1 --dist=loadfile -s --make-reports=tests_torch_examples_multi_gpu examples

     - name: Failure short reports
       if: ${{ always() }}
-      run: cat reports/examples_torch_multi_gpu_failures_short.txt
+      run: cat reports/tests_torch_examples_multi_gpu_failures_short.txt

     - name: Run all pipeline tests on multi-GPU
       if: ${{ always() }}
@@ -314,7 +314,7 @@ jobs:
       run: |
         source .env/bin/activate
         pip install --upgrade pip
-        pip install .[tf,sklearn,testing,onnxruntime]
+        pip install .[tf,sklearn,testing,onnxruntime,sentencepiece]
         pip install git+https://github.com/huggingface/datasets
         pip list

@@ -345,11 +345,11 @@ jobs:
         RUN_PIPELINE_TESTS: yes
       run: |
         source .env/bin/activate
-        python -m pytest -n 1 --dist=loadfile -s -m is_pipeline_test --make-reports=tests_tf_pipelines_multi_gpu tests
+        python -m pytest -n 1 --dist=loadfile -s -m is_pipeline_test --make-reports=tests_tf_pipeline_multi_gpu tests

     - name: Failure short reports
       if: ${{ always() }}
-      run: cat reports/tests_tf_multi_gpu_pipelines_failures_short.txt
+      run: cat reports/tests_tf_pipeline_multi_gpu_failures_short.txt

     - name: Test suite reports artifacts
       if: ${{ always() }}
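
Note: `--make-reports` is not a built-in pytest flag; the repository registers it in its conftest.py, and the renames above keep the option value in sync with the `reports/<id>_failures_short.txt` paths read by the follow-up steps. A rough sketch of how such an option can be wired up (simplified, not the exact transformers conftest):

# conftest.py (simplified sketch)
def pytest_addoption(parser):
    # Register a custom CLI option; its value becomes the report id.
    parser.addoption("--make-reports", action="store", default=None,
                     help="generate report files prefixed with this id")

def pytest_terminal_summary(terminalreporter):
    report_id = terminalreporter.config.getoption("--make-reports")
    if report_id:
        # e.g. write reports/{report_id}_failures_short.txt from collected stats
        ...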

modeling_tf_dpr.py
@@ -82,7 +82,7 @@ class TFDPRContextEncoderOutput(ModelOutput):
         heads.
     """

-    pooler_output: tf.Tensor
+    pooler_output: tf.Tensor = None
     hidden_states: Optional[Tuple[tf.Tensor]] = None
     attentions: Optional[Tuple[tf.Tensor]] = None

@@ -110,7 +110,7 @@ class TFDPRQuestionEncoderOutput(ModelOutput):
         heads.
     """

-    pooler_output: tf.Tensor
+    pooler_output: tf.Tensor = None
     hidden_states: Optional[Tuple[tf.Tensor]] = None
     attentions: Optional[Tuple[tf.Tensor]] = None

@@ -141,7 +141,7 @@ class TFDPRReaderOutput(ModelOutput):
         heads.
     """

-    start_logits: tf.Tensor
+    start_logits: tf.Tensor = None
     end_logits: tf.Tensor = None
     relevance_logits: tf.Tensor = None
     hidden_states: Optional[Tuple[tf.Tensor]] = None
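
Note: the three hunks above give the first tensor field of each output dataclass a `None` default, matching the fields that follow it, so the outputs can be instantiated empty and filled in afterwards. A minimal standalone sketch of the constraint, using a plain dataclass rather than ModelOutput:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ReaderOutput:
    # Without "= None", ReaderOutput() raises TypeError:
    # missing 1 required positional argument: 'start_logits'.
    start_logits: Optional[list] = None
    end_logits: Optional[list] = None

out = ReaderOutput()           # fine: construct first,
out.start_logits = [0.1, 0.9]  # populate later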
@@ -181,7 +181,7 @@ class TFDPREncoder(TFPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.bert_model.return_dict

         outputs = self.bert_model(
-            inputs=input_ids,
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             inputs_embeds=inputs_embeds,
@@ -228,7 +228,8 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
     def call(
         self,
         input_ids: Tensor,
-        attention_mask: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        token_type_ids: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
@@ -242,6 +243,7 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
         outputs = self.encoder(
             input_ids,
             attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -474,19 +476,21 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
-            inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
-            output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
-            output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
-            return_dict = inputs[5] if len(inputs) > 5 else return_dict
-            assert len(inputs) <= 6, "Too many inputs."
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+            output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+            output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+            return_dict = inputs[6] if len(inputs) > 6 else return_dict
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
             output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
             return_dict = inputs.get("return_dict", return_dict)
-            assert len(inputs) <= 6, "Too many inputs."
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs

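
Note: the reordering above slots `token_type_ids` in at index 2 of tuple inputs and shifts everything after it by one, so tuple and dict calling conventions agree with the updated signatures. A hypothetical usage sketch (the checkpoint name is a real DPR checkpoint; the inputs are toy values):

import tensorflow as tf
from transformers import TFDPRContextEncoder

model = TFDPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

input_ids = tf.constant([[101, 2054, 2003, 102]])  # toy token ids
attention_mask = tf.ones_like(input_ids)
token_type_ids = tf.zeros_like(input_ids)

# Index 2 of a tuple input now means token_type_ids:
out_tuple = model((input_ids, attention_mask, token_type_ids))
# Dict inputs resolve to the same arguments:
out_dict = model({"input_ids": input_ids,
                  "attention_mask": attention_mask,
                  "token_type_ids": token_type_ids})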
@@ -573,19 +577,21 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
-            inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
-            output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
-            output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
-            return_dict = inputs[5] if len(inputs) > 5 else return_dict
-            assert len(inputs) <= 6, "Too many inputs."
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+            output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+            output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+            return_dict = inputs[6] if len(inputs) > 6 else return_dict
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
             output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
             return_dict = inputs.get("return_dict", return_dict)
-            assert len(inputs) <= 6, "Too many inputs."
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs

@@ -650,6 +656,7 @@ class TFDPRReader(TFDPRPretrainedReader):
         self,
         inputs,
         attention_mask: Optional[Tensor] = None,
+        token_type_ids: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
         output_attentions: bool = None,
         output_hidden_states: bool = None,
@@ -679,19 +686,21 @@ class TFDPRReader(TFDPRPretrainedReader):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
-            inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
-            output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
-            output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
-            return_dict = inputs[5] if len(inputs) > 5 else return_dict
-            assert len(inputs) <= 6, "Too many inputs."
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+            output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+            output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+            return_dict = inputs[6] if len(inputs) > 6 else return_dict
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
             output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
             return_dict = inputs.get("return_dict", return_dict)
-            assert len(inputs) <= 6, "Too many inputs."
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs

@@ -713,9 +722,13 @@ class TFDPRReader(TFDPRPretrainedReader):
         if attention_mask is None:
             attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)

+        if token_type_ids is None:
+            token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
+
         return self.span_predictor(
             input_ids,
-            attention_mask,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
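
Note: the reader now materializes the usual BERT-style defaults for both arguments, attending to every token and treating the whole sequence as segment 0; passing attention_mask by keyword also keeps the call explicit now that the span predictor accepts token_type_ids. The defaults, as a standalone check:

import tensorflow as tf

input_shape = (1, 9)  # illustrative (batch_size, sequence_length)
attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)   # attend everywhere
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)  # single segment 0
print(attention_mask.numpy())
print(token_type_ids.numpy())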

test_modeling_tf_bert.py
@@ -340,6 +340,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])


+@require_tf
 class TFBertModelIntegrationTest(unittest.TestCase):
     @slow
     def test_inference_masked_lm(self):

test_modeling_tf_dpr.py
@@ -12,8 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
+import tempfile
 import unittest

 from transformers import is_tf_available
@@ -124,8 +123,6 @@ class TFDPRModelTester:
             type_vocab_size=self.type_vocab_size,
             is_decoder=False,
             initializer_range=self.initializer_range,
-            # MODIFY
-            return_dict=False,
         )
         config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())

@@ -137,7 +134,7 @@ class TFDPRModelTester:
         model = TFDPRContextEncoder(config=config)
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
         result = model(input_ids, token_type_ids=token_type_ids)
-        result = model(input_ids, return_dict=True)  # MODIFY
+        result = model(input_ids)
         self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

     def create_and_check_dpr_question_encoder(
@@ -146,14 +143,14 @@ class TFDPRModelTester:
         model = TFDPRQuestionEncoder(config=config)
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
         result = model(input_ids, token_type_ids=token_type_ids)
-        result = model(input_ids, return_dict=True)  # MODIFY
+        result = model(input_ids)
         self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

     def create_and_check_dpr_reader(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = TFDPRReader(config=config)
-        result = model(input_ids, attention_mask=input_mask, return_dict=True)  # MODIFY
+        result = model(input_ids, attention_mask=input_mask)

         self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
         self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
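
Note: with the `return_dict=False` config override and the `# MODIFY` markers gone, the tester exercises the default call path, which returns ModelOutput objects with attribute access. A self-contained sketch of what the assertions now rely on (the config values are placeholders, not the tester's):

import tensorflow as tf
from transformers import DPRConfig, TFDPRContextEncoder

config = DPRConfig(vocab_size=99, hidden_size=32, num_hidden_layers=2,
                   num_attention_heads=4, intermediate_size=37)
model = TFDPRContextEncoder(config=config)
input_ids = tf.constant([[1, 2, 3, 4]])

result = model(input_ids)          # a TFDPRContextEncoderOutput, not a bare tuple
print(result.pooler_output.shape)  # attribute access without return_dict=True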
@@ -214,27 +211,61 @@ class TFDPRModelTest(TFModelTesterMixin, unittest.TestCase):
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-            model = TFDPRContextEncoder.from_pretrained(model_name, from_pt=True)
+            model = TFDPRContextEncoder.from_pretrained(model_name)
             self.assertIsNotNone(model)

         for model_name in TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-            model = TFDPRContextEncoder.from_pretrained(model_name, from_pt=True)
+            model = TFDPRContextEncoder.from_pretrained(model_name)
             self.assertIsNotNone(model)

         for model_name in TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-            model = TFDPRQuestionEncoder.from_pretrained(model_name, from_pt=True)
+            model = TFDPRQuestionEncoder.from_pretrained(model_name)
             self.assertIsNotNone(model)

         for model_name in TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-            model = TFDPRReader.from_pretrained(model_name, from_pt=True)
+            model = TFDPRReader.from_pretrained(model_name)
             self.assertIsNotNone(model)

+    @slow
+    def test_saved_model_with_attentions_output(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.output_attentions = True
+
+        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
+        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+
+        for model_class in self.all_model_classes:
+            print(model_class)
+            class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config)
+            num_out = len(model(class_inputs_dict))
+            model._saved_model_inputs_spec = None
+            model._set_save_spec(class_inputs_dict)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                tf.saved_model.save(model, tmpdirname)
+                model = tf.keras.models.load_model(tmpdirname)
+                outputs = model(class_inputs_dict)
+
+                if self.is_encoder_decoder:
+                    output = outputs["encoder_attentions"] if isinstance(outputs, dict) else outputs[-1]
+                else:
+                    output = outputs["attentions"] if isinstance(outputs, dict) else outputs[-1]
+
+                attentions = [t.numpy() for t in output]
+                self.assertEqual(len(outputs), num_out)
+                self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+                self.assertListEqual(
+                    list(attentions[0].shape[-3:]),
+                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                )
+

 @require_tf
 class TFDPRModelIntegrationTest(unittest.TestCase):
     @slow
     def test_inference_no_head(self):
-        model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base", return_dict=False)
+        model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

         input_ids = tf.constant(
             [[101, 7592, 1010, 2003, 2026, 3899, 10140, 1029, 102]]
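
Note: the new test_saved_model_with_attentions_output round-trips each model through the TensorFlow SavedModel format and checks that attention outputs survive. The core pattern, stripped of the test harness (a stand-in Keras model instead of the TFDPR classes):

import tempfile
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
inputs = tf.random.uniform((2, 8))

with tempfile.TemporaryDirectory() as tmpdirname:
    tf.saved_model.save(model, tmpdirname)             # export to SavedModel
    reloaded = tf.keras.models.load_model(tmpdirname)  # load it back
    print(reloaded(inputs).shape)  # same call API after the round trip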

test_modeling_tf_electra.py
@@ -249,6 +249,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
         self.assertIsNotNone(model)


+@require_tf
 class TFElectraModelIntegrationTest(unittest.TestCase):
     @slow
     def test_inference_masked_lm(self):