[setup] make fairscale and deepspeed setup extras (#11151)

* make fairscale and deepspeed setup extras

* fix default

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* no reason not to ask for the good version

* update the CIs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman 2021-04-08 15:46:54 -07:00 committed by GitHub
parent ba8b1f4754
commit c2e0fd5283
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 41 additions and 14 deletions

View File

@ -33,8 +33,7 @@ jobs:
run: |
apt -y update && apt install -y libsndfile1-dev
pip install --upgrade pip
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech]
pip install deepspeed
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed]
- name: Are GPUs recognized by our DL frameworks
run: |
@ -156,9 +155,7 @@ jobs:
run: |
apt -y update && apt install -y libsndfile1-dev
pip install --upgrade pip
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech]
pip install fairscale
pip install deepspeed
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed,fairscale]
- name: Are GPUs recognized by our DL frameworks
run: |

View File

@ -274,6 +274,14 @@ Install the library via pypi:
pip install fairscale
or via ``transformers``' ``extras``:
.. code-block:: bash
pip install transformers[fairscale]
(will become available starting from ``transformers==4.6.0``)
or find more details on `the FairScale GitHub page <https://github.com/facebookresearch/fairscale/#installation>`__.
If you're still struggling with the build, first make sure to read :ref:`zero-install-notes`.
@ -419,6 +427,14 @@ Install the library via pypi:
pip install deepspeed
or via ``transformers``' ``extras``:
.. code-block:: bash
pip install transformers[deepspeed]
(will become available starting from ``transformers==4.6.0``)
or find more details on `the DeepSpeed GitHub page <https://github.com/microsoft/deepspeed#installation>`__ and
the `advanced install <https://www.deepspeed.ai/tutorials/advanced-install/>`__ guide.

View File

@ -90,7 +90,9 @@ _deps = [
"cookiecutter==1.7.2",
"dataclasses",
"datasets",
"deepspeed>0.3.13",
"docutils==0.16.0",
"fairscale>0.3",
"faiss-cpu",
"fastapi",
"filelock",
@ -233,6 +235,8 @@ extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxr
extras["modelcreation"] = deps_list("cookiecutter")
extras["sagemaker"] = deps_list("sagemaker")
extras["deepspeed"] = deps_list("deepspeed")
extras["fairscale"] = deps_list("fairscale")
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["speech"] = deps_list("soundfile", "torchaudio")

View File

@ -14,7 +14,7 @@
import sys
from .dependency_versions_table import deps
from .utils.versions import require_version_core
from .utils.versions import require_version, require_version_core
# define which module versions we always want to check at run time
@ -41,3 +41,7 @@ for pkg in pkgs_to_check_at_runtime:
require_version_core(deps[pkg])
else:
raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
def dep_version_check(pkg, hint=None):
require_version(deps[pkg], hint)

View File

@ -7,7 +7,9 @@ deps = {
"cookiecutter": "cookiecutter==1.7.2",
"dataclasses": "dataclasses",
"datasets": "datasets",
"deepspeed": "deepspeed>0.3.13",
"docutils": "docutils==0.16.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",
"fastapi": "fastapi",
"filelock": "filelock",

View File

@ -24,8 +24,8 @@ import tempfile
from copy import deepcopy
from pathlib import Path
from .dependency_versions_check import dep_version_check
from .utils import logging
from .utils.versions import require_version
logger = logging.get_logger(__name__)
@ -324,7 +324,7 @@ def deepspeed_parse_config(ds_config):
If it's already a dict, return a copy of it, so that we can freely modify it.
"""
require_version("deepspeed>0.3.13")
dep_version_check("deepspeed")
if isinstance(ds_config, dict):
# Don't modify user's data should they want to reuse it (e.g. in tests), because once we

View File

@ -54,6 +54,7 @@ from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .dependency_versions_check import dep_version_check
from .file_utils import (
WEIGHTS_NAME,
is_apex_available,
@ -139,17 +140,14 @@ if is_torch_tpu_available():
import torch_xla.distributed.parallel_loader as pl
if is_fairscale_available():
dep_version_check("fairscale")
import fairscale
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.nn.wrap import auto_wrap
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
if version.parse(fairscale.__version__) >= version.parse("0.3"):
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.wrap import auto_wrap
else:
FullyShardedDDP = None
if is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP

View File

@ -60,6 +60,12 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
Args:
requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met
Example::
require_version("pandas>1.1.2")
require_version("numpy>1.18.5", "this is important to have for whatever reason")
"""
hint = f"\n{hint}" if hint is not None else ""