Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 13:20:12 +06:00)
[modular] Do not track imports in functions (#36279)
* Add check
* just check for function
* Update examples
This commit is contained in:
parent 4b5cf5496d
commit bc65f3fc1c
@@ -140,6 +140,11 @@ class MyNewModelConfig(PretrainedConfig):
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",
     }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
 
     def __init__(
         self,
@@ -43,6 +43,11 @@ class MyNewModel2Config(PretrainedConfig):
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",
     }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
 
     def __init__(
         self,
@@ -79,6 +79,20 @@ class NewModelConfig(PretrainedConfig):
 
     model_type = "new_model"
     keys_to_ignore_at_inference = ["past_key_values"]
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
 
     def __init__(
         self,
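Aside: the two class attributes added in the config hunks above follow the shapes transformers uses for its parallelism plans. As far as I can tell, base_model_tp_plan maps parameter-name patterns to a tensor-parallel partition style, and base_model_pp_plan maps submodule names to their pipeline input/output tensor names. A minimal sketch of a config declaring both, mirroring the dicts in the diff (TinyConfig is a made-up name, not part of the commit):

# Illustrative only: how a config subclass carries the two plans shown above.
from transformers import PretrainedConfig


class TinyConfig(PretrainedConfig):
    model_type = "tiny"
    # parameter-name pattern -> tensor-parallel partition style
    base_model_tp_plan = {
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # submodule name -> (input tensor names, output tensor names) for pipeline parallel
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }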
@@ -19,7 +19,7 @@ from ...image_utils import (
     PILImageResampling,
     infer_channel_dimension_format,
     is_scaled_image,
-    make_list_of_images,
+    make_flat_list_of_images,
     to_numpy_array,
     valid_images,
     validate_preprocess_arguments,
@@ -221,8 +221,7 @@ class ImgprocModelImageProcessor(BaseImageProcessor):
 
         size = size if size is not None else self.size
         size = get_size_dict(size, default_to_square=False)
-
-        images = make_list_of_images(images)
+        images = make_flat_list_of_images(images)
 
         if not valid_images(images):
             raise ValueError(
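Aside: the image-processor hunks swap make_list_of_images for make_flat_list_of_images. A quick sketch of the assumed difference, using numpy arrays as image stand-ins (behavior inferred from the helper's name and typical usage, so treat it as an assumption rather than a guarantee):

import numpy as np
from transformers.image_utils import make_flat_list_of_images

img = np.zeros((8, 8, 3), dtype=np.uint8)   # toy image
nested = [[img, img], [img]]                # nested batch (list of lists of images)
flat = make_flat_list_of_images(nested)     # expected: one flat list of images
print(len(flat))                            # 3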
@@ -356,6 +356,7 @@ class DummyPreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
+    _supports_attention_backend = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
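Aside: the new _supports_attention_backend flag appears to mark these example models as compatible with the selectable attention implementations. Illustrative load call only, just to show the knob the flag relates to (the checkpoint and backend choice here are an assumption, not part of the commit):

from transformers import AutoModelForCausalLM

# Loading with an explicit attention implementation; "eager" and "sdpa" are the
# commonly available choices, "flash_attention_2" needs the extra package.
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="sdpa")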
@@ -698,7 +699,9 @@ class DummyModel(DummyPreTrainedModel):
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                causal_mask.device
+            )
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
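Aside: the .to(causal_mask.device) added above (and in the identical hunks below) guards against the attention mask living on a different device than the causal mask. A standalone sketch of the fixed pattern with toy shapes, not the model code:

import torch

causal_mask = torch.zeros(1, 1, 4, 4)                 # stands in for the mask on the model's device
attention_mask = torch.ones(1, 4, dtype=torch.long)   # may arrive on another device (e.g. CPU vs GPU)
mask_length = attention_mask.shape[-1]

# Moving the attention mask onto causal_mask.device keeps the addition on one
# device and avoids a "tensors on different devices" error in multi-device setups.
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
padding_mask = padding_mask == 0
print(padding_mask.shape)  # torch.Size([1, 1, 4, 4])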
@@ -356,6 +356,7 @@ class Multimodal1TextPreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
+    _supports_attention_backend = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -698,7 +699,9 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                causal_mask.device
+            )
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
@@ -356,6 +356,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
+    _supports_attention_backend = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -491,6 +492,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,  # NOOP kwarg for now
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -703,7 +705,9 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                causal_mask.device
+            )
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
@@ -787,17 +791,20 @@ class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel):
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )
 
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
 
         loss = None
         if labels is not None:
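Aside: the rewritten pooling above picks the rightmost non-pad token instead of the old modulo trick, which handles both left- and right-padded batches. A toy check of the indexing (pad id 0 is an assumption for the example, not taken from the diff):

import torch

pad_token_id = 0
input_ids = torch.tensor([
    [5, 6, 7, 0, 0],   # right-padded -> last real token at index 2
    [0, 0, 8, 9, 4],   # left-padded  -> last real token at index 4
])
non_pad_mask = (input_ids != pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])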
@@ -19,6 +19,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
+from ...utils.deprecation import deprecate_kwarg
 from ..auto import AutoModel, AutoModelForCausalLM
 from .configuration_new_task_model import NewTaskModelConfig
 
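Aside: the newly imported deprecate_kwarg decorator (applied to forward further down) remaps a renamed keyword argument so old call sites keep working. Illustrative standalone use on a dummy function; the exact warning behavior depends on the installed transformers version:

from transformers.utils.deprecation import deprecate_kwarg


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def forward(logits_to_keep=0):
    return logits_to_keep


# The old name is remapped to logits_to_keep; a deprecation message may be emitted.
print(forward(num_logits_to_keep=5))  # 5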
@@ -254,8 +255,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
         token_type_ids,
         past_key_values,
         cache_position,
-        input_ids=None,
-        inputs_embeds=None,
+        input_tensor,
         is_training: bool = False,
     ):
         if self.config.text_config._attn_implementation == "flash_attention_2":
@@ -265,8 +265,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
 
         using_static_cache = isinstance(past_key_values, StaticCache)
         min_dtype = torch.finfo(self.dtype).min
-        inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
-        sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
         elif isinstance(past_key_values, HybridCache):
@@ -297,16 +296,20 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
+
+            # First unmask prefix tokens during training
+            if is_training:
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
+                )
+
+            # Then apply padding mask (will mask pad tokens)
             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
             )
-            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
-            if is_training:
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
-                )
+
         return causal_mask
 
     def get_image_features(self, pixel_values: torch.FloatTensor):
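Aside: the reordering above matters; the prefix is unmasked first and the padding mask is applied afterwards, so padded positions stay masked even when they fall inside the prefix. A toy illustration with a 3-token sequence (shapes and values are made up, not the model's):

import torch

min_dtype = torch.finfo(torch.float32).min
causal_mask = torch.full((1, 1, 3, 3), min_dtype).triu(1)  # causal template: future positions masked
token_type_ids = torch.tensor([[0, 0, 1]])                 # 0 marks image/prefix tokens
attention_mask = torch.tensor([[0, 1, 1]])                 # position 0 is padding
mask_length = attention_mask.shape[-1]

# 1) training: give the prefix full (bidirectional) attention
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    token_type_ids[:, None, None, :] == 0, 0
)
# 2) then re-apply padding, so the padded prefix position ends up masked again
padding_mask = (causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]) == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)

print((causal_mask[0, 0] == min_dtype).int())  # column 0 (the pad token) stays fully masked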
@@ -351,10 +355,12 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
         Returns:
 
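Aside: per the updated docstring, logits_to_keep is either an int (keep the last N positions) or a 1D index tensor (keep explicit positions, useful for packed sequences). A sketch of the documented slicing semantics on a random tensor, not the model's internal code:

import torch

logits = torch.randn(2, 10, 32000)           # (batch, seq_len, vocab) stand-in

logits_to_keep = 1                           # int: keep the last position (0 would keep everything)
kept = logits[:, -logits_to_keep:, :]        # shape (2, 1, 32000)

logits_to_keep = torch.tensor([3, 7, 9])     # 1D tensor: keep explicit sequence indices
kept = logits[:, logits_to_keep, :]          # shape (2, 3, 32000)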
@@ -418,7 +424,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
         attention_mask=None,
         token_type_ids=None,
         use_cache=True,
-        num_logits_to_keep=None,
+        logits_to_keep=None,
         labels=None,
         **kwargs,
     ):
@@ -431,7 +437,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
             position_ids=position_ids,
             cache_position=cache_position,
             use_cache=use_cache,
-            num_logits_to_keep=num_logits_to_keep,
+            logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
             **kwargs,
         )
@@ -445,10 +451,12 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
             model_inputs["pixel_values"] = pixel_values
         is_training = token_type_ids is not None and labels is not None
         if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
             causal_mask = self._update_causal_mask(
-                attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
+                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
             )
             model_inputs["attention_mask"] = causal_mask
+
         return model_inputs
 
     def resize_token_embeddings(
@@ -356,6 +356,7 @@ class SuperPreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
+    _supports_attention_backend = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -620,7 +621,9 @@ class SuperModel(SuperPreTrainedModel):
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                causal_mask.device
+            )
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
@@ -649,6 +649,8 @@ class ModuleMapper(CSTVisitor, ABC):
         self.current_function = None
 
     def visit_If(self, node):
-        for stmt in node.body.body:
-            if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
-                self.imports.append(node)
+        # If we are inside a function, do not add the import to the list of imports
+        if self.current_function is None:
+            for stmt in node.body.body:
+                if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
+                    self.imports.append(node)
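Aside: the converter change above is the core of the commit; ModuleMapper now records imports found in `if` blocks only while it is not inside a function (self.current_function is None), so function-local imports are no longer hoisted into the generated file. A self-contained libcst sketch of the same idea, written against a toy visitor rather than the converter's actual classes:

# Toy sketch (simplified relative to ModuleMapper): collect only module-level
# imports by tracking how deep we currently are inside function bodies.
import libcst as cst


class TopLevelImportCollector(cst.CSTVisitor):
    def __init__(self):
        self.imports = []
        self.function_depth = 0  # > 0 means we are inside a function body

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.function_depth += 1

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        self.function_depth -= 1

    def visit_Import(self, node: cst.Import) -> None:
        if self.function_depth == 0:
            self.imports.append(node)

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        if self.function_depth == 0:
            self.imports.append(node)


source = """
import os

def lazy_loader():
    import json  # function-local: should not be tracked
    return json
"""
collector = TopLevelImportCollector()
cst.parse_module(source).visit(collector)
print(len(collector.imports))  # 1 -> only `import os` was tracked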