Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
* Repo consistency fix after #33339
* [run-slow] omdet_turbo
This commit is contained in:
parent 68a2b50069
commit 1baa08897d
```diff
@@ -418,29 +418,6 @@ class OmDetTurboMultiscaleDeformableAttention(nn.Module):
 
         self.disable_custom_kernels = config.disable_custom_kernels
 
-        self._reset_parameters()
-
-    def _reset_parameters(self):
-        nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
-        default_dtype = torch.get_default_dtype()
-        thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads)
-        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
-        grid_init = (
-            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
-            .view(self.n_heads, 1, 1, 2)
-            .repeat(1, self.n_levels, self.n_points, 1)
-        )
-        for i in range(self.n_points):
-            grid_init[:, :, i, :] *= i + 1
-        with torch.no_grad():
-            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
-        nn.init.constant_(self.attention_weights.weight.data, 0.0)
-        nn.init.constant_(self.attention_weights.bias.data, 0.0)
-        nn.init.xavier_uniform_(self.value_proj.weight.data)
-        nn.init.constant_(self.value_proj.bias.data, 0.0)
-        nn.init.xavier_uniform_(self.output_proj.weight.data)
-        nn.init.constant_(self.output_proj.bias.data, 0.0)
-
     def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
         return tensor if position_embeddings is None else tensor + position_embeddings
```
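The removed `_reset_parameters` is the standard Deformable-DETR-style initialization of the sampling offsets: each attention head gets a unit direction evenly spaced around the circle, repeated across feature levels, with point `i` pushed `i + 1` steps out along that direction. A minimal standalone sketch of that initialization follows; the sizes `n_heads=8, n_levels=4, n_points=4` are illustrative assumptions, not values read from the OmDet-Turbo config:

```python
import math

import torch

def build_sampling_offset_bias(n_heads: int, n_levels: int, n_points: int) -> torch.Tensor:
    """Reproduce the grid initialization from the removed _reset_parameters."""
    default_dtype = torch.get_default_dtype()
    # One angle per head, evenly spaced on [0, 2*pi).
    thetas = torch.arange(n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / n_heads)
    # Unit direction (cos, sin) per head, normalized by its largest component.
    grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
    grid_init = (
        (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
        .view(n_heads, 1, 1, 2)
        .repeat(1, n_levels, n_points, 1)
    )
    # Scale point i by (i + 1) so successive points step further out.
    for i in range(n_points):
        grid_init[:, :, i, :] *= i + 1
    return grid_init.view(-1)

# Illustrative sizes (assumed, not taken from any OmDet-Turbo checkpoint):
bias = build_sampling_offset_bias(n_heads=8, n_levels=4, n_points=4)
print(bias.shape)  # torch.Size([256]) == n_heads * n_levels * n_points * 2
```

Note that this hunk only shows the method's removal from the attention module; where the initialization now happens is not visible in the diff itself.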