diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4bde1f4451c..13fb9418224 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -380,10 +380,10 @@ def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): gen = parameter._named_members(get_members_fn=find_tensor_attributes) last_tuple = None - for tuple in gen: - last_tuple = tuple - if tuple[1].is_floating_point(): - return tuple[1].dtype + for gen_tuple in gen: + last_tuple = gen_tuple + if gen_tuple[1].is_floating_point(): + return gen_tuple[1].dtype if last_tuple is not None: # fallback to the last dtype