mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[superglue] fix wrong concatenation which made batching results wrong (#38850)
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
New model PR merged notification / Notify new model (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
New model PR merged notification / Notify new model (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
This commit is contained in:
parent
f8b88866f5
commit
1283877571
@ -725,8 +725,8 @@ class SuperGlueForKeypointMatching(SuperGluePreTrainedModel):
|
||||
matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
|
||||
matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
|
||||
|
||||
matches = torch.cat([matches0, matches1]).reshape(batch_size, 2, -1)
|
||||
matching_scores = torch.cat([matching_scores0, matching_scores1]).reshape(batch_size, 2, -1)
|
||||
matches = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1)
|
||||
matching_scores = torch.cat([matching_scores0, matching_scores1], dim=1).reshape(batch_size, 2, -1)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + encoded_keypoints[1]
|
||||
|
Loading…
Reference in New Issue
Block a user