diff --git a/Makefile b/Makefile index 9ef8e2659d8..6a09470050a 100644 --- a/Makefile +++ b/Makefile @@ -74,6 +74,12 @@ test: test-examples: python -m pytest -n auto --dist=loadfile -s -v ./examples/ +# Run tests for SageMaker DLC release + +test-sagemaker: # install sagemaker dependencies in advance with pip install .[sagemaker] + TEST_SAGEMAKER=True python -m pytest -n auto -s -v ./tests/sagemaker + + # Check that docs can build docs: diff --git a/setup.py b/setup.py index 0744058e661..d25376fa7ca 100644 --- a/setup.py +++ b/setup.py @@ -19,15 +19,17 @@ To create the package for pypi. 1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the documentation. + +2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid. -2. Unpin specific versions from setup.py that use a git install. +3. Unpin specific versions from setup.py that use a git install. -3. Commit these changes with the message: "Release: VERSION" +4. Commit these changes with the message: "Release: VERSION" -4. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' " +5. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' " Push the tag to git: git push --tags origin master -5. Build both the sources and the wheel. Do not change anything in setup.py between +6. Build both the sources and the wheel. Do not change anything in setup.py between creating the wheel and the source distribution (obviously). For the wheel, run: "python setup.py bdist_wheel" in the top level directory. @@ -36,7 +38,7 @@ To create the package for pypi. For the sources, run: "python setup.py sdist" You should now have a /dist directory with both .whl and .tar.gz source versions. -6. Check that everything looks correct by uploading the package to the pypi test server: +7. Check that everything looks correct by uploading the package to the pypi test server: twine upload dist/* -r pypitest (pypi suggest using twine as other methods upload files via plaintext.) @@ -46,12 +48,12 @@ To create the package for pypi. Check that you can install it in a virtualenv by running: pip install -i https://testpypi.python.org/pypi transformers -7. Upload the final version to actual pypi: +8. Upload the final version to actual pypi: twine upload dist/* -r pypi -8. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. +9. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. -9. Run `make post-release` (or `make post-patch` for a patch release). +10. Run `make post-release` (or `make post-patch` for a patch release). """ import os @@ -134,6 +136,7 @@ _deps = [ "unidic>=1.0.2", "unidic_lite>=1.0.7", "uvicorn", + "sagemaker>=2.31.0", ] @@ -223,12 +226,16 @@ extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools") extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxruntime"] extras["modelcreation"] = deps_list("cookiecutter") +extras["sagemaker"] = deps_list("sagemaker") + extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette") extras["speech"] = deps_list("soundfile", "torchaudio") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["testing"] = ( - deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar", "black") + deps_list( + "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar", "black" + ) + extras["retrieval"] + extras["modelcreation"] ) diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 8e0f3773e94..1b89ed9d5c3 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -53,4 +53,5 @@ deps = { "unidic": "unidic>=1.0.2", "unidic_lite": "unidic_lite>=1.0.7", "uvicorn": "uvicorn", + "sagemaker": "sagemaker>=2.31.0", } diff --git a/tests/sagemaker/README.md b/tests/sagemaker/README.md new file mode 100644 index 00000000000..12e2f8d890f --- /dev/null +++ b/tests/sagemaker/README.md @@ -0,0 +1,153 @@ +# Testing new Hugging Face Deep Learning Container. + +This document explains the testing strategy for releasing the new Hugging Face Deep Learning Container. AWS maintains 14 days of currency with framework releases. Besides framework releases, AWS release train is bi-weekly on Monday. Code cutoff date for any changes is the Wednesday before release-Monday. + + +## Test Case 1: Releasing a New Version (Minor/Major) of 🤗 Transformers + +### Requirements: Test should run on Release Candidate for new `transformers` release to validate the new release is compatible with the DLCs. To run these tests you need credentials for the HF SageMaker AWS Account. You can ask @philschmid or @n1t0 to get access. + +### Run Tests: + +Before we can run the tests we need to adjust the `requirements.txt` for PyTorch under `/tests/sagemaker/scripts/pytorch` and for TensorFlow under `/tests/sagemaker/scripts/pytorch`. We adjust the branch to the new RC-tag. + +``` +git+https://github.com/huggingface/transformers.git@v4.5.0.rc0 # install master or adjust ist with vX.X.X for installing version specific-transforms +``` + +After we adjusted the `requirements.txt` we can run Amazon SageMaker tests with: + +```bash +AWS_PROFILE= make sagemaker-test +``` +These tests take around 10-15 minutes to finish. Preferably make a screenshot of the successfully ran tests. + +### After Transformers Release: + +After we have released the Release Candidate we need to create a PR at the [Deep Learning Container Repository](https://github.com/aws/deep-learning-containers). + +**Creating the update PR:** + +1. Update the two latest `buildspec.yaml` config for [PyTorch](https://github.com/aws/deep-learning-containers/tree/master/huggingface/pytorch) and [TensorFlow](https://github.com/aws/deep-learning-containers/tree/master/huggingface/tensorflow). The two latest `buildspec.yaml` are the `buildspec.yaml` without a version tag and the one with the highest framework version, e.g. `buildspec-1-7-1.yml` and not `buildspec-1-6.yml`. + +To update the `buildspec.yaml` we need to adjust either the `transformers_version` or the `datasets_version` or both. Example for upgrading to `transformers 4.5.0` and `datasets 1.6.0`. +```yaml +account_id: &ACCOUNT_ID +region: ®ION +base_framework: &BASE_FRAMEWORK pytorch +framework: &FRAMEWORK !join [ "huggingface_", *BASE_FRAMEWORK] +version: &VERSION 1.6.0 +short_version: &SHORT_VERSION 1.6 + +repository_info: + training_repository: &TRAINING_REPOSITORY + image_type: &TRAINING_IMAGE_TYPE training + root: !join [ "huggingface/", *BASE_FRAMEWORK, "/", *TRAINING_IMAGE_TYPE ] + repository_name: &REPOSITORY_NAME !join ["pr", "-", "huggingface", "-", *BASE_FRAMEWORK, "-", *TRAINING_IMAGE_TYPE] + repository: &REPOSITORY !join [ *ACCOUNT_ID, .dkr.ecr., *REGION, .amazonaws.com/, + *REPOSITORY_NAME ] + +images: + BuildHuggingFacePytorchGpuPy37Cu110TrainingDockerImage: + <<: *TRAINING_REPOSITORY + build: &HUGGINGFACE_PYTORCH_GPU_TRAINING_PY3 false + image_size_baseline: &IMAGE_SIZE_BASELINE 15000 + device_type: &DEVICE_TYPE gpu + python_version: &DOCKER_PYTHON_VERSION py3 + tag_python_version: &TAG_PYTHON_VERSION py36 + cuda_version: &CUDA_VERSION cu110 + os_version: &OS_VERSION ubuntu18.04 + transformers_version: &TRANSFORMERS_VERSION 4.5.0 # this was adjusted from 4.4.2 to 4.5.0 + datasets_version: &DATASETS_VERSION 1.6.0 # this was adjusted from 1.5.0 to 1.6.0 + tag: !join [ *VERSION, '-', 'transformers', *TRANSFORMERS_VERSION, '-', *DEVICE_TYPE, '-', *TAG_PYTHON_VERSION, '-', + *CUDA_VERSION, '-', *OS_VERSION ] + docker_file: !join [ docker/, *SHORT_VERSION, /, *DOCKER_PYTHON_VERSION, /, + *CUDA_VERSION, /Dockerfile., *DEVICE_TYPE ] +``` +2. In the PR comment describe what test, we ran and with which package versions. Here you can copy the table from [Current Tests](#current-tests). + +TODO: Add a screenshot of PR + Text template to make it easy to open. + +## Test Case 2: Releasing a New AWS Framework DLC + + +## Execute Tests + +### Requirements: +AWS is going to release new DLCs for PyTorch and/or TensorFlow. The Tests should run on the new framework versions with current `transformers` release to validate the new framework release is compatible with the `transformers` version. To run these tests you need credentials for the HF SageMaker AWS Account. You can ask @philschmid or @n1t0 to get access. AWS will notify us with a new issue in the repository pointing to their framework upgrade PR. + +### Run Tests: + +Before we can run the tests we need to adjust the `requirements.txt` for Pytorch under `/tests/sagemaker/scripts/pytorch` and for Tensorflow under `/tests/sagemaker/scripts/pytorch`. We add the new framework version to it. + +``` +torch==1.8.1 # for pytorch +tensorflow-gpu==2.5.0 # for tensorflow +``` + +After we adjusted the `requirements.txt` we can run Amazon SageMaker tests with. + +```bash +AWS_PROFILE= make sagemaker-test +``` +These tests take around 10-15 minutes to finish. Preferably make a screenshot of the successfully ran tests. + + +### After successful Tests: + +After we have successfully run tests for the new framework version we need to create a PR at the [Deep Learning Container Repository](https://github.com/aws/deep-learning-containers). + +**Creating the update PR:** + +1. Create a new `buildspec.yaml` config for [PyTorch](https://github.com/aws/deep-learning-containers/tree/master/huggingface/pytorch) and [TensorFlow](https://github.com/aws/deep-learning-containers/tree/master/huggingface/tensorflow) and rename the old `buildspec.yaml` to `buildespec-x.x.x`, where `x.x.x` is the base framework version, e.g. if pytorch 1.6.0 is the latest version in `buildspec.yaml` the file should be renamed to `buildspec-yaml-1-6.yaml`. + +To create the new `buildspec.yaml` we need to adjust the `version` and the `short_version`. Example for upgrading to `pytorch 1.7.1`. + +```yaml +account_id: &ACCOUNT_ID +region: ®ION +base_framework: &BASE_FRAMEWORK pytorch +framework: &FRAMEWORK !join [ "huggingface_", *BASE_FRAMEWORK] +version: &VERSION 1.7.1 # this was adjusted from 1.6.0 to 1.7.1 +short_version: &SHORT_VERSION 1.7 # this was adjusted from 1.6 to 1.7 + +repository_info: + training_repository: &TRAINING_REPOSITORY + image_type: &TRAINING_IMAGE_TYPE training + root: !join [ "huggingface/", *BASE_FRAMEWORK, "/", *TRAINING_IMAGE_TYPE ] + repository_name: &REPOSITORY_NAME !join ["pr", "-", "huggingface", "-", *BASE_FRAMEWORK, "-", *TRAINING_IMAGE_TYPE] + repository: &REPOSITORY !join [ *ACCOUNT_ID, .dkr.ecr., *REGION, .amazonaws.com/, + *REPOSITORY_NAME ] + +images: + BuildHuggingFacePytorchGpuPy37Cu110TrainingDockerImage: + <<: *TRAINING_REPOSITORY + build: &HUGGINGFACE_PYTORCH_GPU_TRAINING_PY3 false + image_size_baseline: &IMAGE_SIZE_BASELINE 15000 + device_type: &DEVICE_TYPE gpu + python_version: &DOCKER_PYTHON_VERSION py3 + tag_python_version: &TAG_PYTHON_VERSION py36 + cuda_version: &CUDA_VERSION cu110 + os_version: &OS_VERSION ubuntu18.04 + transformers_version: &TRANSFORMERS_VERSION 4.4.2 + datasets_version: &DATASETS_VERSION 1.5.0 + tag: !join [ *VERSION, '-', 'transformers', *TRANSFORMERS_VERSION, '-', *DEVICE_TYPE, '-', *TAG_PYTHON_VERSION, '-', + *CUDA_VERSION, '-', *OS_VERSION ] + docker_file: !join [ docker/, *SHORT_VERSION, /, *DOCKER_PYTHON_VERSION, /, + *CUDA_VERSION, /Dockerfile., *DEVICE_TYPE ] +``` +2. In the PR comment describe what test we ran and with which framework versions. Here you can copy the table from [Current Tests](#current-tests). + +TODO: Add a screenshot of PR + Text template to make it easy to open. + + +## Current Tests + +| ID | Description | Platform | #GPUS | Collected & evaluated metrics | +|-------------------------------------|-------------------------------------------------------------------|-----------------------------|-------|------------------------------------------| +| pytorch-transfromers-test-single | test bert finetuning using BERT fromtransformerlib+PT | SageMaker createTrainingJob | 1 | train_runtime, eval_accuracy & eval_loss | +| pytorch-transfromers-test-2-ddp | test bert finetuning using BERT from transformer lib+ PT DPP | SageMaker createTrainingJob | 16 | train_runtime, eval_accuracy & eval_loss | +| pytorch-transfromers-test-2-smd | test bert finetuning using BERT from transformer lib+ PT SM DDP | SageMaker createTrainingJob | 16 | train_runtime, eval_accuracy & eval_loss | +| pytorch-transfromers-test-1-smp | test roberta finetuning using BERT from transformer lib+ PT SM MP | SageMaker createTrainingJob | 8 | train_runtime, eval_accuracy & eval_loss | +| tensorflow-transfromers-test-single | Test bert finetuning using BERT from transformer lib+TF | SageMaker createTrainingJob | 1 | train_runtime, eval_accuracy & eval_loss | +| tensorflow-transfromers-test-2-smd | test bert finetuning using BERT from transformer lib+ TF SM DDP | SageMaker createTrainingJob | 16 | train_runtime, eval_accuracy & eval_loss | \ No newline at end of file diff --git a/tests/sagemaker/__init__.py b/tests/sagemaker/__init__.py new file mode 100644 index 00000000000..ecda04614d4 --- /dev/null +++ b/tests/sagemaker/__init__.py @@ -0,0 +1,5 @@ +import importlib + + +def is_sagemaker_available(): + return importlib.util.find_spec("sagemaker") is not None diff --git a/tests/sagemaker/conftest.py b/tests/sagemaker/conftest.py new file mode 100644 index 00000000000..076e06784bc --- /dev/null +++ b/tests/sagemaker/conftest.py @@ -0,0 +1,65 @@ +# we define a fixture function below and it will be "used" by +# referencing its name from tests + +import os + +import pytest + +from attr import dataclass + + +os.environ["AWS_DEFAULT_REGION"] = "us-east-1" # defaults region + + +@dataclass +class SageMakerTestEnvironment: + framework: str + role = "arn:aws:iam::558105141721:role/sagemaker_execution_role" + hyperparameters = { + "task_name": "mnli", + "per_device_train_batch_size": 32, + "per_device_eval_batch_size": 32, + "do_train": True, + "do_eval": True, + "do_predict": True, + "output_dir": "/opt/ml/model", + "overwrite_output_dir": True, + "max_steps": 500, + "save_steps": 5500, + } + distributed_hyperparameters = {**hyperparameters, "max_steps": 1000} + + @property + def metric_definitions(self) -> str: + if self.framework == "pytorch": + return [ + {"Name": "train_runtime", "Regex": "train_runtime.*=\D*(.*?)$"}, + {"Name": "eval_accuracy", "Regex": "eval_accuracy.*=\D*(.*?)$"}, + {"Name": "eval_loss", "Regex": "eval_loss.*=\D*(.*?)$"}, + ] + else: + return [ + {"Name": "train_runtime", "Regex": "train_runtime.*=\D*(.*?)$"}, + {"Name": "eval_accuracy", "Regex": "loss.*=\D*(.*?)]?$"}, + {"Name": "eval_loss", "Regex": "sparse_categorical_accuracy.*=\D*(.*?)]?$"}, + ] + + @property + def base_job_name(self) -> str: + return f"{self.framework}-transfromers-test" + + @property + def test_path(self) -> str: + return f"./tests/sagemaker/scripts/{self.framework}" + + @property + def image_uri(self) -> str: + if self.framework == "pytorch": + return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + else: + return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-training:2.4.1-transformers4.4.2-gpu-py37-cu110-ubuntu18.04" + + +@pytest.fixture(scope="class") +def sm_env(request): + request.cls.env = SageMakerTestEnvironment(framework=request.cls.framework) diff --git a/tests/sagemaker/scripts/pytorch/requirements.txt b/tests/sagemaker/scripts/pytorch/requirements.txt new file mode 100644 index 00000000000..0194b67c403 --- /dev/null +++ b/tests/sagemaker/scripts/pytorch/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/huggingface/transformers.git@master # install master or adjust ist with vX.X.X for installing version specific transforms \ No newline at end of file diff --git a/tests/sagemaker/scripts/pytorch/run_ddp.py b/tests/sagemaker/scripts/pytorch/run_ddp.py new file mode 100644 index 00000000000..1191caeb96a --- /dev/null +++ b/tests/sagemaker/scripts/pytorch/run_ddp.py @@ -0,0 +1,52 @@ +import json +import logging +import os +import subprocess +from argparse import ArgumentParser + + +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = ArgumentParser() + parsed, unknown = parser.parse_known_args() + for arg in unknown: + if arg.startswith(("-", "--")): + parser.add_argument(arg.split("=")[0]) + + return parser.parse_args() + + +def main(): + args = parse_args() + port = 8888 + num_gpus = int(os.environ["SM_NUM_GPUS"]) + hosts = json.loads(os.environ["SM_HOSTS"]) + num_nodes = len(hosts) + current_host = os.environ["SM_CURRENT_HOST"] + rank = hosts.index(current_host) + os.environ["NCCL_DEBUG"] = "INFO" + + if num_nodes > 1: + cmd = f"""python -m torch.distributed.launch \ + --nnodes={num_nodes} \ + --node_rank={rank} \ + --nproc_per_node={num_gpus} \ + --master_addr={hosts[0]} \ + --master_port={port} \ + ./run_glue.py \ + {"".join([f" --{parameter} {value}" for parameter,value in args.__dict__.items()])}""" + else: + cmd = f"""python -m torch.distributed.launch \ + --nproc_per_node={num_gpus} \ + ./run_glue.py \ + {"".join([f" --{parameter} {value}" for parameter,value in args.__dict__.items()])}""" + try: + subprocess.run(cmd, shell=True) + except Exception as e: + logger.info(e) + + +if __name__ == "__main__": + main() diff --git a/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py b/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py new file mode 100644 index 00000000000..1bc9ed4ce82 --- /dev/null +++ b/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py @@ -0,0 +1,529 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning the library models for sequence classification on GLUE.""" +# You can also adapt this script on your own text classification task. Pointers for this are left as comments. + +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +from datasets import load_dataset, load_metric + +import transformers +from transformers import ( # Trainer,; TrainingArguments, + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + HfArgumentParser, + PretrainedConfig, + default_data_collator, + set_seed, +) + +# Will import SageMaker Model parallelism specific Trainer +from transformers.sagemaker import SageMakerTrainer as Trainer +from transformers.sagemaker import SageMakerTrainingArguments as TrainingArguments +from transformers.trainer_utils import get_last_checkpoint, is_main_process +from transformers.utils import check_min_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.4.2") + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +logger = logging.getLogger(__name__) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + task_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, + ) + max_seq_length: int = field( + default=128, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} + ) + pad_to_max_length: bool = field( + default=True, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_val_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the training data."} + ) + validation_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the validation data."} + ) + test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) + + def __post_init__(self): + if self.task_name is not None: + self.task_name = self.task_name.lower() + if self.task_name not in task_to_keys.keys(): + raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) + elif self.train_file is None or self.validation_file is None: + raise ValueError("Need either a GLUE task or a training/validation file.") + else: + train_extension = self.train_file.split(".")[-1] + assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." + validation_extension = self.validation_file.split(".")[-1] + assert ( + validation_extension == train_extension + ), "`validation_file` should have the same extension (csv or json) as `train_file`." + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the + # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named + # label if at least two columns are provided. + # + # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this + # single column. You can easily tweak this behavior (see below) + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.task_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset("glue", data_args.task_name) + else: + # Loading a dataset from your local files. + # CSV/JSON training and evaluation files are needed. + data_files = {"train": data_args.train_file, "validation": data_args.validation_file} + + # Get the test dataset: you can provide your own CSV/JSON test file (see below) + # when you use `do_predict` without specifying a GLUE benchmark task. + if training_args.do_predict: + if data_args.test_file is not None: + train_extension = data_args.train_file.split(".")[-1] + test_extension = data_args.test_file.split(".")[-1] + assert ( + test_extension == train_extension + ), "`test_file` should have the same extension (csv or json) as `train_file`." + data_files["test"] = data_args.test_file + else: + raise ValueError("Need either a GLUE task or a test file for `do_predict`.") + + for key in data_files.keys(): + logger.info(f"load a local file for {key}: {data_files[key]}") + + if data_args.train_file.endswith(".csv"): + # Loading a dataset from local csv files + datasets = load_dataset("csv", data_files=data_files) + else: + # Loading a dataset from local json files + datasets = load_dataset("json", data_files=data_files) + # See more about loading any type of standard or custom dataset at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Labels + if data_args.task_name is not None: + is_regression = data_args.task_name == "stsb" + if not is_regression: + label_list = datasets["train"].features["label"].names + num_labels = len(label_list) + else: + num_labels = 1 + else: + # Trying to have good defaults here, don't hesitate to tweak to your needs. + is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"] + if is_regression: + num_labels = 1 + else: + # A useful fast method: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique + label_list = datasets["train"].unique("label") + label_list.sort() # Let's sort it for determinism + num_labels = len(label_list) + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Preprocessing the datasets + if data_args.task_name is not None: + sentence1_key, sentence2_key = task_to_keys[data_args.task_name] + else: + # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. + non_label_column_names = [name for name in datasets["train"].column_names if name != "label"] + if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: + sentence1_key, sentence2_key = "sentence1", "sentence2" + else: + if len(non_label_column_names) >= 2: + sentence1_key, sentence2_key = non_label_column_names[:2] + else: + sentence1_key, sentence2_key = non_label_column_names[0], None + + # Padding strategy + if data_args.pad_to_max_length: + padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + label_to_id = None + if ( + model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id + and data_args.task_name is not None + and not is_regression + ): + # Some have all caps in their config, some don't. + label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} + if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): + label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} + else: + logger.warn( + "Your model seems to have been trained with labels, but they don't match the dataset: ", + f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + "\nIgnoring the model labels as a result.", + ) + elif data_args.task_name is None and not is_regression: + label_to_id = {v: i for i, v in enumerate(label_list)} + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + def preprocess_function(examples): + # Tokenize the texts + args = ( + (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) + + # Map labels to IDs (not necessary for GLUE tasks) + if label_to_id is not None and "label" in examples: + result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] + return result + + datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache) + if training_args.do_train: + if "train" not in datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + if "validation" not in datasets and "validation_matched" not in datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] + if data_args.max_val_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + + if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: + if "test" not in datasets and "test_matched" not in datasets: + raise ValueError("--do_predict requires a test dataset") + test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] + if data_args.max_test_samples is not None: + test_dataset = test_dataset.select(range(data_args.max_test_samples)) + + # Log a few random samples from the training set: + if training_args.do_train: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Get the metric function + if data_args.task_name is not None: + metric = load_metric("glue", data_args.task_name) + # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from + # compute_metrics + + # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a + # predictions and label_ids field) and has to return a dictionary string to float. + def compute_metrics(p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) + if data_args.task_name is not None: + result = metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + elif is_regression: + return {"mse": ((preds - p.label_ids) ** 2).mean().item()} + else: + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} + + # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. + if data_args.pad_to_max_length: + data_collator = default_data_collator + elif training_args.fp16: + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + else: + data_collator = None + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + compute_metrics=compute_metrics, + tokenizer=tokenizer, + data_collator=data_collator, + ) + + # Training + if training_args.do_train: + checkpoint = None + if last_checkpoint is not None: + checkpoint = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + # Check the config from that potential checkpoint has the right number of labels before using it as a + # checkpoint. + if AutoConfig.from_pretrained(model_args.model_name_or_path).num_labels == num_labels: + checkpoint = model_args.model_name_or_path + + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.save_model() # Saves the tokenizer too for easy upload + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + eval_datasets = [eval_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + eval_datasets.append(datasets["validation_mismatched"]) + + for eval_dataset, task in zip(eval_datasets, tasks): + metrics = trainer.evaluate(eval_dataset=eval_dataset) + + max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + logger.info("*** Test ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + test_datasets = [test_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + test_datasets.append(datasets["test_mismatched"]) + + for test_dataset, task in zip(test_datasets, tasks): + # Removing the `label` columns because it contains -1 and Trainer won't like that. + test_dataset.remove_columns_("label") + predictions = trainer.predict(test_dataset=test_dataset).predictions + predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) + + output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt") + if trainer.is_world_process_zero(): + with open(output_test_file, "w") as writer: + logger.info(f"***** Test results {task} *****") + writer.write("index\tprediction\n") + for index, item in enumerate(predictions): + if is_regression: + writer.write(f"{index}\t{item:3.3f}\n") + else: + item = label_list[item] + writer.write(f"{index}\t{item}\n") + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/tests/sagemaker/scripts/tensorflow/requirements.txt b/tests/sagemaker/scripts/tensorflow/requirements.txt new file mode 100644 index 00000000000..0194b67c403 --- /dev/null +++ b/tests/sagemaker/scripts/tensorflow/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/huggingface/transformers.git@master # install master or adjust ist with vX.X.X for installing version specific transforms \ No newline at end of file diff --git a/tests/sagemaker/scripts/tensorflow/run_tf.py b/tests/sagemaker/scripts/tensorflow/run_tf.py new file mode 100644 index 00000000000..21716e996c5 --- /dev/null +++ b/tests/sagemaker/scripts/tensorflow/run_tf.py @@ -0,0 +1,91 @@ +import argparse +import logging +import sys +import time + +import tensorflow as tf +from datasets import load_dataset + +from transformers import AutoTokenizer, TFAutoModelForSequenceClassification + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # Hyperparameters sent by the client are passed as command-line arguments to the script. + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--per_device_train_batch_size", type=int, default=16) + parser.add_argument("--per_device_eval_batch_size", type=int, default=8) + parser.add_argument("--model_name_or_path", type=str) + parser.add_argument("--learning_rate", type=str, default=5e-5) + parser.add_argument("--do_train", type=bool, default=True) + parser.add_argument("--do_eval", type=bool, default=True) + parser.add_argument("--output_dir", type=str) + + args, _ = parser.parse_known_args() + + # overwrite batch size until we have tf_glue.py + args.per_device_train_batch_size = 16 + args.per_device_eval_batch_size = 16 + + # Set up logging + logger = logging.getLogger(__name__) + + logging.basicConfig( + level=logging.getLevelName("INFO"), + handlers=[logging.StreamHandler(sys.stdout)], + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + # Load model and tokenizer + model = TFAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) + + # Load dataset + train_dataset, test_dataset = load_dataset("imdb", split=["train", "test"]) + train_dataset = train_dataset.shuffle().select(range(5000)) # smaller the size for train dataset to 5k + test_dataset = test_dataset.shuffle().select(range(500)) # smaller the size for test dataset to 500 + + # Preprocess train dataset + train_dataset = train_dataset.map( + lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True + ) + train_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"]) + + train_features = { + x: train_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length]) + for x in ["input_ids", "attention_mask"] + } + tf_train_dataset = tf.data.Dataset.from_tensor_slices((train_features, train_dataset["label"])).batch( + args.per_device_train_batch_size + ) + + # Preprocess test dataset + test_dataset = test_dataset.map( + lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True + ) + test_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"]) + + test_features = { + x: test_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length]) + for x in ["input_ids", "attention_mask"] + } + tf_test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_dataset["label"])).batch( + args.per_device_eval_batch_size + ) + + # fine optimizer and loss + optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate) + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + metrics = [tf.keras.metrics.SparseCategoricalAccuracy()] + model.compile(optimizer=optimizer, loss=loss, metrics=metrics) + + start_train_time = time.time() + train_results = model.fit(tf_train_dataset, epochs=args.epochs, batch_size=args.per_device_train_batch_size) + end_train_time = time.time() - start_train_time + + logger.info("*** Train ***") + logger.info("train_runtime = %s", end_train_time) + for key, value in train_results.history.items(): + logger.info(" %s = %s", key, value) diff --git a/tests/sagemaker/scripts/tensorflow/run_tf_dist.py b/tests/sagemaker/scripts/tensorflow/run_tf_dist.py new file mode 100644 index 00000000000..7bfe76571af --- /dev/null +++ b/tests/sagemaker/scripts/tensorflow/run_tf_dist.py @@ -0,0 +1,194 @@ +import argparse +import logging +import os +import sys +import time + +import tensorflow as tf +from datasets import load_dataset +from tqdm import tqdm + +from transformers import AutoTokenizer, TFAutoModelForSequenceClassification +from transformers.file_utils import is_sagemaker_distributed_available + + +if os.environ.get("SDP_ENABLED") or is_sagemaker_distributed_available(): + SDP_ENABLED = True + os.environ["SAGEMAKER_INSTANCE_TYPE"] = "p3dn.24xlarge" + import smdistributed.dataparallel.tensorflow as sdp +else: + SDP_ENABLED = False + + +def fit(model, loss, opt, train_dataset, epochs, train_batch_size, max_steps=None): + pbar = tqdm(train_dataset) + for i, batch in enumerate(pbar): + with tf.GradientTape() as tape: + inputs, targets = batch + outputs = model(batch) + loss_value = loss(targets, outputs.logits) + + if SDP_ENABLED: + tape = sdp.DistributedGradientTape(tape, sparse_as_dense=True) + + grads = tape.gradient(loss_value, model.trainable_variables) + opt.apply_gradients(zip(grads, model.trainable_variables)) + + pbar.set_description(f"Loss: {loss_value:.4f}") + + if SDP_ENABLED and i == 0: + sdp.broadcast_variables(model.variables, root_rank=0) + sdp.broadcast_variables(opt.variables(), root_rank=0) + + if max_steps and i >= max_steps: + break + + train_results = {"loss": loss_value.numpy()} + return train_results + + +def get_datasets(tokenizer, train_batch_size, eval_batch_size): + # Load dataset + train_dataset, test_dataset = load_dataset("imdb", split=["train", "test"]) + + # Preprocess train dataset + train_dataset = train_dataset.map( + lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True + ) + train_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"]) + + train_features = { + x: train_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length]) + for x in ["input_ids", "attention_mask"] + } + tf_train_dataset = tf.data.Dataset.from_tensor_slices((train_features, train_dataset["label"])) + + # Preprocess test dataset + test_dataset = test_dataset.map( + lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True + ) + test_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"]) + + test_features = { + x: test_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length]) + for x in ["input_ids", "attention_mask"] + } + tf_test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_dataset["label"])) + + if SDP_ENABLED: + tf_train_dataset = tf_train_dataset.shard(sdp.size(), sdp.rank()) + tf_test_dataset = tf_test_dataset.shard(sdp.size(), sdp.rank()) + tf_train_dataset = tf_train_dataset.batch(train_batch_size, drop_remainder=True) + tf_test_dataset = tf_test_dataset.batch(eval_batch_size, drop_remainder=True) + + return tf_train_dataset, tf_test_dataset + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # Hyperparameters sent by the client are passed as command-line arguments to the script. + parser.add_argument("--epochs", type=int, default=3) + parser.add_argument("--per_device_train_batch_size", type=int, default=16) + parser.add_argument("--per_device_eval_batch_size", type=int, default=8) + parser.add_argument("--model_name_or_path", type=str) + parser.add_argument("--learning_rate", type=str, default=5e-5) + parser.add_argument("--do_train", type=bool, default=True) + parser.add_argument("--do_eval", type=bool, default=True) + parser.add_argument("--output_dir", type=str) + parser.add_argument("--max_steps", type=int, default=None) + + # Data, model, and output directories + parser.add_argument("--output_data_dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"]) + parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"]) + parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"]) + + args, _ = parser.parse_known_args() + + # Set up logging + logger = logging.getLogger(__name__) + + logging.basicConfig( + level=logging.getLevelName("INFO"), + handlers=[logging.StreamHandler(sys.stdout)], + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + if SDP_ENABLED: + sdp.init() + + gpus = tf.config.experimental.list_physical_devices("GPU") + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + if gpus: + tf.config.experimental.set_visible_devices(gpus[sdp.local_rank()], "GPU") + + # Load model and tokenizer + model = TFAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) + + # get datasets + tf_train_dataset, tf_test_dataset = get_datasets( + tokenizer=tokenizer, + train_batch_size=args.per_device_train_batch_size, + eval_batch_size=args.per_device_eval_batch_size, + ) + + # fine optimizer and loss + optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate) + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + metrics = [tf.keras.metrics.SparseCategoricalAccuracy()] + model.compile(optimizer=optimizer, loss=loss, metrics=metrics) + + # Training + if args.do_train: + + # train_results = model.fit(tf_train_dataset, epochs=args.epochs, batch_size=args.train_batch_size) + start_train_time = time.time() + train_results = fit( + model, + loss, + optimizer, + tf_train_dataset, + args.epochs, + args.per_device_train_batch_size, + max_steps=args.max_steps, + ) + end_train_time = time.time() - start_train_time + logger.info("*** Train ***") + logger.info("train_runtime = %s", end_train_time) + + output_eval_file = os.path.join(args.output_dir, "train_results.txt") + + if not SDP_ENABLED or sdp.rank() == 0: + with open(output_eval_file, "w") as writer: + logger.info("***** Train results *****") + logger.info(train_results) + for key, value in train_results.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) + + # Evaluation + if args.do_eval and (not SDP_ENABLED or sdp.rank() == 0): + + result = model.evaluate(tf_test_dataset, batch_size=args.per_device_eval_batch_size, return_dict=True) + logger.info("*** Evaluate ***") + + output_eval_file = os.path.join(args.output_dir, "eval_results.txt") + + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results *****") + logger.info(result) + for key, value in result.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) + + # Save result + if SDP_ENABLED: + if sdp.rank() == 0: + model.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + else: + model.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) diff --git a/tests/sagemaker/test_multi_node_data_parallel.py b/tests/sagemaker/test_multi_node_data_parallel.py new file mode 100644 index 00000000000..460465606cb --- /dev/null +++ b/tests/sagemaker/test_multi_node_data_parallel.py @@ -0,0 +1,104 @@ +import os +import subprocess +import unittest +from ast import literal_eval + +import pytest + +from parameterized import parameterized, parameterized_class + +from . import is_sagemaker_available + + +if is_sagemaker_available(): + from sagemaker import TrainingJobAnalytics + from sagemaker.huggingface import HuggingFace + + +@pytest.mark.skipif( + literal_eval(os.getenv("TEST_SAGEMAKER", "False")) is not True, + reason="Skipping test because should only be run when releasing minor transformers version", +) +@pytest.mark.usefixtures("sm_env") +@parameterized_class( + [ + { + "framework": "pytorch", + "script": "run_glue.py", + "model_name_or_path": "distilbert-base-cased", + "instance_type": "ml.p3dn.24xlarge", + "results": {"train_runtime": 300, "eval_accuracy": 0.7, "eval_loss": 0.6}, + }, + { + "framework": "pytorch", + "script": "run_ddp.py", + "model_name_or_path": "distilbert-base-cased", + "instance_type": "ml.p3dn.24xlarge", + "results": {"train_runtime": 300, "eval_accuracy": 0.7, "eval_loss": 0.6}, + }, + { + "framework": "tensorflow", + "script": "run_tf_dist.py", + "model_name_or_path": "distilbert-base-cased", + "instance_type": "ml.p3dn.24xlarge", + "results": {"train_runtime": 500, "eval_accuracy": 0.6, "eval_loss": 0.7}, + }, + ] +) +class MultiNodeTest(unittest.TestCase): + def setUp(self): + if self.framework == "pytorch": + subprocess.run( + f"cp ./examples/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(), + encoding="utf-8", + check=True, + ) + assert hasattr(self, "env") + + def create_estimator(self, instance_count): + job_name = f"{self.env.base_job_name}-{instance_count}-{'ddp' if 'ddp' in self.script else 'smd'}" + # distributed data settings + distribution = {"smdistributed": {"dataparallel": {"enabled": True}}} + + # creates estimator + return HuggingFace( + entry_point=self.script, + source_dir=self.env.test_path, + role=self.env.role, + image_uri=self.env.image_uri, + base_job_name=job_name, + instance_count=instance_count, + instance_type=self.instance_type, + debugger_hook_config=False, + hyperparameters={**self.env.distributed_hyperparameters, "model_name_or_path": self.model_name_or_path}, + metric_definitions=self.env.metric_definitions, + distribution=distribution, + py_version="py36", + ) + + def save_results_as_csv(self, job_name): + TrainingJobAnalytics(job_name).export_csv(f"{self.env.test_path}/{job_name}_metrics.csv") + + # @parameterized.expand([(2,), (4,),]) + @parameterized.expand([(2,)]) + def test_script(self, instance_count): + # create estimator + estimator = self.create_estimator(instance_count) + + # run training + estimator.fit() + + # save csv + self.save_results_as_csv(estimator.latest_training_job.name) + # result dataframe + result_metrics_df = TrainingJobAnalytics(estimator.latest_training_job.name).dataframe() + + # extract kpis + train_runtime = list(result_metrics_df[result_metrics_df.metric_name == "train_runtime"]["value"]) + eval_accuracy = list(result_metrics_df[result_metrics_df.metric_name == "eval_accuracy"]["value"]) + eval_loss = list(result_metrics_df[result_metrics_df.metric_name == "eval_loss"]["value"]) + + # assert kpis + assert all(t <= self.results["train_runtime"] for t in train_runtime) + assert any(t >= self.results["eval_accuracy"] for t in eval_accuracy) + assert all(t <= self.results["eval_loss"] for t in eval_loss) diff --git a/tests/sagemaker/test_multi_node_model_parallel.py b/tests/sagemaker/test_multi_node_model_parallel.py new file mode 100644 index 00000000000..bca402bcba4 --- /dev/null +++ b/tests/sagemaker/test_multi_node_model_parallel.py @@ -0,0 +1,103 @@ +import os +import unittest +from ast import literal_eval + +import pytest + +from parameterized import parameterized, parameterized_class + +from . import is_sagemaker_available + + +if is_sagemaker_available(): + from sagemaker import TrainingJobAnalytics + from sagemaker.huggingface import HuggingFace + + +@pytest.mark.skipif( + literal_eval(os.getenv("TEST_SAGEMAKER", "False")) is not True, + reason="Skipping test because should only be run when releasing minor transformers version", +) +@pytest.mark.usefixtures("sm_env") +@parameterized_class( + [ + { + "framework": "pytorch", + "script": "run_glue_model_parallelism.py", + "model_name_or_path": "roberta-large", + "instance_type": "ml.p3dn.24xlarge", + "results": {"train_runtime": 700, "eval_accuracy": 0.3, "eval_loss": 1.2}, + }, + ] +) +class MultiNodeTest(unittest.TestCase): + def setUp(self): + assert hasattr(self, "env") + + def create_estimator(self, instance_count): + + # configuration for running training on smdistributed Model Parallel + mpi_options = { + "enabled": True, + "processes_per_host": 8, + } + smp_options = { + "enabled": True, + "parameters": { + "microbatches": 4, + "placement_strategy": "spread", + "pipeline": "interleaved", + "optimize": "speed", + "partitions": 4, + "ddp": True, + }, + } + + distribution = {"smdistributed": {"modelparallel": smp_options}, "mpi": mpi_options} + + # creates estimator + return HuggingFace( + entry_point=self.script, + source_dir=self.env.test_path, + role=self.env.role, + image_uri=self.env.image_uri, + base_job_name=f"{self.env.base_job_name}-{instance_count}-smp", + instance_count=instance_count, + instance_type=self.instance_type, + debugger_hook_config=False, + hyperparameters={ + **self.env.hyperparameters, + "model_name_or_path": self.model_name_or_path, + "max_steps": 500, + }, + metric_definitions=self.env.metric_definitions, + distribution=distribution, + py_version="py36", + ) + + def save_results_as_csv(self, job_name): + TrainingJobAnalytics(job_name).export_csv(f"{self.env.test_path}/{job_name}_metrics.csv") + + # @parameterized.expand([(2,), (4,),]) + @parameterized.expand([(1,)]) + def test_scripz(self, instance_count): + # create estimator + estimator = self.create_estimator(instance_count) + + # run training + estimator.fit() + + # save csv + self.save_results_as_csv(estimator.latest_training_job.name) + # result dataframe + result_metrics_df = TrainingJobAnalytics(estimator.latest_training_job.name).dataframe() + + # extract kpis + train_runtime = list(result_metrics_df[result_metrics_df.metric_name == "train_runtime"]["value"]) + eval_accuracy = list(result_metrics_df[result_metrics_df.metric_name == "eval_accuracy"]["value"]) + eval_loss = list(result_metrics_df[result_metrics_df.metric_name == "eval_loss"]["value"]) + + # assert kpis + assert all(t <= self.results["train_runtime"] for t in train_runtime) + assert all(t >= self.results["eval_accuracy"] for t in eval_accuracy) + assert all(t <= self.results["eval_loss"] for t in eval_loss) diff --git a/tests/sagemaker/test_single_node_gpu.py b/tests/sagemaker/test_single_node_gpu.py new file mode 100644 index 00000000000..aa08bd06419 --- /dev/null +++ b/tests/sagemaker/test_single_node_gpu.py @@ -0,0 +1,90 @@ +import os +import subprocess +import unittest +from ast import literal_eval + +import pytest + +from parameterized import parameterized_class + +from . import is_sagemaker_available + + +if is_sagemaker_available(): + from sagemaker import TrainingJobAnalytics + from sagemaker.huggingface import HuggingFace + + +@pytest.mark.skipif( + literal_eval(os.getenv("TEST_SAGEMAKER", "False")) is not True, + reason="Skipping test because should only be run when releasing minor transformers version", +) +@pytest.mark.usefixtures("sm_env") +@parameterized_class( + [ + { + "framework": "pytorch", + "script": "run_glue.py", + "model_name_or_path": "distilbert-base-cased", + "instance_type": "ml.g4dn.xlarge", + "results": {"train_runtime": 200, "eval_accuracy": 0.6, "eval_loss": 0.9}, + }, + { + "framework": "tensorflow", + "script": "run_tf.py", + "model_name_or_path": "distilbert-base-cased", + "instance_type": "ml.g4dn.xlarge", + "results": {"train_runtime": 350, "eval_accuracy": 0.3, "eval_loss": 0.9}, + }, + ] +) +class SingleNodeTest(unittest.TestCase): + def setUp(self): + if self.framework == "pytorch": + subprocess.run( + f"cp ./examples/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(), + encoding="utf-8", + check=True, + ) + assert hasattr(self, "env") + + def create_estimator(self, instance_count=1): + # creates estimator + return HuggingFace( + entry_point=self.script, + source_dir=self.env.test_path, + role=self.env.role, + image_uri=self.env.image_uri, + base_job_name=f"{self.env.base_job_name}-single", + instance_count=instance_count, + instance_type=self.instance_type, + debugger_hook_config=False, + hyperparameters={**self.env.hyperparameters, "model_name_or_path": self.model_name_or_path}, + metric_definitions=self.env.metric_definitions, + py_version="py36", + ) + + def save_results_as_csv(self, job_name): + TrainingJobAnalytics(job_name).export_csv(f"{self.env.test_path}/{job_name}_metrics.csv") + + def test_glue(self): + # create estimator + estimator = self.create_estimator() + + # run training + estimator.fit() + + # save csv + self.save_results_as_csv(estimator.latest_training_job.name) + # result dataframe + result_metrics_df = TrainingJobAnalytics(estimator.latest_training_job.name).dataframe() + + # extract kpis + train_runtime = list(result_metrics_df[result_metrics_df.metric_name == "train_runtime"]["value"]) + eval_accuracy = list(result_metrics_df[result_metrics_df.metric_name == "eval_accuracy"]["value"]) + eval_loss = list(result_metrics_df[result_metrics_df.metric_name == "eval_loss"]["value"]) + + # assert kpis + assert all(t <= self.results["train_runtime"] for t in train_runtime) + assert all(t >= self.results["eval_accuracy"] for t in eval_accuracy) + assert all(t <= self.results["eval_loss"] for t in eval_loss)