mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[t5
] Fix T5 inference in float16
+ bnb
error (#21281)
* attempts to fix: - upcast input for `T5DenseActDense` - add the condition `self.wo.weight.dtype != torch.int8` - added tests on `test/mixed_int8` - `make fixup` * fix ci test
This commit is contained in:
parent
f424b09410
commit
e2e393c6f2
@ -276,6 +276,8 @@ class LongT5DenseActDense(nn.Module):
|
||||
hidden_states = self.wi(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
@ -146,6 +146,8 @@ class MT5DenseActDense(nn.Module):
|
||||
hidden_states = self.wi(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@ -168,7 +170,8 @@ class MT5DenseGatedActDense(nn.Module):
|
||||
|
||||
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
|
||||
# See https://github.com/huggingface/transformers/issues/20287
|
||||
if hidden_states.dtype != self.wo.weight.dtype:
|
||||
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
|
||||
hidden_states = self.wo(hidden_states)
|
||||
|
@ -273,6 +273,8 @@ class SwitchTransformersDenseActDense(nn.Module):
|
||||
hidden_states = self.wi(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
@ -289,6 +289,8 @@ class T5DenseActDense(nn.Module):
|
||||
hidden_states = self.wi(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@ -310,7 +312,8 @@ class T5DenseGatedActDense(nn.Module):
|
||||
|
||||
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
|
||||
# See https://github.com/huggingface/transformers/issues/20287
|
||||
if hidden_states.dtype != self.wo.weight.dtype:
|
||||
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
|
||||
hidden_states = self.wo(hidden_states)
|
||||
|
@ -163,6 +163,70 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_accelerate
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
class MixedInt8T5Test(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model_name = "t5-small"
|
||||
cls.dense_act_model_name = "google/flan-t5-small" # flan-t5 uses dense-act instead of dense-relu-dense
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
|
||||
cls.input_text = "Translate in German: Hello, my dog is cute"
|
||||
|
||||
def tearDown(self):
|
||||
r"""
|
||||
TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
|
||||
avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
|
||||
"""
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_inference_without_keep_in_fp32(self):
|
||||
r"""
|
||||
Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly.
|
||||
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
|
||||
both cases.
|
||||
"""
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
T5ForConditionalGeneration._keep_in_fp32_modules = None
|
||||
|
||||
# test with `t5-small`
|
||||
model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
|
||||
_ = model.generate(**encoded_input)
|
||||
|
||||
# test with `flan-t5-small`
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
self.dense_act_model_name, load_in_8bit=True, device_map="auto"
|
||||
)
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
|
||||
_ = model.generate(**encoded_input)
|
||||
|
||||
def test_inference_with_keep_in_fp32(self):
|
||||
r"""
|
||||
Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly.
|
||||
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
|
||||
both cases.
|
||||
"""
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
# test with `t5-small`
|
||||
model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
|
||||
_ = model.generate(**encoded_input)
|
||||
|
||||
# test with `flan-t5-small`
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
self.dense_act_model_name, load_in_8bit=True, device_map="auto"
|
||||
)
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
|
||||
_ = model.generate(**encoded_input)
|
||||
|
||||
|
||||
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
Loading…
Reference in New Issue
Block a user