# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.parallel import DistributedDataParallel as DDP

from transformers import AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss

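# This example demonstrates context parallelism (CP): inside `context_parallel`,
# the sequence dimension of every registered buffer is sharded across ranks, so
# each rank attends over its own slice of the sequence.
# Launch one process per GPU, e.g. with 4 GPUs (the filename is a placeholder):
#   torchrun --nproc-per-node 4 <this_file>.py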
world_size = int(os.environ.get("WORLD_SIZE", "1"))
cp_mesh = init_device_mesh("cuda", (world_size,))  # also initializes the default process group
rank = torch.distributed.get_node_local_rank()

device = "cuda"
dtype = torch.bfloat16
sdpa_backend = SDPBackend.FLASH_ATTENTION
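# note: `context_parallel` only shards attention that is computed through SDPA,
# and per the PyTorch docs the flash and efficient backends are the supported ones
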
# prepare inputs
batch_size = 1
seq_len = 128

input_ids = torch.randint(low=8, high=64, size=(batch_size, seq_len), device=device)

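# CP shards the sequence dimension, so labels must be shifted on the full
# sequence *before* sharding; shifting each shard locally would lose the target
# token at every shard boundary.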
ignore_index = -100
# When using CP, we need to use `shift_labels`
shift_labels = torch.nn.functional.pad(input_ids, (0, 1), value=ignore_index)
shift_labels = shift_labels[..., 1:].contiguous()

position_ids = (
    torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1
)

# the inputs are created randomly on each rank, so broadcast rank 0's copies to keep every rank identical
dist.broadcast(input_ids, src=0)
dist.broadcast(shift_labels, src=0)
dist.broadcast(position_ids, src=0)
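# (seeding the RNG identically on every rank, e.g. `torch.manual_seed(0)`
# before creating the tensors, would achieve the same synchronization)
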
# model and optimizer
repo_id = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype, device_map=device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

model.train()
model.zero_grad()
optimizer.zero_grad()

# vocab size is needed to compute the loss manually below
vocab_size = model.config.vocab_size

# wrap with DDP so gradients are synchronized across ranks
model = DDP(model, device_ids=[rank])
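# note: each rank computes the loss on its own sequence shard; DDP's gradient
# all-reduce averages those shard gradients, which (with equal-sized shards)
# matches a step over the full sequence
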
# prepare for CP
buffers = (input_ids, shift_labels, position_ids)
buffer_seq_dims = (1, 1, 1)
# `no_restore_buffers=set(buffers)` is required if `loss.backward` is outside `context_parallel`.
# no_restore_buffers = set(buffers)
no_restore_buffers = None
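# inside `context_parallel`, each tensor in `buffers` is sharded in place along
# its entry in `buffer_seq_dims` (dim 1, the sequence dim, for all three) and
# restored on exit unless listed in `no_restore_buffers`
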
# run with CP
with sdpa_kernel(sdpa_backend):
    with context_parallel(
        cp_mesh,
        buffers=buffers,
        buffer_seq_dims=buffer_seq_dims,
        no_restore_buffers=no_restore_buffers,
    ):
        outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids)
        print(outputs.logits.shape)  # each rank only holds the logits for its sequence shard

        # For now, the loss must be computed outside `model.forward` when using `shift_labels`
        # loss = outputs.loss
        loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)

        # this could be moved outside the `context_parallel` context if `no_restore_buffers` were specified
        loss.backward()
        optimizer.step()
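
# optional cleanup: synchronize all ranks, then tear down the default process
# group that `init_device_mesh` created
dist.barrier()
dist.destroy_process_group()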