fix t5gemma tests (#39052)

* fix
* fix
* fix
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

parent 23b7e73f05
commit 2f50230c59
src/transformers/models/t5gemma/modeling_t5gemma.py
@@ -41,7 +41,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import auto_docstring, can_return_tuple, logging
+from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
 from .configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
 
 
@@ -1112,7 +1112,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
         self.model = T5GemmaModel(config)
         self.vocab_size = config.decoder.vocab_size
         self.lm_head = T5GemmaLMHead(config.decoder.hidden_size, self.vocab_size)
-        self.loss_type = "ForMaskedLMLoss"
+        self.loss_type = "ForMaskedLM"
 
         self.post_init()
 
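Note on the `loss_type` fix: transformers resolves `loss_type` by key lookup in a loss registry whose keys are class-name suffixes such as `ForMaskedLM`, so the old value `ForMaskedLMLoss` could never match and the model silently fell back to a default loss. A minimal sketch of that lookup pattern (the registry contents here are illustrative stand-ins, not the library's real table):

    import torch.nn.functional as F

    # Illustrative stand-in for transformers' loss registry; keys are
    # class-name suffixes, never "<suffix>Loss".
    LOSS_MAPPING = {
        "ForCausalLM": F.cross_entropy,
        "ForMaskedLM": F.cross_entropy,
    }

    def resolve_loss_function(loss_type):
        if loss_type not in LOSS_MAPPING:
            # the library warns and falls back to a default in this case
            return LOSS_MAPPING["ForCausalLM"]
        return LOSS_MAPPING[loss_type]

    assert resolve_loss_function("ForMaskedLM") is F.cross_entropy      # hits
    assert resolve_loss_function("ForMaskedLMLoss") is F.cross_entropy  # misses, falls back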
@@ -1169,10 +1169,14 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """
         if self.training and self.config._attn_implementation != "eager":
-            logger.warning_once(
+            msg = (
                 "It is strongly recommended to train T5Gemma models with the `eager` attention implementation "
                 f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
             )
+            if is_torchdynamo_compiling():
+                raise ValueError(msg)
+            else:
+                logger.warning_once(msg)
 
         if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
             # get decoder inputs from shifting lm labels to the right
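Note on the guard above: the message is now built once and either raised or logged depending on whether dynamo is tracing, since a warning emitted mid-trace is easy to miss and a compiled training run would otherwise proceed on a discouraged configuration. A standalone sketch of the same dispatch, with `warnings` standing in for the library logger:

    import warnings

    def check_train_attn(training: bool, attn_implementation: str, compiling: bool):
        # Mirrors the commit's control flow: one message, two delivery modes.
        if training and attn_implementation != "eager":
            msg = (
                "It is strongly recommended to train T5Gemma models with the "
                f"`eager` attention implementation instead of `{attn_implementation}`."
            )
            if compiling:  # is_torchdynamo_compiling() in the real code
                raise ValueError(msg)
            warnings.warn(msg)

    check_train_attn(True, "sdpa", compiling=False)      # warns
    try:
        check_train_attn(True, "sdpa", compiling=True)   # raises
    except ValueError:
        pass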
src/transformers/models/t5gemma/modular_t5gemma.py
@@ -37,6 +37,7 @@ from ...utils import (
     auto_docstring,
     can_return_tuple,
     is_torch_flex_attn_available,
+    is_torchdynamo_compiling,
     logging,
 )
 from ..gemma2.configuration_gemma2 import Gemma2Config
@@ -1058,7 +1059,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
         self.model = T5GemmaModel(config)
         self.vocab_size = config.decoder.vocab_size
         self.lm_head = T5GemmaLMHead(config.decoder.hidden_size, self.vocab_size)
-        self.loss_type = "ForMaskedLMLoss"
+        self.loss_type = "ForMaskedLM"
 
         self.post_init()
 
@@ -1115,10 +1116,14 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """
         if self.training and self.config._attn_implementation != "eager":
-            logger.warning_once(
+            msg = (
                 "It is strongly recommended to train T5Gemma models with the `eager` attention implementation "
                 f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
             )
+            if is_torchdynamo_compiling():
+                raise ValueError(msg)
+            else:
+                logger.warning_once(msg)
 
         if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
             # get decoder inputs from shifting lm labels to the right
tests/models/t5gemma/test_modeling_t5gemma.py
@@ -595,6 +595,11 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
 
     # used in `test_torch_compile_for_training`
     _torch_compile_train_cls = T5GemmaForConditionalGeneration if is_torch_available() else None
+    # `t5gemma` will give warning or raise error if it is not `eager` during training.
+    _torch_compile_train_attn_implementation = "eager"
+
+    # won't fix
+    test_torchscript = False
 
     def setUp(self):
         self.model_tester = T5GemmaModelTester(self)
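The two new class attributes opt this suite into the shared compile-for-training test: the mixin reads `_torch_compile_train_attn_implementation` via `getattr` (see the `tests/test_modeling_common.py` hunk below), so test classes that don't define it are untouched. A hypothetical opt-in for another model's tests, with a stub mixin so the snippet stands alone:

    import unittest

    class ModelTesterMixinStub:
        # stand-in for transformers' ModelTesterMixin; names are assumptions
        _torch_compile_train_cls = None

    class MyModelTest(ModelTesterMixinStub, unittest.TestCase):
        # force the compile-for-training test onto the eager attention path,
        # exactly as T5GemmaModelTest does above
        _torch_compile_train_attn_implementation = "eager"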
@@ -1584,6 +1589,9 @@ class T5GemmaEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
     is_encoder_decoder = False
     model_split_percents = [0.4, 0.5]
 
+    # won't fix
+    test_torchscript = False
+
     def setUp(self):
         self.model_tester = T5GemmaEncoderOnlyModelTester(self)
         self.config_tester = ConfigTester(
tests/test_modeling_common.py
@@ -3748,7 +3748,7 @@ class ModelTesterMixin:
                self.skipTest(
                    "PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
                )
-            if config.model_type in ["modernbert", "gemma3"]:
+            if config.model_type in ["modernbert", "gemma3", "t5gemma"]:
                self.skipTest(
                    reason=f"{config.model_type} currently (transformers==4.52.0) automatically adds an attention_mask input"
                )
@@ -4414,6 +4414,10 @@ class ModelTesterMixin:
 
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         cls = self._torch_compile_train_cls
+        attn_implementation = getattr(self, "_torch_compile_train_attn_implementation", None)
+        if attn_implementation is not None:
+            config._attn_implementation = attn_implementation
+
         model = cls(config).to(torch_device)
 
         inputs = {
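Putting both halves together: the override runs before `cls(config)`, so the per-test attention choice reaches the config the model is built from, and T5Gemma's new tracing-time ValueError never fires. A standalone sketch of that plumbing (the config object is a stand-in; attribute names come from the diff):

    from types import SimpleNamespace

    class FakeT5GemmaTest:
        _torch_compile_train_attn_implementation = "eager"

    config = SimpleNamespace(_attn_implementation="sdpa")  # stand-in config

    # same getattr-with-default pattern as the mixin change above
    attn_implementation = getattr(FakeT5GemmaTest, "_torch_compile_train_attn_implementation", None)
    if attn_implementation is not None:
        config._attn_implementation = attn_implementation

    assert config._attn_implementation == "eager"  # model now builds on the eager path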