mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 06:20:22 +06:00
Skip FP8 linear tests For device capability < 9.0(#37008)
* skip fp8 linear * add capability check * format
This commit is contained in:
parent
279c2e302a
commit
92429057d9
@ -250,6 +250,10 @@ class FP8QuantizerTest(unittest.TestCase):
|
|||||||
class FP8LinearTest(unittest.TestCase):
|
class FP8LinearTest(unittest.TestCase):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
|
||||||
|
"Skipping FP8LinearTest because it is not supported on GPU with capability < 9.0",
|
||||||
|
)
|
||||||
def test_linear_preserves_shape(self):
|
def test_linear_preserves_shape(self):
|
||||||
"""
|
"""
|
||||||
Test that FP8Linear preserves shape when in_features == out_features.
|
Test that FP8Linear preserves shape when in_features == out_features.
|
||||||
@ -262,6 +266,10 @@ class FP8LinearTest(unittest.TestCase):
|
|||||||
x_ = linear(x)
|
x_ = linear(x)
|
||||||
self.assertEqual(x_.shape, x.shape)
|
self.assertEqual(x_.shape, x.shape)
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
|
||||||
|
"Skipping FP8LinearTest because it is not supported on GPU with capability < 9.0",
|
||||||
|
)
|
||||||
def test_linear_with_diff_feature_size_preserves_shape(self):
|
def test_linear_with_diff_feature_size_preserves_shape(self):
|
||||||
"""
|
"""
|
||||||
Test that FP8Linear generates the correct shape when in_features != out_features.
|
Test that FP8Linear generates the correct shape when in_features != out_features.
|
||||||
|
Loading…
Reference in New Issue
Block a user