diff --git a/birdseye/mqtt.py b/birdseye/mqtt.py index 9615b358..d5200d7c 100644 --- a/birdseye/mqtt.py +++ b/birdseye/mqtt.py @@ -5,13 +5,16 @@ class BirdsEyeMQTT: - def __init__(self, mqtt_host, mqtt_port, message_handler): + def __init__(self, mqtt_host, mqtt_port, topics): + self.mqtt_host = mqtt_host + self.mqtt_port = mqtt_port + self.topics = topics try: self.client = paho.mqtt.client.Client() self.client.on_connect = self.on_connect - self.client.on_message = self.on_message + # self.client.on_message = self.on_message self.client.on_publish = self.on_publish - self.message_handler = message_handler + # self.message_handler = message_handler self.client.connect(mqtt_host, mqtt_port, 60) self.client.loop_start() @@ -24,16 +27,20 @@ def __init__(self, mqtt_host, mqtt_port, message_handler): ) sys.exit(1) - def on_message(self, client, userdata, json_message): - json_data = json.loads(json_message.payload) - self.message_handler(json_data) + def on_message_func(self, message_handler): + def on_message(client, userdata, json_message): + json_data = json.loads(json_message.payload) + message_handler(json_data) + + return on_message def on_connect(self, client, userdata, flags, result_code): - sub_channel = "gamutrf/rssi" logging.info( - "Connected to %s with result code %s", sub_channel, str(result_code) + f"Connected to MQTT broker {self.mqtt_host}:{self.mqtt_port} with result code {result_code}" ) - self.client.subscribe(sub_channel) # also tried qos = 1 and 1 + for topic, handler in self.topics: + self.client.subscribe(topic) + self.client.message_callback_add(topic, self.on_message_func(handler)) def on_publish(self, client, userdata, mid): logging.info("Completed transmission to broker.") diff --git a/birdseye/mqtt_fake.py b/birdseye/mqtt_fake.py index d8698306..1c061cd7 100644 --- a/birdseye/mqtt_fake.py +++ b/birdseye/mqtt_fake.py @@ -20,60 +20,130 @@ def message_handler(data): logging.info("Message handler received data: {}".format(data)) - mqtt_client = BirdsEyeMQTT("localhost", 1883, message_handler) + mqtt_client = BirdsEyeMQTT( + "localhost", + 1883, + [("gamutrf/inference", message_handler), ("gamutrf/targets", message_handler)], + ) - publish_data = { + sensor_data = { "metadata": { - "image_path": "/logs/inference/image_1700632876.885_640x640_2408895999Hz.png", - "orig_rows": "28897", - "rssi_max": "-46.058800", - "rssi_mean": "-69.604271", - "rssi_min": "-117.403526", - "rx_freq": "2408895999", - "ts": "1700632876.885" - }, + "image_path": "/logs/inference/image_1700632876.885_640x640_2408895999Hz.png", + "orig_rows": "28897", + "rssi_max": "-46.058800", + "rssi_mean": "-69.604271", + "rssi_min": "-117.403526", + "rx_freq": "2408895999", + "ts": "1700632876.885", + }, "predictions": { "mini2_telem": [ { - "rssi": "-40", - "conf": "0.33034399151802063", - "xywh": [609.8685302734375, 250.76278686523438, 20.482666015625, 7.45684814453125] + "rssi": "-40", + "conf": "0.33034399151802063", + "xywh": [ + 609.8685302734375, + 250.76278686523438, + 20.482666015625, + 7.45684814453125, + ], } ], "mini2_video": [ { - "rssi": "-80", - "conf": "0.33034399151802063", - "xywh": [609.8685302734375, 250.76278686523438, 20.482666015625, 7.45684814453125] + "rssi": "-80", + "conf": "0.33034399151802063", + "xywh": [ + 609.8685302734375, + 250.76278686523438, + 20.482666015625, + 7.45684814453125, + ], } - ] - }, + ], + }, "position": [32.922651, -117.120815], "heading": 0, "rssi": [-40, -60], - "gps": "fix" + "gps": "fix", + } + target_data = { + "altitude": 4700, + "gps_fix_type": 2, + "gps_stale": "false", + "heading": 174.15, + "latitude": 32.922651, + "longitude": -117.120815, + "relative_alt": 4703, + "target_name": "drone1", + "time_boot_ms": 904414, + "time_usec": None, + "vx": 0.0, + "vy": 0.0, + "vz": 0.0, } + control_options = {"sensor": sensor_data, "target": target_data} + control_map = {i: k for (i, k) in enumerate(control_options)} + control_key = None + def on_key_release(key): + global control_key + + if key == Key.esc: + exit() + + try: + if key.char == "0": + print(f"\nSelected sensor. Up/Down/Left/Right will control sensor.") + control_key = "sensor" + elif key.char == "1": + print(f"\nSelected target 1. Up/Down/Left/Right will control target 1.") + control_key = "target" + return + except: + pass + + if control_key is None: + print( + f"\nPlease select the device to control by pressing a number from {list(range(len(control_options)))}." + ) + print(f"{control_map}") + return + if key == Key.right: - # print("Right key clicked") - publish_data["position"][1] += 0.0001 + print("\nRight key pressed\n") + if control_key == "sensor": + sensor_data["position"][1] += 0.0001 + elif control_key == "target": + target_data["longitude"] += 0.0001 elif key == Key.left: - # print("Left key clicked") - publish_data["position"][1] -= 0.0001 + print("\nLeft key pressed\n") + if control_key == "sensor": + sensor_data["position"][1] -= 0.0001 + elif control_key == "target": + target_data["longitude"] -= 0.0001 elif key == Key.up: - # print("Up key clicked") - publish_data["position"][0] += 0.0001 + print("\nUp key pressed\n") + if control_key == "sensor": + sensor_data["position"][0] += 0.0001 + elif control_key == "target": + target_data["latitude"] += 0.0001 elif key == Key.down: - # print("Down key clicked") - publish_data["position"][0] -= 0.0001 - elif key == Key.esc: - exit() + print("\nDown key pressed\n") + if control_key == "sensor": + sensor_data["position"][0] -= 0.0001 + elif control_key == "target": + target_data["latitude"] -= 0.0001 - mqtt_client.client.publish( - "gamutrf/rssi", json.dumps(publish_data) - ) # also tried qos = 1 and 2 - logging.info("Started transmission to broker: {}".format(publish_data)) + if control_key == "sensor": + mqtt_client.client.publish( + "gamutrf/inference", json.dumps(sensor_data) + ) # also tried qos = 1 and 2 + logging.info("Started transmission to broker: {}".format(sensor_data)) + elif control_key == "target": + mqtt_client.client.publish("gamutrf/targets", json.dumps(target_data)) + logging.info("Started transmission to broker: {}".format(target_data)) with keyboard.Listener(on_release=on_key_release) as listener: listener.join() diff --git a/birdseye/utils.py b/birdseye/utils.py index 6a8d0ff2..a32405ee 100644 --- a/birdseye/utils.py +++ b/birdseye/utils.py @@ -676,7 +676,7 @@ def __init__( config={}, enable_heatmap=False, enable_gps_plot=False, - class_map={} + class_map={}, ): self.num_iters = num_iters self.experiment_name = experiment_name @@ -729,12 +729,14 @@ def __init__( self.target_hist = [] self.sensor_hist = [] self.sensor_gps_hist = [] + self.target_gps_hist = {} self.history_length = 10 self.time_step = 0 self.texts = [] self.openstreetmap = None self.transform = None self.expected_target_rssi = None + self.target_only_map = False if config: write_config_log(config, self.logdir) @@ -800,11 +802,46 @@ def live_plot( """ Create a live plot """ + + lines = ( + [] + ) # https://matplotlib.org/3.5.0/api/_as_gen/matplotlib.pyplot.legend.html + legend_elements = [] + separable_color_array = [ + ["deepskyblue", "blue"], + ["pink", "red"], + ["wheat", "orange"], + ["lightgreen", "green"], + ] + color_array = [ + ["salmon", "darkred", "red"], + ["lightskyblue", "darkblue", "blue"], + ] + sensor_color = "green" + + # Target only openstreetmap + if self.openstreetmap is None and data.get("targets", None): + self.target_only_map = True + print(data["targets"]) + self.openstreetmap = GPSVis( + position=data["targets"][list(data["targets"])[0]]["position"], + distance=map_distance, + ) + self.openstreetmap.set_origin( + data["targets"][list(data["targets"])[0]]["position"] + ) + self.transform = np.array( + [self.openstreetmap.origin[0], self.openstreetmap.origin[1]] + ) + + # Openstreetmap if ( - self.openstreetmap is None + (self.openstreetmap is None or self.target_only_map) and data.get("position", None) is not None and data.get("heading", None) is not None ): + self.target_only_map = False + self.target_gps_hist = {} self.openstreetmap = GPSVis( position=data["position"], distance=map_distance ) @@ -814,6 +851,8 @@ def live_plot( ) self.time_step = time_step + + # Particle filter statistics self.pf_stats["mean_hypothesis"].append( env.pf.mean_hypothesis if hasattr(env.pf, "mean_hypothesis") else [None] ) @@ -827,16 +866,58 @@ def live_plot( env.pf.map_state if hasattr(env.pf, "map_state") else [None] ) + # Sensor state history (from internal state) abs_sensor = env.state.sensor_state - abs_particles = env.get_absolute_particles() self.sensor_hist.append(abs_sensor) + # Sensor state history (from gps) + if self.openstreetmap and data.get("position", None) is not None: + self.sensor_gps_hist.append( + self.openstreetmap.scale_to_img( + data["position"], + (self.openstreetmap.width_meters, self.openstreetmap.height_meters), + ) + ) + + # Particle states + abs_particles = env.get_absolute_particles() + + # Target state history (from internal state) if env.simulated: self.target_hist.append(env.get_absolute_target()) - target_heading = None - target_relative_heading = None + # Target state history (from gps) + if self.openstreetmap and data.get("target_gps", None) is not None: + for t in data["targets"]: + if t not in self.target_gps_hist: + self.target_gps_hist[t] = [] + self.target_gps_hist[t].append( + self.openstreetmap.scale_to_img( + data["targets"][t]["position"], + ( + self.openstreetmap.width_meters, + self.openstreetmap.height_meters, + ), + ) + ) + print( + f"\n{t}:{data['targets'][t]['position']}, {self.target_gps_hist[t][-1]}\n" + ) + # Target state history (old version; from file ) + if self.openstreetmap and data.get("drone_position", None) is not None: + # print(f"{data['drone_position']=}") + self.target_hist.append( + self.openstreetmap.scale_to_img( + data["drone_position"], + (self.openstreetmap.width_meters, self.openstreetmap.height_meters), + ) + ) + + # target_heading = None + # target_relative_heading = None + + # Calculate extra data if target is known if ( data.get("position", None) is not None and data.get("drone_position", None) is not None @@ -857,194 +938,133 @@ def live_plot( ax.clear() if self.openstreetmap is not None: self.openstreetmap.plot_map(axis1=ax) - # TODO get variables + if separable: ax.set_title("Time = {}".format(time_step)) else: abs_particles = np.moveaxis(abs_particles, 1, 0) - # ax.set_title( - # "Time = {}, Frequency = {}, Bandwidth = {}, Gain = {}".format( - # time_step, None, None, None - # ) - # ) ax.set_title( f"Time = {str(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))}" ) - color_array = [ - ["salmon", "darkred", "red"], - ["lightskyblue", "darkblue", "blue"], - ] - lines = ( - [] - ) # https://matplotlib.org/3.5.0/api/_as_gen/matplotlib.pyplot.legend.html - - separable_color_array = [ - ["deepskyblue", "blue"], - ["pink", "red"], - ["wheat", "orange"], - ["lightgreen", "green"], - ] - sensor_color = "green" - legend_elements = [] - for t in range(env.state.n_targets): - # PLOT PARTICLES - # particles_x, particles_y = pol2cart( - # abs_particles[:, t, 0], np.radians(abs_particles[:, t, 1]) - # ) - particles_x, particles_y = pol2cart( - abs_particles[t, :, 0], np.radians(abs_particles[t, :, 1]) - ) - if self.transform is not None: - particles_x += self.transform[0] - particles_y += self.transform[1] - if separable: - particle_color = separable_color_array[t][0] - else: - particle_color = "salmon" - (line1,) = ax.plot( - particles_x, - particles_y, - "o", - color=particle_color, - markersize=4, - markeredgecolor="black", - label="Particles", - alpha=0.4, - zorder=1, - ) - - # PLOT HEATMAP OVER STREET MAP - if self.enable_heatmap and self.openstreetmap: - heatmap, xedges, yedges = np.histogram2d( + # Plot particles + if env.simulated or (self.openstreetmap and not self.target_only_map): + for t in range(env.state.n_targets): + particles_x, particles_y = pol2cart( + abs_particles[t, :, 0], np.radians(abs_particles[t, :, 1]) + ) + if self.transform is not None: + particles_x += self.transform[0] + particles_y += self.transform[1] + if separable: + particle_color = separable_color_array[t][0] + else: + particle_color = "salmon" + (line1,) = ax.plot( particles_x, particles_y, - bins=(self.openstreetmap.xedges, self.openstreetmap.yedges), - ) - heatmap = gaussian_filter(heatmap, sigma=8) - extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] - im = ax.imshow( - heatmap.T, - extent=extent, - origin="lower", - cmap="jet", - interpolation="nearest", - alpha=0.2, + "o", + color=particle_color, + markersize=4, + markeredgecolor="black", + label="Particles", + alpha=0.4, + zorder=1, ) - # plt.colorbar(im) - # PLOT CENTROIDS - centroid_x = np.mean(particles_x) - centroid_y = np.mean(particles_y) - if separable: - centroid_color = separable_color_array[t][1] - else: - centroid_color = "magenta" + # PLOT HEATMAP OVER STREET MAP + if self.enable_heatmap and self.openstreetmap: + heatmap, xedges, yedges = np.histogram2d( + particles_x, + particles_y, + bins=(self.openstreetmap.xedges, self.openstreetmap.yedges), + ) + heatmap = gaussian_filter(heatmap, sigma=8) + extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] + im = ax.imshow( + heatmap.T, + extent=extent, + origin="lower", + cmap="jet", + interpolation="nearest", + alpha=0.2, + ) + # plt.colorbar(im) - (line2,) = ax.plot( - centroid_x, - centroid_y, - "^", - color=centroid_color, - markeredgecolor="black", - label="Mean Estimate", - markersize=20, - zorder=7, - ) + # PLOT CENTROIDS + centroid_x = np.mean(particles_x) + centroid_y = np.mean(particles_y) + if separable: + centroid_color = separable_color_array[t][1] + else: + centroid_color = "magenta" - if t == 0: - lines.extend([line1, line2]) - else: - lines.extend([]) + (line2,) = ax.plot( + centroid_x, + centroid_y, + "^", + color=centroid_color, + markeredgecolor="black", + label="Mean Estimate", + markersize=20, + zorder=7, + ) - target_class_name = f"Target {t}" - for class_name, class_idx in self.class_map.items(): - if t == class_idx: - target_class_name = class_name + if t == 0: + lines.extend([line1, line2]) + else: + lines.extend([]) + target_class_name = f"Target {t} particles" + for class_name, class_idx in self.class_map.items(): + if t == class_idx: + target_class_name = f"{class_name} particles" + + legend_elements.append( + Line2D( + [0], + [0], + marker="o", + color="white", + markersize=8, + markeredgecolor="black", + markerfacecolor=centroid_color, + label=target_class_name, + ) + ) legend_elements.append( Line2D( [0], [0], - marker="o", + marker="^", color="white", - markersize=8, markeredgecolor="black", - markerfacecolor=centroid_color, - label=target_class_name, - ) - ) - - # PLOT SENSOR - if ( - self.enable_gps_plot and self.openstreetmap and data["position"] is not None - ): # and not data["needs_processing"]: - self.sensor_gps_hist.append( - self.openstreetmap.scale_to_img( - data["position"], - (self.openstreetmap.width_meters, self.openstreetmap.height_meters), + markerfacecolor="white", + label="Avg Estimates", + markersize=12, ) ) - temp_np = np.array(self.sensor_gps_hist) - sensor_x = temp_np[:, 0] - sensor_y = temp_np[:, 1] - - if len(self.sensor_gps_hist) > 1: - # print(f"{data['heading']=}") - # print(f"{data['previous_heading']=}") + # Plot sensor + if env.simulated or self.sensor_gps_hist: + if self.sensor_gps_hist: + temp_np = np.array(self.sensor_gps_hist) + sensor_x = temp_np[:, 0] + sensor_y = temp_np[:, 1] arrow_x, arrow_y = pol2cart( - 4, np.radians(data.get("heading", data["previous_heading"])) - ) - ax.arrow( - sensor_x[-1], - sensor_y[-1], - arrow_x, - arrow_y, - width=1.5, - color="green", - zorder=4, + 6, np.radians(data.get("heading", data["previous_heading"])) ) - ax.plot( - sensor_x[len(sensor_x)-self.history_length:-1], - sensor_y[len(sensor_y)-self.history_length:-1], - linewidth=3.0, - color="green", - markeredgecolor="black", - markersize=4, - zorder=4, + else: + sensor_x, sensor_y = pol2cart( + np.array(self.sensor_hist)[:, 0], + np.radians(np.array(self.sensor_hist)[:, 1]), ) - (line4,) = ax.plot( - sensor_x[-1], - sensor_y[-1], - "p", - color="green", - label="SensorGPS", - markersize=10, - zorder=4, - ) - lines.extend([line4]) - legend_elements.append( - mpatches.Patch(facecolor="green", edgecolor="black", label="SensorGPS") - ) + if self.transform is not None: + sensor_x += self.transform[0] + sensor_y += self.transform[1] - sensor_x, sensor_y = pol2cart( - np.array(self.sensor_hist)[:, 0], - np.radians(np.array(self.sensor_hist)[:, 1]), - ) - if self.transform is not None: - sensor_x += self.transform[0] - sensor_y += self.transform[1] - - if len(self.sensor_hist) > 1: - arrow_x, arrow_y = pol2cart( - 6, np.radians(env.state.sensor_state[2]) - ) + arrow_x, arrow_y = pol2cart(6, np.radians(env.state.sensor_state[2])) line4 = mpatches.FancyArrow( - # sensor_x[-2], - # sensor_y[-2], - # 4 * (sensor_x[-1] - sensor_x[-2]), - # 4 * (sensor_y[-1] - sensor_y[-2]), sensor_x[-1], sensor_y[-1], arrow_x, @@ -1059,54 +1079,39 @@ def live_plot( linewidth=1, ) ax.add_patch(line4) - ax.plot( - sensor_x[len(sensor_x)-self.history_length:], - sensor_y[len(sensor_x)-self.history_length:], - linewidth=5, - color=sensor_color, - # markeredgecolor="black", - # markersize=4, - zorder=6, - # path_effects=[pe.Stroke(linewidth=7, foreground='black')] - ) - # (line4,) = ax.plot( - # sensor_x[-1], - # sensor_y[-1], - # "p", - # color="blue", - # label="sensor", - # markersize=10, - # zorder=4, - # ) lines.extend([line4]) + if len(self.sensor_hist) > 1: + ax.plot( + sensor_x[len(sensor_x) - self.history_length :], + sensor_y[len(sensor_x) - self.history_length :], + linewidth=5, + color=sensor_color, + # markeredgecolor="black", + # markersize=4, + zorder=6, + # path_effects=[pe.Stroke(linewidth=7, foreground='black')] + ) legend_elements.append( mpatches.Patch( facecolor=sensor_color, edgecolor="black", label="Sensor" ) ) - if self.openstreetmap and data.get("drone_position", None) is not None: - # print(f"{data['drone_position']=}") - self.target_hist.append( - self.openstreetmap.scale_to_img( - data["drone_position"], - (self.openstreetmap.width_meters, self.openstreetmap.height_meters), - ) - ) - - # PLOT TARGETS - if self.target_hist: - # target_np = np.array(self.target_hist) - # print(f"{self.target_hist[-1]=}") - # assert len(self.target_hist.shape) == 3 - for t in range(env.state.n_targets): + # Plot targets + if self.target_hist or self.target_gps_hist: + if env.simulated: + n_target_hist = env.state.n_targets + else: + n_target_hist = len(self.target_gps_hist) + for t in range(n_target_hist): if env.simulated: target_x, target_y = pol2cart( np.array(self.target_hist)[:, t, 0], np.radians(np.array(self.target_hist)[:, t, 1]), ) else: - temp_np = np.array(self.target_hist) + temp_np = np.array(list(self.target_gps_hist.values())[t]) + # temp_np = np.array(self.target_hist) target_x = temp_np[:, 0] target_y = temp_np[:, 1] @@ -1114,53 +1119,53 @@ def live_plot( # target_x += self.transform[0] # target_y += self.transform[1] - if len(self.target_hist) > 1: + if len(target_x) > 1: ax.plot( - target_x[:-1], - target_y[:-1], + target_x, + target_y, linewidth=3.0, color="black", zorder=3, markersize=4, ) + (line5,) = ax.plot( target_x[-1], target_y[-1], "X", color="black", markeredgecolor="black", - label="Targets", - markersize=10, + # label="Targets", + markersize=8, zorder=3, ) - lines.extend([line5]) - legend_elements.append( - Line2D( - [0], - [0], - marker="X", - color="white", - markerfacecolor="black", - markeredgecolor="black", - label="Targets", - markersize=10, + + target_class_name = f"Target {t}" + if self.target_gps_hist: + target_class_name = list(self.target_gps_hist)[t] + ax.text( + target_x[-1], + target_y[-1], + f"{target_class_name}", + color="black", + fontsize=16, + fontweight="bold", + ) + # lines.extend([line5]) + legend_elements.append( + Line2D( + [0], + [0], + marker="X", + color="white", + markerfacecolor="black", + markeredgecolor="black", + label=target_class_name, + markersize=10, + ) ) - ) # Legend - legend_elements.append( - Line2D( - [0], - [0], - marker="^", - color="white", - markeredgecolor="black", - markerfacecolor="white", - label="Avg Estimates", - markersize=12, - ) - ) - ax.legend( handles=legend_elements, loc="upper center", @@ -1385,8 +1390,8 @@ def build_multitarget_plots( # these are matplotlib.patch.Patch properties props = dict(boxstyle="round", facecolor="wheat", alpha=0.5) - self.history_length = 50 - + self.history_length = 50 + if len(self.abs_target_hist) < self.history_length: self.abs_target_hist = [abs_target] * self.history_length self.abs_sensor_hist = [abs_sensor] * self.history_length diff --git a/geolocate.py b/geolocate.py index d64fb694..6c7d83ae 100644 --- a/geolocate.py +++ b/geolocate.py @@ -11,7 +11,16 @@ import time from datetime import datetime -from flask import Flask, render_template, send_from_directory, redirect, request, url_for, make_response, jsonify +from flask import ( + Flask, + render_template, + send_from_directory, + redirect, + request, + url_for, + make_response, + jsonify, +) from io import BytesIO from timeit import default_timer as timer @@ -69,24 +78,26 @@ def real_observation(self): if observation is None: return default_observation - - if type(observation) == dict: + + if type(observation) == dict: format_observation = default_observation for class_name in observation: - self.class_map[class_name] = self.class_map.get(class_name, len(self.class_map)) + self.class_map[class_name] = self.class_map.get( + class_name, len(self.class_map) + ) format_observation[self.class_map[class_name]] = observation[class_name] observation = format_observation if type(observation) != list: observation = [observation] - if len(observation) != self.n_targets: + if len(observation) != self.n_targets: raise ValueError("len(observation) != n_targets") - for i in range(len(observation)): + for i in range(len(observation)): if observation[i] and observation[i] < self.threshold: observation[i] = None - + return observation @@ -111,6 +122,8 @@ def __init__(self, config_path="geolocate.ini"): "action_taken": None, "needs_processing": False, "gps": None, + "targets": {}, + "target_gps": None, } config = configparser.ConfigParser() config.read(config_path) @@ -149,6 +162,25 @@ def __init__(self, config_path="geolocate.ini"): default_config.update(self.config) self.config = default_config + def target_handler(self, message_data): + logging.info(f"Received gamutrf/target MQTT message: {message_data}") + + if ( + message_data["gps_stale"].lower() != "false" + or int(message_data["gps_fix_type"]) < 2 + ): + return "No target GPS." + + self.data["target_gps"] = "fix" + target_name = message_data["target_name"] + if target_name not in self.data["targets"]: + self.data["targets"][target_name] = {"idx": len(self.data["targets"])} + + self.data["targets"][target_name]["position"] = ( + message_data["latitude"], + message_data["longitude"], + ) + def data_handler(self, message_data): """ Generic data processor @@ -179,19 +211,21 @@ def data_handler(self, message_data): metadata = message_data.get("metadata", None) if predictions: self.data["rssi"] = {} - for class_name in predictions.keys(): - + for class_name in predictions.keys(): self.data["rssi"][class_name] = np.mean( - [float(prediction.get("rssi_max", metadata["rssi_max"])) for prediction in predictions[class_name]] + [ + float(prediction.get("rssi_max", metadata["rssi_max"])) + for prediction in predictions[class_name] + ] ) - elif metadata: + elif metadata: # TODO: how to handle when tracking multiple targets - #self.data["rssi"] = float(metadata["rssi_mean"]) - #self.data["rssi"] = float(metadata["rssi_min"]) - #self.data["rssi"] = float(metadata.get["rssi_max"]) + # self.data["rssi"] = float(metadata["rssi_mean"]) + # self.data["rssi"] = float(metadata["rssi_min"]) + # self.data["rssi"] = float(metadata.get["rssi_max"]) pass - else: + else: self.data["rssi"] = message_data.get("rssi", None) self.data["position"] = message_data.get("position", self.data["position"]) @@ -237,32 +271,34 @@ def run_flask(self, flask_host, flask_port, fig, results): """ app = Flask(__name__) - @app.route('/gui/') + @app.route("/gui/") def gui_file(filename): - return send_from_directory('gui', filename) + return send_from_directory("gui", filename) @app.route("/gui/data") def gui_data(): data = base64.b64encode(self.image_buf.getvalue()).decode("ascii") - img = "data:image/png;base64,"+data + img = "data:image/png;base64," + data return jsonify(img) - @app.route('/refresh') + @app.route("/refresh") def refresh(): - os.makedirs('gui', exist_ok=True) + os.makedirs("gui", exist_ok=True) with open("gui/map.png", "wb") as img: img.write(self.image_buf.getbuffer()) - # newmapp = np.random.rand(500,500,3) * 255 + # newmapp = np.random.rand(500,500,3) * 255 # data = Image.fromarray(newmapp.astype('uint8')).convert('RGBA') # data.save('gui/map.png') - return "OK" + return "OK" - @app.route('/gui/form', methods = ['POST', 'GET']) + @app.route("/gui/form", methods=["POST", "GET"]) def gui_form(): - if request.method == 'POST': - user = request.form.get('name', None) - self.config['n_targets'] = request.form.get('n_targets', self.config['n_targets']) - reset = request.form.get('reset', None) + if request.method == "POST": + user = request.form.get("name", None) + self.config["n_targets"] = request.form.get( + "n_targets", self.config["n_targets"] + ) + reset = request.form.get("reset", None) if reset == "reset": self.stop() self.image_buf = BytesIO() @@ -281,7 +317,6 @@ def gui_from_buffer(): return render_template("loading.html") return render_template("gui_from_buffer.html", config=self.config) - @app.route("/") def index(): flask_start_time = timer() @@ -313,13 +348,13 @@ def index(): host=host_name, port=port, debug=False, use_reloader=False ) ).start() - + def get_replay_json(self, replay_file): with open(replay_file, "r", encoding="UTF-8") as open_file: replay_data = json.load(open_file) for ts in replay_data: yield replay_data[ts] - + def get_replay_log(self, replay_file): with open(replay_file, "r", encoding="UTF-8") as open_file: for line in open_file: @@ -328,15 +363,17 @@ def get_replay_log(self, replay_file): def start(self): self.stop_threads = False - self.main_thread = threading.Thread(target=self.main, args=[lambda: self.stop_threads]) + self.main_thread = threading.Thread( + target=self.main, args=[lambda: self.stop_threads] + ) self.main_thread.start() logging.info("Main thread started.") - def stop(self): + def stop(self): self.stop_threads = True self.main_thread.join() logging.info("Main thread stopped.") - + def main(self, stopped): """ Main loop @@ -398,9 +435,11 @@ def main(self, stopped): ###### MQTT or replay from file if replay_file is None: - mqtt_client = birdseye.mqtt.BirdsEyeMQTT( - mqtt_host, mqtt_port, self.data_handler - ) + topics = [ + ("gamutrf/inference", self.data_handler), + ("gamutrf/targets", self.target_handler), + ] + mqtt_client = birdseye.mqtt.BirdsEyeMQTT(mqtt_host, mqtt_port, topics) else: if replay_file.endswith(".log"): get_replay_data = self.get_replay_log(replay_file) @@ -520,7 +559,12 @@ def main(self, stopped): control_actions = [] step_time = 0 - while self.data["gps"] != "fix" and not replay_file and not stopped(): + while ( + self.data["gps"] != "fix" + and self.data["target_gps"] != "fix" + and not replay_file + and not stopped() + ): time.sleep(1) logging.info("Waiting for GPS...") @@ -530,11 +574,11 @@ def main(self, stopped): if replay_file: # load data from saved file - try: + try: replay_data = next(get_replay_data) except StopIteration: break - + self.data_handler(replay_data) action_start = timer() @@ -629,7 +673,9 @@ def main(self, stopped): if __name__ == "__main__": # pragma: no cover parser = argparse.ArgumentParser() - parser.add_argument("config_path", help="Path to config file, geolocate.ini provided as example.") + parser.add_argument( + "config_path", help="Path to config file, geolocate.ini provided as example." + ) parser.add_argument("--log", default="INFO", help="Log level") args = parser.parse_args()