Moved functions to pytorch_utils.py (#16625)

* Moved functions to pytorch_utils.py

* isort formatting

* Reverted tf changes

* isort, make fix-copies

* documentation fix

* Fixed Conv1D import

* Reverted research examples file

* backward compatibility for pytorch_utils

* missing import

* isort fix
This commit is contained in:
Anmol Joshi 2022-04-12 09:38:50 -07:00 committed by GitHub
parent 0711c45eae
commit a315988bae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
54 changed files with 353 additions and 430 deletions

View File

@ -19,7 +19,7 @@ Most of those are only useful if you are studying the code of the models in the
## Pytorch custom modules
[[autodoc]] modeling_utils.Conv1D
[[autodoc]] pytorch_utils.Conv1D
[[autodoc]] modeling_utils.PoolerStartLogits
- forward
@ -40,15 +40,15 @@ Most of those are only useful if you are studying the code of the models in the
## PyTorch Helper Functions
[[autodoc]] apply_chunking_to_forward
[[autodoc]] pytorch_utils.apply_chunking_to_forward
[[autodoc]] modeling_utils.find_pruneable_heads_and_indices
[[autodoc]] pytorch_utils.find_pruneable_heads_and_indices
[[autodoc]] modeling_utils.prune_layer
[[autodoc]] pytorch_utils.prune_layer
[[autodoc]] modeling_utils.prune_conv1d_layer
[[autodoc]] pytorch_utils.prune_conv1d_layer
[[autodoc]] modeling_utils.prune_linear_layer
[[autodoc]] pytorch_utils.prune_linear_layer
## TensorFlow custom layers

View File

@ -638,7 +638,7 @@ if is_torch_available():
]
_import_structure["generation_utils"] = ["top_k_top_p_filtering"]
_import_structure["modeling_outputs"] = []
_import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
_import_structure["modeling_utils"] = ["PreTrainedModel"]
# PyTorch models structure
_import_structure["models.albert"].extend(
@ -1693,7 +1693,7 @@ if is_torch_available():
"get_polynomial_decay_schedule_with_warmup",
"get_scheduler",
]
_import_structure["pytorch_utils"] = []
_import_structure["pytorch_utils"] = ["Conv1D", "apply_chunking_to_forward", "prune_layer"]
_import_structure["sagemaker"] = []
_import_structure["trainer"] = ["Trainer"]
_import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
@ -2956,7 +2956,7 @@ if TYPE_CHECKING:
StoppingCriteriaList,
)
from .generation_utils import top_k_top_p_filtering
from .modeling_utils import Conv1D, PreTrainedModel, apply_chunking_to_forward, prune_layer
from .modeling_utils import PreTrainedModel
from .models.albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
AlbertForMaskedLM,
@ -3831,6 +3831,7 @@ if TYPE_CHECKING:
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
)
from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer
# Trainer
from .trainer import Trainer

