Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)
Moved functions to pytorch_utils.py (#16625)
* Moved functions to pytorch_utils.py
* isort formatting
* Reverted tf changes
* isort, make fix-copies
* documentation fix
* Fixed Conv1D import
* Reverted research examples file
* backward compatibility for pytorch_utils
* missing import
* isort fix
This commit is contained in: parent 0711c45eae · commit a315988bae
@@ -19,7 +19,7 @@ Most of those are only useful if you are studying the code of the models in the
 
 ## Pytorch custom modules
 
-[[autodoc]] modeling_utils.Conv1D
+[[autodoc]] pytorch_utils.Conv1D
 
 [[autodoc]] modeling_utils.PoolerStartLogits
     - forward
@@ -40,15 +40,15 @@ Most of those are only useful if you are studying the code of the models in the
 
 ## PyTorch Helper Functions
 
-[[autodoc]] apply_chunking_to_forward
+[[autodoc]] pytorch_utils.apply_chunking_to_forward
 
-[[autodoc]] modeling_utils.find_pruneable_heads_and_indices
+[[autodoc]] pytorch_utils.find_pruneable_heads_and_indices
 
-[[autodoc]] modeling_utils.prune_layer
+[[autodoc]] pytorch_utils.prune_layer
 
-[[autodoc]] modeling_utils.prune_conv1d_layer
+[[autodoc]] pytorch_utils.prune_conv1d_layer
 
-[[autodoc]] modeling_utils.prune_linear_layer
+[[autodoc]] pytorch_utils.prune_linear_layer
 
 ## TensorFlow custom layers
 
@@ -638,7 +638,7 @@ if is_torch_available():
     ]
     _import_structure["generation_utils"] = ["top_k_top_p_filtering"]
     _import_structure["modeling_outputs"] = []
-    _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
+    _import_structure["modeling_utils"] = ["PreTrainedModel"]
 
     # PyTorch models structure
     _import_structure["models.albert"].extend(
@@ -1693,7 +1693,7 @@ if is_torch_available():
         "get_polynomial_decay_schedule_with_warmup",
         "get_scheduler",
     ]
-    _import_structure["pytorch_utils"] = []
+    _import_structure["pytorch_utils"] = ["Conv1D", "apply_chunking_to_forward", "prune_layer"]
     _import_structure["sagemaker"] = []
     _import_structure["trainer"] = ["Trainer"]
     _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
@@ -2956,7 +2956,7 @@ if TYPE_CHECKING:
        StoppingCriteriaList,
     )
     from .generation_utils import top_k_top_p_filtering
-    from .modeling_utils import Conv1D, PreTrainedModel, apply_chunking_to_forward, prune_layer
+    from .modeling_utils import PreTrainedModel
     from .models.albert import (
         ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         AlbertForMaskedLM,
@@ -3831,6 +3831,7 @@ if TYPE_CHECKING:
         get_polynomial_decay_schedule_with_warmup,
         get_scheduler,
     )
+    from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer
 
     # Trainer
     from .trainer import Trainer
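The four `__init__.py` hunks above keep the public API stable: the moved names are dropped from the `modeling_utils` entry of `_import_structure` and re-exported through the `pytorch_utils` entry instead. A minimal sketch of the unchanged user-facing behavior (assuming a transformers build containing this commit):

```python
# Top-level imports are unaffected by the move: these names are now
# routed through transformers.pytorch_utils instead of modeling_utils.
from transformers import Conv1D, apply_chunking_to_forward, prune_layer

layer = Conv1D(nf=8, nx=4)  # constructs exactly as before the refactor
```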
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import json
 import os
 import re
@@ -24,7 +23,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor, device, nn
@@ -37,6 +36,14 @@ from .configuration_utils import PretrainedConfig
 from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
 from .dynamic_module_utils import custom_object_save
 from .generation_utils import GenerationMixin
+from .pytorch_utils import (  # noqa: F401
+    Conv1D,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_conv1d_layer,
+    prune_layer,
+    prune_linear_layer,
+)
 from .utils import (
     DUMMY_INPUTS,
     FLAX_WEIGHTS_NAME,
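The `# noqa: F401` re-import above is what the commit message calls "backward compatibility for pytorch_utils": `modeling_utils` pulls the moved names back into its own namespace so that existing import paths keep resolving, while the linter is told the names are intentionally unused here. A hedged sketch of the two equivalent paths:

```python
# Old location still resolves, because modeling_utils re-imports the
# moved names from pytorch_utils.
from transformers.modeling_utils import prune_linear_layer as old_path

# New canonical location introduced by this commit.
from transformers.pytorch_utils import prune_linear_layer as new_path

assert old_path is new_path  # one function object, two import paths
```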
@@ -97,32 +104,6 @@ except ImportError:
             return input
 
 
-def find_pruneable_heads_and_indices(
-    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
-) -> Tuple[Set[int], torch.LongTensor]:
-    """
-    Finds the heads and their indices taking `already_pruned_heads` into account.
-
-    Args:
-        heads (`List[int]`): List of the indices of heads to prune.
-        n_heads (`int`): The number of heads in the model.
-        head_size (`int`): The size of each head.
-        already_pruned_heads (`Set[int]`): A set of already pruned heads.
-
-    Returns:
-        `Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
-    """
-    mask = torch.ones(n_heads, head_size)
-    heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
-    for head in heads:
-        # Compute how many pruned heads are before the head and move the index accordingly
-        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
-        mask[head] = 0
-    mask = mask.view(-1).contiguous().eq(1)
-    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
-    return heads, index
-
-
 def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
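The function removed here is re-added verbatim at the bottom of `pytorch_utils.py` (final hunk). It translates head numbers into flat weight indices while compensating for heads pruned in earlier calls. A small usage sketch with made-up sizes: a layer that started with 12 heads of size 64 and already had head 0 pruned (so 11 remain), now pruning heads 1 and 2:

```python
from transformers.pytorch_utils import find_pruneable_heads_and_indices

heads, index = find_pruneable_heads_and_indices(
    heads=[1, 2], n_heads=11, head_size=64, already_pruned_heads={0}
)
print(heads)        # {1, 2} -- the heads left to prune after filtering
print(index.shape)  # torch.Size([576]) -- 9 surviving heads * 64 entries each
```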
@@ -2305,32 +2286,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return url
 
 
-class Conv1D(nn.Module):
-    """
-    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
-
-    Basically works like a linear layer but the weights are transposed.
-
-    Args:
-        nf (`int`): The number of output features.
-        nx (`int`): The number of input features.
-    """
-
-    def __init__(self, nf, nx):
-        super().__init__()
-        self.nf = nf
-        w = torch.empty(nx, nf)
-        nn.init.normal_(w, std=0.02)
-        self.weight = nn.Parameter(w)
-        self.bias = nn.Parameter(torch.zeros(nf))
-
-    def forward(self, x):
-        size_out = x.size()[:-1] + (self.nf,)
-        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(size_out)
-        return x
-
-
 class PoolerStartLogits(nn.Module):
     """
     Compute SQuAD start logits from sequence hidden states.
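`Conv1D` moves to `pytorch_utils.py` unchanged (see the final hunk). Despite the name it is not a convolution: it is a linear projection whose weight is stored transposed relative to `nn.Linear`, shape `(nx, nf)` instead of `(nf, nx)`. A quick shape sketch:

```python
import torch
from transformers.pytorch_utils import Conv1D  # new home after this commit

layer = Conv1D(nf=3, nx=5)    # 5 input features -> 3 output features
x = torch.randn(2, 7, 5)      # (batch, seq_len, nx)
print(layer(x).shape)         # torch.Size([2, 7, 3])
print(layer.weight.shape)     # torch.Size([5, 3]), transposed vs. nn.Linear
```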
@@ -2762,169 +2717,3 @@ def unwrap_model(model: nn.Module) -> nn.Module:
         return unwrap_model(model.module)
     else:
         return model
-
-
-def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
-    """
-    Prune a linear layer to keep only entries in index.
-
-    Used to remove heads.
-
-    Args:
-        layer (`torch.nn.Linear`): The layer to prune.
-        index (`torch.LongTensor`): The indices to keep in the layer.
-        dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
-
-    Returns:
-        `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
-    """
-    index = index.to(layer.weight.device)
-    W = layer.weight.index_select(dim, index).clone().detach()
-    if layer.bias is not None:
-        if dim == 1:
-            b = layer.bias.clone().detach()
-        else:
-            b = layer.bias[index].clone().detach()
-    new_size = list(layer.weight.size())
-    new_size[dim] = len(index)
-    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
-    new_layer.weight.requires_grad = False
-    new_layer.weight.copy_(W.contiguous())
-    new_layer.weight.requires_grad = True
-    if layer.bias is not None:
-        new_layer.bias.requires_grad = False
-        new_layer.bias.copy_(b.contiguous())
-        new_layer.bias.requires_grad = True
-    return new_layer
-
-
-def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
-    """
-    Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights
-    are transposed.
-
-    Used to remove heads.
-
-    Args:
-        layer ([`~modeling_utils.Conv1D`]): The layer to prune.
-        index (`torch.LongTensor`): The indices to keep in the layer.
-        dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
-
-    Returns:
-        [`~modeling_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
-    """
-    index = index.to(layer.weight.device)
-    W = layer.weight.index_select(dim, index).clone().detach()
-    if dim == 0:
-        b = layer.bias.clone().detach()
-    else:
-        b = layer.bias[index].clone().detach()
-    new_size = list(layer.weight.size())
-    new_size[dim] = len(index)
-    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
-    new_layer.weight.requires_grad = False
-    new_layer.weight.copy_(W.contiguous())
-    new_layer.weight.requires_grad = True
-    new_layer.bias.requires_grad = False
-    new_layer.bias.copy_(b.contiguous())
-    new_layer.bias.requires_grad = True
-    return new_layer
-
-
-def prune_layer(
-    layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
-) -> Union[nn.Linear, Conv1D]:
-    """
-    Prune a Conv1D or linear layer to keep only entries in index.
-
-    Used to remove heads.
-
-    Args:
-        layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
-        index (`torch.LongTensor`): The indices to keep in the layer.
-        dim (`int`, *optional*): The dimension on which to keep the indices.
-
-    Returns:
-        `torch.nn.Linear` or [`~modeling_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
-    """
-    if isinstance(layer, nn.Linear):
-        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
-    elif isinstance(layer, Conv1D):
-        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
-    else:
-        raise ValueError(f"Can't prune layer of class {layer.__class__}")
-
-
-def apply_chunking_to_forward(
-    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
-) -> torch.Tensor:
-    """
-    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
-    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
-
-    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
-    applying `forward_fn` to `input_tensors`.
-
-    Args:
-        forward_fn (`Callable[..., torch.Tensor]`):
-            The forward function of the model.
-        chunk_size (`int`):
-            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
-        chunk_dim (`int`):
-            The dimension over which the `input_tensors` should be chunked.
-        input_tensors (`Tuple[torch.Tensor]`):
-            The input tensors of `forward_fn` which will be chunked
-
-    Returns:
-        `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`.
-
-
-    Examples:
-
-    ```python
-    # rename the usual forward() fn to forward_chunk()
-    def forward_chunk(self, hidden_states):
-        hidden_states = self.decoder(hidden_states)
-        return hidden_states
-
-
-    # implement a chunked forward function
-    def forward(self, hidden_states):
-        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
-    ```"""
-
-    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
-
-    # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
-    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
-    if num_args_in_forward_chunk_fn != len(input_tensors):
-        raise ValueError(
-            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
-            "tensors are given"
-        )
-
-    if chunk_size > 0:
-        tensor_shape = input_tensors[0].shape[chunk_dim]
-        for input_tensor in input_tensors:
-            if input_tensor.shape[chunk_dim] != tensor_shape:
-                raise ValueError(
-                    f"All input tenors have to be of the same shape: {tensor_shape}, "
-                    f"found shape {input_tensor.shape[chunk_dim]}"
-                )
-
-        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
-            raise ValueError(
-                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
-                f"size {chunk_size}"
-            )
-
-        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
-
-        # chunk input tensor into tuples
-        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
-        # apply forward fn to every tuple
-        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
-        # concatenate output at same dimension
-        return torch.cat(output_chunks, dim=chunk_dim)
-
-    return forward_fn(*input_tensors)
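All four helpers deleted above reappear unchanged in `pytorch_utils.py` (final hunk), so behavior is identical after the move. As a reminder of the core one, `prune_linear_layer` rebuilds a smaller `nn.Linear` keeping only the rows (or columns, with `dim=1`) listed in `index`:

```python
import torch
from torch import nn
from transformers.pytorch_utils import prune_linear_layer

layer = nn.Linear(4, 6)                    # weight shape (6, 4)
index = torch.tensor([0, 2, 5])            # keep output rows 0, 2 and 5
pruned = prune_linear_layer(layer, index)  # default dim=0 prunes output features
print(pruned.weight.shape)                 # torch.Size([3, 4])
print(torch.equal(pruned.weight, layer.weight[index]))  # True
```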
@@ -34,12 +34,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -33,7 +33,8 @@ from ...modeling_outputs import (
     MaskedLMOutput,
     SemanticSegmenterOutput,
 )
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -40,12 +40,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -37,7 +37,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -35,12 +35,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_canine import CanineConfig
 
@@ -35,13 +35,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    SequenceSummary,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_convbert import ConvBertConfig
 
@@ -23,7 +23,8 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
-from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_ctrl import CTRLConfig
 
@@ -34,12 +34,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -25,7 +25,8 @@ from packaging import version
 from torch import nn
 
 from ...activations import ACT2FN
-from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -27,7 +27,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -39,12 +39,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -42,7 +42,8 @@ from ...modeling_outputs import (
     DepthEstimatorOutput,
     SemanticSegmenterOutput,
 )
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from .configuration_dpt import DPTConfig
 
@@ -36,13 +36,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    SequenceSummary,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -43,7 +43,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -24,7 +24,8 @@ from torch import nn
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -40,13 +40,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    Conv1D,
-    PreTrainedModel,
-    SequenceSummary,
-    find_pruneable_heads_and_indices,
-    prune_conv1d_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -35,7 +35,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_ibert import IBertConfig
 from .quant_modules import IntGELU, IntLayerNorm, IntSoftmax, QuantAct, QuantEmbedding, QuantLinear
@@ -38,7 +38,8 @@ from ...modeling_outputs import (
     CausalLMOutputWithCrossAttentions,
     SequenceClassifierOutputWithPast,
 )
-from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_imagegpt import ImageGPTConfig
 
@@ -31,12 +31,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_layoutlm import LayoutLMConfig
 
@@ -31,7 +31,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -24,12 +24,8 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN, gelu
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -24,7 +24,8 @@ from torch import nn
 
 from ...activations import ACT2FN, gelu
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -29,7 +29,8 @@ from transformers.utils import logging
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithCrossAttentions
-from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -39,12 +39,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -41,7 +41,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -33,7 +33,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_mpnet import MPNetConfig
 
@@ -33,12 +33,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_nystromformer import NystromformerConfig
 
@@ -28,13 +28,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import gelu_new, silu
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import (
-    Conv1D,
-    PreTrainedModel,
-    SequenceSummary,
-    find_pruneable_heads_and_indices,
-    prune_conv1d_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -29,12 +29,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithCrossAttentions
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -38,7 +38,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -31,12 +31,8 @@ from ...modeling_outputs import (
     MaskedLMOutput,
     ModelOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_realm import RealmConfig
 
@@ -30,7 +30,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     DUMMY_INPUTS,
     DUMMY_MASK,
@@ -35,12 +35,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -35,12 +35,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -35,13 +35,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    SequenceSummary,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -25,7 +25,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput, SemanticSegmenterOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -25,12 +25,8 @@ from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_splinter import SplinterConfig
 
@@ -26,7 +26,8 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -33,7 +33,8 @@ from ...modeling_outputs import (
     Seq2SeqLMOutput,
     Seq2SeqModelOutput,
 )
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     DUMMY_INPUTS,
     DUMMY_MASK,
@@ -28,12 +28,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -33,7 +33,8 @@ from ...modeling_outputs import (
     ModelOutput,
     SequenceClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_vilt import ViltConfig
 
@@ -31,12 +31,8 @@ from ...modeling_outputs import (
     MultipleChoiceModelOutput,
     SequenceClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -26,7 +26,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -28,7 +28,8 @@ from torch import nn
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -35,14 +35,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    SequenceSummary,
-    SQuADHead,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -34,12 +34,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -25,14 +25,8 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_utils import (
-    PoolerAnswerClass,
-    PoolerEndLogits,
-    PoolerStartLogits,
-    PreTrainedModel,
-    SequenceSummary,
-    apply_chunking_to_forward,
-)
+from ...modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
+from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -34,12 +34,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_yoso import YosoConfig
 
@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
+from typing import Callable, List, Optional, Set, Tuple, Union
 
 import torch
 from packaging import version
-from torch import _softmax_backward_data
+from torch import _softmax_backward_data, nn
 
 from .utils import logging
 
@ -45,3 +47,221 @@ def softmax_backward_data(parent, grad_output, output, dim, self):
|
|||||||
return _softmax_backward_data(grad_output, output, parent.dim, self)
|
return _softmax_backward_data(grad_output, output, parent.dim, self)
|
||||||
else:
|
else:
|
||||||
return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
|
return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
|
||||||
|
"""
|
||||||
|
Prune a linear layer to keep only entries in index.
|
||||||
|
|
||||||
|
Used to remove heads.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer (`torch.nn.Linear`): The layer to prune.
|
||||||
|
index (`torch.LongTensor`): The indices to keep in the layer.
|
||||||
|
dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
|
||||||
|
"""
|
||||||
|
index = index.to(layer.weight.device)
|
||||||
|
W = layer.weight.index_select(dim, index).clone().detach()
|
||||||
|
if layer.bias is not None:
|
||||||
|
if dim == 1:
|
||||||
|
b = layer.bias.clone().detach()
|
||||||
|
else:
|
||||||
|
b = layer.bias[index].clone().detach()
|
||||||
|
new_size = list(layer.weight.size())
|
||||||
|
new_size[dim] = len(index)
|
||||||
|
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
|
||||||
|
new_layer.weight.requires_grad = False
|
||||||
|
new_layer.weight.copy_(W.contiguous())
|
||||||
|
new_layer.weight.requires_grad = True
|
||||||
|
if layer.bias is not None:
|
||||||
|
new_layer.bias.requires_grad = False
|
||||||
|
new_layer.bias.copy_(b.contiguous())
|
||||||
|
new_layer.bias.requires_grad = True
|
||||||
|
return new_layer
|
||||||
|
|
||||||
|
|
||||||
|
+class Conv1D(nn.Module):
+    """
+    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
+
+    Basically works like a linear layer but the weights are transposed.
+
+    Args:
+        nf (`int`): The number of output features.
+        nx (`int`): The number of input features.
+    """
+
+    def __init__(self, nf, nx):
+        super().__init__()
+        self.nf = nf
+        w = torch.empty(nx, nf)
+        nn.init.normal_(w, std=0.02)
+        self.weight = nn.Parameter(w)
+        self.bias = nn.Parameter(torch.zeros(nf))
+
+    def forward(self, x):
+        size_out = x.size()[:-1] + (self.nf,)
+        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+        x = x.view(size_out)
+        return x
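Because the weights are stored transposed relative to `nn.Linear`, a `Conv1D(nf, nx)` has weight shape `(nx, nf)` and maps `(..., nx)` inputs to `(..., nf)` outputs. A small sketch:

    import torch
    from transformers.pytorch_utils import Conv1D

    conv = Conv1D(nf=8, nx=4)  # weight shape (4, 8)
    x = torch.randn(2, 5, 4)   # (batch, seq_len, nx)
    print(conv(x).shape)       # torch.Size([2, 5, 8])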
+def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
+    """
+    Prune a Conv1D layer to keep only entries in index. A Conv1D works as a Linear layer (see e.g. BERT) but the
+    weights are transposed.
+
+    Used to remove heads.
+
+    Args:
+        layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
+
+    Returns:
+        [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).clone().detach()
+    if dim == 0:
+        b = layer.bias.clone().detach()
+    else:
+        b = layer.bias[index].clone().detach()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    new_layer.bias.requires_grad = False
+    new_layer.bias.copy_(b.contiguous())
+    new_layer.bias.requires_grad = True
+    return new_layer
+
+
+def prune_layer(
+    layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
+) -> Union[nn.Linear, Conv1D]:
+    """
+    Prune a Conv1D or linear layer to keep only entries in index.
+
+    Used to remove heads.
+
+    Args:
+        layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*): The dimension on which to keep the indices.
+
+    Returns:
+        `torch.nn.Linear` or [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    if isinstance(layer, nn.Linear):
+        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
+    elif isinstance(layer, Conv1D):
+        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
+    else:
+        raise ValueError(f"Can't prune layer of class {layer.__class__}")
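`prune_layer` only dispatches on the layer type and fills in the right default `dim`, so callers don't need to remember that `Conv1D` weights are transposed. A sketch:

    import torch
    from torch import nn
    from transformers.pytorch_utils import Conv1D, prune_layer

    index = torch.tensor([1, 3], dtype=torch.long)
    print(prune_layer(nn.Linear(6, 4), index).weight.shape)     # torch.Size([2, 6]), dim defaults to 0
    print(prune_layer(Conv1D(nf=4, nx=6), index).weight.shape)  # torch.Size([6, 2]), dim defaults to 1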
+def apply_chunking_to_forward(
+    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
+) -> torch.Tensor:
+    """
+    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
+    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
+
+    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
+    applying `forward_fn` to `input_tensors`.
+
+    Args:
+        forward_fn (`Callable[..., torch.Tensor]`):
+            The forward function of the model.
+        chunk_size (`int`):
+            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
+        chunk_dim (`int`):
+            The dimension over which the `input_tensors` should be chunked.
+        input_tensors (`Tuple[torch.Tensor]`):
+            The input tensors of `forward_fn` which will be chunked.
+
+    Returns:
+        `torch.Tensor`: A tensor with the same shape as the one `forward_fn` would have given if applied directly.
+
+    Examples:
+
+    ```python
+    # rename the usual forward() fn to forward_chunk()
+    def forward_chunk(self, hidden_states):
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+    # implement a chunked forward function
+    def forward(self, hidden_states):
+        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
+    ```"""
+
+    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
+
+    # inspect.signature exists since Python 3.5 and is a Python method -> no problem with backward compatibility
+    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
+    if num_args_in_forward_chunk_fn != len(input_tensors):
+        raise ValueError(
+            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
+            "tensors are given"
+        )
+
+    if chunk_size > 0:
+        tensor_shape = input_tensors[0].shape[chunk_dim]
+        for input_tensor in input_tensors:
+            if input_tensor.shape[chunk_dim] != tensor_shape:
+                raise ValueError(
+                    f"All input tensors have to be of the same shape: {tensor_shape}, "
+                    f"found shape {input_tensor.shape[chunk_dim]}"
+                )
+
+        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
+            raise ValueError(
+                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
+                f"size {chunk_size}"
+            )
+
+        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
+
+        # chunk input tensor into tuples
+        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
+        # apply forward fn to every tuple
+        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
+        # concatenate output at same dimension
+        return torch.cat(output_chunks, dim=chunk_dim)
+
+    return forward_fn(*input_tensors)
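Since the chunks are processed sequentially, this trades peak memory for extra calls; for a position-wise `forward_fn` the result is identical to the unchunked call. A self-contained sketch with toy sizes:

    import torch
    from torch import nn
    from transformers.pytorch_utils import apply_chunking_to_forward

    ff = nn.Linear(16, 16)

    def forward_chunk(hidden_states):
        return ff(hidden_states)

    hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden)
    # chunk the sequence dimension (dim 1) into pieces of length 4
    out = apply_chunking_to_forward(forward_chunk, 4, 1, hidden_states)
    assert torch.allclose(out, ff(hidden_states), atol=1e-6)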
+def find_pruneable_heads_and_indices(
+    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
+) -> Tuple[Set[int], torch.LongTensor]:
+    """
+    Finds the heads and their indices taking `already_pruned_heads` into account.
+
+    Args:
+        heads (`List[int]`): List of the indices of heads to prune.
+        n_heads (`int`): The number of heads in the model.
+        head_size (`int`): The size of each head.
+        already_pruned_heads (`Set[int]`): A set of already pruned heads.
+
+    Returns:
+        `Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
+    """
+    mask = torch.ones(n_heads, head_size)
+    heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
+    for head in heads:
+        # Compute how many pruned heads are before the head and move the index accordingly
+        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
+        mask[head] = 0
+    mask = mask.view(-1).contiguous().eq(1)
+    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
+    return heads, index
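A sketch of how the index shift for already-pruned heads plays out; here an 11-head layer (head 3 of the original 12 was pruned earlier) where we ask to prune heads 3 and 5 of the original numbering:

    import torch
    from transformers.pytorch_utils import find_pruneable_heads_and_indices

    heads, index = find_pruneable_heads_and_indices(
        heads=[3, 5], n_heads=11, head_size=64, already_pruned_heads={3}
    )
    print(heads)        # {5}, head 3 is already gone
    print(index.shape)  # torch.Size([640]), i.e. (11 - 1) * 64 surviving entries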
@@ -266,13 +266,6 @@ def top_k_top_p_filtering(*args, **kwargs):
     requires_backends(top_k_top_p_filtering, ["torch"])
 
 
-class Conv1D(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
 class PreTrainedModel(metaclass=DummyObject):
     _backends = ["torch"]
@@ -280,14 +273,6 @@ class PreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
-def apply_chunking_to_forward(*args, **kwargs):
-    requires_backends(apply_chunking_to_forward, ["torch"])
-
-
-def prune_layer(*args, **kwargs):
-    requires_backends(prune_layer, ["torch"])
-
-
 ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -4782,6 +4767,21 @@ def get_scheduler(*args, **kwargs):
     requires_backends(get_scheduler, ["torch"])
 
 
+class Conv1D(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+def apply_chunking_to_forward(*args, **kwargs):
+    requires_backends(apply_chunking_to_forward, ["torch"])
+
+
+def prune_layer(*args, **kwargs):
+    requires_backends(prune_layer, ["torch"])
+
+
 class Trainer(metaclass=DummyObject):
     _backends = ["torch"]
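These dummy objects keep `import transformers` working when torch is absent; the moved names still resolve at import time and only fail when actually used. Roughly:

    # with torch not installed, this import still succeeds...
    from transformers import Conv1D

    # ...but instantiating the dummy raises an ImportError via requires_backends
    layer = Conv1D(8, 4)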
@@ -43,9 +43,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
-    PreTrainedModel,
-    SequenceSummary,
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import (
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     prune_linear_layer,