mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Fix blenderbot conversion script (#16472)
This commit is contained in:
parent
c85547af2b
commit
85295621f1
@ -18,7 +18,7 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BartConfig, BartForConditionalGeneration
|
||||
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
@ -81,8 +81,8 @@ def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_
|
||||
"""
|
||||
model = torch.load(checkpoint_path, map_location="cpu")
|
||||
sd = model["model"]
|
||||
cfg = BartConfig.from_json_file(config_json_path)
|
||||
m = BartForConditionalGeneration(cfg)
|
||||
cfg = BlenderbotConfig.from_json_file(config_json_path)
|
||||
m = BlenderbotForConditionalGeneration(cfg)
|
||||
valid_keys = m.model.state_dict().keys()
|
||||
failures = []
|
||||
mapping = {}
|
||||
|
Loading…
Reference in New Issue
Block a user