forked from minhnh/mas_perception_libs
-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathimage_detection_test
executable file
·154 lines (122 loc) · 6.18 KB
/
image_detection_test
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python
import os
from importlib import import_module
import numpy as np
import cv2
from enum import Enum
import rospy
from cv_bridge import CvBridge, CvBridgeError
from rospy import Subscriber, Publisher
from sensor_msgs.msg import Image as ImageMsg, PointCloud2
from mas_perception_libs import SingleImageDetectionHandler
from mas_perception_libs.utils import case_insensitive_glob, cloud_msg_to_image_msg,\
crop_organized_cloud_msg, crop_cloud_to_xyz
class ImageSourceType(Enum):
    """Kinds of image source this test script can feed to the detector."""
    # images extracted from messages on a sensor_msgs/PointCloud2 topic
    CLOUD_TOPIC = 0
    # messages on a sensor_msgs/Image topic
    IMAGE_TOPIC = 1
    # image files stored in a local directory
    IMAGE_DIR = 2
    # no usable source has been determined
    UNKNOWN = 3
def handle_image_directory(image_dir, detector_handler):
    """Run detection once on every image file found directly inside a directory.

    Files whose names match *.jpg, *.jpeg, *.png or *.bmp are collected with
    case_insensitive_glob, read with OpenCV, converted to sensor_msgs/Image
    messages and passed to the detection handler one at a time, pausing one
    second after each detection. Files that cannot be read or converted are
    logged and skipped instead of aborting the whole run.

    :param image_dir: path of the directory to search for image files
    :param detector_handler: object providing process_image_msg(img_msg),
                             e.g. a SingleImageDetectionHandler
    """
    # glob image files
    image_files = []
    for file_type in ('*.jpg', '*.jpeg', '*.png', '*.bmp'):
        image_files.extend(case_insensitive_glob(os.path.join(image_dir, file_type)))

    # send requests and publish results
    bridge = CvBridge()
    for image_path in image_files:
        # stop early when the node is being shut down
        if rospy.is_shutdown():
            break

        # cv2.imread does not raise on unreadable/corrupt files, it returns None
        img = cv2.imread(image_path)
        if img is None:
            rospy.logwarn('failed to read image file: ' + image_path)
            continue

        # convert image to ROS message
        try:
            img_msg = bridge.cv2_to_imgmsg(img, 'bgr8')
        except CvBridgeError as e:
            # str(e) works on both Python 2 and 3, unlike e.message
            rospy.logerr('failed to convert input from OpenCV image to ROS message: ' + str(e))
            continue

        # detect objects; the returned boxes, classes and confidences are not needed here
        detector_handler.process_image_msg(img_msg)

        # sleep for 1 second; rospy.sleep raises when the node shuts down while sleeping
        try:
            rospy.sleep(1.0)
        except rospy.ROSInterruptException:
            break
class CloudTopicHandler(object):
    """Run detection on incoming point clouds and publish the crop of the first detection.

    Every sensor_msgs/PointCloud2 message received on the given topic is
    converted to an image message and passed to the detection handler. When
    at least one bounding box is returned, the cloud cropped to the first box
    is published on '~first_cropped_cloud' and the mean of the cropped
    coordinates is logged. A publisher for '~first_cropped_image' is also
    created, but nothing in this class publishes on it.
    """
    _detector_handler = None    # type: SingleImageDetectionHandler
    _cloud_sub = None           # type: Subscriber
    _cropped_pub = None         # type: Publisher
    _cropped_img_pub = None     # type: Publisher

    def __init__(self, cloud_topic_name, detector_handler):
        """
        :param cloud_topic_name: name of the sensor_msgs/PointCloud2 topic to subscribe to
        :param detector_handler: object providing process_image_msg(img_msg)
        """
        self._detector_handler = detector_handler
        self._cropped_pub = rospy.Publisher('~first_cropped_cloud', PointCloud2, queue_size=1)
        self._cropped_img_pub = rospy.Publisher('~first_cropped_image', ImageMsg, queue_size=1)
        # subscribe only after the members used by the callback are set
        self._cloud_sub = rospy.Subscriber(cloud_topic_name, PointCloud2, callback=self._cloud_callback)

    def _cloud_callback(self, cloud_msg):
        """Detect objects in one cloud message, then publish and log the first detection's crop.

        :param cloud_msg: received sensor_msgs/PointCloud2 message
        """
        # generate 2D bounding boxes from the image contained in the cloud
        detection_result = self._detector_handler.process_image_msg(cloud_msg_to_image_msg(cloud_msg))
        bounding_boxes, _classes, _confidences = detection_result
        if len(bounding_boxes) == 0:
            return

        # only the first detection is used
        first_box = bounding_boxes[0]

        # publish the cloud cropped to the first bounding box
        self._cropped_pub.publish(crop_organized_cloud_msg(cloud_msg, first_box))

        # log the mean of the cropped coordinates, ignoring NaN entries
        xyz_coords = crop_cloud_to_xyz(cloud_msg, first_box)
        mean_coord = np.nanmean(np.reshape(xyz_coords, (-1, 3)), axis=0)
        rospy.loginfo('label {0}: coord mean (frame {1}): {2}'
                      .format(first_box.label, cloud_msg.header.frame_id, mean_coord))
