transformers/tests/models/glm4/test_modeling_glm4.py
Matt 53fb245eb6
🚨 🚨 Inherited CausalLM Tests (#37590)
* stash commit

* Experiment 1: Try just Gemma

* Experiment 1: Just try Gemma

* make fixup

* Trigger tests

* stash commit

* Try adding Gemma3 as well

* make fixup

* Correct attrib names

* Correct pipeline model mapping

* Add in all_model_classes for Gemma1 again

* Move the pipeline model mapping around again

* make fixup

* Revert Gemma3 changes since it's a VLM

* Let's try Falcon

* Correct attributes

* Correct attributes

* Let's try just overriding get_config() for now

* Do Nemotron too

* And Llama!

* Do llama/persimmon

* Correctly skip tests

* Fix Persimmon

* Include Phimoe

* Fix Gemma2

* Set model_tester_class correctly

* Add GLM

* More models!

* models models models

* make fixup

* Add Qwen3 + Qwen3MoE

* Correct import

* make fixup

* Add the QuestionAnswering classes

* Add the QuestionAnswering classes

* Move pipeline mapping to the right place

* Jetmoe too

* Stop RoPE testing models with no RoPE

* Fix up JetMOE a bit

* Fix up JetMOE a bit

* Can we just force pad_token_id all the time?

* make fixup

* fix starcoder2

* Move pipeline mapping

* Fix RoPE skipping

* Fix RecurrentGemma tests

* Fix Falcon tests

* Add MoE attributes

* Fix values for RoPE testing

* Make sure we set bos_token_id and eos_token_id in an appropriate range

* make fixup

* Fix GLM4

* Add mamba attributes

* Revert bits of JetMOE

* Re-add the JetMOE skips

* Update tests/causal_lm_tester.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add licence

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2025-05-23 18:29:31 +01:00

202 lines
7.6 KiB
Python

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Glm4 model."""
import unittest
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer, Glm4Config, is_torch_available
from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_large_gpu,
require_torch_sdpa,
slow,
torch_device,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
if is_torch_available():
import torch
from transformers import (
Glm4ForCausalLM,
Glm4ForSequenceClassification,
Glm4ForTokenClassification,
Glm4Model,
)
class Glm4ModelTester(CausalLMModelTester):
    """Binds the GLM-4 config and model classes for the shared ``CausalLMModelTester`` machinery."""

    # The Glm4* model classes are only imported when torch is installed (see the
    # guarded import near the top of the file), so the attributes are only set then.
    if is_torch_available():
        config_class = Glm4Config
        base_model_class = Glm4Model
        causal_lm_class = Glm4ForCausalLM
        # Task-specific heads.
        sequence_classification_class = Glm4ForSequenceClassification
        token_classification_class = Glm4ForTokenClassification
@require_torch
class Glm4ModelTest(CausalLMModelTest, unittest.TestCase):
    """Common causal-LM model tests for GLM-4, parameterized by ``Glm4ModelTester``."""

    model_tester_class = Glm4ModelTester

    # The Glm4* classes only exist when torch is installed (they are imported under
    # the same guard near the top of the file); otherwise fall back to empty containers.
    if is_torch_available():
        all_model_classes = (Glm4Model, Glm4ForCausalLM, Glm4ForSequenceClassification, Glm4ForTokenClassification)
        pipeline_model_mapping = {
            "feature-extraction": Glm4Model,
            "text-classification": Glm4ForSequenceClassification,
            "token-classification": Glm4ForTokenClassification,
            "text-generation": Glm4ForCausalLM,
            "zero-shot": Glm4ForSequenceClassification,
        }
    else:
        all_model_classes = ()
        pipeline_model_mapping = {}

    # Switches consumed by the shared test mixins.
    test_pruning = False
    test_headmasking = False
    # NOTE(review): presumably makes the shared generation tests skip cases that assume a
    # stateless cache — confirm against the mixin that reads `_is_stateful`.
    _is_stateful = True
    # NOTE(review): presumably the split ratios used by the shared model-parallel /
    # offload tests — confirm in the common tester.
    model_split_percents = [0.5, 0.6]
@slow
@require_torch_large_gpu
class Glm4IntegrationTest(unittest.TestCase):
    """Greedy-generation regression tests for the GLM-4 9B chat checkpoint.

    Each test loads ``model_id`` at ``revision``, greedily generates 20 new tokens for the
    prompts in ``input_text``, decodes them and compares against reference strings. The
    tests differ only in the weight dtype and the attention implementation requested.
    """

    input_text = ["Hello I am doing", "Hi today"]
    model_id = "THUDM/glm-4-0414-9b-chat"
    revision = "refs/pr/15"

    # Reference generations shared by every test below (currently identical for all
    # dtype / attention-backend combinations). A test can pass its own list to
    # `_assert_generation` if a backend ever diverges.
    EXPECTED_TEXTS = [
        "Hello I am doing a project on the history of the internetSolution:\n\nStep 1: Introduction\nThe history of the",
        "Hi today I am going to show you how to make a simple and easy to make a DIY paper flower.",
    ]

    # This variable is used to determine which CUDA device we are using for our runners (A10 or T4).
    # Depending on the hardware we get different logits / generations.
    # NOTE(review): no test in this class reads it yet; kept so hardware-specific
    # expectations can be added without touching setUpClass.
    cuda_compute_capability_major_version = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        if is_torch_available() and torch.cuda.is_available():
            # 8 is for A100 / A10 and 7 for T4
            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

    def _assert_generation(self, torch_dtype, attn_implementation=None, expected_texts=None):
        """Load the checkpoint, greedily generate for ``input_text`` and compare to references.

        Args:
            torch_dtype: dtype the weights are loaded in (e.g. ``torch.bfloat16``).
            attn_implementation: attention backend forwarded to ``from_pretrained``. When
                ``None`` the argument is omitted entirely, so the library default applies.
            expected_texts: decoded outputs to compare against; defaults to ``EXPECTED_TEXTS``.
        """
        if expected_texts is None:
            expected_texts = self.EXPECTED_TEXTS

        load_kwargs = {"low_cpu_mem_usage": True, "torch_dtype": torch_dtype, "revision": self.revision}
        if attn_implementation is not None:
            load_kwargs["attn_implementation"] = attn_implementation

        model = AutoModelForCausalLM.from_pretrained(self.model_id, **load_kwargs).to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, revision=self.revision)

        # Batch of two prompts of different lengths, so padding is required.
        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
        # do_sample=False -> greedy decoding, so the output is deterministic for a given backend.
        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

        self.assertEqual(output_text, expected_texts)

    def test_model_9b_fp16(self):
        self._assert_generation(torch.float16)

    def test_model_9b_bf16(self):
        self._assert_generation(torch.bfloat16)

    def test_model_9b_eager(self):
        self._assert_generation(torch.bfloat16, attn_implementation="eager")

    @require_torch_sdpa
    def test_model_9b_sdpa(self):
        self._assert_generation(torch.bfloat16, attn_implementation="sdpa")

    @require_flash_attn
    @pytest.mark.flash_attn_test
    def test_model_9b_flash_attn(self):
        self._assert_generation(torch.bfloat16, attn_implementation="flash_attention_2")