mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[TPU tests] Enable first TPU examples pytorch (#14121)
* up * up * fix * up * Update examples/pytorch/test_xla_examples.py * correct labels * up * up * up * up * up * up
This commit is contained in:
parent
232822f36d
commit
01b1466983
39
.github/workflows/self-scheduled.yml
vendored
39
.github/workflows/self-scheduled.yml
vendored
@ -181,6 +181,45 @@ jobs:
|
||||
name: run_all_tests_tf_gpu_test_reports
|
||||
path: reports
|
||||
|
||||
run_all_examples_torch_xla_tpu:
|
||||
runs-on: [self-hosted, docker-tpu-test, tpu-v3-8]
|
||||
container:
|
||||
image: gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
|
||||
options: --privileged -v "/lib/libtpu.so:/lib/libtpu.so" -v /mnt/cache/.cache/huggingface:/mnt/cache/ --shm-size 16G
|
||||
steps:
|
||||
- name: Launcher docker
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install .[testing]
|
||||
|
||||
- name: Are TPUs recognized by our DL frameworks
|
||||
env:
|
||||
XRT_TPU_CONFIG: localservice;0;localhost:51011
|
||||
run: |
|
||||
python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
|
||||
|
||||
- name: Run example tests on TPU
|
||||
env:
|
||||
XRT_TPU_CONFIG: "localservice;0;localhost:51011"
|
||||
MKL_SERVICE_FORCE_INTEL: "1" # See: https://github.com/pytorch/pytorch/issues/37377
|
||||
|
||||
run: |
|
||||
python -m pytest -n 1 -v --dist=loadfile --make-reports=tests_torch_xla_tpu examples/pytorch/test_xla_examples.py
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ always() }}
|
||||
run: cat reports/tests_torch_xla_tpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: run_all_examples_torch_xla_tpu
|
||||
path: reports
|
||||
|
||||
run_all_tests_torch_multi_gpu:
|
||||
runs-on: [self-hosted, docker-gpu, multi-gpu]
|
||||
container:
|
||||
|
@ -14,13 +14,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from time import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers.testing_utils import require_torch_tpu
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_tpu
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@ -28,66 +29,65 @@ logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_results(output_dir):
|
||||
results = {}
|
||||
path = os.path.join(output_dir, "all_results.json")
|
||||
if os.path.exists(path):
|
||||
with open(path, "r") as f:
|
||||
results = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"can't find {path}")
|
||||
return results
|
||||
|
||||
|
||||
@require_torch_tpu
|
||||
class TorchXLAExamplesTests(unittest.TestCase):
|
||||
class TorchXLAExamplesTests(TestCasePlus):
|
||||
def test_run_glue(self):
|
||||
import xla_spawn
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
output_directory = "run_glue_output"
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
transformers/examples/text-classification/run_glue.py
|
||||
./examples/pytorch/text-classification/run_glue.py
|
||||
--num_cores=8
|
||||
transformers/examples/text-classification/run_glue.py
|
||||
./examples/pytorch/text-classification/run_glue.py
|
||||
--model_name_or_path distilbert-base-uncased
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--train_file ./tests/fixtures/tests_samples/MRPC/train.csv
|
||||
--validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
|
||||
--do_train
|
||||
--do_eval
|
||||
--task_name=mrpc
|
||||
--cache_dir=./cache_dir
|
||||
--num_train_epochs=1
|
||||
--debug tpu_metrics_debug
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
--learning_rate=1e-4
|
||||
--max_steps=10
|
||||
--warmup_steps=2
|
||||
--seed=42
|
||||
--max_seq_length=128
|
||||
--learning_rate=3e-5
|
||||
--output_dir={output_directory}
|
||||
--overwrite_output_dir
|
||||
--logging_steps=5
|
||||
--save_steps=5
|
||||
--overwrite_cache
|
||||
--tpu_metrics_debug
|
||||
--model_name_or_path=bert-base-cased
|
||||
--per_device_train_batch_size=64
|
||||
--per_device_eval_batch_size=64
|
||||
--evaluation_strategy steps
|
||||
--overwrite_cache
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
start = time()
|
||||
xla_spawn.main()
|
||||
end = time()
|
||||
|
||||
result = {}
|
||||
with open(f"{output_directory}/eval_results_mrpc.txt") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
key, value = line.split(" = ")
|
||||
result[key] = float(value)
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||
|
||||
del result["eval_loss"]
|
||||
for value in result.values():
|
||||
# Assert that the model trains
|
||||
self.assertGreaterEqual(value, 0.70)
|
||||
|
||||
# Assert that the script takes less than 300 seconds to make sure it doesn't hang.
|
||||
# Assert that the script takes less than 500 seconds to make sure it doesn't hang.
|
||||
self.assertLess(end - start, 500)
|
||||
|
||||
def test_trainer_tpu(self):
|
||||
import xla_spawn
|
||||
|
||||
testargs = """
|
||||
transformers/tests/test_trainer_tpu.py
|
||||
./tests/test_trainer_tpu.py
|
||||
--num_cores=8
|
||||
transformers/tests/test_trainer_tpu.py
|
||||
./tests/test_trainer_tpu.py
|
||||
""".split()
|
||||
with patch.object(sys, "argv", testargs):
|
||||
xla_spawn.main()
|
||||
|
@ -99,7 +99,7 @@ def main():
|
||||
|
||||
p = trainer.predict(dataset)
|
||||
logger.info(p.metrics)
|
||||
if p.metrics["eval_success"] is not True:
|
||||
if p.metrics["test_success"] is not True:
|
||||
logger.error(p.metrics)
|
||||
exit(1)
|
||||
|
||||
@ -113,7 +113,7 @@ def main():
|
||||
|
||||
p = trainer.predict(dataset)
|
||||
logger.info(p.metrics)
|
||||
if p.metrics["eval_success"] is not True:
|
||||
if p.metrics["test_success"] is not True:
|
||||
logger.error(p.metrics)
|
||||
exit(1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user