Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03)
Feed forward chunking (#6024)
* Chunked feed forward for Bert

  This is an initial implementation to test applying feed forward chunking for BERT. Will need additional modifications based on output and benchmark results.

* Black and cleanup
* Feed forward chunking in BertLayer class.
* Isort
* add chunking for all models
* fix docs
* Fix typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
parent 8a3db6b303
commit b25cec13c5
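For context on what the change does: feed forward chunking splits the hidden states along the sequence dimension into chunks of `chunk_size_feed_forward` positions, runs the position-wise feed forward block on one chunk at a time, and concatenates the results. Because the feed forward layers act on every position independently, the output is unchanged; only the peak size of the intermediate activation shrinks. The snippet below is a minimal sketch of that idea, not code from this commit; `SimpleFeedForward`, `chunked_forward`, and the tensor shapes are illustrative assumptions.

```python
import torch
import torch.nn as nn


class SimpleFeedForward(nn.Module):
    """Illustrative position-wise feed forward block (not the transformers implementation)."""

    def __init__(self, hidden_size=16, intermediate_size=64):
        super().__init__()
        self.dense_in = nn.Linear(hidden_size, intermediate_size)
        self.dense_out = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states):
        return self.dense_out(torch.relu(self.dense_in(hidden_states)))


def chunked_forward(ff, hidden_states, chunk_size, seq_len_dim=1):
    """Apply `ff` chunk by chunk along the sequence dimension and concatenate the results."""
    if chunk_size == 0:  # a chunk size of 0 means: do not chunk
        return ff(hidden_states)
    chunks = hidden_states.split(chunk_size, dim=seq_len_dim)
    return torch.cat([ff(chunk) for chunk in chunks], dim=seq_len_dim)


ff = SimpleFeedForward()
hidden = torch.randn(2, 8, 16)  # (batch, sequence_length, hidden_size)
# Chunked and unchunked paths produce the same output, position by position.
assert torch.allclose(ff(hidden), chunked_forward(ff, hidden, chunk_size=1), atol=1e-6)
```

With a chunk size of n, the `(batch, sequence_length, intermediate_size)` activation is only ever materialized n positions at a time, which is the memory saving this feature targets.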
@@ -64,11 +64,6 @@ class ReformerConfig(PretrainedConfig):
             A chunk size of 0 means that the feed forward layer is not chunked.
             A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
             For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
-        chunk_size_feed_forward (:obj:`int`, optional, defaults to 0):
-            The chunk size of all feed forward layers in the residual attention blocks.
-            A chunk size of 0 means that the feed forward layer is not chunked.
-            A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
-            For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
         eos_token_id (:obj:`int`, optional, defaults to 2):
             The token id for the <EOS> token.
         feed_forward_size (:obj:`int`, optional, defaults to 512):
@@ -147,7 +142,6 @@ class ReformerConfig(PretrainedConfig):
         axial_pos_shape=[64, 64],
         axial_pos_embds_dim=[64, 192],
         chunk_size_lm_head=0,
-        chunk_size_feed_forward=0,
         eos_token_id=2,
         feed_forward_size=512,
         hash_seed=None,
@@ -202,5 +196,4 @@ class ReformerConfig(PretrainedConfig):
         self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
         self.axial_norm_std = axial_norm_std
         self.chunk_size_lm_head = chunk_size_lm_head
-        self.chunk_size_feed_forward = chunk_size_feed_forward
         self.attn_layers = attn_layers
src/transformers/configuration_utils.py (6 changes; Normal file → Executable file)
@@ -66,6 +66,11 @@ class PretrainedConfig(object):
             2.
         xla_device (:obj:`bool`, `optional`):
             A flag to indicate if TPU are available or not.
+        chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
+            The chunk size of all feed forward layers in the residual attention blocks.
+            A chunk size of :obj:`0` means that the feed forward layer is not chunked.
+            A chunk size of n means that the feed forward layer processes :obj:`n` < sequence_length embeddings at a time.
+            For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .

     Parameters for sequence generation
         - **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
@@ -163,6 +168,7 @@ class PretrainedConfig(object):
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.bad_words_ids = kwargs.pop("bad_words_ids", None)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
+        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)

         # Fine-tuning task arguments
         self.architectures = kwargs.pop("architectures", None)
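Because `chunk_size_feed_forward` now lives on `PretrainedConfig`, any model configuration accepts it as a keyword argument and it defaults to 0 (no chunking). A minimal usage sketch under that assumption; BERT and the value 2 are arbitrary choices here, used only because this commit wires chunking into `BertLayer`:

```python
from transformers import BertConfig, BertModel

# chunk_size_feed_forward is popped from kwargs in PretrainedConfig.__init__,
# so it can be set on any model config; 0 (the default) disables chunking.
config = BertConfig(chunk_size_feed_forward=2)
model = BertModel(config)
print(model.config.chunk_size_feed_forward)  # 2
```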
src/transformers/modeling_bert.py (20 changes; Normal file → Executable file)
@@ -48,7 +48,12 @@ from .modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from .modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)


 logger = logging.getLogger(__name__)
@@ -88,6 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
     """
     try:
         import re
+
         import numpy as np
         import tensorflow as tf
     except ImportError:
@@ -376,6 +382,8 @@ class BertOutput(nn.Module):
 class BertLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
         self.attention = BertAttention(config)
         self.is_decoder = config.is_decoder
         self.add_cross_attention = config.add_cross_attention
@@ -415,11 +423,17 @@ class BertLayer(nn.Module):
             attention_output = cross_attention_outputs[0]
             outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

-        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
+        layer_output = apply_chunking_to_forward(
+            self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
+        )
         outputs = (layer_output,) + outputs
         return outputs

+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+

 class BertEncoder(nn.Module):
     def __init__(self, config):
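The call above delegates to `apply_chunking_to_forward` from `modeling_utils`, which, going by the call site in this diff, takes the chunk size, the dimension to chunk over, the forward callable, and its input tensors, and must return the same tensor as calling the forward function directly. The sketch below re-implements that contract independently for illustration; `apply_chunking_sketch` is a hypothetical stand-in, not the library helper:

```python
import torch


def apply_chunking_sketch(chunk_size, chunk_dim, forward_fn, *input_tensors):
    """Rough, illustrative stand-in for modeling_utils.apply_chunking_to_forward."""
    if chunk_size == 0:  # no chunking requested
        return forward_fn(*input_tensors)
    # Split every input tensor along `chunk_dim`, run the forward function on
    # each group of corresponding chunks, then concatenate the outputs again.
    chunked_inputs = [t.split(chunk_size, dim=chunk_dim) for t in input_tensors]
    outputs = [forward_fn(*chunk_group) for chunk_group in zip(*chunked_inputs)]
    return torch.cat(outputs, dim=chunk_dim)


# Same result as running the feed forward callable on the full tensor at once.
ff = torch.nn.Linear(16, 16)
x = torch.randn(2, 8, 16)
assert torch.allclose(ff(x), apply_chunking_sketch(1, 1, ff, x), atol=1e-6)
```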
@@ -370,6 +370,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    test_chunking = True

     def setUp(self):
         self.model_tester = BertModelTester(self)
@@ -60,6 +60,7 @@ class ModelTesterMixin:
     test_resize_embeddings = True
     test_head_masking = True
     test_missing_keys = True
+    test_chunking = False
     is_encoder_decoder = False

     def _prepare_for_class(self, inputs_dict, model_class):
@@ -519,6 +520,29 @@ class ModelTesterMixin:

             check_hidden_states_output(inputs_dict, config, model_class)

+    def test_feed_forward_chunking(self):
+        (original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
+        if not self.test_chunking:
+            return
+
+        for model_class in self.all_model_classes:
+            torch.manual_seed(0)
+            config = copy.deepcopy(original_config)
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
+
+            torch.manual_seed(0)
+            config.chunk_size_feed_forward = 1
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
+            self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
+
     def test_resize_tokens_embeddings(self):
         (original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.test_resize_embeddings:
@@ -291,24 +291,6 @@ class ReformerModelTester:
                 torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
             )

-    def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
-        torch.manual_seed(0)
-        model = ReformerModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]
-
-        config.chunk_size_lm_head = 1
-        config.chunk_size_feed_forward = 1
-
-        torch.manual_seed(0)
-        model = ReformerModel(config=config)
-        model.to(torch_device)
-        model.eval()
-
-        hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
-        self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
-
     def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
         if not self.is_training:
             return
@@ -517,10 +499,6 @@ class ReformerTesterMixin:
         self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
         self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)

-    def test_reformer_chunking_forward_equality(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs)
-
     def test_reformer_chunking_backward_equality(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
@@ -577,6 +555,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
+    test_chunking = True

     def prepare_kwargs(self):
         return {
@@ -637,6 +616,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
+    test_chunking = True

     def prepare_kwargs(self):
         return {