fix Gemma3 Config (#36893)

* fix Gemma3 Config

* fix config in modular gemma3
This commit is contained in:
AbdelKarim ELJANDOUBI 2025-03-24 10:05:44 +01:00 committed by GitHub
parent c9d1e5238a
commit fe4ca2f4a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 18 deletions

View File

@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Any, Dict, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
@ -292,8 +292,8 @@ class Gemma3Config(PretrainedConfig):
def __init__(
self,
text_config: Optional[Gemma3TextConfig] = None,
vision_config: Optional[SiglipVisionConfig] = None,
text_config: Optional[Union[Gemma3TextConfig, Dict[str, Any]]] = None,
vision_config: Optional[Union[SiglipVisionConfig, Dict[str, Any]]] = None,
mm_tokens_per_image: int = 256,
boi_token_index: int = 255_999,
eoi_token_index: int = 256_000,
@ -303,18 +303,15 @@ class Gemma3Config(PretrainedConfig):
):
if text_config is None:
text_config = Gemma3TextConfig()
logger.info("text_config is None, using default Gemma3TextConfig vision config.")
logger.info("text_config is None, using default Gemma3TextConfig text config.")
elif isinstance(text_config, dict):
text_config = Gemma3TextConfig(**text_config)
if isinstance(vision_config, dict):
vision_config = SiglipVisionConfig(**vision_config)
else:
elif vision_config is None:
vision_config = SiglipVisionConfig()
logger.info(
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
"to text tasks."
)
logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
self.text_config = text_config
self.vision_config = vision_config

View File

@ -16,7 +16,7 @@
import copy
from collections.abc import Callable
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -266,8 +266,8 @@ class Gemma3Config(PretrainedConfig):
def __init__(
self,
text_config: Optional[Gemma3TextConfig] = None,
vision_config: Optional[SiglipVisionConfig] = None,
text_config: Optional[Union[Gemma3TextConfig, Dict[str, Any]]] = None,
vision_config: Optional[Union[SiglipVisionConfig, Dict[str, Any]]] = None,
mm_tokens_per_image: int = 256,
boi_token_index: int = 255_999,
eoi_token_index: int = 256_000,
@ -277,18 +277,15 @@ class Gemma3Config(PretrainedConfig):
):
if text_config is None:
text_config = Gemma3TextConfig()
logger.info("text_config is None, using default Gemma3TextConfig vision config.")
logger.info("text_config is None, using default Gemma3TextConfig text config.")
elif isinstance(text_config, dict):
text_config = Gemma3TextConfig(**text_config)
if isinstance(vision_config, dict):
vision_config = SiglipVisionConfig(**vision_config)
else:
elif vision_config is None:
vision_config = SiglipVisionConfig()
logger.info(
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
"to text tasks."
)
logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
self.text_config = text_config
self.vision_config = vision_config