From ecd60d01c3fee6811db40b3a5f9a8eb5977d75ba Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 20 Mar 2025 17:17:36 +0000 Subject: [PATCH] [CI] fix update metadata job (#36850) fix updata_metadata job --- .github/workflows/update_metdata.yml | 2 +- src/transformers/cache_utils.py | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.github/workflows/update_metdata.yml b/.github/workflows/update_metdata.yml index 90cd73077ac..d55b6e336c0 100644 --- a/.github/workflows/update_metdata.yml +++ b/.github/workflows/update_metdata.yml @@ -19,7 +19,7 @@ jobs: - name: Setup environment run: | pip install --upgrade pip - pip install datasets pandas==2.0.3 + pip install datasets pandas pip install .[torch,tf,flax] - name: Update metadata diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d873a17d369..5c65157d778 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -537,6 +537,7 @@ class DynamicCache(Cache): self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] +# Utilities for `DynamicCache` <> torch.export support def _flatten_dynamic_cache( dynamic_cache: DynamicCache, ): @@ -584,15 +585,16 @@ def _flatten_dynamic_cache_for_fx(cache, spec): return torch.utils._pytree.tree_flatten(dictionary)[0] -torch.utils._pytree.register_pytree_node( - DynamicCache, - _flatten_dynamic_cache, - _unflatten_dynamic_cache, - serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", - flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, -) -# TODO (tmanlaibaatar) This won't be needed in torch 2.7. -torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) +if is_torch_greater_or_equal("2.2"): + torch.utils._pytree.register_pytree_node( + DynamicCache, + _flatten_dynamic_cache, + _unflatten_dynamic_cache, + serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", + flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, + ) + # TODO (tmanlaibaatar) This won't be needed in torch 2.7. + torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) class OffloadedCache(DynamicCache):