transformers/tests/generation/test_fsdp.py
Matthew Hoffman 70b07d97cf
Default synced_gpus to True when using FullyShardedDataParallel (#33483)
* Default synced_gpus to True when using FullyShardedDataParallel

Fixes #30228

Related:

* https://github.com/pytorch/pytorch/issues/100069
* https://github.com/pytorch/pytorch/issues/123962

Similar to DeepSpeed ZeRO Stage 3, when using FSDP with multiple GPUs and differently sized data per rank, the ranks reach different synchronization points at the same time, leading to a deadlock.

To avoid this, we can automatically set synced_gpus to True if we detect, using _is_fsdp_managed_module, that a PreTrainedModel is being managed by FSDP. That helper was added in PyTorch 2.0.0 for torch.compile: https://github.com/pytorch/pytorch/blob/v2.0.0/torch/distributed/fsdp/_dynamo_utils.py

* Remove test file

* ruff formatting

* ruff format

* Update copyright year

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add test for FSDP-wrapped model generation

Before #33483, these tests would have hung for 10 minutes before crashing due to a timeout error.

* Ruff format

* Move argparse import

* Remove barrier

I think this might cause more problems if one of the workers was killed

* Move import into function to decrease load time

https://github.com/huggingface/transformers/pull/33483#discussion_r1787972735

* Add test for accelerate and Trainer

https://github.com/huggingface/transformers/pull/33483#discussion_r1790309675

* Refactor imports

* Ruff format

* Use nullcontext

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2024-10-10 14:09:04 -04:00

149 lines
5.5 KiB
Python

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Any, Callable
from transformers import is_torch_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_gpu,
)
if is_torch_available():
import functools
import torch
import torch.distributed
from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
# Prompts used by the generation workers, which pick one with ``data[rank]``.
# The two prompts alternate, so neighbouring ranks get inputs of different
# lengths; the pair is repeated four times, giving eight entries in total.
data = [
    "Hello world!",
    "The quick brown fox jumps over the lazy dog.",
] * 4
def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
    """Manage the creation and destruction of the distributed process group for the wrapped function.

    The returned callable initializes the default ``torch.distributed`` process
    group with ``world_size`` equal to the number of visible CUDA devices, calls
    ``func`` with the arguments it was given, and always destroys the process
    group afterwards, even if ``func`` raises.

    Args:
        func: The function to run inside an initialized process group.

    Returns:
        A wrapper with the same call signature and metadata (``__name__``,
        ``__doc__``, ``__wrapped__``) as ``func``; it returns whatever ``func``
        returns.
    """

    # functools.wraps keeps the decorated function's name and docstring, so
    # tracebacks and logs name the real worker instead of "wrapped".
    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        # No backend/init_method is passed here, so the remaining rendezvous
        # settings come from torch's defaults (presumably the environment set
        # up by the launcher; this file launches the script with torchrun).
        torch.distributed.init_process_group(world_size=torch.cuda.device_count())
        try:
            return func(*args, **kwargs)
        finally:
            # Tear the group down on success and on failure alike.
            torch.distributed.destroy_process_group()

    return wrapped
