From 01b14669839ab9fe247253b79b3b1af2d23e1a46 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 28 Oct 2021 01:22:28 +0200
Subject: [PATCH] [TPU tests] Enable first TPU examples pytorch (#14121)

* up

* up

* fix

* up

* Update examples/pytorch/test_xla_examples.py

* correct labels

* up

* up

* up

* up

* up

* up
---
 .github/workflows/self-scheduled.yml  | 39 +++++++++++++++
 examples/pytorch/test_xla_examples.py | 72 +++++++++++++--------------
 tests/test_trainer_tpu.py             |  4 +-
 3 files changed, 77 insertions(+), 38 deletions(-)

diff --git a/.github/workflows/self-scheduled.yml b/.github/workflows/self-scheduled.yml
index 1ecca8f54bc..0027e139975 100644
--- a/.github/workflows/self-scheduled.yml
+++ b/.github/workflows/self-scheduled.yml
@@ -181,6 +181,45 @@ jobs:
         name: run_all_tests_tf_gpu_test_reports
         path: reports
 
+  run_all_examples_torch_xla_tpu:
+    runs-on: [self-hosted, docker-tpu-test, tpu-v3-8]
+    container:
+      image: gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
+      options: --privileged -v "/lib/libtpu.so:/lib/libtpu.so" -v /mnt/cache/.cache/huggingface:/mnt/cache/ --shm-size 16G
+    steps:
+    - name: Launcher docker
+      uses: actions/checkout@v2
+
+    - name: Install dependencies
+      run: |
+        pip install --upgrade pip
+        pip install .[testing]
+
+    - name: Are TPUs recognized by our DL frameworks
+      env:
+        XRT_TPU_CONFIG: localservice;0;localhost:51011
+      run: |
+        python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
+
+    - name: Run example tests on TPU
+      env:
+        XRT_TPU_CONFIG: "localservice;0;localhost:51011"
+        MKL_SERVICE_FORCE_INTEL: "1"  # See: https://github.com/pytorch/pytorch/issues/37377
+
+      run: |
+        python -m pytest -n 1 -v --dist=loadfile --make-reports=tests_torch_xla_tpu examples/pytorch/test_xla_examples.py
+
+    - name: Failure short reports
+      if: ${{ always() }}
+      run: cat reports/tests_torch_xla_tpu_failures_short.txt
+
+    - name: Test suite reports artifacts
+      if: ${{ always() }}
+      uses: actions/upload-artifact@v2
+      with:
+        name: run_all_examples_torch_xla_tpu
+        path: reports
+
   run_all_tests_torch_multi_gpu:
     runs-on: [self-hosted, docker-gpu, multi-gpu]
     container:
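The "Are TPUs recognized by our DL frameworks" step above is the canary for the whole job. Expanded from the workflow's one-liner into a standalone sketch, it looks like the snippet below; this is illustrative only, assuming a TPU VM where torch_xla is installed (as in the gcr.io/tpu-pytorch/xla image) and reusing the XRT_TPU_CONFIG value the workflow exports.

    # Minimal sketch: confirm the TPU is visible before running the example tests.
    import os

    # Same runtime config the workflow sets; export it before torch_xla
    # initializes the device.
    os.environ.setdefault("XRT_TPU_CONFIG", "localservice;0;localhost:51011")

    import torch_xla.core.xla_model as xm

    # On a healthy TPU VM this prints an XLA device such as "xla:1".
    print(xm.xla_device())
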
diff --git a/examples/pytorch/test_xla_examples.py b/examples/pytorch/test_xla_examples.py
index ed1458a010f..8168a1679b5 100644
--- a/examples/pytorch/test_xla_examples.py
+++ b/examples/pytorch/test_xla_examples.py
@@ -14,13 +14,14 @@
 #    limitations under the License.
 
+import json
 import logging
+import os
 import sys
-import unittest
 from time import time
 from unittest.mock import patch
 
-from transformers.testing_utils import require_torch_tpu
+from transformers.testing_utils import TestCasePlus, require_torch_tpu
 
 
 logging.basicConfig(level=logging.DEBUG)
@@ -28,66 +29,65 @@ logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger()
 
 
+def get_results(output_dir):
+    results = {}
+    path = os.path.join(output_dir, "all_results.json")
+    if os.path.exists(path):
+        with open(path, "r") as f:
+            results = json.load(f)
+    else:
+        raise ValueError(f"can't find {path}")
+    return results
+
+
 @require_torch_tpu
-class TorchXLAExamplesTests(unittest.TestCase):
+class TorchXLAExamplesTests(TestCasePlus):
     def test_run_glue(self):
         import xla_spawn
 
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
 
-        output_directory = "run_glue_output"
-
+        tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            transformers/examples/text-classification/run_glue.py
+            ./examples/pytorch/text-classification/run_glue.py
             --num_cores=8
-            transformers/examples/text-classification/run_glue.py
+            ./examples/pytorch/text-classification/run_glue.py
+            --model_name_or_path distilbert-base-uncased
+            --output_dir {tmp_dir}
+            --overwrite_output_dir
+            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
+            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
             --do_train
             --do_eval
-            --task_name=mrpc
-            --cache_dir=./cache_dir
-            --num_train_epochs=1
+            --debug tpu_metrics_debug
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=1
+            --learning_rate=1e-4
+            --max_steps=10
+            --warmup_steps=2
+            --seed=42
             --max_seq_length=128
-            --learning_rate=3e-5
-            --output_dir={output_directory}
-            --overwrite_output_dir
-            --logging_steps=5
-            --save_steps=5
-            --overwrite_cache
-            --tpu_metrics_debug
-            --model_name_or_path=bert-base-cased
-            --per_device_train_batch_size=64
-            --per_device_eval_batch_size=64
-            --evaluation_strategy steps
-            --overwrite_cache
         """.split()
+
         with patch.object(sys, "argv", testargs):
             start = time()
             xla_spawn.main()
             end = time()
 
-            result = {}
-            with open(f"{output_directory}/eval_results_mrpc.txt") as f:
-                lines = f.readlines()
-                for line in lines:
-                    key, value = line.split(" = ")
-                    result[key] = float(value)
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
 
-            del result["eval_loss"]
-            for value in result.values():
-                # Assert that the model trains
-                self.assertGreaterEqual(value, 0.70)
-
-            # Assert that the script takes less than 300 seconds to make sure it doesn't hang.
+            # Assert that the script takes less than 500 seconds to make sure it doesn't hang.
             self.assertLess(end - start, 500)
 
     def test_trainer_tpu(self):
         import xla_spawn
 
         testargs = """
-            transformers/tests/test_trainer_tpu.py
+            ./tests/test_trainer_tpu.py
             --num_cores=8
-            transformers/tests/test_trainer_tpu.py
+            ./tests/test_trainer_tpu.py
         """.split()
         with patch.object(sys, "argv", testargs):
             xla_spawn.main()
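One detail in the testargs blocks above is worth spelling out: each script path appears twice on purpose, not by accident. The tests replace sys.argv wholesale, so the first path only fills the argv[0] slot (the program name, which argparse-style parsers skip), while the second is the positional training-script argument that xla_spawn.py forwards to the spawned TPU processes. A minimal sketch of the pattern, using a hypothetical dummy_script.py rather than the real example script:

    # Sketch of the argv-patching pattern used by the tests above.
    # "dummy_script.py" is a hypothetical placeholder, not a real file.
    import sys
    from unittest.mock import patch

    testargs = """
        dummy_script.py
        --num_cores=8
        dummy_script.py
        --max_steps=10
    """.split()

    with patch.object(sys, "argv", testargs):
        # Inside this block, sys.argv[0] is "dummy_script.py"; an
        # argparse-based launcher like xla_spawn parses everything after it.
        print(sys.argv)
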
diff --git a/tests/test_trainer_tpu.py b/tests/test_trainer_tpu.py
index e9f8f1a1525..135153fdddd 100644
--- a/tests/test_trainer_tpu.py
+++ b/tests/test_trainer_tpu.py
@@ -99,7 +99,7 @@ def main():
     p = trainer.predict(dataset)
 
     logger.info(p.metrics)
-    if p.metrics["eval_success"] is not True:
+    if p.metrics["test_success"] is not True:
         logger.error(p.metrics)
         exit(1)
 
@@ -113,7 +113,7 @@ def main():
     p = trainer.predict(dataset)
 
     logger.info(p.metrics)
-    if p.metrics["eval_success"] is not True:
+    if p.metrics["test_success"] is not True:
         logger.error(p.metrics)
         exit(1)
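The tests/test_trainer_tpu.py change is a key rename rather than a logic change: Trainer.predict() prefixes every metric returned by compute_metrics with its metric_key_prefix argument, which defaults to "test" (Trainer.evaluate() defaults to "eval"), so a compute_metrics returning {"success": ...} surfaces as "test_success" in p.metrics. A small sketch of that prefixing behavior, with a hypothetical compute_metrics standing in for the real one:

    # Sketch of the metric prefixing that motivates the eval_success ->
    # test_success rename; compute_metrics here is hypothetical.
    def compute_metrics(_eval_pred):
        return {"success": True}

    metric_key_prefix = "test"  # Trainer.predict() default; evaluate() uses "eval"
    metrics = {f"{metric_key_prefix}_{k}": v for k, v in compute_metrics(None).items()}

    assert metrics["test_success"] is True  # matches the fixed assertion above
    print(metrics)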