mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Support BatchNorm in Hubert pos_conv_emb as in fairseq (#34389)
* Support BatchNorm in Hubert pos_conv_emb as in fairseq * Correct the new defaults (#34377) * Correct the new defaults * CIs * add check * Update utils.py * Update utils.py * Add the max_length in generate test checking shape without passing length * style * CIs * fix fx CI issue * [auto. ping] Avoid sending empty info + add more team members (#34383) * update * update --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Fix glm (#34388) * Fix duplicated * fix import * Use non nested images and batched text Idefics2/3 (#34222) * add support for non nested images and add tests * add tests error scenario * fix style * added single and no image to error tests * Fix onnx non-expotable inplace aten op (#34376) * fix onnx non-expotable inplace op * mistral, qwen2, qwen2_vl, starcoder2 * fixup copies * Fix right padding in LLaVA models (#34305) * fix right pad llavas * device mismatch * no filter (#34391) * no filter * no filter * no filter --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * SynthID: better example (#34372) * better example * Update src/transformers/generation/configuration_utils.py * Update src/transformers/generation/logits_process.py * nits * Tests: upgrade `test_eager_matches_sdpa_generate` (#34386) * Fix bnb training test failure (#34414) * Fix bnb training test: compatibility with OPTSdpaAttention * Avoid check expected exception when it is on CUDA (#34408) * update * update --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Fix typos in agents_advanced.md (#34405) * [docs] Cache implementations (#34325) cache * [run-slow] hubert * Support BatchNorm in Hubert pos_conv_emb as in fairseq Add conversion integration test, and make batchnorm explicit variable * Support BatchNorm in Hubert pos_conv_emb as in fairseq fix make fixup styling changes * [run-slow] hubert * Support BatchNorm in Hubert pos_conv_emb as in fairseq * [run-slow] hubert * Support BatchNorm in Hubert pos_conv_emb as in fairseq Add conversion integration test, and make batchnorm explicit variable * Support BatchNorm in Hubert pos_conv_emb as in fairseq fix make fixup styling changes * [run-slow] hubert * [run-slow] hubert --------- Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co> Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Co-authored-by: Raushan Turganbay <raushan@huggingface.co> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Co-authored-by: Rudy Delouya <rudy.delouya@gmail.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
This commit is contained in:
parent
80f2b1610f
commit
6acb4e43a7
@ -94,6 +94,8 @@ class HubertConfig(PretrainedConfig):
|
||||
embeddings layer.
|
||||
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
|
||||
Number of groups of 1D convolutional positional embeddings layer.
|
||||
conv_pos_batch_norm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use batch norm instead of weight norm in conv_pos
|
||||
do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
|
||||
Whether do apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
|
||||
True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
|
||||
@ -182,6 +184,7 @@ class HubertConfig(PretrainedConfig):
|
||||
conv_bias=False,
|
||||
num_conv_pos_embeddings=128,
|
||||
num_conv_pos_embedding_groups=16,
|
||||
conv_pos_batch_norm=False,
|
||||
do_stable_layer_norm=False,
|
||||
apply_spec_augment=True,
|
||||
mask_time_prob=0.05,
|
||||
@ -209,6 +212,7 @@ class HubertConfig(PretrainedConfig):
|
||||
self.conv_bias = conv_bias
|
||||
self.num_conv_pos_embeddings = num_conv_pos_embeddings
|
||||
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
|
||||
self.conv_pos_batch_norm = conv_pos_batch_norm
|
||||
self.num_feat_extract_layers = len(self.conv_dim)
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.intermediate_size = intermediate_size
|
||||
|
@ -38,7 +38,8 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
MAPPING = {
|
||||
"post_extract_proj": "feature_projection.projection",
|
||||
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
|
||||
"encoder.pos_conv.0": "encoder.pos_conv_embed.batch_norm",
|
||||
"encoder.pos_conv.1": "encoder.pos_conv_embed.conv",
|
||||
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
|
||||
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
|
||||
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
|
||||
@ -76,6 +77,12 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
hf_pointer.weight_v.data = value
|
||||
elif weight_type == "bias":
|
||||
hf_pointer.bias.data = value
|
||||
elif weight_type == "running_mean":
|
||||
hf_pointer.running_mean.data = value
|
||||
elif weight_type == "running_var":
|
||||
hf_pointer.running_var.data = value
|
||||
elif weight_type == "num_batches_tracked":
|
||||
hf_pointer.num_batches_tracked.data = value
|
||||
else:
|
||||
hf_pointer.data = value
|
||||
|
||||
@ -116,6 +123,12 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
||||
weight_type = "weight"
|
||||
elif "bias" in name:
|
||||
weight_type = "bias"
|
||||
elif "running_mean" in name:
|
||||
weight_type = "running_mean"
|
||||
elif "running_var" in name:
|
||||
weight_type = "running_var"
|
||||
elif "num_batches_tracked" in name:
|
||||
weight_type = "num_batches_tracked"
|
||||
else:
|
||||
weight_type = None
|
||||
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
||||
|
@ -260,7 +260,6 @@ class HubertGroupNormConvLayer(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert
|
||||
class HubertPositionalConvEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@ -272,32 +271,37 @@ class HubertPositionalConvEmbedding(nn.Module):
|
||||
groups=config.num_conv_pos_embedding_groups,
|
||||
)
|
||||
|
||||
weight_norm = nn.utils.weight_norm
|
||||
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||
weight_norm = nn.utils.parametrizations.weight_norm
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
||||
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
||||
if hasattr(self.conv, "parametrizations"):
|
||||
weight_g = self.conv.parametrizations.weight.original0
|
||||
weight_v = self.conv.parametrizations.weight.original1
|
||||
else:
|
||||
weight_g = self.conv.weight_g
|
||||
weight_v = self.conv.weight_v
|
||||
deepspeed.zero.register_external_parameter(self, weight_v)
|
||||
deepspeed.zero.register_external_parameter(self, weight_g)
|
||||
self.batch_norm = None
|
||||
if config.conv_pos_batch_norm:
|
||||
self.batch_norm = nn.BatchNorm1d(config.hidden_size)
|
||||
else:
|
||||
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
||||
weight_norm = nn.utils.weight_norm
|
||||
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||
weight_norm = nn.utils.parametrizations.weight_norm
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
||||
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
||||
if hasattr(self.conv, "parametrizations"):
|
||||
weight_g = self.conv.parametrizations.weight.original0
|
||||
weight_v = self.conv.parametrizations.weight.original1
|
||||
else:
|
||||
weight_g = self.conv.weight_g
|
||||
weight_v = self.conv.weight_v
|
||||
deepspeed.zero.register_external_parameter(self, weight_v)
|
||||
deepspeed.zero.register_external_parameter(self, weight_g)
|
||||
else:
|
||||
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
||||
|
||||
self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
|
||||
self.activation = ACT2FN[config.feat_extract_activation]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
if self.batch_norm is not None:
|
||||
hidden_states = self.batch_norm(hidden_states)
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = self.padding(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
|
@ -943,3 +943,40 @@ class HubertModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
|
||||
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
|
||||
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)
|
||||
|
||||
def test_inference_hubert_25hz(self):
|
||||
model = HubertModel.from_pretrained("slprl/mhubert-base-25hz").to(torch_device)
|
||||
|
||||
sample = self._load_datasamples(1)
|
||||
input_speech = torch.tensor(sample[0], dtype=torch.float, device=torch_device).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_speech, output_hidden_states=True).hidden_states[11]
|
||||
|
||||
# expected outputs taken from the original textlesslib implementation by:
|
||||
# model = SpeechEncoder.by_name(dense_model_name='mhubert-base-25hz', quantizer_model_name='kmeans',
|
||||
# vocab_size=500, deduplicate=False, need_f0=False)
|
||||
# model(wav)['dense']
|
||||
expected_outputs_first = torch.tensor(
|
||||
[
|
||||
[0.0267, 0.1776, -0.1706, -0.4559],
|
||||
[-0.2430, -0.2943, -0.1864, -0.1187],
|
||||
[-0.1812, -0.4239, -0.1916, -0.0858],
|
||||
[-0.1495, -0.4758, -0.4036, 0.0302],
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
expected_outputs_last = torch.tensor(
|
||||
[
|
||||
[0.3366, -0.2734, -0.1415, -0.3055],
|
||||
[0.2329, -0.3580, -0.1421, -0.3197],
|
||||
[0.1631, -0.4301, -0.1965, -0.2956],
|
||||
[0.3342, -0.2185, -0.2253, -0.2363],
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
expected_output_sum = 1681.7603
|
||||
|
||||
self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
|
||||
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
|
||||
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)
|
||||
|
Loading…
Reference in New Issue
Block a user