GPTNeo: handle padded wte (#11079)

* GPTNeo: handle padded wte

* Switch to config.vocab_size

* apply review suggestion

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Leo Gao 2021-04-07 06:05:20 -06:00 committed by GitHub
parent 083ad7d46c
commit 247bed3857
GPG Key ID: 4AEE18F83AFDEB23

@@ -112,6 +112,10 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
        if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]:
            array = array.transpose()

        if name == ["wte"]:
            # if vocab is padded, then trim off the padding embeddings
            array = array[: config.vocab_size]

        try:
            assert (
                pointer.shape == array.shape
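
The trimming step in this hunk can be illustrated in isolation. The following is a minimal sketch, not the loader itself: the helper name `trim_padded_embeddings` and the sizes used are hypothetical, chosen only to show that slicing the first `vocab_size` rows of a padded embedding matrix yields an array whose shape matches an unpadded embedding table, which is what the shape assertion that follows in the loader checks.

```python
import numpy as np


def trim_padded_embeddings(array: np.ndarray, vocab_size: int) -> np.ndarray:
    """Keep only the first `vocab_size` rows of an embedding matrix.

    Hypothetical helper: mirrors `array[: config.vocab_size]` from the diff.
    Rows past `vocab_size` are treated as padding and dropped.
    """
    return array[:vocab_size]


# Illustrative sizes only (assumptions, not taken from the commit).
vocab_size = 10      # true vocabulary size, as in config.vocab_size
padded_vocab = 16    # checkpoint rows after padding
hidden_size = 4

# Fake checkpoint embedding: row i is filled with the value i.
wte = np.arange(padded_vocab, dtype=np.float32)[:, None].repeat(hidden_size, axis=1)

trimmed = trim_padded_embeddings(wte, vocab_size)

# The trimmed array now matches the shape of an unpadded embedding table.
assert trimmed.shape == (vocab_size, hidden_size)

# An array that was never padded passes through unchanged.
unpadded = np.zeros((vocab_size, hidden_size), dtype=np.float32)
assert trim_padded_embeddings(unpadded, vocab_size).shape == unpadded.shape
```

Because the slice is a no-op when the array already has `vocab_size` rows, the same code path handles both padded and unpadded checkpoints.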