mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

Moved functions to pytorch_utils.py (#16625)

* Moved functions to pytorch_utils.py
* isort formatting
* Reverted tf changes
* isort, make fix-copies
* documentation fix
* Fixed Conv1D import
* Reverted research examples file
* backward compatibility for pytorch_utils
* missing import
* isort fix

This commit is contained in:
parent 0711c45eae
commit a315988bae
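In user code, the move looks like this. A minimal sketch assuming a transformers version that includes this commit; the old `modeling_utils` path keeps working because it re-imports the helpers from `pytorch_utils` (see the `# noqa: F401` block in the diff below):

```python
# New canonical location for the torch-only helpers
from transformers.pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer

# The old location still resolves: modeling_utils re-imports these names
# from pytorch_utils for backward compatibility
from transformers.modeling_utils import Conv1D as OldConv1D

assert Conv1D is OldConv1D  # same class object, only its canonical module changed
```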
@@ -19,7 +19,7 @@ Most of those are only useful if you are studying the code of the models in the

## Pytorch custom modules

-[[autodoc]] modeling_utils.Conv1D
+[[autodoc]] pytorch_utils.Conv1D

[[autodoc]] modeling_utils.PoolerStartLogits
    - forward
@@ -40,15 +40,15 @@ Most of those are only useful if you are studying the code of the models in the

## PyTorch Helper Functions

-[[autodoc]] apply_chunking_to_forward
+[[autodoc]] pytorch_utils.apply_chunking_to_forward

-[[autodoc]] modeling_utils.find_pruneable_heads_and_indices
+[[autodoc]] pytorch_utils.find_pruneable_heads_and_indices

-[[autodoc]] modeling_utils.prune_layer
+[[autodoc]] pytorch_utils.prune_layer

-[[autodoc]] modeling_utils.prune_conv1d_layer
+[[autodoc]] pytorch_utils.prune_conv1d_layer

-[[autodoc]] modeling_utils.prune_linear_layer
+[[autodoc]] pytorch_utils.prune_linear_layer

## TensorFlow custom layers
@@ -638,7 +638,7 @@ if is_torch_available():
        ]
    _import_structure["generation_utils"] = ["top_k_top_p_filtering"]
    _import_structure["modeling_outputs"] = []
-    _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
+    _import_structure["modeling_utils"] = ["PreTrainedModel"]

    # PyTorch models structure
    _import_structure["models.albert"].extend(
@@ -1693,7 +1693,7 @@ if is_torch_available():
        "get_polynomial_decay_schedule_with_warmup",
        "get_scheduler",
    ]
-    _import_structure["pytorch_utils"] = []
+    _import_structure["pytorch_utils"] = ["Conv1D", "apply_chunking_to_forward", "prune_layer"]
    _import_structure["sagemaker"] = []
    _import_structure["trainer"] = ["Trainer"]
    _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
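For orientation: `transformers/__init__.py` maps attribute names to submodules and resolves them lazily on first access, which is why the commit has to touch both the `_import_structure` dicts (runtime) and the `if TYPE_CHECKING:` imports below (static analysis and IDEs). A simplified sketch of that mechanism, not the library's exact `_LazyModule` code:

```python
import importlib
from types import ModuleType


class LazyModule(ModuleType):
    """Resolve public names to their defining submodule on first attribute access."""

    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # Invert {"pytorch_utils": ["Conv1D", ...]} into {"Conv1D": "pytorch_utils", ...}
        self._name_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr: str):
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__} has no attribute {attr}")
        submodule = importlib.import_module(f".{self._name_to_module[attr]}", self.__name__)
        return getattr(submodule, attr)
```

With the hunks above, `transformers.Conv1D` now resolves through `pytorch_utils` instead of `modeling_utils`.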
@@ -2956,7 +2956,7 @@ if TYPE_CHECKING:
        StoppingCriteriaList,
    )
    from .generation_utils import top_k_top_p_filtering
-    from .modeling_utils import Conv1D, PreTrainedModel, apply_chunking_to_forward, prune_layer
+    from .modeling_utils import PreTrainedModel
    from .models.albert import (
        ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        AlbertForMaskedLM,
@@ -3831,6 +3831,7 @@ if TYPE_CHECKING:
        get_polynomial_decay_schedule_with_warmup,
        get_scheduler,
    )
+    from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer

    # Trainer
    from .trainer import Trainer
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import inspect
import json
import os
import re
@@ -24,7 +23,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, device, nn
@@ -37,6 +36,14 @@ from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .generation_utils import GenerationMixin
+from .pytorch_utils import (  # noqa: F401
+    Conv1D,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_conv1d_layer,
+    prune_layer,
+    prune_linear_layer,
+)
from .utils import (
    DUMMY_INPUTS,
    FLAX_WEIGHTS_NAME,
@@ -97,32 +104,6 @@ except ImportError:
        return input


-def find_pruneable_heads_and_indices(
-    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
-) -> Tuple[Set[int], torch.LongTensor]:
-    """
-    Finds the heads and their indices taking `already_pruned_heads` into account.
-
-    Args:
-        heads (`List[int]`): List of the indices of heads to prune.
-        n_heads (`int`): The number of heads in the model.
-        head_size (`int`): The size of each head.
-        already_pruned_heads (`Set[int]`): A set of already pruned heads.
-
-    Returns:
-        `Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
-    """
-    mask = torch.ones(n_heads, head_size)
-    heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
-    for head in heads:
-        # Compute how many pruned heads are before the head and move the index accordingly
-        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
-        mask[head] = 0
-    mask = mask.view(-1).contiguous().eq(1)
-    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
-    return heads, index
-
-
def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    try:
        return next(parameter.parameters()).device
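The function removed here is re-added verbatim in `pytorch_utils.py` further down; a small worked call with illustrative values:

```python
from transformers.pytorch_utils import find_pruneable_heads_and_indices

# 4 current heads of size 2; head 1 (original numbering) was already pruned,
# now prune head 2
heads, index = find_pruneable_heads_and_indices(
    heads=[2], n_heads=4, head_size=2, already_pruned_heads={1}
)
print(heads)  # {2}
print(index)  # tensor([0, 1, 4, 5, 6, 7]) -> flat weight positions to keep
```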
@@ -2305,32 +2286,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        return url


-class Conv1D(nn.Module):
-    """
-    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
-
-    Basically works like a linear layer but the weights are transposed.
-
-    Args:
-        nf (`int`): The number of output features.
-        nx (`int`): The number of input features.
-    """
-
-    def __init__(self, nf, nx):
-        super().__init__()
-        self.nf = nf
-        w = torch.empty(nx, nf)
-        nn.init.normal_(w, std=0.02)
-        self.weight = nn.Parameter(w)
-        self.bias = nn.Parameter(torch.zeros(nf))
-
-    def forward(self, x):
-        size_out = x.size()[:-1] + (self.nf,)
-        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(size_out)
-        return x
-
-
class PoolerStartLogits(nn.Module):
    """
    Compute SQuAD start logits from sequence hidden states.
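For reference, `Conv1D` is a dense layer that stores its weight transposed relative to `nn.Linear` (a holdover from the original GPT codebase); a quick equivalence check using the post-commit import path:

```python
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

conv = Conv1D(nf=6, nx=4)  # 4 input features -> 6 output features
linear = nn.Linear(4, 6)
with torch.no_grad():
    linear.weight.copy_(conv.weight.T)  # Conv1D stores weight as (nx, nf)
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 3, 4)
assert torch.allclose(conv(x), linear(x), atol=1e-5)  # same layer, transposed storage
```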
@@ -2762,169 +2717,3 @@ def unwrap_model(model: nn.Module) -> nn.Module:
        return unwrap_model(model.module)
    else:
        return model
-
-
-def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
-    """
-    Prune a linear layer to keep only entries in index.
-
-    Used to remove heads.
-
-    Args:
-        layer (`torch.nn.Linear`): The layer to prune.
-        index (`torch.LongTensor`): The indices to keep in the layer.
-        dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
-
-    Returns:
-        `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
-    """
-    index = index.to(layer.weight.device)
-    W = layer.weight.index_select(dim, index).clone().detach()
-    if layer.bias is not None:
-        if dim == 1:
-            b = layer.bias.clone().detach()
-        else:
-            b = layer.bias[index].clone().detach()
-    new_size = list(layer.weight.size())
-    new_size[dim] = len(index)
-    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
-    new_layer.weight.requires_grad = False
-    new_layer.weight.copy_(W.contiguous())
-    new_layer.weight.requires_grad = True
-    if layer.bias is not None:
-        new_layer.bias.requires_grad = False
-        new_layer.bias.copy_(b.contiguous())
-        new_layer.bias.requires_grad = True
-    return new_layer
-
-
-def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
-    """
-    Prune a Conv1D layer to keep only entries in index. A Conv1D works as a Linear layer (see e.g. BERT) but the
-    weights are transposed.
-
-    Used to remove heads.
-
-    Args:
-        layer ([`~modeling_utils.Conv1D`]): The layer to prune.
-        index (`torch.LongTensor`): The indices to keep in the layer.
-        dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
-
-    Returns:
-        [`~modeling_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
-    """
-    index = index.to(layer.weight.device)
-    W = layer.weight.index_select(dim, index).clone().detach()
-    if dim == 0:
-        b = layer.bias.clone().detach()
-    else:
-        b = layer.bias[index].clone().detach()
-    new_size = list(layer.weight.size())
-    new_size[dim] = len(index)
-    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
-    new_layer.weight.requires_grad = False
-    new_layer.weight.copy_(W.contiguous())
-    new_layer.weight.requires_grad = True
-    new_layer.bias.requires_grad = False
-    new_layer.bias.copy_(b.contiguous())
-    new_layer.bias.requires_grad = True
-    return new_layer
-
-
-def prune_layer(
-    layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
-) -> Union[nn.Linear, Conv1D]:
-    """
-    Prune a Conv1D or linear layer to keep only entries in index.
-
-    Used to remove heads.
-
-    Args:
-        layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
-        index (`torch.LongTensor`): The indices to keep in the layer.
-        dim (`int`, *optional*): The dimension on which to keep the indices.
-
-    Returns:
-        `torch.nn.Linear` or [`~modeling_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
-    """
-    if isinstance(layer, nn.Linear):
-        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
-    elif isinstance(layer, Conv1D):
-        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
-    else:
-        raise ValueError(f"Can't prune layer of class {layer.__class__}")
-
-
-def apply_chunking_to_forward(
-    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
-) -> torch.Tensor:
-    """
-    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
-    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
-
-    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
-    applying `forward_fn` to `input_tensors`.
-
-    Args:
-        forward_fn (`Callable[..., torch.Tensor]`):
-            The forward function of the model.
-        chunk_size (`int`):
-            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
-        chunk_dim (`int`):
-            The dimension over which the `input_tensors` should be chunked.
-        input_tensors (`Tuple[torch.Tensor]`):
-            The input tensors of `forward_fn` which will be chunked.
-
-    Returns:
-        `torch.Tensor`: A tensor with the same shape as `forward_fn` would have given if applied.
-
-    Examples:
-
-    ```python
-    # rename the usual forward() fn to forward_chunk()
-    def forward_chunk(self, hidden_states):
-        hidden_states = self.decoder(hidden_states)
-        return hidden_states
-
-
-    # implement a chunked forward function
-    def forward(self, hidden_states):
-        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
-    ```"""
-
-    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
-
-    # inspect.signature has existed since Python 3.5 and is a Python method -> no problem with backward compatibility
-    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
-    if num_args_in_forward_chunk_fn != len(input_tensors):
-        raise ValueError(
-            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
-            "tensors are given"
-        )
-
-    if chunk_size > 0:
-        tensor_shape = input_tensors[0].shape[chunk_dim]
-        for input_tensor in input_tensors:
-            if input_tensor.shape[chunk_dim] != tensor_shape:
-                raise ValueError(
-                    f"All input tensors have to be of the same shape: {tensor_shape}, "
-                    f"found shape {input_tensor.shape[chunk_dim]}"
-                )
-
-        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
-            raise ValueError(
-                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
-                f"size {chunk_size}"
-            )
-
-        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
-
-        # chunk input tensor into tuples
-        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
-        # apply forward fn to every tuple
-        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
-        # concatenate output at same dimension
-        return torch.cat(output_chunks, dim=chunk_dim)
-
-    return forward_fn(*input_tensors)
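The chunked forward removed above (and re-added in `pytorch_utils.py` below) is easiest to see with a concrete call; a minimal sketch with illustrative shapes, using the post-commit import path:

```python
import torch
from torch import nn
from transformers.pytorch_utils import apply_chunking_to_forward

dense_in, dense_out = nn.Linear(8, 32), nn.Linear(32, 8)


def ffn_chunk(hidden_states):
    # position-wise feed-forward: independent across the sequence dimension
    return dense_out(torch.nn.functional.gelu(dense_in(hidden_states)))


hidden = torch.randn(2, 6, 8)  # (batch, seq_len=6, hidden)

# split dim 1 into chunks of 2 positions, run the FFN per chunk, re-concatenate;
# peak activation memory of the FFN shrinks by the chunking factor
chunked = apply_chunking_to_forward(ffn_chunk, 2, 1, hidden)
assert torch.allclose(chunked, ffn_chunk(hidden), atol=1e-5)
```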
@@ -34,12 +34,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -33,7 +33,8 @@ from ...modeling_outputs import (
    MaskedLMOutput,
    SemanticSegmenterOutput,
)
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -40,12 +40,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -37,7 +37,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -35,12 +35,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_canine import CanineConfig
@@ -35,13 +35,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    SequenceSummary,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_convbert import ConvBertConfig
@@ -23,7 +23,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
-from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_ctrl import CTRLConfig
@@ -34,12 +34,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -25,7 +25,8 @@ from packaging import version
from torch import nn

from ...activations import ACT2FN
-from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
@@ -27,7 +27,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -39,12 +39,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -42,7 +42,8 @@ from ...modeling_outputs import (
    DepthEstimatorOutput,
    SemanticSegmenterOutput,
)
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_dpt import DPTConfig
@@ -36,13 +36,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    SequenceSummary,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -43,7 +43,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -24,7 +24,8 @@ from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -40,13 +40,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    Conv1D,
-    PreTrainedModel,
-    SequenceSummary,
-    find_pruneable_heads_and_indices,
-    prune_conv1d_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -35,7 +35,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_ibert import IBertConfig
from .quant_modules import IntGELU, IntLayerNorm, IntSoftmax, QuantAct, QuantEmbedding, QuantLinear
@@ -38,7 +38,8 @@ from ...modeling_outputs import (
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
)
-from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_imagegpt import ImageGPTConfig
@@ -31,12 +31,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_layoutlm import LayoutLMConfig
@@ -31,7 +31,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
@@ -24,12 +24,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, gelu
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -24,7 +24,8 @@ from torch import nn

from ...activations import ACT2FN, gelu
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
    ModelOutput,
    add_start_docstrings,
@@ -29,7 +29,8 @@ from transformers.utils import logging

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
-from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -39,12 +39,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -41,7 +41,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -33,7 +33,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_mpnet import MPNetConfig
@@ -33,12 +33,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_nystromformer import NystromformerConfig
@@ -28,13 +28,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import gelu_new, silu
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import (
-    Conv1D,
-    PreTrainedModel,
-    SequenceSummary,
-    find_pruneable_heads_and_indices,
-    prune_conv1d_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -29,12 +29,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
@@ -38,7 +38,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -31,12 +31,8 @@ from ...modeling_outputs import (
    MaskedLMOutput,
    ModelOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_realm import RealmConfig
@@ -30,7 +30,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
@@ -35,12 +35,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -35,12 +35,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -35,13 +35,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    SequenceSummary,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -25,7 +25,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput, SemanticSegmenterOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -25,12 +25,8 @@ from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_splinter import SplinterConfig
@@ -26,7 +26,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -33,7 +33,8 @@ from ...modeling_outputs import (
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
@@ -28,12 +28,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
@@ -33,7 +33,8 @@ from ...modeling_outputs import (
    ModelOutput,
    SequenceClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_vilt import ViltConfig
@@ -31,12 +31,8 @@ from ...modeling_outputs import (
    MultipleChoiceModelOutput,
    SequenceClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
@@ -26,7 +26,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -28,7 +28,8 @@ from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
@@ -35,14 +35,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    SequenceSummary,
-    SQuADHead,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -34,12 +34,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -25,14 +25,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
-from ...modeling_utils import (
-    PoolerAnswerClass,
-    PoolerEndLogits,
-    PoolerStartLogits,
-    PreTrainedModel,
-    SequenceSummary,
-    apply_chunking_to_forward,
-)
+from ...modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
+from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -34,12 +34,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_yoso import YosoConfig
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
+from typing import Callable, List, Optional, Set, Tuple, Union

import torch
from packaging import version
-from torch import _softmax_backward_data
+from torch import _softmax_backward_data, nn

from .utils import logging

@@ -45,3 +47,221 @@ def softmax_backward_data(parent, grad_output, output, dim, self):
        return _softmax_backward_data(grad_output, output, parent.dim, self)
    else:
        return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
+
+
+def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
+    """
+    Prune a linear layer to keep only entries in index.
+
+    Used to remove heads.
+
+    Args:
+        layer (`torch.nn.Linear`): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
+
+    Returns:
+        `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).clone().detach()
+    if layer.bias is not None:
+        if dim == 1:
+            b = layer.bias.clone().detach()
+        else:
+            b = layer.bias[index].clone().detach()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    if layer.bias is not None:
+        new_layer.bias.requires_grad = False
+        new_layer.bias.copy_(b.contiguous())
+        new_layer.bias.requires_grad = True
+    return new_layer
+
+
+class Conv1D(nn.Module):
+    """
+    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
+
+    Basically works like a linear layer but the weights are transposed.
+
+    Args:
+        nf (`int`): The number of output features.
+        nx (`int`): The number of input features.
+    """
+
+    def __init__(self, nf, nx):
+        super().__init__()
+        self.nf = nf
+        w = torch.empty(nx, nf)
+        nn.init.normal_(w, std=0.02)
+        self.weight = nn.Parameter(w)
+        self.bias = nn.Parameter(torch.zeros(nf))
+
+    def forward(self, x):
+        size_out = x.size()[:-1] + (self.nf,)
+        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+        x = x.view(size_out)
+        return x
+
+
+def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
+    """
+    Prune a Conv1D layer to keep only entries in index. A Conv1D works as a Linear layer (see e.g. BERT) but the
+    weights are transposed.
+
+    Used to remove heads.
+
+    Args:
+        layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
+
+    Returns:
+        [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).clone().detach()
+    if dim == 0:
+        b = layer.bias.clone().detach()
+    else:
+        b = layer.bias[index].clone().detach()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    new_layer.bias.requires_grad = False
+    new_layer.bias.copy_(b.contiguous())
+    new_layer.bias.requires_grad = True
+    return new_layer
+
+
+def prune_layer(
+    layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
+) -> Union[nn.Linear, Conv1D]:
+    """
+    Prune a Conv1D or linear layer to keep only entries in index.
+
+    Used to remove heads.
+
+    Args:
+        layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*): The dimension on which to keep the indices.
+
+    Returns:
+        `torch.nn.Linear` or [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    if isinstance(layer, nn.Linear):
+        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
+    elif isinstance(layer, Conv1D):
+        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
+    else:
+        raise ValueError(f"Can't prune layer of class {layer.__class__}")
+
+
+def apply_chunking_to_forward(
+    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
+) -> torch.Tensor:
+    """
+    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
+    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
+
+    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
+    applying `forward_fn` to `input_tensors`.
+
+    Args:
+        forward_fn (`Callable[..., torch.Tensor]`):
+            The forward function of the model.
+        chunk_size (`int`):
+            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
+        chunk_dim (`int`):
+            The dimension over which the `input_tensors` should be chunked.
+        input_tensors (`Tuple[torch.Tensor]`):
+            The input tensors of `forward_fn` which will be chunked.
+
+    Returns:
+        `torch.Tensor`: A tensor with the same shape as `forward_fn` would have given if applied.
+
+    Examples:
+
+    ```python
+    # rename the usual forward() fn to forward_chunk()
+    def forward_chunk(self, hidden_states):
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+    # implement a chunked forward function
+    def forward(self, hidden_states):
+        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
+    ```"""
+
+    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
+
+    # inspect.signature has existed since Python 3.5 and is a Python method -> no problem with backward compatibility
+    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
+    if num_args_in_forward_chunk_fn != len(input_tensors):
+        raise ValueError(
+            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
+            "tensors are given"
+        )
+
+    if chunk_size > 0:
+        tensor_shape = input_tensors[0].shape[chunk_dim]
+        for input_tensor in input_tensors:
+            if input_tensor.shape[chunk_dim] != tensor_shape:
+                raise ValueError(
+                    f"All input tensors have to be of the same shape: {tensor_shape}, "
+                    f"found shape {input_tensor.shape[chunk_dim]}"
+                )
+
+        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
+            raise ValueError(
+                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
+                f"size {chunk_size}"
+            )
+
+        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
+
+        # chunk input tensor into tuples
+        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
+        # apply forward fn to every tuple
+        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
+        # concatenate output at same dimension
+        return torch.cat(output_chunks, dim=chunk_dim)
+
+    return forward_fn(*input_tensors)
+
+
+def find_pruneable_heads_and_indices(
+    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
+) -> Tuple[Set[int], torch.LongTensor]:
+    """
+    Finds the heads and their indices taking `already_pruned_heads` into account.
+
+    Args:
+        heads (`List[int]`): List of the indices of heads to prune.
+        n_heads (`int`): The number of heads in the model.
+        head_size (`int`): The size of each head.
+        already_pruned_heads (`Set[int]`): A set of already pruned heads.
+
+    Returns:
+        `Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
+    """
+    mask = torch.ones(n_heads, head_size)
+    heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
+    for head in heads:
+        # Compute how many pruned heads are before the head and move the index accordingly
+        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
+        mask[head] = 0
+    mask = mask.view(-1).contiguous().eq(1)
+    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
+    return heads, index
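A short demonstration of how the pruning helpers added above fit together; illustrative values, e.g. an index like the flat positions returned by `find_pruneable_heads_and_indices`:

```python
import torch
from torch import nn
from transformers.pytorch_utils import prune_layer

layer = nn.Linear(8, 8)
index = torch.tensor([0, 1, 4, 5, 6, 7])  # flat positions to keep

pruned = prune_layer(layer, index)  # nn.Linear defaults to dim=0 (output features)
print(pruned.weight.shape)  # torch.Size([6, 8]) -> 6 output features kept
print(pruned.bias.shape)    # torch.Size([6])
assert pruned.weight.requires_grad  # the returned layer is trainable again
```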
@@ -266,13 +266,6 @@ def top_k_top_p_filtering(*args, **kwargs):
    requires_backends(top_k_top_p_filtering, ["torch"])


-class Conv1D(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
class PreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

@@ -280,14 +273,6 @@ class PreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


-def apply_chunking_to_forward(*args, **kwargs):
-    requires_backends(apply_chunking_to_forward, ["torch"])
-
-
-def prune_layer(*args, **kwargs):
-    requires_backends(prune_layer, ["torch"])
-
-
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -4782,6 +4767,21 @@ def get_scheduler(*args, **kwargs):
    requires_backends(get_scheduler, ["torch"])


+class Conv1D(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+def apply_chunking_to_forward(*args, **kwargs):
+    requires_backends(apply_chunking_to_forward, ["torch"])
+
+
+def prune_layer(*args, **kwargs):
+    requires_backends(prune_layer, ["torch"])
+
+
class Trainer(metaclass=DummyObject):
    _backends = ["torch"]
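These `dummy_pt_objects.py` hunks just keep the torch-less placeholders sorted with the new module: when PyTorch is not installed, every torch-only name is replaced by a stub that fails loudly on use. A simplified sketch of the pattern, not the library's exact utils code:

```python
def requires_backends(obj, backends):
    # Raise a helpful error naming the missing backend(s)
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class DummyObject(type):
    """Metaclass: any attribute access on the class raises the backend error."""

    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)


class Conv1D(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


# Conv1D()  -> ImportError: Conv1D requires the following backends: torch
```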
@@ -43,9 +43,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import (
-    PreTrainedModel,
-    SequenceSummary,
+from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import (
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,