-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path catFinderRoboflow.py
115 lines (95 loc) · 4.99 KB
/
catFinderRoboflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python
import cv2
import catFinder
from roboflow import Roboflow
from inference_sdk import InferenceHTTPClient
import imgUtils
class CatFinderRoboflow(catFinder.CatFinder):
    '''Cat detector backed by a Roboflow object-detection model.

    Inference runs either against the model hosted by Roboflow (via the
    ``roboflow`` package) or against a local inference server (via
    ``inference_sdk``), selected by the ``localServer`` setting.
    '''

    # Address of the local Roboflow inference server used when localServer
    # is true.  Class attribute so a subclass/instance can override it.
    LOCAL_SERVER_URL = "http://localhost:9001"

    # Local-server HTTP client; created lazily on first local inference and
    # then reused, instead of building a new client for every image.
    _localClient = None

    def __init__(self, settingsObj, debug):
        ''' Initialise the Roboflow based cat finding object detector.

        It expects the following elements in settingsObj:
        - apiKey - Roboflow API key
        - projectId - Roboflow project ID of model to be used
        - versionId - Version ID of Roboflow model to be used
        - localServer - if true uses a roboflow inference server on
          localhost:9001, rather than the one hosted by Roboflow.
        - thresholds - mapping from model class name (e.g. "Cat") to the
          threshold 0 to 1 above which that class counts as detected
          (e.g. 0.5 = 50% confidence).

        :param settingsObj: settings mapping as described above.
        :param debug: passed through unchanged to catFinder.CatFinder.
        :raises KeyError: if one of the settings above is missing.
        :raises Exception: re-raised from the roboflow package if the hosted
            model cannot be loaded (only when localServer is false).
        '''
        super().__init__(settingsObj, debug)
        self.apiKey = settingsObj['apiKey']
        self.projectId = settingsObj['projectId']
        self.versionId = settingsObj['versionId']
        self.localServer = settingsObj['localServer']
        self.thresholds = settingsObj['thresholds']
        if not self.localServer:
            # Hosted mode: fetch the model handle once, up front.
            try:
                rf = Roboflow(api_key=self.apiKey)
                project = rf.workspace().project(self.projectId)
                self.model = project.version(self.versionId).model
            except Exception:
                # The local inference server is not involved in hosted mode,
                # so point the user at the hosted-model settings instead.
                print("***************************************************************************************")
                print("****** Error connecting to Roboflow Model ******")
                print("****** Check the apiKey, projectId and versionId settings and your network connection *****")
                print("***************************************************************************************")
                raise

    def getInferenceResults(self, img):
        '''Run the configured Roboflow model on img and return the raw result.

        Note: the hosted call applies its own fixed confidence=50 /
        overlap=30 filter; the per-class self.thresholds are applied
        afterwards by findCat() and getAnnotatedImage().

        :param img: image passed unchanged to the Roboflow client
            (presumably a cv2/numpy image - TODO confirm).
        :returns: the result dict; callers read its 'predictions' list,
            each entry having 'class', 'confidence', 'x', 'y', 'width'
            and 'height' keys.
        :raises Exception: re-raised from inference_sdk if the local server
            cannot be used (only when localServer is true).
        '''
        if not self.localServer:
            # infer using hosted model
            return self.model.predict(img, confidence=50, overlap=30).json()
        try:
            if self._localClient is None:
                self._localClient = InferenceHTTPClient(
                    api_url=self.LOCAL_SERVER_URL,
                    api_key=self.apiKey,
                )
            return self._localClient.infer(
                img, model_id="%s/%s" % (self.projectId, self.versionId))
        except Exception:
            print("***************************************************************************************")
            print("****** Error connecting to local Roboflow Model ******")
            print("****** Have you started the local inference server with 'inference server start'? *****")
            print("***************************************************************************************")
            raise

    def findCat(self, img):
        '''Run inference on img and report whether a cat was detected.

        :returns: tuple (foundCat, retObj). foundCat is True if any
            prediction of class "Cat" has confidence above
            self.thresholds['Cat']; retObj is the raw result from
            getInferenceResults().
        :raises KeyError: if a "Cat" prediction is present but thresholds
            has no 'Cat' entry.
        '''
        retObj = self.getInferenceResults(img)
        foundCat = any(
            pred['class'] == "Cat" and pred['confidence'] > self.thresholds['Cat']
            for pred in retObj['predictions']
        )
        return (foundCat, retObj)

    @staticmethod
    def _drawPrediction(img, pred):
        '''Draw pred's bounding box and class label onto img (in place).'''
        # The prediction's x, y are treated as the box centre; convert to
        # top-left / bottom-right corners for cv2.
        x0 = int(pred['x'] - 0.5 * pred['width'])
        y0 = int(pred['y'] - 0.5 * pred['height'])
        x1 = int(x0 + pred['width'])
        y1 = int(y0 + pred['height'])
        # Clamp only the top-left corner so the box starts on the image.
        x0 = max(x0, 0)
        y0 = max(y0, 0)
        cv2.rectangle(img, (x0, y0), (x1, y1), (255, 255, 0), 3)
        cv2.putText(img, pred['class'], (x0, y0 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)

    def getAnnotatedImage(self, img):
        '''Run inference on img and return it annotated with detections.

        Note: this runs inference again on img; it does NOT reuse the retObj
        returned by an earlier findCat() call.  The boxes and labels are
        drawn onto img itself before it is scaled.

        A prediction is drawn when its confidence exceeds the threshold
        configured for its class; predictions whose class has no entry in
        self.thresholds are skipped rather than raising KeyError.

        :returns: the annotated image passed through imgUtils.scaleW(img, 640)
            (presumably scaled to 640 px wide - TODO confirm).
        '''
        results = self.getInferenceResults(img)
        for pred in results['predictions']:
            cls = pred['class']
            if cls in self.thresholds and pred['confidence'] > self.thresholds[cls]:
                print("found %s" % cls)
                self._drawPrediction(img, pred)
        return imgUtils.scaleW(img, 640)
# infer on an image hosted elsewhere
# print(model.predict("URL_OF_YOUR_IMAGE", hosted=True, confidence=40, overlap=30).json())
def _main():
    '''Script entry point: print a banner only; no detection is performed.'''
    print("CatFinderRoboflow.main()")


if __name__ == "__main__":
    _main()