mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Remove pos arg from Perceiver's Pre/Postprocessors (#18602)
* Remove pos arg from Perceiver's Pre/Postprocessors
* Revert the removed pos args in public methods
This commit is contained in:
parent
71fc331746
commit
408b5e307b
@ -3130,7 +3130,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
|
||||
|
||||
return inp_dim + pos_dim
|
||||
|
||||
def _build_network_inputs(self, inputs: torch.Tensor, pos: torch.Tensor, network_input_is_1d: bool = True):
|
||||
def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool = True):
|
||||
"""
|
||||
Construct the final input, including position encoding.
|
||||
|
||||
@ -3209,7 +3209,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
|
||||
else:
|
||||
raise ValueError("Unsupported data format for conv1x1.")
|
||||
|
||||
inputs, inputs_without_pos = self._build_network_inputs(inputs, pos, network_input_is_1d)
|
||||
inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d)
|
||||
modality_sizes = None # Size for each modality, only needed for multimodal
|
||||
|
||||
return inputs, modality_sizes, inputs_without_pos
|
||||
@ -3308,7 +3308,7 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor):
|
||||
return pos_dim
|
||||
return self.samples_per_patch + pos_dim
|
||||
|
||||
def _build_network_inputs(self, inputs, pos):
|
||||
def _build_network_inputs(self, inputs):
|
||||
"""Construct the final input, including position encoding."""
|
||||
batch_size = inputs.shape[0]
|
||||
index_dims = inputs.shape[1:-1]
|
||||
@ -3332,7 +3332,7 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor):
|
||||
def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
|
||||
inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])
|
||||
|
||||
inputs, inputs_without_pos = self._build_network_inputs(inputs, pos)
|
||||
inputs, inputs_without_pos = self._build_network_inputs(inputs)
|
||||
modality_sizes = None # Size for each modality, only needed for multimodal
|
||||
|
||||
return inputs, modality_sizes, inputs_without_pos
|
||||
|
Loading…
Reference in New Issue
Block a user