Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 21:00:08 +06:00)

* enable generation fsdp/utils test cases on XPU
* fix style
* xx
* use backend_xx APIs
* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
165 lines · 6.3 KiB · Python
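The commit message above indicates that these tests were enabled on XPU by using the backend_* helpers from transformers.testing_utils, which dispatch to whichever accelerator torch_device resolves to. A minimal illustrative sketch of those helpers follows (not part of the file below; it assumes at least one CUDA or XPU accelerator is present and uses only names the file itself imports):

from transformers.testing_utils import (
    backend_device_count,
    backend_torch_accelerator_module,
    torch_device,
)

# torch_device resolves to the available accelerator type, e.g. "cuda" or "xpu".
accelerator_module = backend_torch_accelerator_module(torch_device)  # e.g. torch.cuda or torch.xpu
num_devices = backend_device_count(torch_device)                     # device count for that backend
accelerator_module.set_device(0)                                     # same call on either backend
print(f"{torch_device}: {num_devices} device(s)")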
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from typing import Any, Callable

from transformers import is_torch_available, is_torch_xpu_available
from transformers.testing_utils import (
    TestCasePlus,
    backend_device_count,
    backend_torch_accelerator_module,
    execute_subprocess_async,
    get_torch_dist_unique_port,
    require_torch_multi_accelerator,
    torch_device,
)
from transformers.utils import is_ccl_available, is_ipex_available


if is_torch_available():
    import functools

    import torch

    if is_torch_xpu_available():
        if is_ipex_available():
            # Imported for its side effects: makes Intel XPU support/optimizations available.
            import intel_extension_for_pytorch  # noqa: F401
        if is_ccl_available():
            # Imported for its side effects: registers the oneCCL ("ccl") torch.distributed backend.
            import oneccl_bindings_for_pytorch  # noqa: F401

    import torch.distributed
    from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.models.gpt2.modeling_gpt2 import GPT2Block


# One prompt per rank (indexed via data[rank] below); repeated so up to 8 processes are covered.
data = 4 * [
    "Hello world!",
    "The quick brown fox jumps over the lazy dog.",
]


def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
    """Manage the creation and destruction of the distributed process group for the wrapped function."""

    def wrapped(*args: Any, **kwargs: Any) -> Any:
        device_count = backend_device_count(torch_device)
        torch.distributed.init_process_group(world_size=device_count)
        try:
            return func(*args, **kwargs)
        finally:
            torch.distributed.destroy_process_group()

    return wrapped


@manage_process_group
def fsdp_generate():
    torch_accelerator_module = backend_torch_accelerator_module(torch_device)
    torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

    fsdp_model = FullyShardedDataParallel(
        model,
        auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block}),
        limit_all_gathers=True,
        use_orig_params=True,
    )

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)

    # FSDP1: gather the full parameters, then call generate on the inner (wrapped) module.
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        _ = fsdp_model.module.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=30,
        )


@manage_process_group
def fsdp2_generate():
    torch_accelerator_module = backend_torch_accelerator_module(torch_device)
    torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

    # FSDP2: shard each transformer block, then the root module, over a 1-D device mesh.
    mesh = init_device_mesh(device.type, (torch.distributed.get_world_size(),))
    for submodule in model.modules():
        if isinstance(submodule, GPT2Block):
            fully_shard(submodule, mesh=mesh)
    fully_shard(model, mesh=mesh)

    # Let generate() trigger the same unshard/reshard hooks as forward().
    register_fsdp_forward_method(model, "generate")

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)

    _ = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_length=30,
    )


class TestFSDPGeneration(TestCasePlus):
    @require_torch_multi_accelerator
    def test_fsdp_generate(self):
        device_count = backend_device_count(torch_device)
        distributed_args = f"""--nproc_per_node={device_count}
            --master_port={get_torch_dist_unique_port()}
            {self.test_file_dir}/test_fsdp.py
        """.split()
        args = "--fsdp".split()
        cmd = ["torchrun"] + distributed_args + args
        execute_subprocess_async(cmd, env=self.get_env())
        # successful return here == success - any errors would have caused an error in the sub-call

    @require_torch_multi_accelerator
    def test_fsdp2_generate(self):
        device_count = backend_device_count(torch_device)

        distributed_args = f"""--nproc_per_node={device_count}
            --master_port={get_torch_dist_unique_port()}
            {self.test_file_dir}/test_fsdp.py
        """.split()
        args = "--fsdp2".split()
        cmd = ["torchrun"] + distributed_args + args
        execute_subprocess_async(cmd, env=self.get_env())
        # successful return here == success - any errors would have caused an error in the sub-call


if __name__ == "__main__":
    # The block below is meant to be run under torch.distributed, on a machine with multiple accelerators
    # (GPUs or XPUs), e.g.:
    #
    # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/generation/test_fsdp.py --fsdp

    class CLIArgs(argparse.Namespace):
        fsdp: bool
        fsdp2: bool

    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--fsdp", action="store_true")
    group.add_argument("--fsdp2", action="store_true")
    args = parser.parse_args(namespace=CLIArgs())

    if args.fsdp:
        fsdp_generate()
    elif args.fsdp2:
        fsdp2_generate()
    else:
        raise ValueError("Missing test selection")
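Usage note (a sketch, not from the file above; it assumes the file sits at tests/generation/test_fsdp.py inside a transformers checkout with at least two accelerators visible): running python -m pytest tests/generation/test_fsdp.py exercises both tests, each of which re-launches this same file under torchrun exactly as the command built in TestFSDPGeneration shows. A single code path can also be launched directly, for example: torchrun --nproc_per_node=2 tests/generation/test_fsdp.py --fsdp2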