mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
2d506ea4c4
commit
ae9230af40
@ -275,7 +275,11 @@ class LongT5DenseActDense(nn.Module):
|
||||
hidden_states = self.wi(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
if (
|
||||
isinstance(self.wo.weight, torch.Tensor)
|
||||
and hidden_states.dtype != self.wo.weight.dtype
|
||||
and self.wo.weight.dtype != torch.int8
|
||||
):
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
@ -145,7 +145,11 @@ class MT5DenseActDense(nn.Module):
|
||||
hidden_states = self.wi(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
if (
|
||||
isinstance(self.wo.weight, torch.Tensor)
|
||||
and hidden_states.dtype != self.wo.weight.dtype
|
||||
and self.wo.weight.dtype != torch.int8
|
||||
):
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
@ -170,7 +174,11 @@ class MT5DenseGatedActDense(nn.Module):
|
||||
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
|
||||
# See https://github.com/huggingface/transformers/issues/20287
|
||||
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
if (
|
||||
isinstance(self.wo.weight, torch.Tensor)
|
||||
and hidden_states.dtype != self.wo.weight.dtype
|
||||
and self.wo.weight.dtype != torch.int8
|
||||
):
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
|
||||
hidden_states = self.wo(hidden_states)
|
||||
|
@ -272,7 +272,11 @@ class SwitchTransformersDenseActDense(nn.Module):
|
||||
hidden_states = self.wi(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
if (
|
||||
isinstance(self.wo.weight, torch.Tensor)
|
||||
and hidden_states.dtype != self.wo.weight.dtype
|
||||
and self.wo.weight.dtype != torch.int8
|
||||
):
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
@ -288,7 +288,11 @@ class T5DenseActDense(nn.Module):
|
||||
hidden_states = self.wi(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
if (
|
||||
isinstance(self.wo.weight, torch.Tensor)
|
||||
and hidden_states.dtype != self.wo.weight.dtype
|
||||
and self.wo.weight.dtype != torch.int8
|
||||
):
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
@ -312,7 +316,11 @@ class T5DenseGatedActDense(nn.Module):
|
||||
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
|
||||
# See https://github.com/huggingface/transformers/issues/20287
|
||||
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
|
||||
if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
|
||||
if (
|
||||
isinstance(self.wo.weight, torch.Tensor)
|
||||
and hidden_states.dtype != self.wo.weight.dtype
|
||||
and self.wo.weight.dtype != torch.int8
|
||||
):
|
||||
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
||||
|
||||
hidden_states = self.wo(hidden_states)
|
||||
|
@ -880,6 +880,19 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
def tokenizer(self):
|
||||
return T5Tokenizer.from_pretrained("t5-base")
|
||||
|
||||
@slow
|
||||
def test_torch_quant(self):
|
||||
r"""
|
||||
Test that a simple `torch.quantization.quantize_dynamic` call works on a T5 model.
|
||||
"""
|
||||
model_name = "google/flan-t5-small"
|
||||
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_name)
|
||||
model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
||||
input_text = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
|
||||
_ = model.generate(input_ids)
|
||||
|
||||
@slow
|
||||
def test_small_generation(self):
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
|
||||
|
Loading…
Reference in New Issue
Block a user