[CI] fix update metadata job (#36850)

fix updata_metadata job
This commit is contained in:
Joao Gante 2025-03-20 17:17:36 +00:00 committed by GitHub
parent 42c489f2ae
commit ecd60d01c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 10 deletions

View File

@ -19,7 +19,7 @@ jobs:
- name: Setup environment
run: |
pip install --upgrade pip
pip install datasets pandas==2.0.3
pip install datasets pandas
pip install .[torch,tf,flax]
- name: Update metadata

View File

@ -537,6 +537,7 @@ class DynamicCache(Cache):
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
# Utilities for `DynamicCache` <> torch.export support
def _flatten_dynamic_cache(
dynamic_cache: DynamicCache,
):
@ -584,15 +585,16 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
return torch.utils._pytree.tree_flatten(dictionary)[0]
torch.utils._pytree.register_pytree_node(
DynamicCache,
_flatten_dynamic_cache,
_unflatten_dynamic_cache,
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
)
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
if is_torch_greater_or_equal("2.2"):
torch.utils._pytree.register_pytree_node(
DynamicCache,
_flatten_dynamic_cache,
_unflatten_dynamic_cache,
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
)
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
class OffloadedCache(DynamicCache):