mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Fix dtype for params without tp_plan (#36681)
* Update tensor_parallel.py
* CIs
This commit is contained in:
parent
bb965d8e87
commit
32c95bd847
@@ -531,7 +531,7 @@ def shard_and_distribute_module(
|
||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
)
|
||||
else:
|
||||
- param = param[...]
|
||||
+ param = param[...].to(param_casting_dtype)
|
||||
if is_contiguous:
|
||||
param = param.contiguous()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user