diff --git a/docs/source/en/model_doc/donut.md b/docs/source/en/model_doc/donut.md
index 6e3bd3c51ea..fe2d2d4fe00 100644
--- a/docs/source/en/model_doc/donut.md
+++ b/docs/source/en/model_doc/donut.md
@@ -208,6 +208,11 @@ print(answer)
 [[autodoc]] DonutImageProcessor
     - preprocess
 
+## DonutImageProcessorFast
+
+[[autodoc]] DonutImageProcessorFast
+    - preprocess
+
 ## DonutFeatureExtractor
 
 [[autodoc]] DonutFeatureExtractor
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 6c74035ea22..4439d756382 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -80,7 +80,7 @@ else:
         ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
         ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("dinov2", ("BitImageProcessor",)),
-        ("donut-swin", ("DonutImageProcessor",)),
+        ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
         ("dpt", ("DPTImageProcessor",)),
         ("efficientformer", ("EfficientFormerImageProcessor",)),
         ("efficientnet", ("EfficientNetImageProcessor",)),
diff --git a/src/transformers/models/donut/__init__.py b/src/transformers/models/donut/__init__.py
index 54de054051f..834c451f78f 100644
--- a/src/transformers/models/donut/__init__.py
+++ b/src/transformers/models/donut/__init__.py
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
     from .configuration_donut_swin import *
     from .feature_extraction_donut import *
     from .image_processing_donut import *
+    from .image_processing_donut_fast import *
    from .modeling_donut_swin import *
     from .processing_donut import *
 else:
diff --git a/src/transformers/models/donut/image_processing_donut.py b/src/transformers/models/donut/image_processing_donut.py
index e4afd4a4f78..72d051859a7 100644
--- a/src/transformers/models/donut/image_processing_donut.py
+++ b/src/transformers/models/donut/image_processing_donut.py
@@ -20,6 +20,7 @@ import numpy as np
 
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import (
+    convert_to_rgb,
     get_resize_output_image_size,
     pad,
     resize,
@@ -151,10 +152,21 @@ class DonutImageProcessor(BaseImageProcessor):
         input_height, input_width = get_image_size(image, channel_dim=input_data_format)
         output_height, output_width = size["height"], size["width"]
 
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(image)
+
+        if input_data_format == ChannelDimension.LAST:
+            rot_axes = (0, 1)
+        elif input_data_format == ChannelDimension.FIRST:
+            rot_axes = (1, 2)
+        else:
+            raise ValueError(f"Unsupported data format: {input_data_format}")
+
         if (output_width < output_height and input_width > input_height) or (
             output_width > output_height and input_width < input_height
         ):
-            image = np.rot90(image, 3)
+            image = np.rot90(image, 3, axes=rot_axes)
 
         if data_format is not None:
             image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
@@ -407,6 +419,8 @@ class DonutImageProcessor(BaseImageProcessor):
                 resample=resample,
             )
 
+        images = [convert_to_rgb(image) for image in images]
+
         # All transformations expect numpy arrays.
         images = [to_numpy_array(image) for image in images]
 
diff --git a/src/transformers/models/donut/image_processing_donut_fast.py b/src/transformers/models/donut/image_processing_donut_fast.py
new file mode 100644
index 00000000000..be83b5cc5c6
--- /dev/null
+++ b/src/transformers/models/donut/image_processing_donut_fast.py
@@ -0,0 +1,289 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for Donut."""
+
+from typing import Optional, Union
+
+from ...image_processing_utils_fast import (
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+    BaseImageProcessorFast,
+    BatchFeature,
+    DefaultFastImageProcessorKwargs,
+)
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    ImageInput,
+    PILImageResampling,
+    SizeDict,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+    TensorType,
+    add_start_docstrings,
+    is_torch_available,
+    is_torchvision_available,
+    is_torchvision_v2_available,
+    logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+if is_torch_available():
+    import torch
+
+if is_torchvision_available():
+    if is_torchvision_v2_available():
+        from torchvision.transforms.v2 import functional as F
+    else:
+        from torchvision.transforms import functional as F
+
+
+class DonutFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+    do_thumbnail: Optional[bool]
+    do_align_long_axis: Optional[bool]
+    do_pad: Optional[bool]
+
+
+@add_start_docstrings(
+    "Constructs a fast Donut image processor.",
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+    """
+        do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
+            Whether to resize the image using the thumbnail method.
+        do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
+            Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
+        do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+            Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random
+            amount of padding on each side, up to the largest image size in the batch. Otherwise, all images are
+            padded to the largest image size in the batch.
+ """, +) +class DonutImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 2560, "width": 1920} + do_resize = True + do_rescale = True + do_normalize = True + do_thumbnail = True + do_align_long_axis = False + do_pad = True + valid_kwargs = DonutFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[DonutFastImageProcessorKwargs]): + size = kwargs.pop("size", None) + if isinstance(size, (tuple, list)): + size = size[::-1] + kwargs["size"] = size + super().__init__(**kwargs) + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random + amount of padding on each size, up to the largest image size in the batch. Otherwise, all images are + padded to the largest image size in the batch. + """, + ) + def preprocess(self, images: ImageInput, **kwargs: Unpack[DonutFastImageProcessorKwargs]) -> BatchFeature: + if "size" in kwargs: + size = kwargs.pop("size") + if isinstance(size, (tuple, list)): + size = size[::-1] + kwargs["size"] = size + return super().preprocess(images, **kwargs) + + def align_long_axis( + self, + image: "torch.Tensor", + size: SizeDict, + ) -> "torch.Tensor": + """ + Align the long axis of the image to the longest axis of the specified size. + + Args: + image (`torch.Tensor`): + The image to be aligned. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to align the long axis to. + + Returns: + `torch.Tensor`: The aligned image. + """ + input_height, input_width = image.shape[-2:] + output_height, output_width = size.height, size.width + + if (output_width < output_height and input_width > input_height) or ( + output_width > output_height and input_width < input_height + ): + height_dim, width_dim = image.dim() - 2, image.dim() - 1 + image = torch.rot90(image, 3, dims=[height_dim, width_dim]) + + return image + + def pad_image( + self, + image: "torch.Tensor", + size: SizeDict, + random_padding: bool = False, + ) -> "torch.Tensor": + """ + Pad the image to the specified size. + + Args: + image (`torch.Tensor`): + The image to be padded. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to pad the image to. + random_padding (`bool`, *optional*, defaults to `False`): + Whether to use random padding or not. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
+ """ + output_height, output_width = size.height, size.width + input_height, input_width = image.shape[-2:] + + delta_width = output_width - input_width + delta_height = output_height - input_height + + if random_padding: + pad_top = torch.random.randint(low=0, high=delta_height + 1) + pad_left = torch.random.randint(low=0, high=delta_width + 1) + else: + pad_top = delta_height // 2 + pad_left = delta_width // 2 + + pad_bottom = delta_height - pad_top + pad_right = delta_width - pad_left + + padding = (pad_left, pad_top, pad_right, pad_bottom) + return F.pad(image, padding) + + def pad(self, *args, **kwargs): + logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.") + return self.pad_image(*args, **kwargs) + + def thumbnail( + self, + image: "torch.Tensor", + size: SizeDict, + ) -> "torch.Tensor": + """ + Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any + corresponding dimension of the specified size. + + Args: + image (`torch.Tensor`): + The image to be resized. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to resize the image to. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + The resampling filter to use. + data_format (`Optional[Union[str, ChannelDimension]]`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + input_height, input_width = image.shape[-2:] + output_height, output_width = size.height, size.width + + # We always resize to the smallest of either the input or output size. 
+        height = min(input_height, output_height)
+        width = min(input_width, output_width)
+
+        if height == input_height and width == input_width:
+            return image
+
+        if input_height > input_width:
+            width = int(input_width * height / input_height)
+        elif input_width > input_height:
+            height = int(input_height * width / input_width)
+
+        return self.resize(
+            image,
+            size=SizeDict(width=width, height=height),
+            interpolation=F.InterpolationMode.BICUBIC,
+        )
+
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_resize: bool,
+        do_thumbnail: bool,
+        do_align_long_axis: bool,
+        do_pad: bool,
+        size: SizeDict,
+        interpolation: Optional["F.InterpolationMode"],
+        do_center_crop: bool,
+        crop_size: SizeDict,
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, list[float]]],
+        image_std: Optional[Union[float, list[float]]],
+        return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
+    ) -> BatchFeature:
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images)
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_align_long_axis:
+                stacked_images = self.align_long_axis(image=stacked_images, size=size)
+            if do_resize:
+                shortest_edge = min(size.height, size.width)
+                stacked_images = self.resize(
+                    image=stacked_images, size=SizeDict(shortest_edge=shortest_edge), interpolation=interpolation
+                )
+            if do_thumbnail:
+                stacked_images = self.thumbnail(image=stacked_images, size=size)
+            if do_pad:
+                stacked_images = self.pad_image(image=stacked_images, size=size, random_padding=False)
+
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        processed_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_center_crop:
+                stacked_images = self.center_crop(stacked_images, crop_size)
+            # Fused rescale and normalize
+            stacked_images = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_images_grouped[shape] = stacked_images
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["DonutImageProcessorFast"]
diff --git a/src/transformers/models/nougat/image_processing_nougat.py b/src/transformers/models/nougat/image_processing_nougat.py
index d5251d4ff12..25b5c5e7bc8 100644
--- a/src/transformers/models/nougat/image_processing_nougat.py
+++ b/src/transformers/models/nougat/image_processing_nougat.py
@@ -220,10 +220,21 @@ class NougatImageProcessor(BaseImageProcessor):
         input_height, input_width = get_image_size(image, channel_dim=input_data_format)
         output_height, output_width = size["height"], size["width"]
 
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(image)
+
+        if input_data_format == ChannelDimension.LAST:
+            rot_axes = (0, 1)
+        elif input_data_format == ChannelDimension.FIRST:
+            rot_axes = (1, 2)
+        else:
+            raise ValueError(f"Unsupported data format: {input_data_format}")
+
         if (output_width < output_height and input_width > input_height) or (
             output_width > output_height and input_width < input_height
         ):
-            image = np.rot90(image, 3)
+            image = np.rot90(image, 3, axes=rot_axes)
 
         if data_format is not None:
             image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
diff --git a/tests/models/donut/test_image_processing_donut.py b/tests/models/donut/test_image_processing_donut.py
index 07da7fe74de..29c3bff2a14 100644
--- a/tests/models/donut/test_image_processing_donut.py
+++ b/tests/models/donut/test_image_processing_donut.py
@@ -18,7 +18,7 @@ import unittest
 import numpy as np
 
 from transformers.testing_utils import is_flaky, require_torch, require_vision
-from transformers.utils import is_torch_available, is_vision_available
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
@@ -31,6 +31,9 @@ if is_vision_available():
 
     from transformers import DonutImageProcessor
 
+    if is_torchvision_available():
+        from transformers import DonutImageProcessorFast
+
 
 class DonutImageProcessingTester:
     def __init__(
@@ -96,6 +99,7 @@ class DonutImageProcessingTester:
 @require_vision
 class DonutImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = DonutImageProcessor if is_vision_available() else None
+    fast_image_processing_class = DonutImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
@@ -106,122 +110,156 @@ class DonutImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         return self.image_processor_tester.prepare_image_processor_dict()
 
     def test_image_processor_properties(self):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processing, "do_resize"))
-        self.assertTrue(hasattr(image_processing, "size"))
-        self.assertTrue(hasattr(image_processing, "do_thumbnail"))
-        self.assertTrue(hasattr(image_processing, "do_align_long_axis"))
-        self.assertTrue(hasattr(image_processing, "do_pad"))
-        self.assertTrue(hasattr(image_processing, "do_normalize"))
-        self.assertTrue(hasattr(image_processing, "image_mean"))
-        self.assertTrue(hasattr(image_processing, "image_std"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "do_resize"))
+            self.assertTrue(hasattr(image_processing, "size"))
+            self.assertTrue(hasattr(image_processing, "do_thumbnail"))
+            self.assertTrue(hasattr(image_processing, "do_align_long_axis"))
+            self.assertTrue(hasattr(image_processing, "do_pad"))
+            self.assertTrue(hasattr(image_processing, "do_normalize"))
+            self.assertTrue(hasattr(image_processing, "image_mean"))
+            self.assertTrue(hasattr(image_processing, "image_std"))
 
     def test_image_processor_from_dict_with_kwargs(self):
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
-        self.assertEqual(image_processor.size, {"height": 18, "width": 20})
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class.from_dict(self.image_processor_dict)
+            self.assertEqual(image_processor.size, {"height": 18, "width": 20})
 
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
-        self.assertEqual(image_processor.size, {"height": 42, "width": 42})
+            image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
+            self.assertEqual(image_processor.size, {"height": 42, "width": 42})
 
-        # Previous config had dimensions in (width, height) order
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=(42, 84))
-        self.assertEqual(image_processor.size, {"height": 84, "width": 42})
+            # Previous config had dimensions in (width, height) order
+            image_processor = image_processing_class.from_dict(self.image_processor_dict, size=(42, 84))
+            self.assertEqual(image_processor.size, {"height": 84, "width": 42})
+
+    def test_image_processor_preprocess_with_kwargs(self):
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random PyTorch tensors
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+
+            height = 84
+            width = 42
+            # Previous config had dimensions in (width, height) order
+            encoded_images = image_processing(image_inputs[0], size=(width, height), return_tensors="pt").pixel_values
+            self.assertEqual(
+                encoded_images.shape,
+                (
+                    1,
+                    self.image_processor_tester.num_channels,
+                    height,
+                    width,
+                ),
+            )
 
     @is_flaky()
     def test_call_pil(self):
-        # Initialize image_processing
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        # create random PIL images
-        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
-        for image in image_inputs:
-            self.assertIsInstance(image, Image.Image)
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random PIL images
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
+            for image in image_inputs:
+                self.assertIsInstance(image, Image.Image)
 
-        # Test not batched input
-        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
-        self.assertEqual(
-            encoded_images.shape,
-            (
-                1,
-                self.image_processor_tester.num_channels,
-                self.image_processor_tester.size["height"],
-                self.image_processor_tester.size["width"],
-            ),
-        )
+            # Test not batched input
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            self.assertEqual(
+                encoded_images.shape,
+                (
+                    1,
+                    self.image_processor_tester.num_channels,
+                    self.image_processor_tester.size["height"],
+                    self.image_processor_tester.size["width"],
+                ),
+            )
 
-        # Test batched
-        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
-        self.assertEqual(
-            encoded_images.shape,
-            (
-                self.image_processor_tester.batch_size,
-                self.image_processor_tester.num_channels,
-                self.image_processor_tester.size["height"],
-                self.image_processor_tester.size["width"],
-            ),
-        )
+            # Test batched
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            self.assertEqual(
+                encoded_images.shape,
+                (
+                    self.image_processor_tester.batch_size,
+                    self.image_processor_tester.num_channels,
+                    self.image_processor_tester.size["height"],
+                    self.image_processor_tester.size["width"],
+                ),
+            )
 
     @is_flaky()
     def test_call_numpy(self):
-        # Initialize image_processing
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        # create random numpy tensors
-        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
-        for image in image_inputs:
-            self.assertIsInstance(image, np.ndarray)
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random numpy tensors
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
+            for image in image_inputs:
+                self.assertIsInstance(image, np.ndarray)
 
-        # Test not batched input
-        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
-        self.assertEqual(
-            encoded_images.shape,
-            (
-                1,
-                self.image_processor_tester.num_channels,
-                self.image_processor_tester.size["height"],
-                self.image_processor_tester.size["width"],
-            ),
-        )
+            # Test not batched input
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            self.assertEqual(
+                encoded_images.shape,
+                (
+                    1,
+                    self.image_processor_tester.num_channels,
+                    self.image_processor_tester.size["height"],
+                    self.image_processor_tester.size["width"],
+                ),
+            )
 
-        # Test batched
-        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
-        self.assertEqual(
-            encoded_images.shape,
-            (
-                self.image_processor_tester.batch_size,
-                self.image_processor_tester.num_channels,
-                self.image_processor_tester.size["height"],
-                self.image_processor_tester.size["width"],
-            ),
-        )
+            # Test batched
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            self.assertEqual(
+                encoded_images.shape,
+                (
+                    self.image_processor_tester.batch_size,
+                    self.image_processor_tester.num_channels,
+                    self.image_processor_tester.size["height"],
+                    self.image_processor_tester.size["width"],
+                ),
+            )
 
     @is_flaky()
     def test_call_pytorch(self):
-        # Initialize image_processing
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        # create random PyTorch tensors
-        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
-        for image in image_inputs:
-            self.assertIsInstance(image, torch.Tensor)
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random PyTorch tensors
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+            for image in image_inputs:
+                self.assertIsInstance(image, torch.Tensor)
 
-        # Test not batched input
-        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
-        self.assertEqual(
-            encoded_images.shape,
-            (
-                1,
-                self.image_processor_tester.num_channels,
-                self.image_processor_tester.size["height"],
-                self.image_processor_tester.size["width"],
-            ),
-        )
+            # Test not batched input
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            self.assertEqual(
+                encoded_images.shape,
+                (
+                    1,
+                    self.image_processor_tester.num_channels,
+                    self.image_processor_tester.size["height"],
+                    self.image_processor_tester.size["width"],
+                ),
+            )
 
-        # Test batched
-        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
-        self.assertEqual(
-            encoded_images.shape,
-            (
-                self.image_processor_tester.batch_size,
-                self.image_processor_tester.num_channels,
-                self.image_processor_tester.size["height"],
-                self.image_processor_tester.size["width"],
-            ),
-        )
+            # Test batched
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            self.assertEqual(
+                encoded_images.shape,
+                (
+                    self.image_processor_tester.batch_size,
+                    self.image_processor_tester.num_channels,
+                    self.image_processor_tester.size["height"],
+                    self.image_processor_tester.size["width"],
+                ),
+            )
+
+
+@require_torch
+@require_vision
+class DonutImageProcessingAlignAxisTest(DonutImageProcessingTest):
+    def setUp(self):
+        super().setUp()
+        self.image_processor_tester = DonutImageProcessingTester(self, do_align_axis=True)
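
For reference, here is a minimal usage sketch of the processor this patch adds; it is not part of the patch itself. It assumes the public "naver-clova-ix/donut-base" checkpoint and an environment with torchvision installed, and relies on the auto-mapping change above so that `use_fast=True` resolves to `DonutImageProcessorFast`.

    # Usage sketch only -- not part of the patch.
    from PIL import Image
    from transformers import AutoImageProcessor

    # With the updated "donut-swin" auto-mapping, use_fast=True selects DonutImageProcessorFast.
    processor = AutoImageProcessor.from_pretrained("naver-clova-ix/donut-base", use_fast=True)

    # A dummy document-shaped image; this checkpoint's default size is {"height": 2560, "width": 1920}.
    image = Image.new("RGB", (1920, 2560), color="white")
    inputs = processor(image, return_tensors="pt")

    print(type(processor).__name__)   # DonutImageProcessorFast
    print(inputs.pixel_values.shape)  # expected: torch.Size([1, 3, 2560, 1920])

Outputs of the resize/thumbnail/pad path should match the slow `DonutImageProcessor` for the same inputs, which is what the shared `ImageProcessingTestMixin` equivalence tests exercised above are meant to check.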