Fix blenderbot conversion script (#16472)

This commit is contained in:
Suraj Patil 2022-03-29 11:32:13 +02:00 committed by GitHub
parent c85547af2b
commit 85295621f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,7 +18,7 @@ import argparse
import torch
from transformers import BartConfig, BartForConditionalGeneration
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
from transformers.utils import logging
@ -81,8 +81,8 @@ def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_
"""
model = torch.load(checkpoint_path, map_location="cpu")
sd = model["model"]
cfg = BartConfig.from_json_file(config_json_path)
m = BartForConditionalGeneration(cfg)
cfg = BlenderbotConfig.from_json_file(config_json_path)
m = BlenderbotForConditionalGeneration(cfg)
valid_keys = m.model.state_dict().keys()
failures = []
mapping = {}