mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
parent
04282a9ef5
commit
7c9b0ca08c
@ -43,8 +43,8 @@ import requests
|
|||||||
from transformers import SamHQModel, SamHQProcessor
|
from transformers import SamHQModel, SamHQProcessor
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to(device)
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||||
@ -69,8 +69,8 @@ import requests
|
|||||||
from transformers import SamHQModel, SamHQProcessor
|
from transformers import SamHQModel, SamHQProcessor
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to(device)
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||||
|
@ -715,7 +715,7 @@ class SamHQModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
model_name = "sushmanth/sam_hq_vit_b"
|
model_name = "syscv-community/sam-hq-vit-base"
|
||||||
model = SamHQModel.from_pretrained(model_name)
|
model = SamHQModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@ -801,8 +801,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
cleanup(torch_device, gc_collect=True)
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
def test_inference_mask_generation_no_point(self):
|
def test_inference_mask_generation_no_point(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -821,8 +821,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_inference_mask_generation_one_point_one_bb(self):
|
def test_inference_mask_generation_one_point_one_bb(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -845,8 +845,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_inference_mask_generation_batched_points_batched_images(self):
|
def test_inference_mask_generation_batched_points_batched_images(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -887,8 +887,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=9e-3))
|
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=9e-3))
|
||||||
|
|
||||||
def test_inference_mask_generation_one_point_one_bb_zero(self):
|
def test_inference_mask_generation_one_point_one_bb_zero(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -913,8 +913,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8680), atol=1e-3))
|
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8680), atol=1e-3))
|
||||||
|
|
||||||
def test_inference_mask_generation_with_labels(self):
|
def test_inference_mask_generation_with_labels(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -933,8 +933,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9137), atol=1e-4))
|
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9137), atol=1e-4))
|
||||||
|
|
||||||
def test_inference_mask_generation_without_labels(self):
|
def test_inference_mask_generation_without_labels(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -950,8 +950,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9137), atol=1e-3))
|
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9137), atol=1e-3))
|
||||||
|
|
||||||
def test_inference_mask_generation_two_points_with_labels(self):
|
def test_inference_mask_generation_two_points_with_labels(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -970,8 +970,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8859), atol=1e-3))
|
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8859), atol=1e-3))
|
||||||
|
|
||||||
def test_inference_mask_generation_two_points_without_labels(self):
|
def test_inference_mask_generation_two_points_without_labels(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -987,8 +987,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8859), atol=1e-3))
|
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8859), atol=1e-3))
|
||||||
|
|
||||||
def test_inference_mask_generation_two_points_batched(self):
|
def test_inference_mask_generation_two_points_batched(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -1013,8 +1013,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.4482), atol=1e-4))
|
self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.4482), atol=1e-4))
|
||||||
|
|
||||||
def test_inference_mask_generation_one_box(self):
|
def test_inference_mask_generation_one_box(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -1031,8 +1031,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.6265), atol=1e-4))
|
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.6265), atol=1e-4))
|
||||||
|
|
||||||
def test_inference_mask_generation_batched_image_one_point(self):
|
def test_inference_mask_generation_batched_image_one_point(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -1060,8 +1060,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4))
|
self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4))
|
||||||
|
|
||||||
def test_inference_mask_generation_two_points_point_batch(self):
|
def test_inference_mask_generation_two_points_point_batch(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -1084,8 +1084,8 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_inference_mask_generation_three_boxes_point_batch(self):
|
def test_inference_mask_generation_three_boxes_point_batch(self):
|
||||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
|
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -1110,7 +1110,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
|
torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
|
||||||
|
|
||||||
def test_dummy_pipeline_generation(self):
|
def test_dummy_pipeline_generation(self):
|
||||||
generator = pipeline("mask-generation", model="sushmanth/sam_hq_vit_b", device=torch_device)
|
generator = pipeline("mask-generation", model="syscv-community/sam-hq-vit-base", device=torch_device)
|
||||||
raw_image = prepare_image()
|
raw_image = prepare_image()
|
||||||
|
|
||||||
_ = generator(raw_image, points_per_batch=64)
|
_ = generator(raw_image, points_per_batch=64)
|
||||||
|
Loading…
Reference in New Issue
Block a user