mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
[RWKV
] Rwkv fix for 8bit inference (#23468)
* rwkv fix for 8bit inference * add comment
This commit is contained in:
parent
1c460a5273
commit
21bd3be172
@ -709,8 +709,13 @@ class RwkvModel(RwkvPreTrainedModel):
|
|||||||
block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
||||||
block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
||||||
else:
|
else:
|
||||||
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
# Deal with quantization statistics
|
||||||
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
if hasattr(block.attention.output.weight, "SCB"):
|
||||||
|
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
||||||
|
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
||||||
|
else:
|
||||||
|
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
||||||
|
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
||||||
|
|
||||||
self.layers_are_rescaled = not self.training
|
self.layers_are_rescaled = not self.training
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user