[bart-tiny-random] Put a 5MB model on S3 to allow faster exampl… (#3488)

This commit is contained in:
Sam Shleifer 2020-03-30 12:28:27 -04:00 committed by GitHub
parent 1f72865726
commit 8deff3acf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 6 deletions

View File

@ -16,15 +16,17 @@ def chunks(lst, n):
yield lst[i : i + n]
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
def generate_summaries(
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
fout = Path(out_file).open("w")
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
model = BartForConditionalGeneration.from_pretrained(model_name, output_past=True,).to(device)
tokenizer = BartTokenizer.from_pretrained("bart-large")
max_length = 140
min_length = 55
for batch in tqdm(list(chunks(lns, batch_size))):
for batch in tqdm(list(chunks(examples, batch_size))):
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
summaries = model.generate(
input_ids=dct["input_ids"].to(device),
@ -51,6 +53,9 @@ def _run_generate():
parser.add_argument(
"output_path", type=str, help="where to save summaries",
)
parser.add_argument(
"model_name", type=str, default="bart-large-cnn", help="like bart-large-cnn",
)
parser.add_argument(
"--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
)
@ -58,8 +63,8 @@ def _run_generate():
"--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
)
args = parser.parse_args()
lns = [" " + x.rstrip() for x in open(args.source_path).readlines()]
generate_summaries(lns, args.output_path, batch_size=args.bs, device=args.device)
examples = [" " + x.rstrip() for x in open(args.source_path).readlines()]
generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
if __name__ == "__main__":

View File

@ -25,7 +25,8 @@ class TestBartExamples(unittest.TestCase):
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
with tmp.open("w") as f:
f.write("\n".join(articles))
testargs = ["evaluate_cnn.py", str(tmp), output_file_name]
testargs = ["evaluate_cnn.py", str(tmp), output_file_name, "sshleifer/bart-tiny-random"]
with patch.object(sys, "argv", testargs):
_run_generate()
self.assertTrue(Path(output_file_name).exists())

View File

@ -27,7 +27,9 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
import torch
from transformers import (
AutoModel,
AutoModelForSequenceClassification,
AutoTokenizer,
BartModel,
BartForConditionalGeneration,
BartForSequenceClassification,
@ -183,6 +185,15 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
def test_tiny_model(self):
model_name = "sshleifer/bart-tiny-random"
tiny = AutoModel.from_pretrained(model_name) # same vocab size
tok = AutoTokenizer.from_pretrained(model_name) # same tokenizer
inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
with torch.no_grad():
tiny(**inputs_dict)
@require_torch
class BartHeadTests(unittest.TestCase):