mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fixing image segmentation with inference mode. (#14204)
* Fixing image segmentation for inference mode. * Update src/transformers/pipelines/base.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
c28bc80bbb
commit
b338596346
@ -1011,17 +1011,19 @@ class Pipeline(_ScikitCompat):
|
||||
"""
|
||||
raise NotImplementedError("postprocess not implemented")
|
||||
|
||||
def get_inference_context(self):
|
||||
inference_context = (
|
||||
torch.inference_mode if version.parse(torch.__version__) >= version.parse("1.9.0") else torch.no_grad
|
||||
)
|
||||
return inference_context
|
||||
|
||||
def forward(self, model_inputs, **forward_params):
|
||||
with self.device_placement():
|
||||
if self.framework == "tf":
|
||||
model_inputs["training"] = False
|
||||
model_outputs = self._forward(model_inputs, **forward_params)
|
||||
elif self.framework == "pt":
|
||||
inference_context = (
|
||||
torch.inference_mode
|
||||
if version.parse(torch.__version__) >= version.parse("1.9.0")
|
||||
else torch.no_grad
|
||||
)
|
||||
inference_context = self.get_inference_context()
|
||||
with inference_context():
|
||||
model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
|
||||
model_outputs = self._forward(model_inputs, **forward_params)
|
||||
|
@ -114,6 +114,9 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
def get_inference_context(self):
|
||||
return torch.no_grad
|
||||
|
||||
def preprocess(self, image):
|
||||
image = self.load_image(image)
|
||||
target_size = torch.IntTensor([[image.height, image.width]])
|
||||
|
Loading…
Reference in New Issue
Block a user