[T5] Fix torchquant issue (#21843)

* fix torchquant issue

* add tests
Author: Younes Belkada (committed via GitHub)
Date: 2023-02-28 15:09:44 +01:00
Commit: ae9230af40 (parent: 2d506ea4c4)
5 changed files with 43 additions and 6 deletions
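
For context, the crash this PR fixes comes from `torch.quantization.quantize_dynamic` replacing `nn.Linear` modules with dynamically quantized variants whose `weight` attribute is an accessor method rather than a `Tensor`, so the old `self.wo.weight.dtype` check broke (a bound method has no `.dtype`). A minimal sketch of that behavior (assuming a PyTorch build with dynamic quantization support; not part of this PR):

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
print(isinstance(lin.weight, torch.Tensor))  # True: a plain nn.Parameter

# quantize_dynamic swaps nn.Linear for a quantized module whose `weight`
# is a method, not a Tensor, so `.weight.dtype` would fail on it.
qmodel = torch.quantization.quantize_dynamic(
    nn.Sequential(lin), {nn.Linear}, dtype=torch.qint8
)
print(isinstance(qmodel[0].weight, torch.Tensor))  # False: a bound method
print(qmodel[0].weight().dtype)  # torch.qint8, via the accessor call

The new `isinstance(self.wo.weight, torch.Tensor)` guard simply skips the dtype cast for such modules and lets the quantized linear handle the float input itself.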

src/transformers/models/longt5/modeling_longt5.py

@@ -275,7 +275,11 @@ class LongT5DenseActDense(nn.Module):
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
             hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)
         return hidden_states

src/transformers/models/mt5/modeling_mt5.py

@@ -145,7 +145,11 @@ class MT5DenseActDense(nn.Module):
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
             hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)
         return hidden_states
@@ -170,7 +174,11 @@ class MT5DenseGatedActDense(nn.Module):
         # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
         # See https://github.com/huggingface/transformers/issues/20287
         # we also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`
-        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
             hidden_states = hidden_states.to(self.wo.weight.dtype)
 
         hidden_states = self.wo(hidden_states)

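The comment in the gated hunk above explains why the cast exists at all: for 8-bit loading, `wo` is deliberately kept in float32 while the surrounding activations may be float16, so the input is upcast before the projection. A standalone illustration of that dtype dance (hypothetical tensors, not the transformers code):

import torch

# `wo` kept in float32 (as `_keep_in_fp32_modules` arranges for flan-t5),
# while the rest of the model runs in float16.
wo_weight = torch.randn(8, 16, dtype=torch.float32)
hidden_states = torch.randn(2, 16, dtype=torch.float16)

# The same guard as in the patch: cast only when the weight is a real
# Tensor whose dtype differs and is not int8.
if (
    isinstance(wo_weight, torch.Tensor)
    and hidden_states.dtype != wo_weight.dtype
    and wo_weight.dtype != torch.int8
):
    hidden_states = hidden_states.to(wo_weight.dtype)  # fp16 -> fp32

print((hidden_states @ wo_weight.t()).dtype)  # torch.float32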
src/transformers/models/switch_transformers/modeling_switch_transformers.py

@@ -272,7 +272,11 @@ class SwitchTransformersDenseActDense(nn.Module):
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
             hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)
         return hidden_states

src/transformers/models/t5/modeling_t5.py

@@ -288,7 +288,11 @@ class T5DenseActDense(nn.Module):
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
             hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)
         return hidden_states
@@ -312,7 +316,11 @@ class T5DenseGatedActDense(nn.Module):
         # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
         # See https://github.com/huggingface/transformers/issues/20287
         # we also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`
-        if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
             hidden_states = hidden_states.to(self.wo.weight.dtype)
 
         hidden_states = self.wo(hidden_states)

tests/models/t5/test_modeling_t5.py

@@ -880,6 +880,19 @@ class T5ModelIntegrationTests(unittest.TestCase):
     def tokenizer(self):
         return T5Tokenizer.from_pretrained("t5-base")
 
+    @slow
+    def test_torch_quant(self):
+        r"""
+        Test that a simple `torch.quantization.quantize_dynamic` call works on a T5 model.
+        """
+        model_name = "google/flan-t5-small"
+        tokenizer = T5Tokenizer.from_pretrained(model_name)
+        model = T5ForConditionalGeneration.from_pretrained(model_name)
+        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
+        input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
+        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
+        _ = model.generate(input_ids)
+
     @slow
     def test_small_generation(self):
         model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
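
The new test is gated behind the slow-test marker, so it only runs with something like `RUN_SLOW=1 python -m pytest tests/models/t5/test_modeling_t5.py -k test_torch_quant`. Outside the test suite, the same end-to-end check looks roughly like this (a sketch that downloads google/flan-t5-small; not part of the PR):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model_name = "google/flan-t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Before this fix, generation crashed in the dense layers because the
# quantized `wo.weight` is no longer a Tensor attribute.
model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

input_ids = tokenizer("translate English to German: Hello.", return_tensors="pt").input_ids
print(tokenizer.decode(model.generate(input_ids)[0], skip_special_tokens=True))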