Remove pos arg from Perceiver's Pre/Postprocessors (#18602)

* Remove pos arg from Perceiver's Pre/Postprocessors

* Revert the removed pos args in public methods
This commit is contained in:
Ahmad Elawady 2022-09-26 14:50:58 +02:00 committed by GitHub
parent 71fc331746
commit 408b5e307b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3130,7 +3130,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
return inp_dim + pos_dim
def _build_network_inputs(self, inputs: torch.Tensor, pos: torch.Tensor, network_input_is_1d: bool = True):
def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool = True):
"""
Construct the final input, including position encoding.
@ -3209,7 +3209,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
else:
raise ValueError("Unsupported data format for conv1x1.")
inputs, inputs_without_pos = self._build_network_inputs(inputs, pos, network_input_is_1d)
inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d)
modality_sizes = None # Size for each modality, only needed for multimodal
return inputs, modality_sizes, inputs_without_pos
@ -3308,7 +3308,7 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor):
return pos_dim
return self.samples_per_patch + pos_dim
def _build_network_inputs(self, inputs, pos):
def _build_network_inputs(self, inputs):
"""Construct the final input, including position encoding."""
batch_size = inputs.shape[0]
index_dims = inputs.shape[1:-1]
@ -3332,7 +3332,7 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor):
def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])
inputs, inputs_without_pos = self._build_network_inputs(inputs, pos)
inputs, inputs_without_pos = self._build_network_inputs(inputs)
modality_sizes = None # Size for each modality, only needed for multimodal
return inputs, modality_sizes, inputs_without_pos