Fix dtype for params without tp_plan (#36681)

* Update tensor_parallel.py

* CIs
This commit is contained in:
Cyril Vallez 2025-03-13 15:28:14 +01:00 committed by GitHub
parent bb965d8e87
commit 32c95bd847
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -531,7 +531,7 @@ def shard_and_distribute_module(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
else:
param = param[...]
param = param[...].to(param_casting_dtype)
if is_contiguous:
param = param.contiguous()