Pass model_kwargs when loading a model in pipeline() (#12449)

* Pass model_kwargs when loading a model in pipeline

* Add test for model_kwargs parameter of pipeline()

* Rewrite test to not download model

* Fix failing style checks
This commit is contained in:
Alex Hedges 2021-07-09 09:24:55 -04:00 committed by GitHub
parent 18ca59e1d3
commit e7f33e8cb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 1 deletions

View File

@ -426,7 +426,13 @@ def pipeline(
# Will load the correct model if possible
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
framework, model = infer_framework_load_model(
model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task
model,
model_classes=model_classes,
config=config,
framework=framework,
revision=revision,
task=task,
**model_kwargs,
)
model_config = model.config

View File

@ -61,6 +61,13 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
for key in output_keys:
self.assertIn(key, result)
@require_torch
def test_model_kwargs_passed_to_model_load(self):
ner_pipeline = pipeline(task="ner", model=self.small_models[0])
self.assertFalse(ner_pipeline.model.config.output_attentions)
ner_pipeline = pipeline(task="ner", model=self.small_models[0], model_kwargs={"output_attentions": True})
self.assertTrue(ner_pipeline.model.config.output_attentions)
@require_torch
@slow
def test_spanish_bert(self):