Add type hints for PoolFormer in Pytorch (#16121)

This commit is contained in:
Hyeonsoo Lee 2022-03-15 01:14:04 +09:00 committed by GitHub
parent 6c2f3ed74c
commit 5493c10ecb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,7 +17,7 @@
import collections.abc
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@ -379,7 +379,12 @@ class PoolFormerModel(PoolFormerPreTrainedModel):
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None):
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PoolFormerModelOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
@ -446,11 +451,11 @@ class PoolFormerForImageClassification(PoolFormerPreTrainedModel):
)
def forward(
self,
pixel_values=None,
labels=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PoolFormerClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,