@manage_process_group
def fsdp_generate():
    """Run ``generate`` on a tiny GPT-2 wrapped in ``FullyShardedDataParallel``.

    Each process binds to the CUDA device whose index equals its rank, wraps the
    model with FSDP (auto-wrapping every ``GPT2Block``), tokenizes the prompt at
    ``data[rank]`` and generates up to ``max_length=30`` tokens inside
    ``summon_full_params``. The generated tokens are discarded.

    Note: the prompt lookup is ``data[rank]``, so the rank must be a valid index
    into ``data``.
    """
    checkpoint = "hf-internal-testing/tiny-random-gpt2"

    # One process per GPU: the rank doubles as the CUDA device index.
    rank = torch.distributed.get_rank()
    device = torch.device(rank)
    torch.cuda.set_device(device)

    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

    # Wrap each transformer block as its own FSDP unit.
    wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block})
    fsdp_model = FullyShardedDataParallel(
        model,
        auto_wrap_policy=wrap_policy,
        limit_all_gathers=True,
        use_orig_params=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoded = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True)
    encoded = encoded.to(device)

    generate_kwargs = {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "max_length": 30,
    }
    # generate() is called on the inner module while the full parameters are gathered.
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        fsdp_model.module.generate(**generate_kwargs)
@manage_process_group
def fsdp2_generate():
    """Run ``generate`` on a tiny GPT-2 sharded with the composable ``fully_shard`` API.

    Each process binds to the CUDA device whose index equals its rank, shards
    every ``GPT2Block`` and then the whole model over a 1-D device mesh spanning
    the world size, registers ``generate`` as an FSDP forward method, tokenizes
    the prompt at ``data[rank]`` and generates up to ``max_length=30`` tokens.
    The generated tokens are discarded.

    Note: the prompt lookup is ``data[rank]``, so the rank must be a valid index
    into ``data``.
    """
    checkpoint = "hf-internal-testing/tiny-random-gpt2"

    # One process per GPU: the rank doubles as the CUDA device index.
    rank = torch.distributed.get_rank()
    device = torch.device(rank)
    torch.cuda.set_device(device)

    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

    world_size = torch.distributed.get_world_size()
    mesh = init_device_mesh("cuda", (world_size,))

    # Shard the transformer blocks first, then the root module.
    for module in model.modules():
        if not isinstance(module, GPT2Block):
            continue
        fully_shard(module, mesh=mesh)
    fully_shard(model, mesh=mesh)

    register_fsdp_forward_method(model, "generate")

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoded = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True)
    encoded = encoded.to(device)

    generate_kwargs = {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "max_length": 30,
    }
    model.generate(**generate_kwargs)
class TestFSDPGeneration(TestCasePlus):
    """Launch this file under ``torchrun`` to exercise generation with FSDP-sharded models.

    Each test starts one worker process per visible CUDA device and runs this
    file's ``__main__`` section with a flag selecting the FSDP variant.
    """

    def _run_generation_script(self, mode_flag: str) -> None:
        """Run ``test_fsdp.py`` under ``torchrun`` with the given mode flag.

        The argument vector is built as a list rather than by splitting a
        formatted string, so a ``test_file_dir`` containing whitespace stays a
        single argument.

        Args:
            mode_flag: Flag forwarded to the script, e.g. ``"--fsdp"`` or ``"--fsdp2"``.
        """
        cmd = [
            "torchrun",
            f"--nproc_per_node={torch.cuda.device_count()}",
            f"--master_port={get_torch_dist_unique_port()}",
            f"{self.test_file_dir}/test_fsdp.py",
            mode_flag,
        ]
        execute_subprocess_async(cmd, env=self.get_env())

    @require_torch_multi_gpu
    def test_fsdp_generate(self):
        self._run_generation_script("--fsdp")
        # successful return here == success - any errors would have caused an error in the sub-call

    @require_torch_multi_gpu
    def test_fsdp2_generate(self):
        self._run_generation_script("--fsdp2")
        # successful return here == success - any errors would have caused an error in the sub-call
if __name__ == "__main__":
    # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
    #
    # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 --output_dir output_dir ./tests/generation/test_fsdp.py --fsdp

    class CLIArgs(argparse.Namespace):
        """Typed namespace for the two mutually exclusive mode flags."""

        fsdp: bool
        fsdp2: bool

    parser = argparse.ArgumentParser()
    mode_group = parser.add_mutually_exclusive_group()
    for flag in ("--fsdp", "--fsdp2"):
        mode_group.add_argument(flag, action="store_true")
    cli_args = parser.parse_args(namespace=CLIArgs())

    # Exactly one mode must be chosen; argparse already rejects passing both.
    if not (cli_args.fsdp or cli_args.fsdp2):
        raise ValueError("Missing test selection")

    selected = fsdp_generate if cli_args.fsdp else fsdp2_generate
    selected()