[RWKV] Rwkv fix for 8bit inference (#23468)

* rwkv fix for 8bit inference

* add comment
Author: Younes Belkada, 2023-05-19 16:12:25 +02:00 (committed by GitHub)
Commit: 21bd3be172 (parent 1c460a5273)

@@ -709,8 +709,13 @@ class RwkvModel(RwkvPreTrainedModel):
                         block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
                         block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
                     else:
-                        block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
-                        block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
+                        # Deal with quantization statistics
+                        if hasattr(block.attention.output.weight, "SCB"):
+                            block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
+                            block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
+                        else:
+                            block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
+                            block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
 
         self.layers_are_rescaled = not self.training
 
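For context, bitsandbytes 8-bit layers keep their weights as int8 tensors and store the quantization statistics in a separate `SCB` attribute, so the inference-time rescaling has to divide those statistics rather than the int8 weights themselves; that is the branch this commit adds. Below is a minimal sketch of how the fixed code path would be exercised, assuming a CUDA GPU with bitsandbytes installed; the checkpoint name is illustrative, not part of this commit.

```python
# Sketch only: load an RWKV checkpoint in 8-bit so that
# RwkvModel._rescale_layers takes the new `SCB` branch during inference.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "RWKV/rwkv-4-169m-pile"  # example checkpoint name, for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(0)
with torch.no_grad():
    # eval-mode generation triggers the layer rescaling on the quantized weights
    outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
```

Dividing the `SCB` statistics achieves the same per-layer rescaling as dividing full-precision weights, while leaving the int8 weight tensors of the quantized linear layers untouched.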