From ab62a23d8c4927a26775a01cd0cca7ba77368e04 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Mon, 8 Aug 2022 23:48:49 +0200
Subject: [PATCH] Let's not cast them all (#18471)

* add correct dtypes when checking for params dtype

* forward contrib credits

* Update src/transformers/modeling_utils.py

Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>

* more comments - added more comments on why we cast only floating point parameters

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: sgugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
---
 src/transformers/modeling_utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 2a86128c221..8bce35f9e33 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -543,8 +543,10 @@ def _load_state_dict_into_meta_model(
             param_name = param_name[len(start_prefix) :]
 
         module_name = param_name
-        # We convert floating dtypes to the `dtype` passed.
-        if dtype is not None and not str(param.dtype).startswith("torch.int"):
+
+        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
+        # in int/uint/bool and not cast them.
+        if dtype is not None and torch.is_floating_point(param):
             param = param.to(dtype)
 
         if device_map is None:
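
A minimal sketch (not part of the patch) of the behavior difference this change addresses. The
removed condition only excluded dtypes whose name starts with "torch.int", so uint8 and bool
tensors would still be cast to the floating `dtype`; `torch.is_floating_point` casts only genuine
floating-point tensors. The sample tensors below are illustrative, not taken from the patch.

import torch

def old_check(p):
    # Removed condition: cast unless the dtype name starts with "torch.int".
    return not str(p.dtype).startswith("torch.int")

def new_check(p):
    # Added condition: cast only genuine floating-point tensors.
    return torch.is_floating_point(p)

for t in (
    torch.zeros(2, dtype=torch.float32),  # cast by both checks (intended)
    torch.zeros(2, dtype=torch.int64),    # kept by both checks (intended)
    torch.zeros(2, dtype=torch.uint8),    # old check would cast it; new check keeps it
    torch.zeros(2, dtype=torch.bool),     # old check would cast it; new check keeps it
):
    print(f"{t.dtype}: old casts={old_check(t)}, new casts={new_check(t)}")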