Fixing image segmentation with inference mode. (#14204)

* Fixing image segmentation for inference mode.

* Update src/transformers/pipelines/base.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Nicolas Patry 2021-10-29 17:24:09 +02:00 committed by GitHub
parent c28bc80bbb
commit b338596346
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 5 deletions

View File

@ -1011,17 +1011,19 @@ class Pipeline(_ScikitCompat):
"""
raise NotImplementedError("postprocess not implemented")
def get_inference_context(self):
inference_context = (
torch.inference_mode if version.parse(torch.__version__) >= version.parse("1.9.0") else torch.no_grad
)
return inference_context
def forward(self, model_inputs, **forward_params):
with self.device_placement():
if self.framework == "tf":
model_inputs["training"] = False
model_outputs = self._forward(model_inputs, **forward_params)
elif self.framework == "pt":
inference_context = (
torch.inference_mode
if version.parse(torch.__version__) >= version.parse("1.9.0")
else torch.no_grad
)
inference_context = self.get_inference_context()
with inference_context():
model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
model_outputs = self._forward(model_inputs, **forward_params)

View File

@ -114,6 +114,9 @@ class ImageSegmentationPipeline(Pipeline):
return super().__call__(*args, **kwargs)
def get_inference_context(self):
return torch.no_grad
def preprocess(self, image):
image = self.load_image(image)
target_size = torch.IntTensor([[image.height, image.width]])