mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[CI] Fix ci (#21940)
* fix `get_proposal_pos_embed` * fix order * style * zero shot simplify test * add approximate values for zero shot audio classification
This commit is contained in:
parent
fcf813417a
commit
bc33fbf956
@ -497,7 +497,7 @@ class DeformableDetrSinePositionEmbedding(nn.Module):
|
||||
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
|
||||
dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2 / self.embedding_dim))
|
||||
dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
@ -1552,7 +1552,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
scale = 2 * math.pi
|
||||
|
||||
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
|
||||
dim_t = temperature ** (2 * torch.div(dim_t, 2) / num_pos_feats)
|
||||
dim_t = temperature ** (2 * torch_int_div(dim_t, 2) / num_pos_feats)
|
||||
# batch_size, num_queries, 4
|
||||
proposals = proposals.sigmoid() * scale
|
||||
# batch_size, num_queries, 4, 128
|
||||
|
@ -399,7 +399,7 @@ class DetaSinePositionEmbedding(nn.Module):
|
||||
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
|
||||
dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2 / self.embedding_dim))
|
||||
dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
@ -1463,7 +1463,7 @@ class DetaModel(DetaPreTrainedModel):
|
||||
scale = 2 * math.pi
|
||||
|
||||
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
|
||||
dim_t = temperature ** (2 * torch.div(dim_t, 2) / num_pos_feats)
|
||||
dim_t = temperature ** (2 * torch_int_div(dim_t, 2) / num_pos_feats)
|
||||
# batch_size, num_queries, 4
|
||||
proposals = proposals.sigmoid() * scale
|
||||
# batch_size, num_queries, 4, 128
|
||||
|
@ -44,7 +44,7 @@ class ZeroShotAudioClassificationPipeline(Pipeline):
|
||||
>>> audio = next(iter(dataset["train"]["audio"]))["array"]
|
||||
>>> classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-unfused")
|
||||
>>> classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
|
||||
[{'score': 0.9995999932289124, 'label': 'Sound of a dog'}, {'score': 0.00040007088682614267, 'label': 'Sound of vaccum cleaner'}]
|
||||
[{'score': 0.9996, 'label': 'Sound of a dog'}, {'score': 0.0004, 'label': 'Sound of vaccum cleaner'}]
|
||||
```
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user