mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
parent
42c489f2ae
commit
ecd60d01c3
2
.github/workflows/update_metdata.yml
vendored
2
.github/workflows/update_metdata.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
- name: Setup environment
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install datasets pandas==2.0.3
|
||||
pip install datasets pandas
|
||||
pip install .[torch,tf,flax]
|
||||
|
||||
- name: Update metadata
|
||||
|
@ -537,6 +537,7 @@ class DynamicCache(Cache):
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
|
||||
|
||||
|
||||
# Utilities for `DynamicCache` <> torch.export support
|
||||
def _flatten_dynamic_cache(
|
||||
dynamic_cache: DynamicCache,
|
||||
):
|
||||
@ -584,15 +585,16 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
|
||||
return torch.utils._pytree.tree_flatten(dictionary)[0]
|
||||
|
||||
|
||||
torch.utils._pytree.register_pytree_node(
|
||||
DynamicCache,
|
||||
_flatten_dynamic_cache,
|
||||
_unflatten_dynamic_cache,
|
||||
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
|
||||
flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
|
||||
)
|
||||
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
|
||||
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
|
||||
if is_torch_greater_or_equal("2.2"):
|
||||
torch.utils._pytree.register_pytree_node(
|
||||
DynamicCache,
|
||||
_flatten_dynamic_cache,
|
||||
_unflatten_dynamic_cache,
|
||||
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
|
||||
flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
|
||||
)
|
||||
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
|
||||
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
|
||||
|
||||
|
||||
class OffloadedCache(DynamicCache):
|
||||
|
Loading…
Reference in New Issue
Block a user