Skip FP8 linear tests for device capability < 9.0 (#37008)

* skip fp8 linear

* add capability check

* format
Mohamed Mekkouri 2025-03-27 12:38:37 +01:00 committed by GitHub
parent 279c2e302a
commit 92429057d9

@@ -250,6 +250,10 @@ class FP8QuantizerTest(unittest.TestCase):
 class FP8LinearTest(unittest.TestCase):
     device = "cuda"
 
+    @unittest.skipIf(
+        torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
+        "Skipping FP8LinearTest because it is not supported on GPU with capability < 9.0",
+    )
     def test_linear_preserves_shape(self):
         """
         Test that FP8Linear preserves shape when in_features == out_features.
@@ -262,6 +266,10 @@ class FP8LinearTest(unittest.TestCase):
         x_ = linear(x)
         self.assertEqual(x_.shape, x.shape)
 
+    @unittest.skipIf(
+        torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
+        "Skipping FP8LinearTest because it is not supported on GPU with capability < 9.0",
+    )
     def test_linear_with_diff_feature_size_preserves_shape(self):
         """
         Test that FP8Linear generates the correct shape when in_features != out_features.
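
Both hunks add the same guard: torch.cuda.get_device_capability() returns a (major, minor) tuple, and a major version below 9 means the GPU cannot run the FP8 path these tests exercise. As a minimal standalone sketch (not part of the PR; the requires_fp8_capable_gpu name is hypothetical), the condition can be factored into a reusable decorator:

# Minimal sketch of the capability guard added in this commit (illustration only).
import unittest

import torch

# Hypothetical reusable decorator built from the same condition as the diff.
requires_fp8_capable_gpu = unittest.skipIf(
    torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
    "Skipping FP8 tests because they are not supported on GPU with capability < 9.0",
)


class ExampleFP8Test(unittest.TestCase):
    @requires_fp8_capable_gpu
    def test_runs_only_on_capable_gpu(self):
        # Reaching this body implies a CUDA device with compute capability >= 9.0,
        # or no CUDA device at all, since the guard only skips when CUDA is available.
        self.assertTrue(True)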