Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Fix SEW-D implementation differences (#14191)
* Fix SEW-D
* Update tests
* isort
parent 78b6a2ecbd
commit 1251072f46
--- a/src/transformers/activations.py
+++ b/src/transformers/activations.py
@@ -24,7 +24,7 @@ from .utils import logging
 logger = logging.get_logger(__name__)
 
 
-def _gelu_python(x):
+def gelu_python(x):
     """
     Original Implementation of the GELU activation function in Google BERT repo when initially created. For
     information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
@@ -43,7 +43,7 @@ def gelu_new(x):
 
 
 if version.parse(torch.__version__) < version.parse("1.4"):
-    gelu = _gelu_python
+    gelu = gelu_python
 else:
     gelu = nn.functional.gelu
 
@@ -97,6 +97,7 @@ ACT2FN = {
     "swish": silu,
     "gelu": gelu,
     "tanh": torch.tanh,
+    "gelu_python": gelu_python,
     "gelu_new": gelu_new,
     "gelu_fast": gelu_fast,
     "quick_gelu": quick_gelu,
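Note: the function renamed here is the exact, erf-based GELU; the commit only drops the leading underscore and registers it under the new `"gelu_python"` key in `ACT2FN`. A minimal sketch of that formulation, written from the docstring above rather than copied verbatim from the repository:

```python
import math

import torch


def gelu_python(x):
    # Exact GELU from the original Google BERT repo: x * Phi(x), computed with the error function.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


# On PyTorch >= 1.4 the built-in implementation is numerically equivalent:
x = torch.tensor([-1.0, 0.0, 1.0])
assert torch.allclose(gelu_python(x), torch.nn.functional.gelu(x))
```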
--- a/src/transformers/models/sew_d/configuration_sew_d.py
+++ b/src/transformers/models/sew_d/configuration_sew_d.py
@@ -67,9 +67,9 @@ class SEWDConfig(PretrainedConfig):
             :obj:`("p2c")`, :obj:`("p2c", "c2p")`, :obj:`("p2c", "c2p", 'p2p")`.
         norm_rel_ebd (:obj:`str`, `optional`, defaults to :obj:`"layer_norm"`):
             Whether to use layer norm in relative embedding (:obj:`"layer_norm"` if yes)
-        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
+        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_python"`):
             The non-linear activation function (function or string) in the encoder and pooler. If string,
-            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
+            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"`, :obj:`"gelu_python"` and :obj:`"gelu_new"` are supported.
         hidden_dropout (:obj:`float`, `optional`, defaults to 0.1):
             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
         attention_dropout (:obj:`float`, `optional`, defaults to 0.1):
@@ -78,8 +78,10 @@ class SEWDConfig(PretrainedConfig):
             The dropout probability for the final projection layer of :class:`SEWDForCTC`.
         initializer_range (:obj:`float`, `optional`, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
-            The epsilon used by the layer normalization layers.
+        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-7):
+            The epsilon used by the layer normalization layers in the transformer encoder.
+        feature_layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-5):
+            The epsilon used by the layer normalization after the feature extractor.
         feat_extract_norm (:obj:`str`, `optional`, defaults to :obj:`"group"`):
             The norm to be applied to 1D convolutional layers in feature extractor. One of :obj:`"group"` for group
             normalization of only the first 1D convolutional layer or :obj:`"layer"` for layer normalization of all 1D
@@ -167,7 +169,7 @@ class SEWDConfig(PretrainedConfig):
         position_biased_input=False,
         pos_att_type=("p2c", "c2p"),
         norm_rel_ebd="layer_norm",
-        hidden_act="gelu",
+        hidden_act="gelu_python",
         hidden_dropout=0.1,
         activation_dropout=0.1,
         attention_dropout=0.1,
@@ -175,7 +177,8 @@ class SEWDConfig(PretrainedConfig):
         final_dropout=0.1,
         layerdrop=0.1,
         initializer_range=0.02,
-        layer_norm_eps=1e-5,
+        layer_norm_eps=1e-7,
+        feature_layer_norm_eps=1e-5,
         feat_extract_norm="group",
         feat_extract_activation="gelu",
         conv_dim=(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512),
@@ -228,6 +231,7 @@ class SEWDConfig(PretrainedConfig):
         self.final_dropout = final_dropout
         self.layerdrop = layerdrop
         self.layer_norm_eps = layer_norm_eps
+        self.feature_layer_norm_eps = feature_layer_norm_eps
         self.initializer_range = initializer_range
         self.vocab_size = vocab_size
 
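A quick way to confirm the new configuration defaults, assuming an installed `transformers` build that already contains this commit:

```python
from transformers import SEWDConfig

config = SEWDConfig()
print(config.hidden_act)              # "gelu_python"
print(config.layer_norm_eps)          # 1e-07 -- layer norms inside the transformer encoder
print(config.feature_layer_norm_eps)  # 1e-05 -- layer norm after the feature extractor
```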
--- a/src/transformers/models/sew_d/modeling_sew_d.py
+++ b/src/transformers/models/sew_d/modeling_sew_d.py
@@ -1310,13 +1310,13 @@ SEWD_INPUTS_DOCSTRING = r"""
     "The bare SEW-D Model transformer outputting raw hidden-states without any specific head on top.",
     SEWD_START_DOCSTRING,
 )
-# Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD
+# Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD, layer_norm_eps->feature_layer_norm_eps
 class SEWDModel(SEWDPreTrainedModel):
     def __init__(self, config: SEWDConfig):
         super().__init__(config)
         self.config = config
         self.feature_extractor = SEWDFeatureExtractor(config)
-        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.feature_layer_norm_eps)
 
         self.project_features = config.conv_dim[-1] != config.hidden_size
         if self.project_features:
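The point of the new config field is that `SEWDModel` now uses two different layer-norm epsilons: the norm applied to the feature-extractor output keeps 1e-5 via `feature_layer_norm_eps`, while the transformer encoder layers use the tighter `layer_norm_eps` of 1e-7. A schematic sketch of that split; the `hidden_size` value below is made up, only the field names and 512 (`conv_dim[-1]`) come from the diff:

```python
import torch.nn as nn

conv_dim_last = 512   # config.conv_dim[-1] from the defaults above
hidden_size = 384     # hypothetical; the real value comes from config.hidden_size

# Layer norm over the convolutional feature extractor's output
feature_layer_norm = nn.LayerNorm(conv_dim_last, eps=1e-5)  # config.feature_layer_norm_eps

# Layer norms inside the DeBERTa-style transformer encoder
encoder_layer_norm = nn.LayerNorm(hidden_size, eps=1e-7)    # config.layer_norm_eps
```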
|
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -21,7 +21,7 @@ from transformers.testing_utils import require_torch
 if is_torch_available():
     import torch
 
-    from transformers.activations import _gelu_python, gelu_new, get_activation
+    from transformers.activations import gelu_new, gelu_python, get_activation
 
 
 @require_torch
@@ -29,8 +29,8 @@ class TestActivations(unittest.TestCase):
     def test_gelu_versions(self):
         x = torch.tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
         torch_builtin = get_activation("gelu")
-        self.assertTrue(torch.allclose(_gelu_python(x), torch_builtin(x)))
-        self.assertFalse(torch.allclose(_gelu_python(x), gelu_new(x)))
+        self.assertTrue(torch.allclose(gelu_python(x), torch_builtin(x)))
+        self.assertFalse(torch.allclose(gelu_python(x), gelu_new(x)))
 
     def test_get_activation(self):
         get_activation("swish")
@@ -39,6 +39,7 @@ class TestActivations(unittest.TestCase):
         get_activation("tanh")
         get_activation("gelu_new")
         get_activation("gelu_fast")
+        get_activation("gelu_python")
         with self.assertRaises(KeyError):
             get_activation("bogus")
         with self.assertRaises(KeyError):
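With the new `ACT2FN` entry, the activation is also reachable by name through `get_activation`, which is exactly what the added test line checks. A small usage sketch:

```python
import torch
from transformers.activations import get_activation

act = get_activation("gelu_python")  # raised KeyError before this commit
x = torch.randn(8)
print(torch.allclose(act(x), get_activation("gelu")(x)))  # True: same erf-based GELU
```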
--- a/tests/test_modeling_sew_d.py
+++ b/tests/test_modeling_sew_d.py
@@ -540,9 +540,9 @@ class SEWDModelIntegrationTest(unittest.TestCase):
         )
         expected_output_sum = 54201.0469
 
-        self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
-        self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
-        self.assertTrue(abs(outputs.sum() - expected_output_sum) < 5)
+        self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=1e-3))
+        self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=1e-3))
+        self.assertTrue(abs(outputs.sum() - expected_output_sum) < 1)
 
     def test_inference_ctc_batched(self):
         model = SEWDForCTC.from_pretrained("asapp/sew-d-tiny-100k-ft-ls100h").to(torch_device)
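For context, the integration check whose tolerances are tightened above runs a forward pass on a reference checkpoint and compares the hidden states against recorded slices. A rough sketch of that pattern; the checkpoint name and the random input are placeholders, not taken from the test file:

```python
import torch
from transformers import SEWDModel

model = SEWDModel.from_pretrained("asapp/sew-d-tiny-100k").eval()  # placeholder checkpoint
input_values = torch.randn(1, 16000)  # placeholder one-second waveform at 16 kHz
with torch.no_grad():
    outputs = model(input_values).last_hidden_state
# The tightened assertions compare slices of `outputs` against stored values with atol=1e-3
# and the total sum against a stored scalar within an absolute error of 1.
```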