Feed forward chunking others (#6365)

* Feed forward chunking for Distilbert & Albert

* Added ff chunking for many other models

* Change model signature

* Added chunking for XLM

* Cleaned up by removing some variables.

* remove test_chunking flag

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
Pradhy729 2020-08-19 05:31:10 -07:00 committed by GitHub
parent fe0b85e77a
commit 2a7402cbd3
13 changed files with 78 additions and 31 deletions

0
src/transformers/configuration_reformer.py Normal file → Executable file

1
src/transformers/configuration_utils.py

@ -191,6 +191,7 @@ class PretrainedConfig(object):
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
# task specific arguments
self.task_specific_params = kwargs.pop("task_specific_params", None)
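
The attribute defaults to 0, which keeps feed-forward chunking disabled. A minimal illustration of the behaviour this hunk adds (an editor's sketch, not part of the diff):

from transformers import PretrainedConfig

config = PretrainedConfig()
assert config.chunk_size_feed_forward == 0    # default: feed-forward chunking disabled

config = PretrainedConfig(chunk_size_feed_forward=64)
assert config.chunk_size_feed_forward == 64   # feed forward is applied 64 positions at a time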

19
src/transformers/modeling_albert.py Normal file → Executable file

@ -43,7 +43,7 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices
logger = logging.getLogger(__name__)
@ -69,6 +69,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model."""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
@ -286,6 +287,8 @@ class AlbertLayer(nn.Module):
super().__init__()
self.config = config
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = AlbertAttention(config)
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
@ -297,14 +300,20 @@ class AlbertLayer(nn.Module):
self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
):
attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
ffn_output = self.ffn(attention_output[0])
ffn_output = self.activation(ffn_output)
ffn_output = self.ffn_output(ffn_output)
ffn_output = self.dropout(ffn_output)
ffn_output = apply_chunking_to_forward(
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[0],
)
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
return (hidden_states,) + attention_output[1:] # add attentions if we output them
def ff_chunk(self, attention_output):
ffn_output = self.ffn(attention_output)
ffn_output = self.activation(ffn_output)
ffn_output = self.ffn_output(ffn_output)
return ffn_output
class AlbertLayerGroup(nn.Module):
def __init__(self, config):
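
ff_chunk applies the two linear projections and the activation to each position independently, so running it chunk-by-chunk along the sequence dimension reproduces the full-size result. A hedged sanity-check sketch (config values and shapes are illustrative, not from the commit):

import torch
from transformers import AlbertConfig
from transformers.modeling_albert import AlbertLayer
from transformers.modeling_utils import apply_chunking_to_forward

config = AlbertConfig(hidden_size=16, intermediate_size=32, num_attention_heads=4)
layer = AlbertLayer(config).eval()
attention_output = torch.rand(2, 8, 16)  # (batch, seq_len, hidden_size)

full = layer.ff_chunk(attention_output)                                      # one full-size pass
chunked = apply_chunking_to_forward(layer.ff_chunk, 2, 1, attention_output)  # chunks of 2 along dim 1
assert torch.allclose(full, chunked, atol=1e-6)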

2
src/transformers/modeling_bert.py

@ -424,7 +424,7 @@ class BertLayer(nn.Module):
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
return outputs
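
BertLayer already called apply_chunking_to_forward before this commit; only the argument order changes here. For context, a hedged usage sketch of how chunking is switched on for a whole model through the shared config attribute (sizes are illustrative):

from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
                    intermediate_size=64, chunk_size_feed_forward=16)
model = BertModel(config)
assert model.encoder.layer[0].chunk_size_feed_forward == 16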

12
src/transformers/modeling_distilbert.py Normal file → Executable file

@ -44,7 +44,12 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
logger = logging.getLogger(__name__)
@ -208,6 +213,8 @@ class FFN(nn.Module):
def __init__(self, config):
super().__init__()
self.dropout = nn.Dropout(p=config.dropout)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
@ -216,6 +223,9 @@ class FFN(nn.Module):
self.activation = gelu if config.activation == "gelu" else nn.ReLU()
def forward(self, input):
return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
def ff_chunk(self, input):
x = self.lin1(input)
x = self.activation(x)
x = self.lin2(x)
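
Chunking the FFN trades compute batching for peak memory: only a (batch, chunk_size, hidden_dim) slice of the intermediate activation has to exist at a time during the forward computation (gradient bookkeeping aside), at the cost of calling the two linear layers once per chunk. A back-of-the-envelope sketch with assumed sizes, not taken from the commit:

batch, seq_len, hidden_dim, chunk_size = 8, 4096, 3072, 128
bytes_per_float = 4  # fp32
full_peak = batch * seq_len * hidden_dim * bytes_per_float
chunked_peak = batch * chunk_size * hidden_dim * bytes_per_float
print(f"intermediate activation: {full_peak / 2**20:.0f} MiB full vs {chunked_peak / 2**20:.0f} MiB per chunk")
# intermediate activation: 384 MiB full vs 12 MiB per chunk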

19
src/transformers/modeling_longformer.py Normal file → Executable file

@ -41,7 +41,12 @@ from .modeling_outputs import (
TokenClassifierOutput,
)
from .modeling_roberta import RobertaEmbeddings, RobertaLMHead
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
logger = logging.getLogger(__name__)
@ -685,6 +690,8 @@ class LongformerLayer(nn.Module):
self.attention = LongformerAttention(config, layer_id)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
@ -693,11 +700,17 @@ class LongformerLayer(nn.Module):
attn_output = self_attn_outputs[0]
outputs = self_attn_outputs[1:] # add self attentions if we output attention weights
intermediate_output = self.intermediate(attn_output)
layer_output = self.output(intermediate_output, attn_output)
layer_output = apply_chunking_to_forward(
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attn_output
)
outputs = (layer_output,) + outputs
return outputs
def ff_chunk(self, attn_output):
intermediate_output = self.intermediate(attn_output)
layer_output = self.output(intermediate_output, attn_output)
return layer_output
class LongformerEncoder(nn.Module):
def __init__(self, config):
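
Unlike the DistilBERT FFN, Longformer's ff_chunk also folds in BertOutput, i.e. the residual addition and the LayerNorm. Chunking stays exact because LayerNorm normalizes over the hidden dimension of each position independently. A tiny pure-torch illustration of that property (shapes are assumptions):

import torch
import torch.nn as nn

ln = nn.LayerNorm(16)
x, residual = torch.rand(2, 8, 16), torch.rand(2, 8, 16)

full = ln(x + residual)
chunked = torch.cat(
    [ln(xc + rc) for xc, rc in zip(x.chunk(4, dim=1), residual.chunk(4, dim=1))], dim=1
)
assert torch.allclose(full, chunked, atol=1e-6)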

4
src/transformers/modeling_reformer.py Normal file → Executable file

@ -1369,7 +1369,7 @@ class ChunkReformerFeedForward(nn.Module):
def forward(self, attention_output):
return apply_chunking_to_forward(
self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output,
self.forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output,
)
def forward_chunk(self, hidden_states):
@ -1730,7 +1730,7 @@ class ReformerOnlyLMHead(nn.Module):
self.decoder.bias = self.bias
def forward(self, hidden_states):
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
def forward_chunk(self, hidden_states):
hidden_states = self.decoder(hidden_states)

8
src/transformers/modeling_utils.py Normal file → Executable file

@ -1519,7 +1519,7 @@ def prune_layer(
def apply_chunking_to_forward(
chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> torch.Tensor:
"""
This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
@ -1529,12 +1529,12 @@ def apply_chunking_to_forward(
directly applying :obj:`forward_fn` to :obj:`input_tensors`.
Args:
forward_fn (:obj:`Callable[..., torch.Tensor]`):
The forward function of the model.
chunk_size (:obj:`int`):
The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
chunk_dim (:obj:`int`):
The dimension over which the :obj:`input_tensors` should be chunked.
forward_fn (:obj:`Callable[..., torch.Tensor]`):
The forward function of the model.
input_tensors (:obj:`Tuple[torch.Tensor]`):
The input tensors of ``forward_fn`` which will be chunked.
Returns:
@ -1550,7 +1550,7 @@ def apply_chunking_to_forward(
# implement a chunked forward function
def forward(self, hidden_states):
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
"""
assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
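
For reference, a condensed sketch of what the helper does after its checks (the real implementation also validates that forward_fn accepts as many arguments as there are input tensors and that chunk_size divides the chunked dimension): split every input along chunk_dim, apply forward_fn to one slice of each input at a time, and concatenate the per-chunk outputs along chunk_dim.

import torch

def chunked_forward_sketch(forward_fn, chunk_size, chunk_dim, *input_tensors):
    if chunk_size == 0:  # chunking disabled: one full-size call
        return forward_fn(*input_tensors)
    num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
    chunks_per_input = tuple(t.chunk(num_chunks, dim=chunk_dim) for t in input_tensors)
    output_chunks = [forward_fn(*slices) for slices in zip(*chunks_per_input)]
    return torch.cat(output_chunks, dim=chunk_dim)

Because the helper accepts any number of input tensors, a forward_fn taking several tensors receives one matching slice of each per call; the models touched in this commit all pass a single tensor.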

6
src/transformers/modeling_xlm.py Normal file → Executable file

@ -50,6 +50,7 @@ from .modeling_utils import (
PreTrainedModel,
SequenceSummary,
SQuADHead,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
@ -212,8 +213,13 @@ class TransformerFFN(nn.Module):
self.lin1 = nn.Linear(in_dim, dim_hidden)
self.lin2 = nn.Linear(dim_hidden, out_dim)
self.act = gelu if config.gelu_activation else F.relu
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
def forward(self, input):
return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
def ff_chunk(self, input):
x = self.lin1(input)
x = self.act(x)
x = self.lin2(x)

21
src/transformers/modeling_xlnet.py Normal file → Executable file

@ -35,7 +35,14 @@ from .file_utils import (
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
from .modeling_utils import (
PoolerAnswerClass,
PoolerEndLogits,
PoolerStartLogits,
PreTrainedModel,
SequenceSummary,
apply_chunking_to_forward,
)
logger = logging.getLogger(__name__)
@ -495,6 +502,8 @@ class XLNetLayer(nn.Module):
self.rel_attn = XLNetRelativeAttention(config)
self.ff = XLNetFeedForward(config)
self.dropout = nn.Dropout(config.dropout)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
def forward(
self,
@ -524,12 +533,18 @@ class XLNetLayer(nn.Module):
output_h, output_g = outputs[:2]
if output_g is not None:
output_g = self.ff(output_g)
output_h = self.ff(output_h)
output_g = apply_chunking_to_forward(
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_g
)
output_h = apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_h)
outputs = (output_h, output_g) + outputs[2:] # Add again attentions if there are there
return outputs
def ff_chunk(self, output_x):
output_x = self.ff(output_x)
return output_x
class XLNetPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and

7
tests/test_modeling_bert.py Normal file → Executable file

@ -26,15 +26,15 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
from transformers import (
BertConfig,
BertModel,
BertLMHeadModel,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertForMultipleChoice,
BertLMHeadModel,
BertModel,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
@ -370,7 +370,6 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
test_chunking = True
def setUp(self):
self.model_tester = BertModelTester(self)

8
tests/test_modeling_common.py Normal file → Executable file

@ -25,15 +25,15 @@ from transformers.testing_utils import require_multigpu, require_torch, slow, to
if is_torch_available():
import torch
import numpy as np
import torch
from transformers import (
AdaptiveEmbedding,
PretrainedConfig,
PreTrainedModel,
BertModel,
BertConfig,
BertModel,
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@ -65,7 +65,6 @@ class ModelTesterMixin:
test_resize_embeddings = True
test_head_masking = True
test_missing_keys = True
test_chunking = False
is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@ -552,9 +551,6 @@ class ModelTesterMixin:
def test_feed_forward_chunking(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_chunking:
return
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
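
With the flag gone, test_feed_forward_chunking runs for every model class. A simplified sketch of the check (tolerance and output unpacking are assumptions): build the model twice from the same seed, once with chunking off and once with chunk_size_feed_forward set, and require matching outputs.

import copy
import torch

def check_feed_forward_chunking(model_class, config, inputs):
    torch.manual_seed(0)
    model = model_class(copy.deepcopy(config)).eval()
    hidden_states_no_chunk = model(**inputs)[0]

    chunked_config = copy.deepcopy(config)
    chunked_config.chunk_size_feed_forward = 1
    torch.manual_seed(0)
    model = model_class(chunked_config).eval()
    hidden_states_with_chunk = model(**inputs)[0]

    assert torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)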

2
tests/test_modeling_reformer.py

@ -555,7 +555,6 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True
def prepare_kwargs(self):
return {
@ -616,7 +615,6 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True
def prepare_kwargs(self):
return {