Feed forward chunking (#6024)

* Chunked feed forward for Bert

This is an initial implementation to test applying feed forward chunking to BERT.
Additional modifications will be needed based on output and benchmark results.

* Black and cleanup

* Feed forward chunking in BertLayer class.

* Isort

* add chunking for all models

* fix docs

* Fix typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
Pradhy729 2020-08-11 00:12:45 -07:00 committed by GitHub
parent 8a3db6b303
commit b25cec13c5
6 changed files with 50 additions and 32 deletions
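
For context: the feed forward block of a transformer layer is applied independently at every position, so its large intermediate activation only has to be materialised for one slice of the sequence at a time. The following back-of-the-envelope estimate (illustrative BERT-base-like sizes, not part of the diff) shows why chunking reduces peak memory at the cost of extra kernel launches:

# Illustrative peak-memory estimate for the feed forward intermediate buffer.
# Numbers are made up but BERT-base-like; not part of the commit.
batch, seq_len, intermediate_size, bytes_per_float = 8, 512, 3072, 4

full_buffer = batch * seq_len * intermediate_size * bytes_per_float    # all positions at once
chunked_buffer = batch * 64 * intermediate_size * bytes_per_float      # chunk_size_feed_forward=64

print(f"unchunked intermediate buffer: {full_buffer / 2**20:.0f} MiB")   # ~48 MiB
print(f"chunked intermediate buffer:   {chunked_buffer / 2**20:.0f} MiB")  # ~6 MiB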

src/transformers/configuration_reformer.py

@@ -64,11 +64,6 @@ class ReformerConfig(PretrainedConfig):
A chunk size of 0 means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
chunk_size_feed_forward (:obj:`int`, optional, defaults to 0):
The chunk size of all feed forward layers in the residual attention blocks.
A chunk size of 0 means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
eos_token_id (:obj:`int`, optional, defaults to 2):
The token id for the <EOS> token.
feed_forward_size (:obj:`int`, optional, defaults to 512):
@@ -147,7 +142,6 @@ class ReformerConfig(PretrainedConfig):
axial_pos_shape=[64, 64],
axial_pos_embds_dim=[64, 192],
chunk_size_lm_head=0,
chunk_size_feed_forward=0,
eos_token_id=2,
feed_forward_size=512,
hash_seed=None,
@@ -202,5 +196,4 @@ class ReformerConfig(PretrainedConfig):
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
self.axial_norm_std = axial_norm_std
self.chunk_size_lm_head = chunk_size_lm_head
self.chunk_size_feed_forward = chunk_size_feed_forward
self.attn_layers = attn_layers
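
Removing the explicit parameter should not change the Reformer API: ReformerConfig still forwards unknown keyword arguments to PretrainedConfig, which (as the next file shows) now pops chunk_size_feed_forward itself. A small sketch of the expected behaviour after this commit (values are illustrative):

from transformers import ReformerConfig

# chunk_size_feed_forward is now handled by the base config class via **kwargs
config = ReformerConfig(chunk_size_feed_forward=8)
assert config.chunk_size_feed_forward == 8

# leaving it out falls back to the new common default of 0 (no chunking)
assert ReformerConfig().chunk_size_feed_forward == 0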

src/transformers/configuration_utils.py Normal file → Executable file

@@ -66,6 +66,11 @@ class PretrainedConfig(object):
2.
xla_device (:obj:`bool`, `optional`):
A flag to indicate if TPU are available or not.
chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
The chunk size of all feed forward layers in the residual attention blocks.
A chunk size of :obj:`0` means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes :obj:`n` < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
Parameters for sequence generation
- **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
@@ -163,6 +168,7 @@ class PretrainedConfig(object):
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)
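
Because the attribute now lives on PretrainedConfig, any model configuration accepts it, not just Reformer. A usage sketch (model and values chosen for illustration, not taken from the diff):

from transformers import BertConfig, BertModel

# chunk_size_feed_forward is a common config attribute after this change;
# 0 (the default) disables chunking, n processes n positions per forward chunk.
config = BertConfig(chunk_size_feed_forward=64)
model = BertModel(config)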

src/transformers/modeling_bert.py Normal file → Executable file

@@ -48,7 +48,12 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
logger = logging.getLogger(__name__)
@@ -88,6 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
@@ -376,6 +382,8 @@ class BertOutput(nn.Module):
class BertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
@@ -415,11 +423,17 @@ class BertLayer(nn.Module):
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
layer_output = apply_chunking_to_forward(
self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
)
outputs = (layer_output,) + outputs
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
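
For readers unfamiliar with the helper imported above: apply_chunking_to_forward splits its tensor arguments into pieces of chunk_size_feed_forward along seq_len_dim, runs the given function (here feed_forward_chunk) on each piece, and concatenates the results; with a chunk size of 0 it simply calls the function once. A simplified sketch of that behaviour, not the actual utility in modeling_utils.py:

import torch

def apply_chunking_to_forward_sketch(chunk_size, chunk_dim, forward_fn, *input_tensors):
    # chunk size 0 means no chunking: run the feed forward pass in one go
    if chunk_size == 0:
        return forward_fn(*input_tensors)
    # split every input along the sequence dimension (the real helper also
    # validates that the dimension is divisible by chunk_size)
    num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
    chunked_inputs = tuple(t.chunk(num_chunks, dim=chunk_dim) for t in input_tensors)
    # run the chunks independently and stitch the outputs back together
    outputs = [forward_fn(*chunks) for chunks in zip(*chunked_inputs)]
    return torch.cat(outputs, dim=chunk_dim)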

tests/test_modeling_bert.py

@@ -370,6 +370,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
test_chunking = True
def setUp(self):
self.model_tester = BertModelTester(self)

tests/test_modeling_common.py

@@ -60,6 +60,7 @@ class ModelTesterMixin:
test_resize_embeddings = True
test_head_masking = True
test_missing_keys = True
test_chunking = False
is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class):
@@ -519,6 +520,29 @@ class ModelTesterMixin:
check_hidden_states_output(inputs_dict, config, model_class)
def test_feed_forward_chunking(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_chunking:
return
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
torch.manual_seed(0)
config.chunk_size_feed_forward = 1
model = model_class(config)
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
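
The new common test can also be reproduced standalone; a sketch with a small made-up BERT configuration, mirroring the seeding and tolerance used above:

import copy
import torch
from transformers import BertConfig, BertModel

# tiny illustrative config so the check runs quickly on CPU
config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=37)
input_ids = torch.randint(0, config.vocab_size, (2, 7))

torch.manual_seed(0)
no_chunk = BertModel(config).eval()(input_ids)[0]

torch.manual_seed(0)
chunked_config = copy.deepcopy(config)
chunked_config.chunk_size_feed_forward = 1   # chunk one position at a time
with_chunk = BertModel(chunked_config).eval()(input_ids)[0]

# chunked and unchunked feed forward passes must give the same hidden states
assert torch.allclose(no_chunk, with_chunk, atol=1e-3)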

tests/test_modeling_reformer.py

@@ -291,24 +291,6 @@ class ReformerModelTester:
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
)
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]
config.chunk_size_lm_head = 1
config.chunk_size_feed_forward = 1
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
if not self.is_training:
return
@@ -517,10 +499,6 @@ class ReformerTesterMixin:
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)
def test_reformer_chunking_forward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs)
def test_reformer_chunking_backward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
@@ -577,6 +555,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True
def prepare_kwargs(self):
return {
@@ -637,6 +616,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True
def prepare_kwargs(self):
return {