View File

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import re
@ -24,7 +23,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor, device, nn
@ -37,6 +36,14 @@ from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .generation_utils import GenerationMixin
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
)
from .utils import (
DUMMY_INPUTS,
FLAX_WEIGHTS_NAME,
@ -97,32 +104,6 @@ except ImportError:
return input
def find_pruneable_heads_and_indices(
heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
"""
Finds the heads and their indices taking `already_pruned_heads` into account.
Args:
heads (`List[int]`): List of the indices of heads to prune.
n_heads (`int`): The number of heads in the model.
head_size (`int`): The size of each head.
already_pruned_heads (`Set[int]`): A set of already pruned heads.
Returns:
`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
"""
mask = torch.ones(n_heads, head_size)
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index: torch.LongTensor = torch.arange(len(mask))[mask].long()
return heads, index
def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
try:
return next(parameter.parameters()).device
@ -2305,32 +2286,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return url
class Conv1D(nn.Module):
"""
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
Basically works like a linear layer but the weights are transposed.
Args:
nf (`int`): The number of output features.
nx (`int`): The number of input features.
"""
def __init__(self, nf, nx):
super().__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = nn.Parameter(w)
self.bias = nn.Parameter(torch.zeros(nf))
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(size_out)
return x
class PoolerStartLogits(nn.Module):
"""
Compute SQuAD start logits from sequence hidden states.
@ -2762,169 +2717,3 @@ def unwrap_model(model: nn.Module) -> nn.Module:
return unwrap_model(model.module)
else:
return model
def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
"""
Prune a linear layer to keep only entries in index.
Used to remove heads.
Args:
layer (`torch.nn.Linear`): The layer to prune.
index (`torch.LongTensor`): The indices to keep in the layer.
dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
Returns:
`torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if layer.bias is not None:
if dim == 1:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
if layer.bias is not None:
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
"""
Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights
are transposed.
Used to remove heads.
Args:
layer ([`~modeling_utils.Conv1D`]): The layer to prune.
index (`torch.LongTensor`): The indices to keep in the layer.
dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
Returns:
[`~modeling_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if dim == 0:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
def prune_layer(
layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
) -> Union[nn.Linear, Conv1D]:
"""
Prune a Conv1D or linear layer to keep only entries in index.
Used to remove heads.
Args:
layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
index (`torch.LongTensor`): The indices to keep in the layer.
dim (`int`, *optional*): The dimension on which to keep the indices.
Returns:
`torch.nn.Linear` or [`~modeling_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
"""
if isinstance(layer, nn.Linear):
return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
elif isinstance(layer, Conv1D):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError(f"Can't prune layer of class {layer.__class__}")
def apply_chunking_to_forward(
forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> torch.Tensor:
"""
This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
`chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
applying `forward_fn` to `input_tensors`.
Args:
forward_fn (`Callable[..., torch.Tensor]`):
The forward function of the model.
chunk_size (`int`):
The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
chunk_dim (`int`):
The dimension over which the `input_tensors` should be chunked.
input_tensors (`Tuple[torch.Tensor]`):
The input tensors of `forward_fn` which will be chunked
Returns:
`torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`.
Examples:
```python
# rename the usual forward() fn to forward_chunk()
def forward_chunk(self, hidden_states):
hidden_states = self.decoder(hidden_states)
return hidden_states
# implement a chunked forward function
def forward(self, hidden_states):
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
```"""
assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
# inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
if num_args_in_forward_chunk_fn != len(input_tensors):
raise ValueError(
f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
"tensors are given"
)
if chunk_size > 0:
tensor_shape = input_tensors[0].shape[chunk_dim]
for input_tensor in input_tensors:
if input_tensor.shape[chunk_dim] != tensor_shape:
raise ValueError(
f"All input tenors have to be of the same shape: {tensor_shape}, "
f"found shape {input_tensor.shape[chunk_dim]}"
)
if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
f"size {chunk_size}"
)
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
# chunk input tensor into tuples
input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
# apply forward fn to every tuple
output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
# concatenate output at same dimension
return torch.cat(output_chunks, dim=chunk_dim)
return forward_fn(*input_tensors)

View File

@ -34,12 +34,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -33,7 +33,8 @@ from ...modeling_outputs import (
MaskedLMOutput,
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -40,12 +40,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -37,7 +37,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -35,12 +35,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_canine import CanineConfig

View File

@ -35,13 +35,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
SequenceSummary,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_convbert import ConvBertConfig

View File

@ -23,7 +23,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_ctrl import CTRLConfig

View File

@ -34,12 +34,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -25,7 +25,8 @@ from packaging import version
from torch import nn
from ...activations import ACT2FN
from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import (
ModelOutput,
add_start_docstrings,

View File

@ -27,7 +27,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -39,12 +39,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -42,7 +42,8 @@ from ...modeling_outputs import (
DepthEstimatorOutput,
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_dpt import DPTConfig

View File

@ -36,13 +36,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
SequenceSummary,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -43,7 +43,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -24,7 +24,8 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -40,13 +40,8 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import (
Conv1D,
PreTrainedModel,
SequenceSummary,
find_pruneable_heads_and_indices,
prune_conv1d_layer,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -35,7 +35,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_ibert import IBertConfig
from .quant_modules import IntGELU, IntLayerNorm, IntSoftmax, QuantAct, QuantEmbedding, QuantLinear

View File

@ -38,7 +38,8 @@ from ...modeling_outputs import (
CausalLMOutputWithCrossAttentions,
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_imagegpt import ImageGPTConfig

View File

@ -31,12 +31,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_layoutlm import LayoutLMConfig

View File

@ -31,7 +31,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,

View File

@ -24,12 +24,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, gelu
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -24,7 +24,8 @@ from torch import nn
from ...activations import ACT2FN, gelu
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
ModelOutput,
add_start_docstrings,

View File

@ -29,7 +29,8 @@ from transformers.utils import logging
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -39,12 +39,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -41,7 +41,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -33,7 +33,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_mpnet import MPNetConfig

View File

@ -33,12 +33,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_nystromformer import NystromformerConfig

View File

@ -28,13 +28,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import gelu_new, silu
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import (
Conv1D,
PreTrainedModel,
SequenceSummary,
find_pruneable_heads_and_indices,
prune_conv1d_layer,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -29,12 +29,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_start_docstrings,

View File

@ -38,7 +38,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -31,12 +31,8 @@ from ...modeling_outputs import (
MaskedLMOutput,
ModelOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_realm import RealmConfig

View File

@ -30,7 +30,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,

View File

@ -35,12 +35,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -35,12 +35,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -35,13 +35,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
SequenceSummary,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -25,7 +25,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput, SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -25,12 +25,8 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_splinter import SplinterConfig

View File

@ -26,7 +26,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -33,7 +33,8 @@ from ...modeling_outputs import (
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,

View File

@ -28,12 +28,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_start_docstrings,

View File

@ -33,7 +33,8 @@ from ...modeling_outputs import (
ModelOutput,
SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_vilt import ViltConfig

View File

@ -31,12 +31,8 @@ from ...modeling_outputs import (
MultipleChoiceModelOutput,
SequenceClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_start_docstrings,

View File

@ -26,7 +26,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -28,7 +28,8 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_start_docstrings,

View File

@ -35,14 +35,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
SequenceSummary,
SQuADHead,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -34,12 +34,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

View File

@ -25,14 +25,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_utils import (
PoolerAnswerClass,
PoolerEndLogits,
PoolerStartLogits,
PreTrainedModel,
SequenceSummary,
apply_chunking_to_forward,
)
from ...modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
ModelOutput,
add_code_sample_docstrings,

View File

@ -34,12 +34,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_yoso import YosoConfig

View File

@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Set, Tuple, Union
import torch
from packaging import version
from torch import _softmax_backward_data
from torch import _softmax_backward_data, nn
from .utils import logging
@ -45,3 +47,221 @@ def softmax_backward_data(parent, grad_output, output, dim, self):
return _softmax_backward_data(grad_output, output, parent.dim, self)
else:
return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
"""
Prune a linear layer to keep only entries in index.
Used to remove heads.
Args:
layer (`torch.nn.Linear`): The layer to prune.
index (`torch.LongTensor`): The indices to keep in the layer.
dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
Returns:
`torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if layer.bias is not None:
if dim == 1:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
if layer.bias is not None:
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
class Conv1D(nn.Module):
"""
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
Basically works like a linear layer but the weights are transposed.
Args:
nf (`int`): The number of output features.
nx (`int`): The number of input features.
"""
def __init__(self, nf, nx):
super().__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = nn.Parameter(w)
self.bias = nn.Parameter(torch.zeros(nf))
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(size_out)
return x
def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
"""
Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights
are transposed.
Used to remove heads.
Args:
layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
index (`torch.LongTensor`): The indices to keep in the layer.
dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
Returns:
[`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if dim == 0:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
def prune_layer(
layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
) -> Union[nn.Linear, Conv1D]:
"""
Prune a Conv1D or linear layer to keep only entries in index.
Used to remove heads.
Args:
layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
index (`torch.LongTensor`): The indices to keep in the layer.
dim (`int`, *optional*): The dimension on which to keep the indices.
Returns:
`torch.nn.Linear` or [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
"""
if isinstance(layer, nn.Linear):
return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
elif isinstance(layer, Conv1D):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError(f"Can't prune layer of class {layer.__class__}")
def apply_chunking_to_forward(
forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> torch.Tensor:
"""
This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
`chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
applying `forward_fn` to `input_tensors`.
Args:
forward_fn (`Callable[..., torch.Tensor]`):
The forward function of the model.
chunk_size (`int`):
The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
chunk_dim (`int`):
The dimension over which the `input_tensors` should be chunked.
input_tensors (`Tuple[torch.Tensor]`):
The input tensors of `forward_fn` which will be chunked
Returns:
`torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`.
Examples:
```python
# rename the usual forward() fn to forward_chunk()
def forward_chunk(self, hidden_states):
hidden_states = self.decoder(hidden_states)
return hidden_states
# implement a chunked forward function
def forward(self, hidden_states):
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
```"""
assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
# inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
if num_args_in_forward_chunk_fn != len(input_tensors):
raise ValueError(
f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
"tensors are given"
)
if chunk_size > 0:
tensor_shape = input_tensors[0].shape[chunk_dim]
for input_tensor in input_tensors:
if input_tensor.shape[chunk_dim] != tensor_shape:
raise ValueError(
f"All input tenors have to be of the same shape: {tensor_shape}, "
f"found shape {input_tensor.shape[chunk_dim]}"
)
if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
f"size {chunk_size}"
)
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
# chunk input tensor into tuples
input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
# apply forward fn to every tuple
output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
# concatenate output at same dimension
return torch.cat(output_chunks, dim=chunk_dim)
return forward_fn(*input_tensors)
def find_pruneable_heads_and_indices(
heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
"""
Finds the heads and their indices taking `already_pruned_heads` into account.
Args:
heads (`List[int]`): List of the indices of heads to prune.
n_heads (`int`): The number of heads in the model.
head_size (`int`): The size of each head.
already_pruned_heads (`Set[int]`): A set of already pruned heads.
Returns:
`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
"""
mask = torch.ones(n_heads, head_size)
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index: torch.LongTensor = torch.arange(len(mask))[mask].long()
return heads, index

View File

@ -266,13 +266,6 @@ def top_k_top_p_filtering(*args, **kwargs):
requires_backends(top_k_top_p_filtering, ["torch"])
class Conv1D(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
@ -280,14 +273,6 @@ class PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
def apply_chunking_to_forward(*args, **kwargs):
requires_backends(apply_chunking_to_forward, ["torch"])
def prune_layer(*args, **kwargs):
requires_backends(prune_layer, ["torch"])
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -4782,6 +4767,21 @@ def get_scheduler(*args, **kwargs):
requires_backends(get_scheduler, ["torch"])
class Conv1D(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
def apply_chunking_to_forward(*args, **kwargs):
requires_backends(apply_chunking_to_forward, ["torch"])
def prune_layer(*args, **kwargs):
requires_backends(prune_layer, ["torch"])
class Trainer(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -43,9 +43,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
SequenceSummary,
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,