Adapt find_tied_parameters to handle breaking change in Accelerate (#22360)

This commit is contained in:
Sylvain Gugger 2023-03-27 10:11:14 -04:00 committed by GitHub
parent 204737fcc5
commit 8cfc6678da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -154,7 +154,12 @@ def get_keys_to_not_convert(model):
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
tied_model.tie_weights()
tied_keys = list(find_tied_parameters(tied_model).values())
tied_params = find_tied_parameters(tied_model)
# For compatibility with Accelerate < 0.18
if isinstance(tied_params, dict):
tied_keys = list(tied_params.values())
else:
tied_keys = sum([x[1:] for x in tied_params], [])
has_tied_params = len(tied_keys) > 0
# Check if it is a base model