From a315988baeddd64da3eb4f030ca804ef92a73d1f Mon Sep 17 00:00:00 2001
From: Anmol Joshi
Date: Tue, 12 Apr 2022 09:38:50 -0700
Subject: [PATCH] Moved functions to pytorch_utils.py (#16625)

* Moved functions to pytorch_utils.py
* isort formatting
* Reverted tf changes
* isort, make fix-copies
* documentation fix
* Fixed Conv1D import
* Reverted research examples file
* backward compatibility for pytorch_utils
* missing import
* isort fix
---
 docs/source/en/internal/modeling_utils.mdx | 12 +-
 src/transformers/__init__.py | 7 +-
 src/transformers/modeling_utils.py | 229 +-----------------
 .../models/albert/modeling_albert.py | 8 +-
 src/transformers/models/beit/modeling_beit.py | 3 +-
 src/transformers/models/bert/modeling_bert.py | 8 +-
 .../models/big_bird/modeling_big_bird.py | 3 +-
 .../models/canine/modeling_canine.py | 8 +-
 .../models/convbert/modeling_convbert.py | 9 +-
 src/transformers/models/ctrl/modeling_ctrl.py | 3 +-
 .../models/data2vec/modeling_data2vec_text.py | 8 +-
 .../modeling_decision_transformer.py | 3 +-
 src/transformers/models/deit/modeling_deit.py | 3 +-
 .../models/distilbert/modeling_distilbert.py | 8 +-
 src/transformers/models/dpt/modeling_dpt.py | 3 +-
 .../models/electra/modeling_electra.py | 9 +-
 src/transformers/models/fnet/modeling_fnet.py | 3 +-
 src/transformers/models/glpn/modeling_glpn.py | 3 +-
 src/transformers/models/gpt2/modeling_gpt2.py | 9 +-
 .../models/ibert/modeling_ibert.py | 3 +-
 .../models/imagegpt/modeling_imagegpt.py | 3 +-
 .../models/layoutlm/modeling_layoutlm.py | 8 +-
 .../models/layoutlmv2/modeling_layoutlmv2.py | 3 +-
 .../models/longformer/modeling_longformer.py | 8 +-
 src/transformers/models/luke/modeling_luke.py | 3 +-
 .../models/maskformer/modeling_maskformer.py | 3 +-
 .../megatron_bert/modeling_megatron_bert.py | 8 +-
 .../models/mobilebert/modeling_mobilebert.py | 3 +-
 .../models/mpnet/modeling_mpnet.py | 3 +-
 .../nystromformer/modeling_nystromformer.py | 8 +-
 .../models/openai/modeling_openai.py | 9 +-
 .../models/perceiver/modeling_perceiver.py | 8 +-
 .../models/qdqbert/modeling_qdqbert.py | 3 +-
 .../models/realm/modeling_realm.py | 8 +-
 .../models/reformer/modeling_reformer.py | 3 +-
 .../models/rembert/modeling_rembert.py | 8 +-
 .../models/roberta/modeling_roberta.py | 8 +-
 .../models/roformer/modeling_roformer.py | 9 +-
 .../models/segformer/modeling_segformer.py | 3 +-
 .../models/splinter/modeling_splinter.py | 8 +-
 src/transformers/models/swin/modeling_swin.py | 3 +-
 src/transformers/models/t5/modeling_t5.py | 3 +-
 .../models/tapas/modeling_tapas.py | 8 +-
 src/transformers/models/vilt/modeling_vilt.py | 3 +-
 .../visual_bert/modeling_visual_bert.py | 8 +-
 src/transformers/models/vit/modeling_vit.py | 3 +-
 .../models/vit_mae/modeling_vit_mae.py | 3 +-
 src/transformers/models/xlm/modeling_xlm.py | 10 +-
 .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 8 +-
 .../models/xlnet/modeling_xlnet.py | 10 +-
 src/transformers/models/yoso/modeling_yoso.py | 8 +-
 src/transformers/pytorch_utils.py | 222 ++++++++++++++++-
 src/transformers/utils/dummy_pt_objects.py | 30 +--
 ...ng_{{cookiecutter.lowercase_modelname}}.py | 5 +-
 54 files changed, 353 insertions(+), 430 deletions(-)

diff --git a/docs/source/en/internal/modeling_utils.mdx b/docs/source/en/internal/modeling_utils.mdx
index af4fdf72a8e..914b8ca3679 100644
--- a/docs/source/en/internal/modeling_utils.mdx
+++ b/docs/source/en/internal/modeling_utils.mdx
@@ -19,7 +19,7 @@ Most of those are only useful if you are studying the code of the models in the
 
 ## Pytorch custom modules
 
-[[autodoc]]
modeling_utils.Conv1D +[[autodoc]] pytorch_utils.Conv1D [[autodoc]] modeling_utils.PoolerStartLogits - forward @@ -40,15 +40,15 @@ Most of those are only useful if you are studying the code of the models in the ## PyTorch Helper Functions -[[autodoc]] apply_chunking_to_forward +[[autodoc]] pytorch_utils.apply_chunking_to_forward -[[autodoc]] modeling_utils.find_pruneable_heads_and_indices +[[autodoc]] pytorch_utils.find_pruneable_heads_and_indices -[[autodoc]] modeling_utils.prune_layer +[[autodoc]] pytorch_utils.prune_layer -[[autodoc]] modeling_utils.prune_conv1d_layer +[[autodoc]] pytorch_utils.prune_conv1d_layer -[[autodoc]] modeling_utils.prune_linear_layer +[[autodoc]] pytorch_utils.prune_linear_layer ## TensorFlow custom layers diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index dbdef5824d6..a399cbe3633 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -638,7 +638,7 @@ if is_torch_available(): ] _import_structure["generation_utils"] = ["top_k_top_p_filtering"] _import_structure["modeling_outputs"] = [] - _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] + _import_structure["modeling_utils"] = ["PreTrainedModel"] # PyTorch models structure _import_structure["models.albert"].extend( @@ -1693,7 +1693,7 @@ if is_torch_available(): "get_polynomial_decay_schedule_with_warmup", "get_scheduler", ] - _import_structure["pytorch_utils"] = [] + _import_structure["pytorch_utils"] = ["Conv1D", "apply_chunking_to_forward", "prune_layer"] _import_structure["sagemaker"] = [] _import_structure["trainer"] = ["Trainer"] _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"] @@ -2956,7 +2956,7 @@ if TYPE_CHECKING: StoppingCriteriaList, ) from .generation_utils import top_k_top_p_filtering - from .modeling_utils import Conv1D, PreTrainedModel, apply_chunking_to_forward, prune_layer + from .modeling_utils import PreTrainedModel from .models.albert import ( ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, AlbertForMaskedLM, @@ -3831,6 +3831,7 @@ if TYPE_CHECKING: get_polynomial_decay_schedule_with_warmup, get_scheduler, ) + from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer # Trainer from .trainer import Trainer diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e7d253c2b78..011134de7af 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect import json import os import re @@ -24,7 +23,7 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor, device, nn @@ -37,6 +36,14 @@ from .configuration_utils import PretrainedConfig from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled from .dynamic_module_utils import custom_object_save from .generation_utils import GenerationMixin +from .pytorch_utils import ( # noqa: F401 + Conv1D, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_conv1d_layer, + prune_layer, + prune_linear_layer, +) from .utils import ( DUMMY_INPUTS, FLAX_WEIGHTS_NAME, @@ -97,32 +104,6 @@ except ImportError: return input -def find_pruneable_heads_and_indices( - heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int] -) -> Tuple[Set[int], torch.LongTensor]: - """ - Finds the heads and their indices taking `already_pruned_heads` into account. - - Args: - heads (`List[int]`): List of the indices of heads to prune. - n_heads (`int`): The number of heads in the model. - head_size (`int`): The size of each head. - already_pruned_heads (`Set[int]`): A set of already pruned heads. - - Returns: - `Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices. - """ - mask = torch.ones(n_heads, head_size) - heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads - for head in heads: - # Compute how many pruned heads are before the head and move the index accordingly - head = head - sum(1 if h < head else 0 for h in already_pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index: torch.LongTensor = torch.arange(len(mask))[mask].long() - return heads, index - - def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): try: return next(parameter.parameters()).device @@ -2305,32 +2286,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return url -class Conv1D(nn.Module): - """ - 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). - - Basically works like a linear layer but the weights are transposed. - - Args: - nf (`int`): The number of output features. - nx (`int`): The number of input features. - """ - - def __init__(self, nf, nx): - super().__init__() - self.nf = nf - w = torch.empty(nx, nf) - nn.init.normal_(w, std=0.02) - self.weight = nn.Parameter(w) - self.bias = nn.Parameter(torch.zeros(nf)) - - def forward(self, x): - size_out = x.size()[:-1] + (self.nf,) - x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - x = x.view(size_out) - return x - - class PoolerStartLogits(nn.Module): """ Compute SQuAD start logits from sequence hidden states. @@ -2762,169 +2717,3 @@ def unwrap_model(model: nn.Module) -> nn.Module: return unwrap_model(model.module) else: return model - - -def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear: - """ - Prune a linear layer to keep only entries in index. - - Used to remove heads. - - Args: - layer (`torch.nn.Linear`): The layer to prune. - index (`torch.LongTensor`): The indices to keep in the layer. - dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices. 
- - Returns: - `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`. - """ - index = index.to(layer.weight.device) - W = layer.weight.index_select(dim, index).clone().detach() - if layer.bias is not None: - if dim == 1: - b = layer.bias.clone().detach() - else: - b = layer.bias[index].clone().detach() - new_size = list(layer.weight.size()) - new_size[dim] = len(index) - new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) - new_layer.weight.requires_grad = False - new_layer.weight.copy_(W.contiguous()) - new_layer.weight.requires_grad = True - if layer.bias is not None: - new_layer.bias.requires_grad = False - new_layer.bias.copy_(b.contiguous()) - new_layer.bias.requires_grad = True - return new_layer - - -def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D: - """ - Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights - are transposed. - - Used to remove heads. - - Args: - layer ([`~modeling_utils.Conv1D`]): The layer to prune. - index (`torch.LongTensor`): The indices to keep in the layer. - dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices. - - Returns: - [`~modeling_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`. - """ - index = index.to(layer.weight.device) - W = layer.weight.index_select(dim, index).clone().detach() - if dim == 0: - b = layer.bias.clone().detach() - else: - b = layer.bias[index].clone().detach() - new_size = list(layer.weight.size()) - new_size[dim] = len(index) - new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) - new_layer.weight.requires_grad = False - new_layer.weight.copy_(W.contiguous()) - new_layer.weight.requires_grad = True - new_layer.bias.requires_grad = False - new_layer.bias.copy_(b.contiguous()) - new_layer.bias.requires_grad = True - return new_layer - - -def prune_layer( - layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None -) -> Union[nn.Linear, Conv1D]: - """ - Prune a Conv1D or linear layer to keep only entries in index. - - Used to remove heads. - - Args: - layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune. - index (`torch.LongTensor`): The indices to keep in the layer. - dim (`int`, *optional*): The dimension on which to keep the indices. - - Returns: - `torch.nn.Linear` or [`~modeling_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`. - """ - if isinstance(layer, nn.Linear): - return prune_linear_layer(layer, index, dim=0 if dim is None else dim) - elif isinstance(layer, Conv1D): - return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) - else: - raise ValueError(f"Can't prune layer of class {layer.__class__}") - - -def apply_chunking_to_forward( - forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors -) -> torch.Tensor: - """ - This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension - `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory. - - If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly - applying `forward_fn` to `input_tensors`. - - Args: - forward_fn (`Callable[..., torch.Tensor]`): - The forward function of the model. 
- chunk_size (`int`): - The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`. - chunk_dim (`int`): - The dimension over which the `input_tensors` should be chunked. - input_tensors (`Tuple[torch.Tensor]`): - The input tensors of `forward_fn` which will be chunked - - Returns: - `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`. - - - Examples: - - ```python - # rename the usual forward() fn to forward_chunk() - def forward_chunk(self, hidden_states): - hidden_states = self.decoder(hidden_states) - return hidden_states - - - # implement a chunked forward function - def forward(self, hidden_states): - return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) - ```""" - - assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors" - - # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility - num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) - if num_args_in_forward_chunk_fn != len(input_tensors): - raise ValueError( - f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input " - "tensors are given" - ) - - if chunk_size > 0: - tensor_shape = input_tensors[0].shape[chunk_dim] - for input_tensor in input_tensors: - if input_tensor.shape[chunk_dim] != tensor_shape: - raise ValueError( - f"All input tenors have to be of the same shape: {tensor_shape}, " - f"found shape {input_tensor.shape[chunk_dim]}" - ) - - if input_tensors[0].shape[chunk_dim] % chunk_size != 0: - raise ValueError( - f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk " - f"size {chunk_size}" - ) - - num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size - - # chunk input tensor into tuples - input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) - # apply forward fn to every tuple - output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) - # concatenate output at same dimension - return torch.cat(output_chunks, dim=chunk_dim) - - return forward_fn(*input_tensors) diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 7563ae52897..241e3550811 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -34,12 +34,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 3422afcce5d..80a49fcd20b 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -33,7 +33,8 @@ from ...modeling_outputs import ( MaskedLMOutput, SemanticSegmenterOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import 
find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 20b65d0c065..eb9c35e9788 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -40,12 +40,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index dbd41964aaa..47da3ff721a 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -37,7 +37,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 5052f56cc86..2d903109ac0 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -35,12 +35,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_canine import CanineConfig diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 328242b82b1..cf2240b79c0 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -35,13 +35,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - SequenceSummary, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_convbert import ConvBertConfig diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 9622d02e811..38dbec957c9 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -23,7 +23,8 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from 
...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput -from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_ctrl import CTRLConfig diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 13988eafb23..9168281eb84 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -34,12 +34,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 50d6930af90..b36ece31b73 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -25,7 +25,8 @@ from packaging import version from torch import nn from ...activations import ACT2FN -from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( ModelOutput, add_start_docstrings, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 11591aeaeed..9b858fc3c13 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -27,7 +27,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 14bc7b9e949..5ef86541f23 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -39,12 +39,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from 
...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index ba454dc964d..ab742a1279c 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -42,7 +42,8 @@ from ...modeling_outputs import ( DepthEstimatorOutput, SemanticSegmenterOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from .configuration_dpt import DPTConfig diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index e5347cc19c2..f2313a6cc2f 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -36,13 +36,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - SequenceSummary, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 702a66de8bc..3c301727a65 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -43,7 +43,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index 86e53c78757..2361e8e61b4 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -24,7 +24,8 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 94c844f5932..d5c1860e168 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -40,13 +40,8 @@ from ...modeling_outputs import ( SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import ( - Conv1D, - PreTrainedModel, - SequenceSummary, - find_pruneable_heads_and_indices, - prune_conv1d_layer, -) +from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git 
a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 4d0c2a1514f..420e8b27404 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -35,7 +35,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_ibert import IBertConfig from .quant_modules import IntGELU, IntLayerNorm, IntSoftmax, QuantAct, QuantEmbedding, QuantLinear diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 6bca6c921ab..5866744bd8d 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -38,7 +38,8 @@ from ...modeling_outputs import ( CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, ) -from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_imagegpt import ImageGPTConfig diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 2b6a8dca7bc..b198c183d98 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -31,12 +31,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_layoutlm import LayoutLMConfig diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index af5ba22c2e9..1ce330760e7 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -31,7 +31,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 87b1974ab94..647bb8fb731 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -24,12 +24,8 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, 
CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 938f95a6341..cd5a53ddae3 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -24,7 +24,8 @@ from torch import nn from ...activations import ACT2FN, gelu from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( ModelOutput, add_start_docstrings, diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index a3e6092c135..339de6eeeb1 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -29,7 +29,8 @@ from transformers.utils import logging from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithCrossAttentions -from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index c0cf71080f0..b64a0d41b93 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -39,12 +39,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 4eb1fbf69d4..1a2156ed31d 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -41,7 +41,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py index 61f50f72a70..89b68544a1e 100644 --- a/src/transformers/models/mpnet/modeling_mpnet.py +++ 
b/src/transformers/models/mpnet/modeling_mpnet.py @@ -33,7 +33,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_mpnet import MPNetConfig diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 3e1592721fa..636e5df108c 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -33,12 +33,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_nystromformer import NystromformerConfig diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index ec943347a9b..8ded535cef1 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -28,13 +28,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu_new, silu from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput -from ...modeling_utils import ( - Conv1D, - PreTrainedModel, - SequenceSummary, - find_pruneable_heads_and_indices, - prune_conv1d_layer, -) +from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index f323b9091a2..b8acef6c226 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -29,12 +29,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithCrossAttentions -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_start_docstrings, diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index 4089f12d1fe..ecd9fe73a95 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -38,7 +38,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, 
find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index fbaa1c14b7f..eec4fb2b7de 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -31,12 +31,8 @@ from ...modeling_outputs import ( MaskedLMOutput, ModelOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_realm import RealmConfig diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 7934320a4d8..027fb0dc121 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -30,7 +30,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput -from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 76b4c715d75..dc6f88f886a 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -35,12 +35,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 520e24034e6..2b7d2f87859 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -35,12 +35,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 475655b15b5..fe746971504 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ 
b/src/transformers/models/roformer/modeling_roformer.py @@ -35,13 +35,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - SequenceSummary, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index d90e285ebcb..d8989e340c3 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -25,7 +25,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput, SemanticSegmenterOutput -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 86e1c846c82..4d695b3137e 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -25,12 +25,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_splinter import SplinterConfig diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 81e91a19dcc..51a19ab73b8 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -26,7 +26,8 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index b4793c204c7..edbf95d8cef 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -33,7 +33,8 @@ from ...modeling_outputs import ( Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, 
prune_linear_layer from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index eaf4fc8435b..e34c1abb57e 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -28,12 +28,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_start_docstrings, diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 3e294d62f46..f29057addec 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -33,7 +33,8 @@ from ...modeling_outputs import ( ModelOutput, SequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_vilt import ViltConfig diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 69495785fe8..506b0c749ae 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -31,12 +31,8 @@ from ...modeling_outputs import ( MultipleChoiceModelOutput, SequenceClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_start_docstrings, diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index ef4a4191ccf..f4a88e51b51 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -26,7 +26,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 8acae9fca75..5f257dd61f9 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -28,7 +28,8 @@ from torch import nn from ...activations import ACT2FN from 
...modeling_outputs import BaseModelOutput -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_start_docstrings, diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 0c05590d61c..ed0817afbb8 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -35,14 +35,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - SequenceSummary, - SQuADHead, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 04d30b43209..e6c3ac3ec8c 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -34,12 +34,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, add_start_docstrings, diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index e84eb740403..079c636628f 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -25,14 +25,8 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_utils import ( - PoolerAnswerClass, - PoolerEndLogits, - PoolerStartLogits, - PreTrainedModel, - SequenceSummary, - apply_chunking_to_forward, -) +from ...modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary +from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 994f80b9bdd..bcd9c516cc8 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -34,12 +34,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_yoso import YosoConfig diff --git 
a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index ee0c94bd9c7..189853f5577 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
+from typing import Callable, List, Optional, Set, Tuple, Union
 
 import torch
 from packaging import version
-from torch import _softmax_backward_data
+from torch import _softmax_backward_data, nn
 
 from .utils import logging
 
@@ -45,3 +47,221 @@ def softmax_backward_data(parent, grad_output, output, dim, self):
         return _softmax_backward_data(grad_output, output, parent.dim, self)
     else:
         return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
+
+
+def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
+    """
+    Prune a linear layer to keep only entries in index.
+
+    Used to remove heads.
+
+    Args:
+        layer (`torch.nn.Linear`): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
+
+    Returns:
+        `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).clone().detach()
+    if layer.bias is not None:
+        if dim == 1:
+            b = layer.bias.clone().detach()
+        else:
+            b = layer.bias[index].clone().detach()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    if layer.bias is not None:
+        new_layer.bias.requires_grad = False
+        new_layer.bias.copy_(b.contiguous())
+        new_layer.bias.requires_grad = True
+    return new_layer
+
+
+class Conv1D(nn.Module):
+    """
+    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
+
+    Basically works like a linear layer but the weights are transposed.
+
+    Args:
+        nf (`int`): The number of output features.
+        nx (`int`): The number of input features.
+    """
+
+    def __init__(self, nf, nx):
+        super().__init__()
+        self.nf = nf
+        w = torch.empty(nx, nf)
+        nn.init.normal_(w, std=0.02)
+        self.weight = nn.Parameter(w)
+        self.bias = nn.Parameter(torch.zeros(nf))
+
+    def forward(self, x):
+        size_out = x.size()[:-1] + (self.nf,)
+        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+        x = x.view(size_out)
+        return x
+
+
+def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
+    """
+    Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights
+    are transposed.
+
+    Used to remove heads.
+
+    Args:
+        layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
+
+    Returns:
+        [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).clone().detach()
+    if dim == 0:
+        b = layer.bias.clone().detach()
+    else:
+        b = layer.bias[index].clone().detach()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    new_layer.bias.requires_grad = False
+    new_layer.bias.copy_(b.contiguous())
+    new_layer.bias.requires_grad = True
+    return new_layer
+
+
+def prune_layer(
+    layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
+) -> Union[nn.Linear, Conv1D]:
+    """
+    Prune a Conv1D or linear layer to keep only entries in index.
+
+    Used to remove heads.
+
+    Args:
+        layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*): The dimension on which to keep the indices.
+
+    Returns:
+        `torch.nn.Linear` or [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    if isinstance(layer, nn.Linear):
+        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
+    elif isinstance(layer, Conv1D):
+        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
+    else:
+        raise ValueError(f"Can't prune layer of class {layer.__class__}")
+
+
+def apply_chunking_to_forward(
+    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
+) -> torch.Tensor:
+    """
+    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
+    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
+
+    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
+    applying `forward_fn` to `input_tensors`.
+
+    Args:
+        forward_fn (`Callable[..., torch.Tensor]`):
+            The forward function of the model.
+        chunk_size (`int`):
+            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
+        chunk_dim (`int`):
+            The dimension over which the `input_tensors` should be chunked.
+        input_tensors (`Tuple[torch.Tensor]`):
+            The input tensors of `forward_fn` which will be chunked
+
+    Returns:
+        `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`.
+
+
+    Examples:
+
+    ```python
+    # rename the usual forward() fn to forward_chunk()
+    def forward_chunk(self, hidden_states):
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+    # implement a chunked forward function
+    def forward(self, hidden_states):
+        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
+    ```"""
+
+    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
+
+    # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
+    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
+    if num_args_in_forward_chunk_fn != len(input_tensors):
+        raise ValueError(
+            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
+            "tensors are given"
+        )
+
+    if chunk_size > 0:
+        tensor_shape = input_tensors[0].shape[chunk_dim]
+        for input_tensor in input_tensors:
+            if input_tensor.shape[chunk_dim] != tensor_shape:
+                raise ValueError(
+                    f"All input tenors have to be of the same shape: {tensor_shape}, "
+                    f"found shape {input_tensor.shape[chunk_dim]}"
+                )
+
+        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
+            raise ValueError(
+                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
+                f"size {chunk_size}"
+            )
+
+        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
+
+        # chunk input tensor into tuples
+        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
+        # apply forward fn to every tuple
+        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
+        # concatenate output at same dimension
+        return torch.cat(output_chunks, dim=chunk_dim)
+
+    return forward_fn(*input_tensors)
+
+
+def find_pruneable_heads_and_indices(
+    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
+) -> Tuple[Set[int], torch.LongTensor]:
+    """
+    Finds the heads and their indices taking `already_pruned_heads` into account.
+
+    Args:
+        heads (`List[int]`): List of the indices of heads to prune.
+        n_heads (`int`): The number of heads in the model.
+        head_size (`int`): The size of each head.
+        already_pruned_heads (`Set[int]`): A set of already pruned heads.
+
+    Returns:
+        `Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
+ """ + mask = torch.ones(n_heads, head_size) + heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in already_pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index: torch.LongTensor = torch.arange(len(mask))[mask].long() + return heads, index diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 3233b44d896..c4485d91bfd 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -266,13 +266,6 @@ def top_k_top_p_filtering(*args, **kwargs): requires_backends(top_k_top_p_filtering, ["torch"]) -class Conv1D(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - class PreTrainedModel(metaclass=DummyObject): _backends = ["torch"] @@ -280,14 +273,6 @@ class PreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) -def apply_chunking_to_forward(*args, **kwargs): - requires_backends(apply_chunking_to_forward, ["torch"]) - - -def prune_layer(*args, **kwargs): - requires_backends(prune_layer, ["torch"]) - - ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -4782,6 +4767,21 @@ def get_scheduler(*args, **kwargs): requires_backends(get_scheduler, ["torch"]) +class Conv1D(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +def apply_chunking_to_forward(*args, **kwargs): + requires_backends(apply_chunking_to_forward, ["torch"]) + + +def prune_layer(*args, **kwargs): + requires_backends(prune_layer, ["torch"]) + + class Trainer(metaclass=DummyObject): _backends = ["torch"] diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index bdb896a2527..bde5eaa2e3b 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -43,9 +43,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - SequenceSummary, +from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...pytorch_utils import ( apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer,