Tensor location is already handled (#14224)

in `base.py` not in subclasses.
This commit is contained in:
Nicolas Patry 2021-11-01 13:42:27 +01:00 committed by GitHub
parent 323f28dce2
commit 999540dfe0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -145,10 +145,7 @@ class TextClassificationPipeline(Pipeline):
function_to_apply = ClassificationFunction.NONE
outputs = model_outputs["logits"][0]
if self.framework == "pt":
outputs = outputs.cpu().numpy()
else:
outputs = outputs.numpy()
outputs = outputs.numpy()
if function_to_apply == ClassificationFunction.SIGMOID:
scores = sigmoid(outputs)