From bc65f3fc1c1714cccf58ce3d9dcdca8ba9072879 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 25 Feb 2025 10:29:47 +0100 Subject: [PATCH] [modular] Do not track imports in functions (#36279) * Add check * just check for function * Update examples --- .../configuration_my_new_model.py | 5 +++ .../configuration_my_new_model2.py | 5 +++ .../configuration_new_model.py | 14 ++++++++ .../image_processing_new_imgproc_model.py | 5 ++- .../modular-transformers/modeling_dummy.py | 5 ++- .../modeling_multimodal1.py | 5 ++- .../modeling_my_new_model2.py | 27 ++++++++------ .../modeling_new_task_model.py | 36 +++++++++++-------- .../modular-transformers/modeling_super.py | 5 ++- utils/modular_model_converter.py | 8 +++-- 10 files changed, 82 insertions(+), 33 deletions(-) diff --git a/examples/modular-transformers/configuration_my_new_model.py b/examples/modular-transformers/configuration_my_new_model.py index 59637e02d3f..febd1b88675 100644 --- a/examples/modular-transformers/configuration_my_new_model.py +++ b/examples/modular-transformers/configuration_my_new_model.py @@ -140,6 +140,11 @@ class MyNewModelConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/examples/modular-transformers/configuration_my_new_model2.py b/examples/modular-transformers/configuration_my_new_model2.py index eddd7fe4797..a5364a85d53 100644 --- a/examples/modular-transformers/configuration_my_new_model2.py +++ b/examples/modular-transformers/configuration_my_new_model2.py @@ -43,6 +43,11 @@ class MyNewModel2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/examples/modular-transformers/configuration_new_model.py b/examples/modular-transformers/configuration_new_model.py index 4d164fe3e75..ba05b4ea51b 100644 --- a/examples/modular-transformers/configuration_new_model.py +++ b/examples/modular-transformers/configuration_new_model.py @@ -79,6 +79,20 @@ class NewModelConfig(PretrainedConfig): model_type = "new_model" keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/examples/modular-transformers/image_processing_new_imgproc_model.py b/examples/modular-transformers/image_processing_new_imgproc_model.py index f3ab1772ec5..bde80dd0964 100644 --- a/examples/modular-transformers/image_processing_new_imgproc_model.py +++ b/examples/modular-transformers/image_processing_new_imgproc_model.py @@ -19,7 +19,7 @@ from ...image_utils import ( PILImageResampling, infer_channel_dimension_format, is_scaled_image, - make_list_of_images, + make_flat_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -221,8 +221,7 @@ class ImgprocModelImageProcessor(BaseImageProcessor): size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) - - images = make_list_of_images(images) + images = make_flat_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 1b0ad5ad92f..7ade13e977f 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -356,6 +356,7 @@ class DummyPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -698,7 +699,9 @@ class DummyModel(DummyPreTrainedModel): if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index ec54af22186..ee649f92864 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -356,6 +356,7 @@ class Multimodal1TextPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -698,7 +699,9 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel): if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 86669310c4f..854d280663f 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -356,6 +356,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -491,6 +492,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, # NOOP kwarg for now ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -703,7 +705,9 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel): if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype @@ -787,17 +791,20 @@ class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel): if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: - sequence_lengths = -1 + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] loss = None if labels is not None: diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index 3cea4ef2c45..a01094737cd 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -19,6 +19,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_new_task_model import NewTaskModelConfig @@ -254,8 +255,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): token_type_ids, past_key_values, cache_position, - input_ids=None, - inputs_embeds=None, + input_tensor, is_training: bool = False, ): if self.config.text_config._attn_implementation == "flash_attention_2": @@ -265,8 +265,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): using_static_cache = isinstance(past_key_values, StaticCache) min_dtype = torch.finfo(self.dtype).min - inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] - sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + inputs_lead_dim, sequence_length = input_tensor.shape[:2] if using_static_cache: target_length = past_key_values.get_max_cache_shape() elif isinstance(past_key_values, HybridCache): @@ -297,16 +296,20 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] + + # First unmask prefix tokens during training + if is_training: + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 + ) + + # Then apply padding mask (will mask pad tokens) padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - # we are training thus we need to create a full mask on the image + prefix but causal on suffix - if is_training: - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 - ) + return causal_mask def get_image_features(self, pixel_values: torch.FloatTensor): @@ -325,6 +328,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -351,10 +355,12 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -418,7 +424,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): attention_mask=None, token_type_ids=None, use_cache=True, - num_logits_to_keep=None, + logits_to_keep=None, labels=None, **kwargs, ): @@ -431,7 +437,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): position_ids=position_ids, cache_position=cache_position, use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, + logits_to_keep=logits_to_keep, token_type_ids=token_type_ids, **kwargs, ) @@ -445,10 +451,12 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): model_inputs["pixel_values"] = pixel_values is_training = token_type_ids is not None and labels is not None if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training + attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training ) model_inputs["attention_mask"] = causal_mask + return model_inputs def resize_token_embeddings( diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 45486045863..d618cd54e90 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -356,6 +356,7 @@ class SuperPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -620,7 +621,9 @@ class SuperModel(SuperPreTrainedModel): if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 3c5d062fbe3..6c85a1e7a1c 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -649,9 +649,11 @@ class ModuleMapper(CSTVisitor, ABC): self.current_function = None def visit_If(self, node): - for stmt in node.body.body: - if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): - self.imports.append(node) + # If we are inside a function, do not add the import to the list of imports + if self.current_function is None: + for stmt in node.body.body: + if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): + self.imports.append(node) def visit_ClassDef(self, node: ClassDef) -> None: """Record class nodes to create their dependencies at the end."""