mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Skip FP8 linear tests for device capability < 9.0 (#37008)
* skip fp8 linear * add capability check * format
This commit is contained in:
parent
279c2e302a
commit
92429057d9
@ -250,6 +250,10 @@ class FP8QuantizerTest(unittest.TestCase):
|
||||
class FP8LinearTest(unittest.TestCase):
|
||||
device = "cuda"
|
||||
|
||||
@unittest.skipIf(
|
||||
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
|
||||
"Skipping FP8LinearTest because it is not supported on GPU with capability < 9.0",
|
||||
)
|
||||
def test_linear_preserves_shape(self):
|
||||
"""
|
||||
Test that FP8Linear preserves shape when in_features == out_features.
|
||||
@ -262,6 +266,10 @@ class FP8LinearTest(unittest.TestCase):
|
||||
x_ = linear(x)
|
||||
self.assertEqual(x_.shape, x.shape)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
|
||||
"Skipping FP8LinearTest because it is not supported on GPU with capability < 9.0",
|
||||
)
|
||||
def test_linear_with_diff_feature_size_preserves_shape(self):
|
||||
"""
|
||||
Test that FP8Linear generates the correct shape when in_features != out_features.
|
||||
|
Loading…
Reference in New Issue
Block a user