mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Pass model_kwargs
when loading a model in pipeline()
(#12449)
* Pass model_kwargs when loading a model in pipeline * Add test for model_kwargs parameter of pipeline() * Rewrite test to not download model * Fix failing style checks
This commit is contained in:
parent
18ca59e1d3
commit
e7f33e8cb3
@ -426,7 +426,13 @@ def pipeline(
|
||||
# Will load the correct model if possible
|
||||
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
|
||||
framework, model = infer_framework_load_model(
|
||||
model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task
|
||||
model,
|
||||
model_classes=model_classes,
|
||||
config=config,
|
||||
framework=framework,
|
||||
revision=revision,
|
||||
task=task,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
model_config = model.config
|
||||
|
@ -61,6 +61,13 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
||||
for key in output_keys:
|
||||
self.assertIn(key, result)
|
||||
|
||||
@require_torch
|
||||
def test_model_kwargs_passed_to_model_load(self):
|
||||
ner_pipeline = pipeline(task="ner", model=self.small_models[0])
|
||||
self.assertFalse(ner_pipeline.model.config.output_attentions)
|
||||
ner_pipeline = pipeline(task="ner", model=self.small_models[0], model_kwargs={"output_attentions": True})
|
||||
self.assertTrue(ner_pipeline.model.config.output_attentions)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_spanish_bert(self):
|
||||
|
Loading…
Reference in New Issue
Block a user