use latest __init__ standards and auto-generate modular

geetu040 2025-01-27 07:48:21 +05:00
parent c54f8045ec
commit d8d3c409d8
2 changed files with 19 additions and 83 deletions

View File

@@ -15,54 +15,15 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...utils import (
-    OptionalDependencyNotAvailable,
-    _LazyModule,
-    is_torch_available,
-)
-
-
-_import_structure = {
-    "configuration_minimax_text_01": ["MiniMaxText01Config"],
-}
-
-
-try:
-    if not is_torch_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    _import_structure["modeling_minimax_text_01"] = [
-        "MiniMaxText01ForCausalLM",
-        "MiniMaxText01ForQuestionAnswering",
-        "MiniMaxText01Model",
-        "MiniMaxText01PreTrainedModel",
-        "MiniMaxText01ForSequenceClassification",
-        "MiniMaxText01ForTokenClassification",
-    ]
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure


 if TYPE_CHECKING:
-    from .configuration_minimax_text_01 import MiniMaxText01Config
-
-    try:
-        if not is_torch_available():
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        pass
-    else:
-        from .modeling_minimax_text_01 import (
-            MiniMaxText01ForCausalLM,
-            MiniMaxText01ForQuestionAnswering,
-            MiniMaxText01ForSequenceClassification,
-            MiniMaxText01ForTokenClassification,
-            MiniMaxText01Model,
-            MiniMaxText01PreTrainedModel,
-        )
+    from .configuration_minimax_text_01 import *
+    from .modeling_minimax_text_01 import *
 else:
     import sys

-    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
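
Note on the new pattern: `define_import_structure(_file)` is assumed to build the lazy-import map from the `__all__` declared in each submodule, so the wildcard imports under `TYPE_CHECKING` only resolve if those files export their public names explicitly. A minimal sketch of what the companion modules are assumed to declare, with the names taken from the explicit imports removed above:

    # configuration_minimax_text_01.py, assumed to end with:
    __all__ = ["MiniMaxText01Config"]

    # modeling_minimax_text_01.py, assumed to end with:
    __all__ = [
        "MiniMaxText01ForCausalLM",
        "MiniMaxText01ForQuestionAnswering",
        "MiniMaxText01ForSequenceClassification",
        "MiniMaxText01ForTokenClassification",
        "MiniMaxText01Model",
        "MiniMaxText01PreTrainedModel",
    ]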

View File

@@ -259,13 +259,6 @@ def eager_attention_forward(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask

-    # print()
-    # ic(module.layer_idx)
-    # show_tensor(query, False, True)
-    # show_tensor(key_states, False, True)
-    # show_tensor(value_states, False, True)
-    # show_tensor(attn_weights, False, True)
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
@@ -310,23 +303,11 @@ class MiniMaxText01Attention(nn.Module):
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        # print(self.layer_idx)
-        # show_tensor(query_states, end=False, only_shapes=False)
-        # show_tensor(key_states, end=False, only_shapes=True)
-        # show_tensor(value_states, end=True, only_shapes=True)
-        # print()
-
-        # print()
-        # ic(self.layer_idx)
-        # show_tensor(key_states, False, True)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-            # show_tensor(key_states, False, True)

         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -351,10 +332,6 @@ class MiniMaxText01Attention(nn.Module):
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)

-        # ic(self.layer_idx)
-        # show_tensor(attn_output, False, True)
-
         return attn_output, attn_weights
@@ -592,7 +569,7 @@ class MiniMaxText01RotaryEmbedding(nn.Module):
         2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
         """
         seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth_dynamic_frequency_update
+        if seq_len > self.max_seq_len_cached:  # growth
             inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len
@@ -628,7 +605,7 @@ class MiniMaxText01RotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-MINI_MAX_TEXT01_START_DOCSTRING = r"""
+MINIMAX_TEXT_01_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
@@ -647,7 +624,7 @@ MINI_MAX_TEXT01_START_DOCSTRING = r"""

 @add_start_docstrings(
     "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01PreTrainedModel(PreTrainedModel):
     config_class = MiniMaxText01Config
@@ -674,7 +651,7 @@ class MiniMaxText01PreTrainedModel(PreTrainedModel):
                 module.weight.data[module.padding_idx].zero_()


-MINI_MAX_TEXT01_INPUTS_DOCSTRING = r"""
+MINIMAX_TEXT_01_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -751,7 +728,7 @@ MINI_MAX_TEXT01_INPUTS_DOCSTRING = r"""

 @add_start_docstrings(
     "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
     """
@@ -783,7 +760,7 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -820,7 +797,6 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
             )
             use_cache = False

-        # TODO: raise exception here?
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache()
@@ -1173,7 +1149,7 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model

-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
@@ -1222,7 +1198,6 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel, GenerationMixin):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        # ic(input_ids.shape, input_ids)

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -1299,7 +1274,7 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel, GenerationMixin):
     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
     each row of the batch).
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
     def __init__(self, config):
@@ -1317,7 +1292,7 @@ class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value

-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1395,7 +1370,7 @@ class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
     The MiniMaxText01 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
     output) e.g. for Named-Entity-Recognition (NER) tasks.
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel):
     def __init__(self, config):
@@ -1420,7 +1395,7 @@ class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value

-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=TokenClassifierOutput,
@@ -1483,7 +1458,7 @@ class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel):
     The MiniMaxText01 Model transformer with a span classification head on top for extractive question-answering tasks like
     SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForQuestionAnswering(MiniMaxText01PreTrainedModel):
     base_model_prefix = "model"
@@ -1502,7 +1477,7 @@ class MiniMaxText01ForQuestionAnswering(MiniMaxText01PreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value

-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,