[TPU tests] Enable first TPU examples pytorch (#14121)

* up

* up

* fix

* up

* Update examples/pytorch/test_xla_examples.py

* correct labels

* up

* up

* up

* up

* up

* up
Patrick von Platen 2021-10-28 01:22:28 +02:00 committed by GitHub
parent 232822f36d
commit 01b1466983
3 changed files with 77 additions and 38 deletions


@@ -181,6 +181,45 @@ jobs:
         name: run_all_tests_tf_gpu_test_reports
         path: reports
+  run_all_examples_torch_xla_tpu:
+    runs-on: [self-hosted, docker-tpu-test, tpu-v3-8]
+    container:
+      image: gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
+      options: --privileged -v "/lib/libtpu.so:/lib/libtpu.so" -v /mnt/cache/.cache/huggingface:/mnt/cache/ --shm-size 16G
+    steps:
+    - name: Launcher docker
+      uses: actions/checkout@v2
+    - name: Install dependencies
+      run: |
+        pip install --upgrade pip
+        pip install .[testing]
+    - name: Are TPUs recognized by our DL frameworks
+      env:
+        XRT_TPU_CONFIG: localservice;0;localhost:51011
+      run: |
+        python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
+    - name: Run example tests on TPU
+      env:
+        XRT_TPU_CONFIG: "localservice;0;localhost:51011"
+        MKL_SERVICE_FORCE_INTEL: "1" # See: https://github.com/pytorch/pytorch/issues/37377
+      run: |
+        python -m pytest -n 1 -v --dist=loadfile --make-reports=tests_torch_xla_tpu examples/pytorch/test_xla_examples.py
+    - name: Failure short reports
+      if: ${{ always() }}
+      run: cat reports/tests_torch_xla_tpu_failures_short.txt
+    - name: Test suite reports artifacts
+      if: ${{ always() }}
+      uses: actions/upload-artifact@v2
+      with:
+        name: run_all_examples_torch_xla_tpu
+        path: reports
   run_all_tests_torch_multi_gpu:
     runs-on: [self-hosted, docker-gpu, multi-gpu]
     container:
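The "Are TPUs recognized by our DL frameworks" step above only prints the XLA device. A slightly longer local sanity check in the same spirit, assuming torch_xla is installed and XRT_TPU_CONFIG points at the local TPU service as in the job (the extra tensor op is illustrative and not part of the CI step):

import os

# Set before torch_xla initializes its runtime.
os.environ.setdefault("XRT_TPU_CONFIG", "localservice;0;localhost:51011")

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # e.g. xla:1 when a TPU core is reachable
print(device)

# Running a small op forces a compile/execute round-trip on the device.
x = torch.ones((2, 2), device=device)
print(x.sum().item())  # expected: 4.0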

examples/pytorch/test_xla_examples.py

@@ -14,13 +14,14 @@
 # limitations under the License.
+import json
 import logging
+import os
 import sys
-import unittest
 from time import time
 from unittest.mock import patch
-from transformers.testing_utils import require_torch_tpu
+from transformers.testing_utils import TestCasePlus, require_torch_tpu
 logging.basicConfig(level=logging.DEBUG)
@@ -28,66 +29,65 @@ logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger()
+def get_results(output_dir):
+    results = {}
+    path = os.path.join(output_dir, "all_results.json")
+    if os.path.exists(path):
+        with open(path, "r") as f:
+            results = json.load(f)
+    else:
+        raise ValueError(f"can't find {path}")
+    return results
 @require_torch_tpu
-class TorchXLAExamplesTests(unittest.TestCase):
+class TorchXLAExamplesTests(TestCasePlus):
     def test_run_glue(self):
         import xla_spawn
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
-        output_directory = "run_glue_output"
+        tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            transformers/examples/text-classification/run_glue.py
+            ./examples/pytorch/text-classification/run_glue.py
             --num_cores=8
-            transformers/examples/text-classification/run_glue.py
+            ./examples/pytorch/text-classification/run_glue.py
+            --model_name_or_path distilbert-base-uncased
+            --output_dir {tmp_dir}
+            --overwrite_output_dir
+            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
+            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
             --do_train
             --do_eval
-            --task_name=mrpc
-            --cache_dir=./cache_dir
-            --num_train_epochs=1
+            --debug tpu_metrics_debug
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=1
+            --learning_rate=1e-4
+            --max_steps=10
+            --warmup_steps=2
+            --seed=42
             --max_seq_length=128
-            --learning_rate=3e-5
-            --output_dir={output_directory}
-            --overwrite_output_dir
-            --logging_steps=5
-            --save_steps=5
-            --overwrite_cache
-            --tpu_metrics_debug
-            --model_name_or_path=bert-base-cased
-            --per_device_train_batch_size=64
-            --per_device_eval_batch_size=64
-            --evaluation_strategy steps
-            --overwrite_cache
             """.split()
         with patch.object(sys, "argv", testargs):
             start = time()
             xla_spawn.main()
             end = time()
-            result = {}
-            with open(f"{output_directory}/eval_results_mrpc.txt") as f:
-                lines = f.readlines()
-                for line in lines:
-                    key, value = line.split(" = ")
-                    result[key] = float(value)
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
-            del result["eval_loss"]
-            for value in result.values():
-                # Assert that the model trains
-                self.assertGreaterEqual(value, 0.70)
-            # Assert that the script takes less than 300 seconds to make sure it doesn't hang.
+            # Assert that the script takes less than 500 seconds to make sure it doesn't hang.
             self.assertLess(end - start, 500)
     def test_trainer_tpu(self):
         import xla_spawn
         testargs = """
-            transformers/tests/test_trainer_tpu.py
+            ./tests/test_trainer_tpu.py
             --num_cores=8
-            transformers/tests/test_trainer_tpu.py
+            ./tests/test_trainer_tpu.py
             """.split()
         with patch.object(sys, "argv", testargs):
             xla_spawn.main()
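The rewritten test no longer parses eval_results_mrpc.txt; it reads the all_results.json that the example script writes into its --output_dir, via get_results(). A minimal, self-contained illustration of that contract, with made-up metric values (the key names mirror typical Trainer metrics; only eval_accuracy is asserted by the test):

import json
import tempfile
from pathlib import Path

# Hypothetical metrics, shaped like what run_glue.py saves to all_results.json.
example_results = {"eval_accuracy": 0.79, "eval_loss": 0.52, "train_runtime": 210.3}

with tempfile.TemporaryDirectory() as tmp_dir:
    Path(tmp_dir, "all_results.json").write_text(json.dumps(example_results))

    # Equivalent of get_results(tmp_dir) in the test above.
    results = json.loads(Path(tmp_dir, "all_results.json").read_text())
    assert results["eval_accuracy"] >= 0.75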

tests/test_trainer_tpu.py

@@ -99,7 +99,7 @@ def main():
         p = trainer.predict(dataset)
         logger.info(p.metrics)
-        if p.metrics["eval_success"] is not True:
+        if p.metrics["test_success"] is not True:
             logger.error(p.metrics)
             exit(1)
@@ -113,7 +113,7 @@ def main():
         p = trainer.predict(dataset)
         logger.info(p.metrics)
-        if p.metrics["eval_success"] is not True:
+        if p.metrics["test_success"] is not True:
             logger.error(p.metrics)
             exit(1)
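The tests/test_trainer_tpu.py change swaps "eval_success" for "test_success" because Trainer.predict() prefixes the metrics returned by compute_metrics with its metric_key_prefix, which defaults to "test" (Trainer.evaluate() defaults to "eval"). A minimal sketch of that renaming rule, independent of the real Trainer internals and assuming the test's compute_metrics returns a plain "success" entry:

# Illustrative only: mimics how Trainer prefixes metric keys.
def prefix_metrics(metrics, metric_key_prefix="test"):
    return {
        k if k.startswith(f"{metric_key_prefix}_") else f"{metric_key_prefix}_{k}": v
        for k, v in metrics.items()
    }

# A {"success": ...} metric surfaces as "test_success" after trainer.predict(),
# hence the key fixed in the hunks above.
print(prefix_metrics({"success": True}))                              # {'test_success': True}
print(prefix_metrics({"success": True}, metric_key_prefix="eval"))    # {'eval_success': True}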