Fix dt proj bias reassigned (#33314)

* Setting `self.dt_proj.bias = None` removes the bias parameter from the model. When we later try to assign a plain tensor back to `self.dt_proj.bias`, PyTorch raises a TypeError, because an attribute registered as a parameter only accepts an `nn.Parameter` (or None).

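For illustration, a minimal sketch of the failure mode using a plain `nn.Linear` as a stand-in for `dt_proj` (the names here are illustrative, not from the Jamba code):

```python
import torch
import torch.nn as nn

proj = nn.Linear(4, 8)        # hypothetical stand-in for self.dt_proj
x = torch.randn(2, 4)

saved_bias = proj.bias.data   # a plain tensor, not an nn.Parameter
proj.bias = None              # 'bias' stays registered in _parameters, but as None
out = proj(x)                 # forward now runs without the bias term

# Putting the plain tensor back trips nn.Module.__setattr__, which only accepts
# an nn.Parameter (or None) for an attribute registered as a parameter:
proj.bias = saved_bias
# TypeError: cannot assign 'torch.FloatTensor' as parameter 'bias'
# (torch.nn.Parameter or None expected)
```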
HofitBata 2024-10-03 10:51:03 +03:00 committed by GitHub
parent d7950bff82
commit dc8156fdd8

@@ -713,11 +713,14 @@ class JambaMambaMixer(nn.Module):
         # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
         # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
         # linear layers, and requires to call the forward pass directly.
-        # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
-        time_proj_bias = self.dt_proj.bias
-        self.dt_proj.bias = None
+        # Quantized model can't work with the original code:
+        # ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
+        time_proj_bias = self.dt_proj.bias.data
+        with torch.no_grad():
+            self.dt_proj.bias.data = torch.zeros_like(self.dt_proj.bias.data)
         discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
-        self.dt_proj.bias = time_proj_bias
+        with torch.no_grad():
+            self.dt_proj.bias.data = time_proj_bias
 
         A = -torch.exp(self.A_log.float())
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
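For reference, the pattern introduced by the patch, sketched standalone on a plain `nn.Linear` (again, the names other than the torch calls are illustrative):

```python
import torch
import torch.nn as nn

proj = nn.Linear(4, 8)        # hypothetical stand-in for self.dt_proj
x = torch.randn(2, 4)

# Swap the bias *values* instead of the Parameter itself, so `proj.bias`
# stays registered and no __setattr__ type check is triggered.
saved_bias = proj.bias.data
with torch.no_grad():
    proj.bias.data = torch.zeros_like(proj.bias.data)
out_without_bias = proj(x)    # forward pass with the bias zeroed out
with torch.no_grad():
    proj.bias.data = saved_bias
```

Because the call still goes through the module's forward, quantization backends that replace `nn.Linear` with a quantized layer keep working, which is the constraint described in the surrounding comments.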