Mirror of https://github.com/huggingface/transformers.git
Update-tp test (#35844)
* update test for now
* up
* cleanup
* update todo
Parent: 62db3e6ed6
Commit: 7eecdf2a86
@@ -343,6 +343,8 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
    return torch.isin(elements, test_elements)


# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
def translate_to_torch_parallel_style(style: str):
    """
    In model configurations, we use a neutral type (string) to specify parallel
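As context for the TODO above (not part of the diff): translate_to_torch_parallel_style maps the neutral string used in model configs onto a torch.distributed.tensor.parallel style object. The sketch below is a hedged approximation of that mapping, plus one possible workaround for the missing __repr__ tracked in the linked PyTorch issue; the exact set of strings supported by transformers may differ.

# Hedged sketch, not the actual transformers implementation.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

class ReprColwiseParallel(ColwiseParallel):
    # Illustrative workaround for https://github.com/pytorch/pytorch/issues/145726:
    # give the style a readable repr so sharding plans are easier to inspect.
    def __repr__(self) -> str:
        return "ColwiseParallel"

def translate_to_torch_parallel_style_sketch(style: str):
    # Map the neutral config string onto a torch ParallelStyle object.
    if style == "colwise":
        return ReprColwiseParallel()
    if style == "rowwise":
        return RowwiseParallel()
    raise ValueError(f"Unsupported parallel style: {style}")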
@@ -17,6 +17,7 @@ import subprocess
import tempfile
import textwrap

# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
from transformers import is_torch_available
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
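The comment above documents how to launch this script directly from the shell. As a hedged illustration (not part of the diff, and not necessarily how tests/tp/test_tp.py wraps it), the same launch can be driven from Python with subprocess, which the file already imports:

# Hedged sketch of launching the script across 2 GPUs from Python.
import os
import subprocess

cmd = [
    "python", "-m", "torch.distributed.run",
    "--nproc_per_node", "2",
    "./tests/tp/test_tp.py",
]
env = {**os.environ, "PYTHONPATH": "src", "TORCH_LOGS": "+dtensor"}
subprocess.run(cmd, env=env, check=True)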
@@ -110,9 +111,8 @@ if __name__ == "__main__":

    # Test settings
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    bs = 4
    seqlen = 64

    bs = 1
    seqlen = 4096
    # Get distributed settings
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
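The RANK/WORLD_SIZE reads above feed the distributed setup that the diff elides between hunks. A minimal sketch of that setup, assuming the stock init_device_mesh API and one GPU per rank (the real script may differ):

# Hedged sketch of the elided distributed setup.
import os
import torch
from torch.distributed.device_mesh import init_device_mesh

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Bind each rank to its own GPU, then build a 1-D mesh over all ranks.
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)
device_mesh = init_device_mesh("cuda", (world_size,))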
@@ -124,23 +124,45 @@ if __name__ == "__main__":

    # Get model config
    config = LlamaConfig.from_pretrained(model_id)
    # Shrink model size
    config.num_hidden_layers //= 8
    config.vocab_size //= 8

    config.hidden_size = 2048
    config.attention_bias = False
    # Instantiate model
    with device:
        model = LlamaModel(config)
        model = LlamaModel(config).to(dtype=torch.float16)

    model.eval()

    # Tensor Parallel
    if world_size > 1:
        model.tensor_parallel(device_mesh)

    # Run model

    inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
    with torch.no_grad():
        out = model(inputs)

    # Test cuda graphing explicitly
    with torch.cuda.device(device):
        print("Cuda graphing")
        with torch.no_grad():
            inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
            # CUDA Graph setup
            s = torch.cuda.Stream(device=device)
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for i in range(3):
                    out = model(inputs)
            torch.cuda.current_stream().wait_stream(s)
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                out = model(inputs)

            for _ in range(2):
                g.replay()
            s.synchronize()

    assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])

    # Test compile
    with torch.no_grad():
        out = model(inputs)
        model.forward = torch.compile(model.forward, mode="reduce-overhead")
        out = model(inputs)
        out = model(inputs)
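As a follow-up illustration (not part of the diff): after model.tensor_parallel(device_mesh), the sharded weights are expected to be DTensors, so a quick sanity check like the sketch below can complement the shape assert. Depending on the torch release, DTensor lives in torch.distributed.tensor or the older private torch.distributed._tensor module.

# Hedged sketch: count parameters that were actually sharded as DTensors.
try:
    from torch.distributed.tensor import DTensor  # torch >= 2.5
except ImportError:
    from torch.distributed._tensor import DTensor  # older torch releases

if world_size > 1:
    sharded = [name for name, p in model.named_parameters() if isinstance(p, DTensor)]
    print(f"rank {rank}: {len(sharded)} parameters sharded as DTensor")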