[t5] Fix T5 inference in float16 + bnb error (#21281)

* attempts to fix:

- upcast input for `T5DenseActDense`
- add the condition `self.wo.weight.dtype != torch.int8`
- added tests in `tests/mixed_int8`
- `make fixup`

* fix ci test
Younes Belkada 2023-01-24 18:14:38 +01:00 committed by GitHub
parent f424b09410
commit e2e393c6f2
5 changed files with 76 additions and 2 deletions
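
For context, a minimal reproduction sketch of the setup this commit targets (illustrative only; it assumes `bitsandbytes`, `accelerate`, and a CUDA GPU are installed, and reuses the model name and prompt from the new tests): loading T5 with `load_in_8bit=True` quantizes most linear layers to int8 while `wo` is kept in float32 and the remaining activations run in float16, so the feed-forward blocks see mixed dtypes during generation.

# Sketch only, not part of the diff: T5 with bitsandbytes 8-bit weights.
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")

# `wo` stays in float32 via `_keep_in_fp32_modules`, other linears are int8,
# remaining activations are float16: the feed-forward block mixes dtypes.
inputs = tokenizer("Translate in German: Hello, my dog is cute", return_tensors="pt").to(0)
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Before this commit, the `DenseActDense` variants did not reconcile those dtypes at all, and the `DenseGatedActDense` variants cast activations straight to the `wo` weight dtype, which breaks once that dtype is `torch.int8`.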


@@ -276,6 +276,8 @@ class LongT5DenseActDense(nn.Module):
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
+        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)
         return hidden_states
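
The same two-line guard is repeated for the MT5, SwitchTransformers, and T5 `DenseActDense` variants below. As a standalone illustration only (a toy module, not the library code): upcast low-precision activations to whatever dtype `wo` was kept in, but skip the cast entirely when `wo` is an 8-bit quantized layer, whose weight reports `torch.int8`.

import torch
import torch.nn as nn

class ToyDenseActDense(nn.Module):
    """Toy stand-in for the T5 feed-forward block, for illustration only."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)  # e.g. kept in float32
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, hidden_states):
        hidden_states = self.dropout(self.act(self.wi(hidden_states)))
        # Upcast the activations to match `wo`, but never cast to int8: a
        # bitsandbytes-quantized `wo` reports weight.dtype == torch.int8 and
        # expects float activations, handling the mixed precision internally.
        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
            hidden_states = hidden_states.to(self.wo.weight.dtype)
        return self.wo(hidden_states)

# bfloat16 is used so the example also runs on CPU; on GPU the same applies
# to float16 activations meeting a float32 `wo`.
block = ToyDenseActDense(8, 16)
block.wi.to(torch.bfloat16)
x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
print(block(x).dtype)  # torch.float32 – no dtype-mismatch error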


@@ -146,6 +146,8 @@ class MT5DenseActDense(nn.Module):
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
+        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)
         return hidden_states

@@ -168,7 +170,8 @@ class MT5DenseGatedActDense(nn.Module):
         # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
         # See https://github.com/huggingface/transformers/issues/20287
-        if hidden_states.dtype != self.wo.weight.dtype:
+        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
+        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
             hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)


@@ -273,6 +273,8 @@ class SwitchTransformersDenseActDense(nn.Module):
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
+        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)
         return hidden_states


@@ -289,6 +289,8 @@ class T5DenseActDense(nn.Module):
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
+        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)
         return hidden_states

@@ -310,7 +312,8 @@ class T5DenseGatedActDense(nn.Module):
         # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
         # See https://github.com/huggingface/transformers/issues/20287
-        if hidden_states.dtype != self.wo.weight.dtype:
+        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
+        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
             hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)
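
The extra `self.wo.weight.dtype != torch.int8` check in the gated variants matters when `_keep_in_fp32_modules` is disabled, as the new test below forces: `wo` is then quantized, its weight dtype reports `torch.int8`, and without the guard the float16 activations would be cast to `torch.int8` before the 8-bit linear even runs, destroying them. A quick illustration of that truncation (values are made up):

import torch

hidden_states = torch.tensor([0.31, -1.72, 2.98], dtype=torch.float16)
# Casting floats to int8 truncates toward zero: tensor([ 0, -1,  2], dtype=torch.int8)
print(hidden_states.to(torch.int8))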


@@ -163,6 +163,70 @@ class MixedInt8Test(BaseMixedInt8Test):
         self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
 
+
+@require_bitsandbytes
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class MixedInt8T5Test(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model_name = "t5-small"
+        cls.dense_act_model_name = "google/flan-t5-small"  # flan-t5 uses dense-act instead of dense-relu-dense
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.input_text = "Translate in German: Hello, my dog is cute"
+
+    def tearDown(self):
+        r"""
+        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
+        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
+        """
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_inference_without_keep_in_fp32(self):
+        r"""
+        Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly.
+        `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
+        both cases.
+        """
+        from transformers import T5ForConditionalGeneration
+
+        T5ForConditionalGeneration._keep_in_fp32_modules = None
+
+        # test with `t5-small`
+        model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
+        _ = model.generate(**encoded_input)
+
+        # test with `flan-t5-small`
+        model = T5ForConditionalGeneration.from_pretrained(
+            self.dense_act_model_name, load_in_8bit=True, device_map="auto"
+        )
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
+        _ = model.generate(**encoded_input)
+
+    def test_inference_with_keep_in_fp32(self):
+        r"""
+        Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly.
+        `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
+        both cases.
+        """
+        from transformers import T5ForConditionalGeneration
+
+        # test with `t5-small`
+        model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
+        _ = model.generate(**encoded_input)
+
+        # test with `flan-t5-small`
+        model = T5ForConditionalGeneration.from_pretrained(
+            self.dense_act_model_name, load_in_8bit=True, device_map="auto"
+        )
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
+        _ = model.generate(**encoded_input)
+
 
 class MixedInt8ModelClassesTest(BaseMixedInt8Test):
     def setUp(self):
         super().setUp()
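
The new tests are gated behind `@slow`, `@require_torch_gpu`, `@require_accelerate`, and `@require_bitsandbytes`, so they are skipped unless slow tests are enabled on a CUDA machine with those libraries installed. A rough invocation sketch (the test path follows the commit message and may differ in your checkout):

# Sketch only: run just the new test class. RUN_SLOW=1 enables @slow tests.
import os
import pytest

os.environ["RUN_SLOW"] = "1"
pytest.main(["-k", "MixedInt8T5Test", "tests/mixed_int8"])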