fix gemma3 grad acc (#37208)

* fix gemma3 grad acc

* fix

* fix

* fix

* fix

* rmv print

* rm

* Update setup.py

* Apply style fixes

* propagate the changes

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Arthur <arthur.zucker@gmail.com>
Marc Sun 2025-06-25 16:28:44 +02:00 committed by GitHub
parent 860b898d03
commit 3c322c9cdf
4 changed files with 13 additions and 8 deletions


@@ -777,6 +777,8 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_
 )
 class Gemma3Model(Gemma3PreTrainedModel):
     _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+    accepts_loss_kwargs = False
 
     def __init__(self, config: Gemma3Config):
         super().__init__(config)

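The comment added above carries the actual reasoning behind the fix: Gemma3 filters some positions out of the logits/labels before computing the loss, so the `num_items_in_batch` count that the Trainer accumulates from the raw labels no longer matches the tokens that really contribute. A minimal, hypothetical illustration of the mismatch (the numbers below are made up for the example, not taken from the library):

# Hypothetical numbers: a micro-batch has 10 label positions != -100, but the
# model drops 4 of them (e.g. positions tied to image tokens) before the loss.
summed_loss = 12.0        # sum of per-token losses over the 6 surviving positions
num_items_in_batch = 10   # what the Trainer would count from the raw labels

scaled_by_trainer_count = summed_loss / num_items_in_batch  # 1.2 -> scaled by the wrong denominator
scaled_by_used_tokens = summed_loss / 6                      # 2.0 -> mean over positions actually used

Setting `accepts_loss_kwargs = False` keeps the Trainer from passing `num_items_in_batch` into such models, so the model's own loss reduction is used and gradient accumulation falls back to the Trainer's default scaling.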

@@ -727,6 +727,9 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_
 
 
 class Gemma3Model(PaliGemmaModel):
+    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+    accepts_loss_kwargs = False
+
     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """
         Projects the last hidden state from the vision model into language model space.


@@ -132,6 +132,8 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
 )
 class PaliGemmaModel(PaliGemmaPreTrainedModel):
     _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+    accepts_loss_kwargs = False
 
     def __init__(self, config: PaliGemmaConfig):
         super().__init__(config)


@@ -629,18 +629,16 @@ class Trainer:
 
         # Just in case the model was wrapped outside of the `Trainer`
        unwrapped_model = self.accelerator.unwrap_model(model)
-        model_forward = (
-            unwrapped_model.forward
-            if not _is_peft_model(unwrapped_model)
-            else unwrapped_model.get_base_model().forward
-        )
-        forward_params = inspect.signature(model_forward).parameters
+        # We also unwrap peft model
+        if _is_peft_model(unwrapped_model):
+            unwrapped_model = unwrapped_model.get_base_model()
 
         # Check if the model has explicit setup for loss kwargs,
         # if not, check if `**kwargs` are in model.forward
-        if hasattr(model, "accepts_loss_kwargs"):
-            self.model_accepts_loss_kwargs = model.accepts_loss_kwargs
+        if hasattr(unwrapped_model, "accepts_loss_kwargs"):
+            self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs
         else:
+            forward_params = inspect.signature(unwrapped_model.forward).parameters
             self.model_accepts_loss_kwargs = any(
                 k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
             )