diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py
index 18163f230ee..b3fab0fe19c 100644
--- a/src/transformers/integrations/tensor_parallel.py
+++ b/src/transformers/integrations/tensor_parallel.py
@@ -531,7 +531,7 @@ def shard_and_distribute_module(
             param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
         )
     else:
-        param = param[...]
+        param = param[...].to(param_casting_dtype)
         if is_contiguous:
             param = param.contiguous()
 
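
For context, a minimal sketch of the behavior this one-line patch changes. This is illustrative only, not the actual transformers code: `param` here is a plain torch tensor standing in for the loaded checkpoint weight, and the values are made up. The point is that `param[...]` keeps the tensor's source dtype, so before the patch the non-sharded fallback branch silently ignored `param_casting_dtype`, while the sharded branch already receives that dtype through `partition_tensor`.

import torch

param = torch.randn(4, 4, dtype=torch.float32)  # stands in for a checkpoint weight
param_casting_dtype = torch.bfloat16            # dtype requested by the loader

old_behavior = param[...]                          # before the patch: dtype stays float32
new_behavior = param[...].to(param_casting_dtype)  # after the patch: cast is applied

assert old_behavior.dtype == torch.float32
assert new_behavior.dtype == torch.bfloat16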