Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
fix gemma3 grad acc (#37208)
* fix gemma3 grad acc
* fix
* fix
* fix
* fix
* rmv print
* rm
* Update setup.py
* Apply style fixes
* propagate the changes

--------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Arthur <arthur.zucker@gmail.com>
This commit is contained in:
parent 860b898d03
commit 3c322c9cdf
@@ -777,6 +777,8 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_
 )
 class Gemma3Model(Gemma3PreTrainedModel):
     _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+    accepts_loss_kwargs = False

     def __init__(self, config: Gemma3Config):
         super().__init__(config)
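Note (not part of the diff): setting `accepts_loss_kwargs = False` tells the Trainer that this model does not take `num_items_in_batch` through its loss kwargs, so the Trainer handles gradient-accumulation scaling itself rather than relying on the model to normalize by token count. A rough sketch of that division of responsibility, with hypothetical names (this is not the actual Trainer code):

# Illustrative sketch only; function and argument names are hypothetical.
def scale_loss_for_grad_accum(loss, model_accepts_loss_kwargs, grad_accum_steps):
    # A model that accepts loss kwargs normalizes its own loss using
    # num_items_in_batch, so the Trainer must not rescale it again.
    if model_accepts_loss_kwargs:
        return loss
    # Otherwise the Trainer averages the per-step loss across accumulation steps.
    return loss / grad_accum_steps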
@@ -727,6 +727,9 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_


 class Gemma3Model(PaliGemmaModel):
+    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+    accepts_loss_kwargs = False
+
     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """
         Projects the last hidden state from the vision model into language model space.
@@ -132,6 +132,8 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
 )
 class PaliGemmaModel(PaliGemmaPreTrainedModel):
     _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+    accepts_loss_kwargs = False

     def __init__(self, config: PaliGemmaConfig):
         super().__init__(config)
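The Trainer hunk below moves the `accepts_loss_kwargs` lookup onto the unwrapped model. A hypothetical sketch (illustrative stand-in classes, not the real PEFT or Accelerate wrappers) of why the flag can be invisible when checked on the wrapped model:

# Hypothetical sketch: these classes are illustrative stand-ins only.
class Gemma3Like:
    accepts_loss_kwargs = False  # the flag added in the hunks above

class PeftLikeWrapper:
    # Stand-in for a wrapper that does not expose the base model's attributes.
    def __init__(self, base_model):
        self.base_model = base_model

    def get_base_model(self):
        return self.base_model

wrapped = PeftLikeWrapper(Gemma3Like())
print(hasattr(wrapped, "accepts_loss_kwargs"))                   # False: the wrapper hides the flag
print(hasattr(wrapped.get_base_model(), "accepts_loss_kwargs"))  # True: the base model exposes it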
@@ -629,18 +629,16 @@ class Trainer:

         # Just in case the model was wrapped outside of the `Trainer`
         unwrapped_model = self.accelerator.unwrap_model(model)
-        model_forward = (
-            unwrapped_model.forward
-            if not _is_peft_model(unwrapped_model)
-            else unwrapped_model.get_base_model().forward
-        )
-        forward_params = inspect.signature(model_forward).parameters
+        # We also unwrap peft model
+        if _is_peft_model(unwrapped_model):
+            unwrapped_model = unwrapped_model.get_base_model()

         # Check if the model has explicit setup for loss kwargs,
         # if not, check if `**kwargs` are in model.forward
-        if hasattr(model, "accepts_loss_kwargs"):
-            self.model_accepts_loss_kwargs = model.accepts_loss_kwargs
+        if hasattr(unwrapped_model, "accepts_loss_kwargs"):
+            self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs
         else:
+            forward_params = inspect.signature(unwrapped_model.forward).parameters
             self.model_accepts_loss_kwargs = any(
                 k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
             )
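For reference, the fallback branch's `**kwargs` detection can be reproduced in isolation. A minimal, self-contained sketch (the `forward_*` functions are made up for illustration):

import inspect

# A forward signature "accepts loss kwargs" when it declares **kwargs (VAR_KEYWORD).
def forward_plain(input_ids, attention_mask=None):
    return input_ids

def forward_with_kwargs(input_ids, attention_mask=None, **kwargs):
    return input_ids

def has_var_keyword(fn):
    params = inspect.signature(fn).parameters
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())

print(has_var_keyword(forward_plain))        # False
print(has_var_keyword(forward_with_kwargs))  # True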