[TPU tests] Enable first TPU examples pytorch (#14121)
* up
* up
* fix
* up
* Update examples/pytorch/test_xla_examples.py
* correct labels
* up
* up
* up
* up
* up
* up
This commit is contained in:
parent 232822f36d
commit 01b1466983

.github/workflows/self-scheduled.yml (+39 lines)
@@ -181,6 +181,45 @@ jobs:
           name: run_all_tests_tf_gpu_test_reports
           path: reports
 
+  run_all_examples_torch_xla_tpu:
+    runs-on: [self-hosted, docker-tpu-test, tpu-v3-8]
+    container:
+      image: gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
+      options: --privileged -v "/lib/libtpu.so:/lib/libtpu.so" -v /mnt/cache/.cache/huggingface:/mnt/cache/ --shm-size 16G
+    steps:
+      - name: Launcher docker
+        uses: actions/checkout@v2
+
+      - name: Install dependencies
+        run: |
+          pip install --upgrade pip
+          pip install .[testing]
+
+      - name: Are TPUs recognized by our DL frameworks
+        env:
+          XRT_TPU_CONFIG: localservice;0;localhost:51011
+        run: |
+          python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
+
+      - name: Run example tests on TPU
+        env:
+          XRT_TPU_CONFIG: "localservice;0;localhost:51011"
+          MKL_SERVICE_FORCE_INTEL: "1" # See: https://github.com/pytorch/pytorch/issues/37377
+
+        run: |
+          python -m pytest -n 1 -v --dist=loadfile --make-reports=tests_torch_xla_tpu examples/pytorch/test_xla_examples.py
+
+      - name: Failure short reports
+        if: ${{ always() }}
+        run: cat reports/tests_torch_xla_tpu_failures_short.txt
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v2
+        with:
+          name: run_all_examples_torch_xla_tpu
+          path: reports
+
   run_all_tests_torch_multi_gpu:
     runs-on: [self-hosted, docker-gpu, multi-gpu]
     container:
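The "Are TPUs recognized by our DL frameworks" step is the cheap gate in this job: if the container cannot see a TPU, the run fails in seconds rather than timing out inside pytest. A standalone sketch of what that one-liner does, assuming torch_xla is installed and a local TPU runtime is listening on port 51011 (the address the job's env configures):

```python
# Standalone version of the workflow's TPU sanity check. Assumes torch_xla is
# installed and a TPU runtime is reachable at localhost:51011, as in the job env.
import os

# Must be set before torch_xla initializes so it can find the local TPU service.
os.environ.setdefault("XRT_TPU_CONFIG", "localservice;0;localhost:51011")

import torch_xla.core.xla_model as xm  # noqa: E402

# xla_device() returns the default XLA device (e.g. "xla:1"); if no TPU is
# reachable it errors out, failing the CI step before the slow example tests run.
print(xm.xla_device())
```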
examples/pytorch/test_xla_examples.py

@@ -14,13 +14,14 @@
 # limitations under the License.
 
 
+import json
 import logging
+import os
 import sys
-import unittest
 from time import time
 from unittest.mock import patch
 
-from transformers.testing_utils import require_torch_tpu
+from transformers.testing_utils import TestCasePlus, require_torch_tpu
 
 
 logging.basicConfig(level=logging.DEBUG)
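The new `TestCasePlus` import replaces `unittest.TestCase` in the hunk below so the test can write into a self-cleaning output directory instead of a hard-coded `run_glue_output`. A minimal sketch of that pattern, assuming only an installed transformers; the class and test names here are made up:

```python
# Minimal sketch of the TestCasePlus temp-dir pattern; names are hypothetical.
import unittest

from transformers.testing_utils import TestCasePlus


class TmpDirExampleTest(TestCasePlus):
    def test_writes_outputs(self):
        # Returns a fresh directory that the test harness deletes automatically
        # after the test, so no output directories are left behind.
        tmp_dir = self.get_auto_remove_tmp_dir()
        with open(f"{tmp_dir}/marker.txt", "w") as f:
            f.write("ok")


if __name__ == "__main__":
    unittest.main()
```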
@@ -28,66 +29,65 @@ logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger()
 
 
+def get_results(output_dir):
+    results = {}
+    path = os.path.join(output_dir, "all_results.json")
+    if os.path.exists(path):
+        with open(path, "r") as f:
+            results = json.load(f)
+    else:
+        raise ValueError(f"can't find {path}")
+    return results
+
+
 @require_torch_tpu
-class TorchXLAExamplesTests(unittest.TestCase):
+class TorchXLAExamplesTests(TestCasePlus):
     def test_run_glue(self):
         import xla_spawn
 
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
 
-        output_directory = "run_glue_output"
+        tmp_dir = self.get_auto_remove_tmp_dir()
 
         testargs = f"""
-            transformers/examples/text-classification/run_glue.py
+            ./examples/pytorch/text-classification/run_glue.py
             --num_cores=8
-            transformers/examples/text-classification/run_glue.py
+            ./examples/pytorch/text-classification/run_glue.py
+            --model_name_or_path distilbert-base-uncased
+            --output_dir {tmp_dir}
+            --overwrite_output_dir
+            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
+            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
             --do_train
             --do_eval
-            --task_name=mrpc
-            --cache_dir=./cache_dir
-            --num_train_epochs=1
+            --debug tpu_metrics_debug
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=1
+            --learning_rate=1e-4
+            --max_steps=10
+            --warmup_steps=2
+            --seed=42
             --max_seq_length=128
-            --learning_rate=3e-5
-            --output_dir={output_directory}
-            --overwrite_output_dir
-            --logging_steps=5
-            --save_steps=5
-            --overwrite_cache
-            --tpu_metrics_debug
-            --model_name_or_path=bert-base-cased
-            --per_device_train_batch_size=64
-            --per_device_eval_batch_size=64
-            --evaluation_strategy steps
-            --overwrite_cache
         """.split()
 
         with patch.object(sys, "argv", testargs):
             start = time()
             xla_spawn.main()
             end = time()
 
-        result = {}
-        with open(f"{output_directory}/eval_results_mrpc.txt") as f:
-            lines = f.readlines()
-        for line in lines:
-            key, value = line.split(" = ")
-            result[key] = float(value)
-
-        del result["eval_loss"]
-        for value in result.values():
-            # Assert that the model trains
-            self.assertGreaterEqual(value, 0.70)
-
-        # Assert that the script takes less than 300 seconds to make sure it doesn't hang.
+        result = get_results(tmp_dir)
+        self.assertGreaterEqual(result["eval_accuracy"], 0.75)
+
+        # Assert that the script takes less than 500 seconds to make sure it doesn't hang.
         self.assertLess(end - start, 500)
 
     def test_trainer_tpu(self):
         import xla_spawn
 
         testargs = """
-            transformers/tests/test_trainer_tpu.py
+            ./tests/test_trainer_tpu.py
             --num_cores=8
-            transformers/tests/test_trainer_tpu.py
+            ./tests/test_trainer_tpu.py
         """.split()
         with patch.object(sys, "argv", testargs):
             xla_spawn.main()
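The rewritten assertion no longer parses `eval_results_mrpc.txt` line by line; it reads the `all_results.json` that the PyTorch example scripts write into their output directory. A runnable sketch of the helper against a hand-written file; the helper body is copied from the diff above, and the metric values are invented:

```python
import json
import os
import tempfile


def get_results(output_dir):  # copied from the diff above
    results = {}
    path = os.path.join(output_dir, "all_results.json")
    if os.path.exists(path):
        with open(path, "r") as f:
            results = json.load(f)
    else:
        raise ValueError(f"can't find {path}")
    return results


# Simulate an output directory left behind by run_glue.py; values are made up.
tmp_dir = tempfile.mkdtemp()
with open(os.path.join(tmp_dir, "all_results.json"), "w") as f:
    json.dump({"eval_accuracy": 0.82, "eval_loss": 0.47}, f)

assert get_results(tmp_dir)["eval_accuracy"] >= 0.75  # the bar the test enforces
```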
tests/test_trainer_tpu.py

@@ -99,7 +99,7 @@ def main():
 
     p = trainer.predict(dataset)
     logger.info(p.metrics)
-    if p.metrics["eval_success"] is not True:
+    if p.metrics["test_success"] is not True:
         logger.error(p.metrics)
         exit(1)
 
@@ -113,7 +113,7 @@ def main():
 
     p = trainer.predict(dataset)
     logger.info(p.metrics)
-    if p.metrics["eval_success"] is not True:
+    if p.metrics["test_success"] is not True:
         logger.error(p.metrics)
         exit(1)
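The `eval_success` to `test_success` rename tracks the Trainer API: metrics returned by `compute_metrics` during `Trainer.predict` are prefixed with `metric_key_prefix`, whose default is `"test"` for `predict` (while `evaluate` uses `"eval"`). A minimal sketch of that prefixing, independent of any TPU; the helper is illustrative, not the Trainer's actual implementation:

```python
# Mirrors how the Trainer prefixes metrics returned by compute_metrics.
def prefix_metrics(metrics: dict, metric_key_prefix: str = "test") -> dict:
    return {f"{metric_key_prefix}_{key}": value for key, value in metrics.items()}


# The test's compute_metrics presumably returns a "success" entry, so after
# trainer.predict(dataset) the key to check is "test_success".
assert prefix_metrics({"success": True}) == {"test_success": True}
```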