mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Add type hints for PoolFormer in Pytorch (#16121)
This commit is contained in:
parent
6c2f3ed74c
commit
5493c10ecb
@ -17,7 +17,7 @@
|
||||
|
||||
import collections.abc
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -379,7 +379,12 @@ class PoolFormerModel(PoolFormerPreTrainedModel):
|
||||
modality="vision",
|
||||
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||
)
|
||||
def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, PoolFormerModelOutput]:
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
@ -446,11 +451,11 @@ class PoolFormerForImageClassification(PoolFormerPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
labels=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, PoolFormerClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
||||
|
Loading…
Reference in New Issue
Block a user