Fix SEW-D implementation differences (#14191)

* Fix SEW-D

* Update tests

* isort
Anton Lozhkov 2021-10-28 16:22:18 +03:00 committed by GitHub
parent 78b6a2ecbd
commit 1251072f46
5 changed files with 22 additions and 16 deletions

src/transformers/activations.py

@@ -24,7 +24,7 @@ from .utils import logging
 logger = logging.get_logger(__name__)
-def _gelu_python(x):
+def gelu_python(x):
     """
     Original Implementation of the GELU activation function in Google BERT repo when initially created. For
     information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
@@ -43,7 +43,7 @@ def gelu_new(x):
 if version.parse(torch.__version__) < version.parse("1.4"):
-    gelu = _gelu_python
+    gelu = gelu_python
 else:
     gelu = nn.functional.gelu
@@ -97,6 +97,7 @@ ACT2FN = {
     "swish": silu,
     "gelu": gelu,
     "tanh": torch.tanh,
+    "gelu_python": gelu_python,
     "gelu_new": gelu_new,
     "gelu_fast": gelu_fast,
     "quick_gelu": quick_gelu,

src/transformers/models/sew_d/configuration_sew_d.py

@@ -67,9 +67,9 @@ class SEWDConfig(PretrainedConfig):
             :obj:`("p2c")`, :obj:`("p2c", "c2p")`, :obj:`("p2c", "c2p", 'p2p")`.
         norm_rel_ebd (:obj:`str`, `optional`, defaults to :obj:`"layer_norm"`):
             Whether to use layer norm in relative embedding (:obj:`"layer_norm"` if yes)
-        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
+        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_python"`):
             The non-linear activation function (function or string) in the encoder and pooler. If string,
-            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
+            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"`, :obj:`"gelu_python"` and :obj:`"gelu_new"` are supported.
         hidden_dropout (:obj:`float`, `optional`, defaults to 0.1):
             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
         attention_dropout (:obj:`float`, `optional`, defaults to 0.1):
@@ -78,8 +78,10 @@ class SEWDConfig(PretrainedConfig):
             The dropout probability for the final projection layer of :class:`SEWDForCTC`.
         initializer_range (:obj:`float`, `optional`, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
-            The epsilon used by the layer normalization layers.
+        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-7):
+            The epsilon used by the layer normalization layers in the transformer encoder.
+        feature_layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-5):
+            The epsilon used by the layer normalization after the feature extractor.
         feat_extract_norm (:obj:`str`, `optional`, defaults to :obj:`"group"`):
             The norm to be applied to 1D convolutional layers in feature extractor. One of :obj:`"group"` for group
             normalization of only the first 1D convolutional layer or :obj:`"layer"` for layer normalization of all 1D
@@ -167,7 +169,7 @@ class SEWDConfig(PretrainedConfig):
         position_biased_input=False,
         pos_att_type=("p2c", "c2p"),
         norm_rel_ebd="layer_norm",
-        hidden_act="gelu",
+        hidden_act="gelu_python",
         hidden_dropout=0.1,
         activation_dropout=0.1,
         attention_dropout=0.1,
@@ -175,7 +177,8 @@ class SEWDConfig(PretrainedConfig):
         final_dropout=0.1,
         layerdrop=0.1,
         initializer_range=0.02,
-        layer_norm_eps=1e-5,
+        layer_norm_eps=1e-7,
+        feature_layer_norm_eps=1e-5,
         feat_extract_norm="group",
         feat_extract_activation="gelu",
         conv_dim=(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512),
@@ -228,6 +231,7 @@ class SEWDConfig(PretrainedConfig):
         self.final_dropout = final_dropout
         self.layerdrop = layerdrop
         self.layer_norm_eps = layer_norm_eps
+        self.feature_layer_norm_eps = feature_layer_norm_eps
         self.initializer_range = initializer_range
         self.vocab_size = vocab_size
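
To illustrate the new defaults (a sketch, not part of the diff, assuming a transformers install that includes this commit):

    from transformers import SEWDConfig

    # The encoder LayerNorms and the LayerNorm after the feature extractor now
    # use separate epsilons, and the default activation is the erf-based GELU.
    config = SEWDConfig()
    print(config.hidden_act)              # "gelu_python"
    print(config.layer_norm_eps)          # 1e-07, used inside the transformer encoder
    print(config.feature_layer_norm_eps)  # 1e-05, used after the feature extractor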

src/transformers/models/sew_d/modeling_sew_d.py

@@ -1310,13 +1310,13 @@ SEWD_INPUTS_DOCSTRING = r"""
     "The bare SEW-D Model transformer outputting raw hidden-states without any specific head on top.",
     SEWD_START_DOCSTRING,
 )
-# Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD
+# Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD, layer_norm_eps->feature_layer_norm_eps
 class SEWDModel(SEWDPreTrainedModel):
     def __init__(self, config: SEWDConfig):
         super().__init__(config)
         self.config = config
         self.feature_extractor = SEWDFeatureExtractor(config)
-        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.feature_layer_norm_eps)
         self.project_features = config.conv_dim[-1] != config.hidden_size
         if self.project_features:
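
And a quick sanity-check sketch of the modeling change (again assuming a transformers version with this commit; the model below is randomly initialized, nothing is downloaded):

    from transformers import SEWDConfig, SEWDModel

    config = SEWDConfig()
    model = SEWDModel(config)
    # The LayerNorm applied to the convolutional features now reads the new
    # feature_layer_norm_eps field; encoder LayerNorms keep using layer_norm_eps.
    print(model.layer_norm.eps == config.feature_layer_norm_eps)  # True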

tests/test_activations.py

@@ -21,7 +21,7 @@ from transformers.testing_utils import require_torch
 if is_torch_available():
     import torch
-    from transformers.activations import _gelu_python, gelu_new, get_activation
+    from transformers.activations import gelu_new, gelu_python, get_activation
 @require_torch
@@ -29,8 +29,8 @@ class TestActivations(unittest.TestCase):
     def test_gelu_versions(self):
         x = torch.tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
         torch_builtin = get_activation("gelu")
-        self.assertTrue(torch.allclose(_gelu_python(x), torch_builtin(x)))
-        self.assertFalse(torch.allclose(_gelu_python(x), gelu_new(x)))
+        self.assertTrue(torch.allclose(gelu_python(x), torch_builtin(x)))
+        self.assertFalse(torch.allclose(gelu_python(x), gelu_new(x)))
     def test_get_activation(self):
         get_activation("swish")
@@ -39,6 +39,7 @@ class TestActivations(unittest.TestCase):
         get_activation("tanh")
         get_activation("gelu_new")
         get_activation("gelu_fast")
+        get_activation("gelu_python")
         with self.assertRaises(KeyError):
             get_activation("bogus")
         with self.assertRaises(KeyError):

tests/test_modeling_sew_d.py

@@ -540,9 +540,9 @@ class SEWDModelIntegrationTest(unittest.TestCase):
         )
         expected_output_sum = 54201.0469
-        self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
-        self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
-        self.assertTrue(abs(outputs.sum() - expected_output_sum) < 5)
+        self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=1e-3))
+        self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=1e-3))
+        self.assertTrue(abs(outputs.sum() - expected_output_sum) < 1)
     def test_inference_ctc_batched(self):
         model = SEWDForCTC.from_pretrained("asapp/sew-d-tiny-100k-ft-ls100h").to(torch_device)