class ImageTopicHandler(object):
    """Pass every sensor_msgs/Image message received on a topic to the detection handler."""

    def __init__(self, image_topic_name, detector_handler):
        """
        :param image_topic_name: name of the sensor_msgs/Image topic to subscribe to
        :param detector_handler: object providing process_image_msg(img_msg)
        """
        self._detector_handler = detector_handler
        self._image_sub = rospy.Subscriber(image_topic_name, ImageMsg, callback=self._image_callback)

    def _image_callback(self, image_msg):
        """Run detection on one received image message; the returned values are discarded."""
        _boxes, _classes, _confidences = self._detector_handler.process_image_msg(image_msg)
def main(image_source_type, image_source, detection_class, kwargs_file, class_annotation_file, result_topic):
    """Create the detection handler and connect it to the selected image source.

    :param image_source_type: ImageSourceType value selecting how images are obtained
    :param image_source: directory path or topic name, depending on image_source_type
    :param detection_class: detector class passed on to SingleImageDetectionHandler
    :param kwargs_file: path of the keyword-argument file passed on to the handler
    :param class_annotation_file: path of the class annotation file passed on to the handler
    :param result_topic: result topic name passed on to the handler
    :raises ValueError: if image_source_type is not one of the handled source types
    """
    rospy.loginfo('Image source: {0}: {1}'.format(image_source_type, image_source))

    # create detector instance
    detector_handler = SingleImageDetectionHandler(detection_class, class_annotation_file, kwargs_file, result_topic)

    # each source type maps to a callable taking (image_source, detector_handler);
    # the order mirrors the order in which the source types are checked
    source_consumers = (
        (ImageSourceType.IMAGE_DIR, handle_image_directory),
        (ImageSourceType.CLOUD_TOPIC, CloudTopicHandler),
        (ImageSourceType.IMAGE_TOPIC, ImageTopicHandler),
    )
    for source_type, consume_source in source_consumers:
        if image_source_type == source_type:
            consume_source(image_source, detector_handler)
            return

    raise ValueError('unhandled image source type: ' + str(image_source_type))
if __name__ == '__main__':
    # Script entry point: read the node's private parameters, validate them,
    # pick an image source and start the detection test.
    #
    # node names must be base names, so the private-name prefix '~' is not used here
    rospy.init_node('image_detection_test')

    # get parameters
    param_class_annotations = rospy.get_param('~class_annotations', None)
    if not param_class_annotations:
        raise ValueError('no class annotation file specified')
    if not os.path.exists(param_class_annotations):
        raise ValueError('class annotation file does not exist: ' + param_class_annotations)

    param_kwargs_file = rospy.get_param('~kwargs_file', None)
    if not param_kwargs_file:
        raise ValueError('no kwargs file specified for detection module')
    if not os.path.exists(param_kwargs_file):
        # report the kwargs file path itself, not the class annotation path
        raise ValueError('kwargs file for detection does not exist: ' + param_kwargs_file)

    param_detection_module = rospy.get_param('~detection_module', None)
    if not param_detection_module:
        raise ValueError('no detection module specified')

    param_detection_class = rospy.get_param('~detection_class', None)
    if not param_detection_class:
        raise ValueError('no detection class specified')

    # look up the detection class by name inside the configured module
    imported_class = getattr(import_module(param_detection_module), param_detection_class)

    param_result_topic = rospy.get_param('~result_topic', '/mas_perception/detection_result')

    # determine image source, whether local directory, sensor_msgs/Image topic, or sensor_msgs/PointCloud2 topic;
    # the first usable one wins, in the order: ~image_directory, ~cloud_topic, ~image_topic
    img_source_type = ImageSourceType.UNKNOWN
    img_source = rospy.get_param('~image_directory', None)
    if img_source:
        if os.path.exists(img_source):
            img_source_type = ImageSourceType.IMAGE_DIR
        else:
            rospy.logwarn('image directory does not exist: ' + img_source)

    if img_source_type == ImageSourceType.UNKNOWN:
        img_source = rospy.get_param('~cloud_topic', None)
        if img_source:
            img_source_type = ImageSourceType.CLOUD_TOPIC

    if img_source_type == ImageSourceType.UNKNOWN:
        img_source = rospy.get_param('~image_topic', None)
        if img_source:
            img_source_type = ImageSourceType.IMAGE_TOPIC

    if img_source_type == ImageSourceType.UNKNOWN:
        raise ValueError('no valid image source specified')

    # test detection server
    main(img_source_type, img_source, imported_class, param_kwargs_file, param_class_annotations, param_result_topic)

    # keep the node alive so topic callbacks continue to be served
    rospy.spin()