[RWKV] Rwkv fix for 8bit inference (#23468)

* rwkv fix for 8bit inference

* add comment
This commit is contained in:
Younes Belkada 2023-05-19 16:12:25 +02:00 committed by GitHub
parent 1c460a5273
commit 21bd3be172
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -708,6 +708,11 @@ class RwkvModel(RwkvPreTrainedModel):
if self.training:
block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
else:
# 8-bit quantized weights expose their quantization statistics as `SCB`; rescale those instead of the weight tensor itself
if hasattr(block.attention.output.weight, "SCB"):
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
else:
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))