mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[RWKV] Rwkv fix for 8bit inference (#23468)
* rwkv fix for 8bit inference
* add comment
This commit is contained in:
parent
1c460a5273
commit
21bd3be172
@ -709,8 +709,13 @@ class RwkvModel(RwkvPreTrainedModel):
|
||||
block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
||||
block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
||||
else:
|
||||
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
||||
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
||||
# Deal with quantization statistics
|
||||
if hasattr(block.attention.output.weight, "SCB"):
|
||||
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
||||
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
||||
else:
|
||||
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
||||
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
||||
|
||||
self.layers_are_rescaled = not self.training
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user