mirror of https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00

use latest __init__ standards and auto-generate modular

This commit is contained in:
parent c54f8045ec
commit d8d3c409d8
@@ -15,54 +15,15 @@
 # limitations under the License.
 from typing import TYPE_CHECKING
 
-from ...utils import (
-    OptionalDependencyNotAvailable,
-    _LazyModule,
-    is_torch_available,
-)
-
-
-_import_structure = {
-    "configuration_minimax_text_01": ["MiniMaxText01Config"],
-}
-
-
-try:
-    if not is_torch_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    _import_structure["modeling_minimax_text_01"] = [
-        "MiniMaxText01ForCausalLM",
-        "MiniMaxText01ForQuestionAnswering",
-        "MiniMaxText01Model",
-        "MiniMaxText01PreTrainedModel",
-        "MiniMaxText01ForSequenceClassification",
-        "MiniMaxText01ForTokenClassification",
-    ]
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
 
 
 if TYPE_CHECKING:
-    from .configuration_minimax_text_01 import MiniMaxText01Config
-
-    try:
-        if not is_torch_available():
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        pass
-    else:
-        from .modeling_minimax_text_01 import (
-            MiniMaxText01ForCausalLM,
-            MiniMaxText01ForQuestionAnswering,
-            MiniMaxText01ForSequenceClassification,
-            MiniMaxText01ForTokenClassification,
-            MiniMaxText01Model,
-            MiniMaxText01PreTrainedModel,
-        )
+    from .configuration_minimax_text_01 import *
+    from .modeling_minimax_text_01 import *
 else:
     import sys
 
-    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
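Context for the hunk above: the new-style `__init__.py` keeps no hand-written `_import_structure`; `define_import_structure(_file)` derives it from the submodules, and `_LazyModule` resolves attributes only when they are first accessed. A minimal sketch of that lazy-resolution idea (illustrative only, not transformers' actual `_LazyModule`; the class and mapping names here are made up):

# Illustrative sketch of the lazy-module pattern (not transformers' _LazyModule).
import importlib
import types


class LazySubmoduleProxy(types.ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported symbol to its defining submodule, e.g.
        # {"SomeConfig": ".configuration_some_model"} (hypothetical names).
        self._origin = {
            symbol: submodule
            for submodule, symbols in import_structure.items()
            for symbol in symbols
        }
        self.__all__ = list(self._origin)

    def __getattr__(self, name):
        if name not in self._origin:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        # Import the defining submodule only on first access, then cache the symbol.
        module = importlib.import_module(self._origin[name], package=self.__name__)
        value = getattr(module, name)
        setattr(self, name, value)
        return value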
@@ -259,13 +259,6 @@ def eager_attention_forward(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
 
-    # print()
-    # ic(module.layer_idx)
-    # show_tensor(query, False, True)
-    # show_tensor(key_states, False, True)
-    # show_tensor(value_states, False, True)
-    # show_tensor(attn_weights, False, True)
-
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
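For reference, the attention math kept by this hunk is the standard eager path: additive causal mask, float32 softmax, dropout, then a matmul with the values. A self-contained sketch under assumed (batch, heads, seq, head_dim) shapes — it omits key/value head repetition and is not the exact MiniMaxText01 function signature:

# Standalone sketch of the eager attention path (assumed shapes, illustrative only).
import torch
from torch import nn


def eager_attention_sketch(query, key, value, attention_mask=None, scaling=None, dropout=0.0, training=False):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    if scaling is None:
        scaling = query.shape[-1] ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Slice the mask to the key length and add it (masked positions hold large negative values).
        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + causal_mask
    # Softmax in float32 for numerical stability, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights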
@@ -310,23 +303,11 @@ class MiniMaxText01Attention(nn.Module):
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        # print(self.layer_idx)
-        # show_tensor(query_states, end=False, only_shapes=False)
-        # show_tensor(key_states, end=False, only_shapes=True)
-        # show_tensor(value_states, end=True, only_shapes=True)
-
-        # print()
-        # print()
-        # ic(self.layer_idx)
-        # show_tensor(key_states, False, True)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # show_tensor(key_states, False, True)
-
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -351,10 +332,6 @@ class MiniMaxText01Attention(nn.Module):
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
 
-        # ic(self.layer_idx)
-        # show_tensor(attn_output, False, True)
-
         return attn_output, attn_weights
 
 
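The `apply_rotary_pos_emb` call kept above follows the usual rotate-half formulation used across transformers models. A reference sketch of that formulation (the MiniMaxText01 module imports it rather than redefining it; the broadcasting assumptions are stated in comments):

# Reference sketch of rotary position embedding application (rotate-half form).
import torch


def rotate_half(x):
    # Split the head dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin assumed shaped (batch, seq_len, head_dim); unsqueeze to broadcast over heads.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed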
@@ -592,7 +569,7 @@ class MiniMaxText01RotaryEmbedding(nn.Module):
         2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
         """
         seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth_dynamic_frequency_update
+        if seq_len > self.max_seq_len_cached:  # growth
             inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len
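In the hunk above, `rope_init_fn` recomputes the inverse frequencies whenever the sequence grows past the cached length, so the rotary embedding can rescale dynamically. As a rough sketch of what a default RoPE init produces (parameter names assumed, not the exact transformers helper):

# Rough sketch of a default RoPE frequency init (parameter names assumed).
import torch


def default_rope_init(base: float, head_dim: int, device=None):
    # inv_freq[i] = 1 / base^(2i / head_dim), one frequency per pair of channels.
    inv_freq = 1.0 / (
        base ** (torch.arange(0, head_dim, 2, dtype=torch.int64, device=device).float() / head_dim)
    )
    attention_scaling = 1.0  # the default schedule applies no extra attention scaling
    return inv_freq, attention_scaling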
@@ -628,7 +605,7 @@ class MiniMaxText01RotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-MINI_MAX_TEXT01_START_DOCSTRING = r"""
+MINIMAX_TEXT_01_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)
@@ -647,7 +624,7 @@ MINI_MAX_TEXT01_START_DOCSTRING = r"""
 
 @add_start_docstrings(
     "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01PreTrainedModel(PreTrainedModel):
     config_class = MiniMaxText01Config
@@ -674,7 +651,7 @@ class MiniMaxText01PreTrainedModel(PreTrainedModel):
             module.weight.data[module.padding_idx].zero_()
 
 
-MINI_MAX_TEXT01_INPUTS_DOCSTRING = r"""
+MINIMAX_TEXT_01_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -751,7 +728,7 @@ MINI_MAX_TEXT01_INPUTS_DOCSTRING = r"""
 
 @add_start_docstrings(
     "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
     """
@@ -783,7 +760,7 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -820,7 +797,6 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
             )
             use_cache = False
 
-        # TODO: raise exception here?
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache()
 
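Behaviour kept by this hunk: when `use_cache` is set and no cache is supplied, the model instantiates a `DynamicCache` itself. A hedged usage sketch of that path ("org/minimax-text-01-checkpoint" is a placeholder checkpoint name, and it assumes a transformers build that includes this model):

# Hedged usage sketch; the checkpoint name below is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model = AutoModelForCausalLM.from_pretrained("org/minimax-text-01-checkpoint")
tokenizer = AutoTokenizer.from_pretrained("org/minimax-text-01-checkpoint")

inputs = tokenizer("Hey, are you conscious?", return_tensors="pt")
cache = DynamicCache()  # optional: omit it and the model creates one when use_cache=True
with torch.no_grad():
    outputs = model(**inputs, past_key_values=cache, use_cache=True)
print(outputs.logits.shape)  # (batch, seq_len, vocab_size)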
@@ -1173,7 +1149,7 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
@@ -1222,7 +1198,6 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel, GenerationMixin):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        # ic(input_ids.shape, input_ids)
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -1299,7 +1274,7 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel, GenerationMixin):
     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
     each row of the batch).
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
     def __init__(self, config):
@@ -1317,7 +1292,7 @@ class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1395,7 +1370,7 @@ class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
     The MiniMaxText01 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
     output) e.g. for Named-Entity-Recognition (NER) tasks.
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel):
     def __init__(self, config):
@@ -1420,7 +1395,7 @@ class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=TokenClassifierOutput,
@@ -1483,7 +1458,7 @@ class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel):
     The MiniMaxText01 Model transformer with a span classification head on top for extractive question-answering tasks like
     SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForQuestionAnswering(MiniMaxText01PreTrainedModel):
     base_model_prefix = "model"
@@ -1502,7 +1477,7 @@ class MiniMaxText01ForQuestionAnswering(MiniMaxText01PreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,