Fix ImageGPT doc example (#24317)

* Fix ImageGPT doc example

* Update src/transformers/models/imagegpt/image_processing_imagegpt.py

* Fix types
This commit is contained in:
amyeroberts 2023-06-16 17:01:22 +01:00 committed by GitHub
parent 096f2cf126
commit bdfd57d1d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 10 deletions

View File

@ -60,9 +60,9 @@ class ImageGPTImageProcessor(BaseImageProcessor):
(color clusters).
Args:
clusters (`np.ndarray`, *optional*):
The color clusters to use, as a `np.ndarray` of shape `(n_clusters, 3)` when color quantizing. Can be
overriden by `clusters` in `preprocess`.
clusters (`np.ndarray` or `List[List[int]]`, *optional*):
The color clusters to use, of shape `(n_clusters, 3)` when color quantizing. Can be overriden by `clusters`
in `preprocess`.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's dimensions to `(size["height"], size["width"])`. Can be overridden by
`do_resize` in `preprocess`.
@ -82,7 +82,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
def __init__(
self,
# clusters is a first argument to maintain backwards compatibility with the old ImageGPTFeatureExtractor
clusters: Optional[np.ndarray] = None,
clusters: Optional[Union[List[List[int]], np.ndarray]] = None,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
@ -93,7 +93,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
super().__init__(**kwargs)
size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size)
self.clusters = clusters
self.clusters = np.array(clusters) if clusters is not None else None
self.do_resize = do_resize
self.size = size
self.resample = resample
@ -154,7 +154,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
resample: PILImageResampling = None,
do_normalize: bool = None,
do_color_quantize: Optional[bool] = None,
clusters: Optional[Union[int, List[int]]] = None,
clusters: Optional[Union[List[List[int]], np.ndarray]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
**kwargs,
@ -176,7 +176,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
Whether to normalize the image
do_color_quantize (`bool`, *optional*, defaults to `self.do_color_quantize`):
Whether to color quantize the image.
clusters (`np.ndarray`, *optional*, defaults to `self.clusters`):
clusters (`np.ndarray` or `List[List[int]]`, *optional*, defaults to `self.clusters`):
Clusters used to quantize the image of shape `(n_clusters, 3)`. Only has an effect if
`do_color_quantize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
@ -199,6 +199,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
do_color_quantize = do_color_quantize if do_color_quantize is not None else self.do_color_quantize
clusters = clusters if clusters is not None else self.clusters
clusters = np.array(clusters)
images = make_list_of_images(images)
@ -227,7 +228,6 @@ class ImageGPTImageProcessor(BaseImageProcessor):
images = [to_channel_dimension_format(image, ChannelDimension.LAST) for image in images]
# color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
images = np.array(images)
clusters = np.array(clusters)
images = color_quantize(images, clusters).reshape(images.shape[:-1])
# flatten to (batch_size, height*width)

View File

@ -983,9 +983,9 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
>>> model.to(device)
>>> # unconditional generation of 8 images
>>> batch_size = 8
>>> batch_size = 4
>>> context = torch.full((batch_size, 1), model.config.vocab_size - 1) # initialize with SOS token
>>> context = torch.tensor(context).to(device)
>>> context = context.to(device)
>>> output = model.generate(
... input_ids=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40
... )

View File

@ -102,6 +102,7 @@ src/transformers/models/groupvit/modeling_groupvit.py
src/transformers/models/groupvit/modeling_tf_groupvit.py
src/transformers/models/hubert/modeling_hubert.py
src/transformers/models/imagegpt/configuration_imagegpt.py
src/transformers/models/imagegpt/modeling_imagegpt.py
src/transformers/models/layoutlm/configuration_layoutlm.py
src/transformers/models/layoutlm/modeling_layoutlm.py
src/transformers/models/layoutlm/modeling_tf_layoutlm.py