diff --git a/turtlebot4_vision_tutorials/turtlebot4_vision_tutorials/pose_display.py b/turtlebot4_vision_tutorials/turtlebot4_vision_tutorials/pose_display.py index eb034e8..49194b0 100644 --- a/turtlebot4_vision_tutorials/turtlebot4_vision_tutorials/pose_display.py +++ b/turtlebot4_vision_tutorials/turtlebot4_vision_tutorials/pose_display.py @@ -17,6 +17,7 @@ # @author Hilary Luo (hluo@clearpathrobotics.com) import cv2 +from functools import partial import numpy as np from cv_bridge import CvBridge @@ -25,7 +26,7 @@ from rclpy.node import Node from rclpy.qos import qos_profile_sensor_data from std_msgs.msg import String as string_msg -from sensor_msgs.msg import Image +from sensor_msgs.msg import BatteryState, Image from turtlebot4_vision_tutorials.MovenetDepthaiEdge import Body @@ -35,37 +36,74 @@ [6, 12], [12, 11], [11, 5], [12, 14], [14, 16], [11, 13], [13, 15]] - -class PoseDetection(Node): +class PoseDisplay(Node): lights_on_ = False - frame = None - body = Body() + frame = [ None, None, None, None, None, None, None] + body = [ None, None, None, None, None, None, None] + percentage = [ None, None, None, None, None, None, None] + def __init__(self): super().__init__('pose_display') + self.declare_parameter('tile_x', 3) + self.declare_parameter('tile_y', 2) + self.declare_parameter('namespaces', ['tb11', 'tb12']) + self.declare_parameter('image_height', 432) + self.declare_parameter('image_width', 768) + + self.tile_x = self.get_parameter('tile_x').get_parameter_value().integer_value + self.tile_y = self.get_parameter('tile_y').get_parameter_value().integer_value + self.namespaces = self.get_parameter('namespaces').get_parameter_value().string_array_value + self.image_height = self.get_parameter('image_height').get_parameter_value().integer_value + self.image_width = self.get_parameter('image_width').get_parameter_value().integer_value + + self.full_frame = np.zeros((self.image_height*self.tile_y, + self.image_width*self.tile_x, 3), np.uint8) self.output = None - # Subscribe to the /semaphore_flags topic - self.semaphore_flag_subscriber = self.create_subscription( - string_msg, - 'semaphore_flag', - self.semaphore_flag_callback, - qos_profile_sensor_data) - - # Subscribe to the ffmpeg_decoded topic - self.body_pose_subscriber = self.create_subscription( - PoseArray, - 'body_pose', - self.body_pose_callback, - qos_profile_sensor_data) - - # Subscribe to the ffmpeg_decoded topic - self.ffmpeg_subscriber = self.create_subscription( - Image, - 'oakd/rgb/preview/ffmpeg_decoded', - self.frame_callback, - qos_profile_sensor_data) + # # Subscribe to the /semaphore_flags topic + # self.semaphore_flag_subscriber = self.create_subscription( + # string_msg, + # 'semaphore_flag', + # self.semaphore_flag_callback, + # qos_profile_sensor_data) + + self.body_pose_subscribers = [] + self.ffmpeg_subscribers = [] + self.battery_subscribers = [] + + # Naming a window + cv2.namedWindow("Movenet", cv2.WINDOW_NORMAL) + + # Subscribe to the pose topics + for i, ns in enumerate(self.namespaces): + subscriber = self.create_subscription( + PoseArray, + f'/{ns}/body_pose', + partial(self.body_pose_callback, num = i), + qos_profile_sensor_data) + self.body_pose_subscribers.append(subscriber) + + # Subscribe to the ffmpeg_decoded topics + subscriber = self.create_subscription( + Image, + f'/{ns}/oakd/rgb/preview/ffmpeg_decoded', + partial(self.frame_callback, num = i), + qos_profile_sensor_data) + self.ffmpeg_subscribers.append(subscriber) + + # Subscribe to the battery topics + subscriber = self.create_subscription( + BatteryState, + f'/{ns}/battery_state', + partial(self.battery_callback, num = i), + qos_profile_sensor_data) + self.battery_subscribers.append(subscriber) + + + timer_period = 0.0833 # seconds + self.timer = self.create_timer(timer_period, self.updateDisplay) self.bridge = CvBridge() @@ -82,48 +120,83 @@ def semaphore_flag_callback(self, letter_msg: string_msg): (0, 190, 255), 3) - def body_pose_callback(self, pose_msg: PoseArray): - # self.get_logger().info('body_pose_callback') + def body_pose_callback(self, pose_msg: PoseArray, num: int): + self.get_logger().info(f'Body_pose_callback {num} - start') temp_keypoints = [] temp_scores = [] for i, point in enumerate(pose_msg.poses): temp_keypoints.append((int(point.position.x), int(point.position.y))) temp_scores.append(point.position.z) - self.body.keypoints = np.array(temp_keypoints) - self.body.scores = np.array(temp_scores) - - def frame_callback(self, image_msg: Image): - # self.get_logger().info('frame_callback') + b = Body() + b.keypoints = np.array(temp_keypoints) + b.scores = np.array(temp_scores) + self.body[num] = b + # self.get_logger().info(f'Body_pose_callback {num} - end') + + def frame_callback(self, image_msg: Image, num: int): + self.get_logger().info(f'Frame_callback {num} - start') if image_msg.data is None: return - self.frame = self.bridge.imgmsg_to_cv2(image_msg, "bgr8") - if self.body.keypoints is not None: - self.draw() - self.waitKey() - - def draw(self): - lines = [np.array([self.body.keypoints[point] for point in line]) - for line in LINES_BODY if self.body.scores[line[0]] > SCORE_THRESH and - self.body.scores[line[1]] > SCORE_THRESH] + self.frame[num] = self.bridge.imgmsg_to_cv2(image_msg, "bgr8") + if self.body[num] is not None and self.body[num].keypoints is not None: + self.draw(num) + self.updateFrame(num) + # self.get_logger().info(f'Frame_callback {num} - end') + + def draw(self, num: int): + lines = [np.array([self.body[num].keypoints[point] for point in line]) + for line in LINES_BODY if self.body[num].scores[line[0]] > SCORE_THRESH and + self.body[num].scores[line[1]] > SCORE_THRESH] if lines is not None: - cv2.polylines(self.frame, lines, False, (255, 180, 90), 2, cv2.LINE_AA) + cv2.polylines(self.frame[num], lines, False, (255, 180, 90), 2, cv2.LINE_AA) - for i, x_y in enumerate(self.body.keypoints): - if self.body.scores[i] > SCORE_THRESH: + for i, x_y in enumerate(self.body[num].keypoints): + if self.body[num].scores[i] > SCORE_THRESH: if i % 2 == 1: color = (0, 255, 0) elif i == 0: color = (0, 255, 255) else: color = (0, 0, 255) - cv2.circle(self.frame, (x_y[0], x_y[1]), 4, color, -11) + cv2.circle(self.frame[num], (x_y[0], x_y[1]), 4, color, -11) + + def updateFrame(self, num: int): + self.get_logger().info(f'Updated frame {num}') + x = num%self.tile_x + y = int((num - x)/self.tile_y) + cv2.putText(self.frame[num], + f'{self.namespaces[num]}', + (50, 50), + # (self.frame[num].shape[1] // 2, 100), + cv2.FONT_HERSHEY_PLAIN, + 2, + (0, 0, 255), + 2) + if self.percentage[num]: + cv2.putText(self.frame[num], + f'{self.percentage[num]:.1f}%', + (self.frame[num].shape[1] - 120, self.frame[num].shape[0] - 20), + cv2.FONT_HERSHEY_PLAIN, + 2, + (0, 0, 255), + 2) + + self.full_frame[ + int(y*self.image_height):int((y+1)*self.image_height), + int(x*self.image_width):int((x+1)*self.image_width), + 0:3] = self.frame[num] + + def updateDisplay(self): + self.get_logger().info(f'Updated display') + cv2.imshow("Movenet", self.full_frame) + cv2.waitKey(1) - def waitKey(self, delay=1): + def waitKey(self, delay=0.1): # if self.show_fps: # self.pose.fps.draw(self.frame, orig=(50,50), size=1, color=(240,180,100)) - cv2.imshow("Movenet", self.frame) + cv2.imshow("Movenet", self.full_frame) if self.output: - self.output.write(self.frame) + self.output.write(self.full_frame) key = cv2.waitKey(delay) if key == 32: # Pause on space bar @@ -134,10 +207,12 @@ def waitKey(self, delay=1): self.show_crop = not self.show_crop return key + def battery_callback(self, batt_msg: BatteryState, num: int): + self.percentage[num] = batt_msg.percentage * 100 def main(args=None): rclpy.init(args=args) - node = PoseDetection() + node = PoseDisplay() rclpy.spin(node) node.destroy_node() rclpy.shutdown()