mirror of https://github.com/huggingface/transformers.git
Fix typo of Block. (#28727)
parent 9e8f35fa28
commit e694e985d7
@@ -787,7 +787,7 @@ MIXTRAL_ATTENTION_CLASSES = {
 }
 
 
-class MixtralBLockSparseTop2MLP(nn.Module):
+class MixtralBlockSparseTop2MLP(nn.Module):
     def __init__(self, config: MixtralConfig):
         super().__init__()
         self.ffn_dim = config.intermediate_size
@@ -805,6 +805,14 @@ class MixtralBLockSparseTop2MLP(nn.Module):
         return current_hidden_states
 
 
+class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP):
+    def __init__(self, *args, **kwargs):
+        logger.warning_once(
+            "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
+        )
+        super().__init__(*args, **kwargs)
+
+
 class MixtralSparseMoeBlock(nn.Module):
     """
     This implementation is
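The eight added lines above are a backward-compatibility shim: the misspelled name is kept as a thin subclass of the corrected class, so existing imports keep working while construction emits a deprecation warning through transformers' module-level logger. Below is a minimal, self-contained sketch of the same pattern with hypothetical class names (OldMLP/NewMLP), substituting the standard-library warnings module for logger.warning_once so it runs without transformers' logging utilities:

    import warnings


    class NewMLP:
        """Stand-in for the correctly spelled class (hypothetical minimal body)."""

        def __init__(self, ffn_dim: int):
            self.ffn_dim = ffn_dim


    class OldMLP(NewMLP):
        """Deprecated alias: identical behavior, plus a warning at construction."""

        def __init__(self, *args, **kwargs):
            # warnings.warn stands in for logger.warning_once; unlike the real
            # helper, it can fire more than once per process.
            warnings.warn(
                "OldMLP is deprecated by NewMLP and will be removed in a future release.",
                FutureWarning,
            )
            super().__init__(*args, **kwargs)


    # The old spelling still yields a fully functional instance of the new class:
    mlp = OldMLP(ffn_dim=14336)
    assert isinstance(mlp, NewMLP) and mlp.ffn_dim == 14336

Because the alias subclasses the corrected class, isinstance checks against either name keep passing for the duration of the deprecation window.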
@@ -827,7 +835,7 @@ class MixtralSparseMoeBlock(nn.Module):
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
-        self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
+        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
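For context on the single changed line: the MoE block builds one expert MLP per num_experts in a plain nn.ModuleList, while the gate linear layer produces per-token router logits over those experts. A simplified, runnable sketch follows; ExpertMLP is a hypothetical stand-in for MixtralBlockSparseTop2MLP using the same gated-SiLU w1/w2/w3 shape, and the dimensions are made up for illustration:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class ExpertMLP(nn.Module):
        """Hypothetical stand-in for MixtralBlockSparseTop2MLP (gated-SiLU MLP)."""

        def __init__(self, hidden_dim: int, ffn_dim: int):
            super().__init__()
            self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
            self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
            self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(F.silu(self.w1(x)) * self.w3(x))


    hidden_dim, ffn_dim, num_experts = 32, 64, 8
    gate = nn.Linear(hidden_dim, num_experts, bias=False)
    experts = nn.ModuleList([ExpertMLP(hidden_dim, ffn_dim) for _ in range(num_experts)])

    tokens = torch.randn(4, hidden_dim)           # 4 tokens
    router_logits = gate(tokens)                  # (4, num_experts)
    top2 = router_logits.topk(2, dim=-1).indices  # each token routes to its top-2 experts
    print(top2.shape)                             # torch.Size([4, 2])

The rename changes nothing about this routing logic; only the spelling of the expert class being instantiated differs.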