Move rescale dtype recasting to match torchvision ToTensor (#25229)

Move dtype recasting to match torchvision ToTensor
This commit is contained in:
amyeroberts 2023-08-01 12:33:12 +01:00 committed by GitHub
parent 3170af71e1
commit d27e4c18fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -110,11 +110,12 @@ def rescale(
if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
image = image.astype(dtype)
rescaled_image = image * scale
if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
rescaled_image = rescaled_image.astype(dtype)
return rescaled_image