Fix typo in Block (`MixtralBLockSparseTop2MLP` → `MixtralBlockSparseTop2MLP`). (#28727)

This commit is contained in:
xkszltl 2024-01-29 07:25:00 -08:00 committed by GitHub
parent 9e8f35fa28
commit e694e985d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -787,7 +787,7 @@ MIXTRAL_ATTENTION_CLASSES = {
}
class MixtralBLockSparseTop2MLP(nn.Module):
class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
@ -805,6 +805,14 @@ class MixtralBLockSparseTop2MLP(nn.Module):
return current_hidden_states
class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP):
    """Deprecated alias preserving the old misspelled class name (`BLock` instead of `Block`).

    Kept for backward compatibility so existing code importing the misspelled
    name keeps working; slated for removal in v4.40 per the warning below.
    """

    def __init__(self, *args, **kwargs):
        # Warn once per process (warning_once), then delegate entirely to the
        # correctly named parent class — this shim adds no behavior of its own.
        logger.warning_once(
            "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
        )
        super().__init__(*args, **kwargs)
class MixtralSparseMoeBlock(nn.Module):
"""
This implementation is
@ -827,7 +835,7 @@ class MixtralSparseMoeBlock(nn.Module):
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """