[Marian Conversion] Fix eos_token_id conversion in conversion script (#14320)

This commit is contained in:
Patrick von Platen 2021-11-08 11:42:34 +01:00 committed by GitHub
parent c016dbdbda
commit b48faae364
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -455,7 +455,7 @@ BART_CONVERTER = { # for each encoder and decoder layer
class OpusState:
def __init__(self, source_dir):
def __init__(self, source_dir, eos_token_id=0):
npz_path = find_model_file(source_dir)
self.state_dict = np.load(npz_path)
cfg = load_config_from_state_dict(self.state_dict)
@ -492,7 +492,8 @@ class OpusState:
d_model=cfg["dim-emb"],
activation_function=cfg["transformer-aan-activation"],
pad_token_id=self.pad_token_id,
eos_token_id=0,
eos_token_id=eos_token_id,
forced_eos_token_id=eos_token_id,
bos_token_id=0,
max_position_embeddings=cfg["dim-emb"],
scale_embedding=True,
@ -595,7 +596,11 @@ def convert(source_dir: Path, dest_dir):
tokenizer = MarianTokenizer.from_pretrained(str(source_dir))
tokenizer.save_pretrained(dest_dir)
opus_state = OpusState(source_dir)
# retrieve EOS token and set correctly
tokenizer_has_eos_token_id = hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None
eos_token_id = tokenizer.eos_token_id if tokenizer_has_eos_token_id else 0
opus_state = OpusState(source_dir, eos_token_id=eos_token_id)
if opus_state.cfg["vocab_size"] != len(tokenizer.encoder):
raise ValueError(
f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"