mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
parent
42c489f2ae
commit
ecd60d01c3
2
.github/workflows/update_metdata.yml
vendored
2
.github/workflows/update_metdata.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: Setup environment
|
- name: Setup environment
|
||||||
run: |
|
run: |
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install datasets pandas==2.0.3
|
pip install datasets pandas
|
||||||
pip install .[torch,tf,flax]
|
pip install .[torch,tf,flax]
|
||||||
|
|
||||||
- name: Update metadata
|
- name: Update metadata
|
||||||
|
@ -537,6 +537,7 @@ class DynamicCache(Cache):
|
|||||||
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
|
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
|
||||||
|
|
||||||
|
|
||||||
|
# Utilities for `DynamicCache` <> torch.export support
|
||||||
def _flatten_dynamic_cache(
|
def _flatten_dynamic_cache(
|
||||||
dynamic_cache: DynamicCache,
|
dynamic_cache: DynamicCache,
|
||||||
):
|
):
|
||||||
@ -584,15 +585,16 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
|
|||||||
return torch.utils._pytree.tree_flatten(dictionary)[0]
|
return torch.utils._pytree.tree_flatten(dictionary)[0]
|
||||||
|
|
||||||
|
|
||||||
torch.utils._pytree.register_pytree_node(
|
if is_torch_greater_or_equal("2.2"):
|
||||||
|
torch.utils._pytree.register_pytree_node(
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
_flatten_dynamic_cache,
|
_flatten_dynamic_cache,
|
||||||
_unflatten_dynamic_cache,
|
_unflatten_dynamic_cache,
|
||||||
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
|
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
|
||||||
flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
|
flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
|
||||||
)
|
)
|
||||||
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
|
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
|
||||||
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
|
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
|
||||||
|
|
||||||
|
|
||||||
class OffloadedCache(DynamicCache):
|
class OffloadedCache(DynamicCache):
|
||||||
|
Loading…
Reference in New Issue
Block a user