-
Notifications
You must be signed in to change notification settings - Fork 197
/
Copy pathimage_segment_mapper.py
85 lines (67 loc) · 3.04 KB
/
image_segment_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import numpy as np
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.mm_utils import load_data_with_context, load_image
from data_juicer.utils.model_utils import get_model, prepare_model
from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_IMAGES
# Name under which this operator is registered in the op registries below.
OP_NAME = 'image_segment_mapper'
# Handles for heavy optional dependencies, created through LazyLoader.
# NOTE(review): presumably the real import is deferred until first attribute
# access -- confirm against data_juicer.utils.lazy_loader. Neither handle is
# referenced directly by the code visible in this file.
torch = LazyLoader('torch', 'torch')
ultralytics = LazyLoader('ultralytics', 'ultralytics')
@UNFORKABLE.register_module(OP_NAME)
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImageSegmentMapper(Mapper):
    """Tag every image of a sample with bounding boxes from a FastSAM model.

    The result is written to ``sample[Fields.meta][MetaKeys.bbox_tag]``:
    one array per processed image, taken from the ``boxes.xywh`` attribute
    of the model output.
    """

    _accelerator = 'cuda'

    def __init__(self,
                 imgsz=1024,
                 conf=0.05,
                 iou=0.5,
                 model_path='FastSAM-x.pt',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param imgsz: resolution for image resizing
        :param conf: confidence score threshold
        :param iou: IoU (Intersection over Union) score threshold
        :param model_path: the path to the FastSAM model. Model name should be
            one of ['FastSAM-x.pt', 'FastSAM-s.pt'].
        :param args: extra positional args forwarded to the base class
        :param kwargs: extra keyword args forwarded to the base class;
            ``mem_required`` falls back to ``'800MB'`` when not supplied
        """
        # Only fill in the memory requirement when the caller did not set it.
        if 'mem_required' not in kwargs:
            kwargs['mem_required'] = '800MB'
        super().__init__(*args, **kwargs)

        # Inference settings forwarded to the model on every call.
        self.imgsz, self.conf, self.iou = imgsz, conf, iou
        self.model_key = prepare_model(model_type='fastsam',
                                       model_path=model_path)

    @staticmethod
    def _empty_bboxes():
        """Build the placeholder stored when a sample yields no boxes."""
        # N x M x 4 for N images, M boxes, 4 coords
        return np.empty((0, 0, 4), dtype=np.float32)

    def process_single(self, sample, rank=None, context=False):
        """
        Compute bounding boxes for the images of one sample.

        :param sample: the sample to tag; its ``Fields.meta`` entry is
            updated in place
        :param rank: forwarded to ``get_model`` when fetching the model
        :param context: forwarded to ``load_data_with_context``
        :return: the sample, with ``MetaKeys.bbox_tag`` set in its meta
        """
        has_images = self.image_key in sample and sample[self.image_key]
        if not has_images:
            # Nothing to segment: store the empty placeholder and stop.
            sample[Fields.meta][MetaKeys.bbox_tag] = self._empty_bboxes()
            return sample

        # Already tagged: keep the existing result untouched.
        if MetaKeys.bbox_tag in sample[Fields.meta]:
            return sample

        image_keys = sample[self.image_key]
        sample, images = load_data_with_context(sample, context, image_keys,
                                                load_image)
        model = get_model(self.model_key, rank=rank, use_cuda=self.use_cuda())

        meta = sample[Fields.meta]
        per_image_boxes = []
        meta[MetaKeys.bbox_tag] = per_image_boxes

        # NOTE(review): this iterates whatever `load_data_with_context`
        # returns. If that is a mapping keyed by image path, the loop visits
        # the keys rather than the loaded images -- confirm against the
        # helper.
        for image in images:
            prediction = model(image,
                               retina_masks=True,
                               imgsz=self.imgsz,
                               conf=self.conf,
                               iou=self.iou,
                               verbose=False)[0]
            boxes_xywh = prediction.boxes.xywh.cpu().numpy()
            per_image_boxes.append(boxes_xywh)

        # match schema
        if not per_image_boxes:
            meta[MetaKeys.bbox_tag] = self._empty_bboxes()
        return sample