Update-tp test (#35844)

* update test for now

* up

* cleanup

* update todo
Author: Arthur, 2025-02-03 09:37:02 +01:00, committed by GitHub
parent 62db3e6ed6
commit 7eecdf2a86
2 changed files with 36 additions and 12 deletions


@@ -343,6 +343,8 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
        return torch.isin(elements, test_elements)
# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
def translate_to_torch_parallel_style(style: str):
    """
    In model configurations, we use a neutral type (string) to specify parallel

@@ -17,6 +17,7 @@ import subprocess
import tempfile
import textwrap
# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
from transformers import is_torch_available
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
@@ -110,9 +111,8 @@ if __name__ == "__main__":
     # Test settings
     model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
-    bs = 4
-    seqlen = 64
+    bs = 1
+    seqlen = 4096
     # Get distributed settings
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
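The process-group and device-mesh creation happens in the lines elided between these two hunks. Assuming it follows the standard torch.distributed recipe for a script launched via torch.distributed.run (as in the command at the top of the file), the setup would look roughly like this; only the names device and device_mesh are taken from the diff, the rest is an assumption:

# Sketch of the elided distributed setup (assumed, not part of this diff).
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group("nccl")                        # one process per GPU, spawned by torchrun
device_mesh = init_device_mesh("cuda", (world_size,))  # 1-D mesh consumed by model.tensor_parallel()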
@@ -124,23 +124,45 @@ if __name__ == "__main__":
    # Get model config
    config = LlamaConfig.from_pretrained(model_id)
    # Shrink model size
    config.num_hidden_layers //= 8
    config.vocab_size //= 8
    config.hidden_size = 2048
    config.attention_bias = False
    # Instantiate model
    with device:
        model = LlamaModel(config)
        model = LlamaModel(config).to(dtype=torch.float16)
    model.eval()
    # Tensor Parallel
    if world_size > 1:
        model.tensor_parallel(device_mesh)
    # Run model
    inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
    with torch.no_grad():
        out = model(inputs)
    # Test cuda graphing explicitly
    with torch.cuda.device(device):
        print("Cuda graphing")
        with torch.no_grad():
            inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
            # CUDA Graph setup
            s = torch.cuda.Stream(device=device)
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for i in range(3):
                    out = model(inputs)
            torch.cuda.current_stream().wait_stream(s)
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                out = model(inputs)
            for _ in range(2):
                g.replay()
            s.synchronize()
    assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])
    # Test compile
    with torch.no_grad():
        out = model(inputs)
        model.forward = torch.compile(model.forward, mode="reduce-overhead")
        out = model(inputs)
        out = model(inputs)
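For reference, the capture sequence added above (warm-up iterations on a side stream, graph capture, then replays) is the standard PyTorch CUDA Graphs recipe. A minimal self-contained sketch of the same pattern, with a toy nn.Linear standing in for the tensor-parallel LlamaModel and purely illustrative sizes:

# Warm-up / capture / replay pattern, reduced to a toy module.
import torch
import torch.nn as nn

device = torch.device("cuda:0")
model = nn.Linear(256, 256, device=device, dtype=torch.float16).eval()
static_input = torch.randn(8, 256, device=device, dtype=torch.float16)

with torch.no_grad():
    # 1) Warm up on a side stream so one-time init work (cuBLAS handles,
    #    allocator pools) is not baked into the captured graph.
    s = torch.cuda.Stream(device=device)
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            out = model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # 2) Capture one forward pass. Inputs and outputs must live in fixed
    #    ("static") tensors because replay() reuses the same memory addresses.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = model(static_input)

    # 3) Replay: copy fresh data into the static input, then relaunch the
    #    whole captured graph in one shot.
    static_input.copy_(torch.randn_like(static_input))
    g.replay()
    torch.cuda.synchronize()

The final "# Test compile" lines exercise the higher-level route to the same benefit: torch.compile with mode="reduce-overhead" applies CUDA Graphs automatically, which is presumably why the compiled forward is called twice after compilation (one warm-up pass, then the graph-backed replay path).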