From be8ef38bf71f0e60b4c37d145672583a2c8ed37c Mon Sep 17 00:00:00 2001 From: Lucas Tindall Date: Tue, 23 Jan 2024 15:20:47 -0800 Subject: [PATCH] formatting --- birdseye/mqtt.py | 7 +-- birdseye/mqtt_fake.py | 103 ++++++++++++++++++++----------------- birdseye/utils.py | 86 +++++++++++++++---------------- geolocate.py | 117 +++++++++++++++++++++++++----------------- 4 files changed, 170 insertions(+), 143 deletions(-) diff --git a/birdseye/mqtt.py b/birdseye/mqtt.py index c9f9d70e..d5200d7c 100644 --- a/birdseye/mqtt.py +++ b/birdseye/mqtt.py @@ -12,9 +12,9 @@ def __init__(self, mqtt_host, mqtt_port, topics): try: self.client = paho.mqtt.client.Client() self.client.on_connect = self.on_connect - #self.client.on_message = self.on_message + # self.client.on_message = self.on_message self.client.on_publish = self.on_publish - #self.message_handler = message_handler + # self.message_handler = message_handler self.client.connect(mqtt_host, mqtt_port, 60) self.client.loop_start() @@ -31,6 +31,7 @@ def on_message_func(self, message_handler): def on_message(client, userdata, json_message): json_data = json.loads(json_message.payload) message_handler(json_data) + return on_message def on_connect(self, client, userdata, flags, result_code): @@ -38,7 +39,7 @@ def on_connect(self, client, userdata, flags, result_code): f"Connected to MQTT broker {self.mqtt_host}:{self.mqtt_port} with result code {result_code}" ) for topic, handler in self.topics: - self.client.subscribe(topic) + self.client.subscribe(topic) self.client.message_callback_add(topic, self.on_message_func(handler)) def on_publish(self, client, userdata, mid): diff --git a/birdseye/mqtt_fake.py b/birdseye/mqtt_fake.py index ba8997d3..1c061cd7 100644 --- a/birdseye/mqtt_fake.py +++ b/birdseye/mqtt_fake.py @@ -20,76 +20,84 @@ def message_handler(data): logging.info("Message handler received data: {}".format(data)) - mqtt_client = BirdsEyeMQTT("localhost", 1883, [("gamutrf/inference",message_handler),("gamutrf/targets", message_handler)]) + mqtt_client = BirdsEyeMQTT( + "localhost", + 1883, + [("gamutrf/inference", message_handler), ("gamutrf/targets", message_handler)], + ) sensor_data = { "metadata": { - "image_path": "/logs/inference/image_1700632876.885_640x640_2408895999Hz.png", - "orig_rows": "28897", - "rssi_max": "-46.058800", - "rssi_mean": "-69.604271", - "rssi_min": "-117.403526", - "rx_freq": "2408895999", - "ts": "1700632876.885" - }, + "image_path": "/logs/inference/image_1700632876.885_640x640_2408895999Hz.png", + "orig_rows": "28897", + "rssi_max": "-46.058800", + "rssi_mean": "-69.604271", + "rssi_min": "-117.403526", + "rx_freq": "2408895999", + "ts": "1700632876.885", + }, "predictions": { "mini2_telem": [ { - "rssi": "-40", - "conf": "0.33034399151802063", - "xywh": [609.8685302734375, 250.76278686523438, 20.482666015625, 7.45684814453125] + "rssi": "-40", + "conf": "0.33034399151802063", + "xywh": [ + 609.8685302734375, + 250.76278686523438, + 20.482666015625, + 7.45684814453125, + ], } ], "mini2_video": [ { - "rssi": "-80", - "conf": "0.33034399151802063", - "xywh": [609.8685302734375, 250.76278686523438, 20.482666015625, 7.45684814453125] + "rssi": "-80", + "conf": "0.33034399151802063", + "xywh": [ + 609.8685302734375, + 250.76278686523438, + 20.482666015625, + 7.45684814453125, + ], } - ] - }, + ], + }, "position": [32.922651, -117.120815], "heading": 0, "rssi": [-40, -60], - "gps": "fix" + "gps": "fix", } target_data = { - - "altitude":4700, - "gps_fix_type":2, - "gps_stale":"false", - "heading":174.15, - "latitude":32.922651, - "longitude":-117.120815, - "relative_alt":4703, - "target_name":"drone1", - "time_boot_ms":904414, - "time_usec":None, - "vx":0.0, - "vy":0.0, - "vz":0.0 - + "altitude": 4700, + "gps_fix_type": 2, + "gps_stale": "false", + "heading": 174.15, + "latitude": 32.922651, + "longitude": -117.120815, + "relative_alt": 4703, + "target_name": "drone1", + "time_boot_ms": 904414, + "time_usec": None, + "vx": 0.0, + "vy": 0.0, + "vz": 0.0, } - control_options = { - "sensor": sensor_data, - "target": target_data - } - control_map = {i:k for (i,k) in enumerate(control_options)} + control_options = {"sensor": sensor_data, "target": target_data} + control_map = {i: k for (i, k) in enumerate(control_options)} control_key = None def on_key_release(key): - - global control_key + global control_key if key == Key.esc: exit() try: - if key.char == '0': + if key.char == "0": print(f"\nSelected sensor. Up/Down/Left/Right will control sensor.") control_key = "sensor" - elif key.char == '1': + elif key.char == "1": print(f"\nSelected target 1. Up/Down/Left/Right will control target 1.") control_key = "target" return @@ -97,9 +105,11 @@ def on_key_release(key): pass if control_key is None: - print(f"\nPlease select the device to control by pressing a number from {list(range(len(control_options)))}.") + print( + f"\nPlease select the device to control by pressing a number from {list(range(len(control_options)))}." + ) print(f"{control_map}") - return + return if key == Key.right: print("\nRight key pressed\n") @@ -126,16 +136,13 @@ def on_key_release(key): elif control_key == "target": target_data["latitude"] -= 0.0001 - if control_key == "sensor": mqtt_client.client.publish( "gamutrf/inference", json.dumps(sensor_data) ) # also tried qos = 1 and 2 logging.info("Started transmission to broker: {}".format(sensor_data)) elif control_key == "target": - mqtt_client.client.publish( - "gamutrf/targets", json.dumps(target_data) - ) + mqtt_client.client.publish("gamutrf/targets", json.dumps(target_data)) logging.info("Started transmission to broker: {}".format(target_data)) with keyboard.Listener(on_release=on_key_release) as listener: diff --git a/birdseye/utils.py b/birdseye/utils.py index 8383ed46..a32405ee 100644 --- a/birdseye/utils.py +++ b/birdseye/utils.py @@ -676,7 +676,7 @@ def __init__( config={}, enable_heatmap=False, enable_gps_plot=False, - class_map={} + class_map={}, ): self.num_iters = num_iters self.experiment_name = experiment_name @@ -803,7 +803,6 @@ def live_plot( Create a live plot """ - lines = ( [] ) # https://matplotlib.org/3.5.0/api/_as_gen/matplotlib.pyplot.legend.html @@ -819,24 +818,23 @@ def live_plot( ["lightskyblue", "darkblue", "blue"], ] sensor_color = "green" - # Target only openstreetmap - if ( - self.openstreetmap is None - and data.get("targets", None) - ): + if self.openstreetmap is None and data.get("targets", None): self.target_only_map = True print(data["targets"]) self.openstreetmap = GPSVis( - position=data["targets"][list(data["targets"])[0]]["position"], distance=map_distance + position=data["targets"][list(data["targets"])[0]]["position"], + distance=map_distance, + ) + self.openstreetmap.set_origin( + data["targets"][list(data["targets"])[0]]["position"] ) - self.openstreetmap.set_origin(data["targets"][list(data["targets"])[0]]["position"]) self.transform = np.array( [self.openstreetmap.origin[0], self.openstreetmap.origin[1]] ) - # Openstreetmap + # Openstreetmap if ( (self.openstreetmap is None or self.target_only_map) and data.get("position", None) is not None @@ -853,8 +851,8 @@ def live_plot( ) self.time_step = time_step - - # Particle filter statistics + + # Particle filter statistics self.pf_stats["mean_hypothesis"].append( env.pf.mean_hypothesis if hasattr(env.pf, "mean_hypothesis") else [None] ) @@ -873,7 +871,7 @@ def live_plot( self.sensor_hist.append(abs_sensor) # Sensor state history (from gps) - if self.openstreetmap and data.get("position", None) is not None: + if self.openstreetmap and data.get("position", None) is not None: self.sensor_gps_hist.append( self.openstreetmap.scale_to_img( data["position"], @@ -889,20 +887,23 @@ def live_plot( self.target_hist.append(env.get_absolute_target()) # Target state history (from gps) - if self.openstreetmap and data.get("target_gps", None) is not None: - + if self.openstreetmap and data.get("target_gps", None) is not None: for t in data["targets"]: - if t not in self.target_gps_hist: + if t not in self.target_gps_hist: self.target_gps_hist[t] = [] self.target_gps_hist[t].append( self.openstreetmap.scale_to_img( data["targets"][t]["position"], - (self.openstreetmap.width_meters, self.openstreetmap.height_meters), + ( + self.openstreetmap.width_meters, + self.openstreetmap.height_meters, + ), ) - ) - print(f"\n{t}:{data['targets'][t]['position']}, {self.target_gps_hist[t][-1]}\n") - + print( + f"\n{t}:{data['targets'][t]['position']}, {self.target_gps_hist[t][-1]}\n" + ) + # Target state history (old version; from file ) if self.openstreetmap and data.get("drone_position", None) is not None: # print(f"{data['drone_position']=}") @@ -912,10 +913,9 @@ def live_plot( (self.openstreetmap.width_meters, self.openstreetmap.height_meters), ) ) - - #target_heading = None - #target_relative_heading = None + # target_heading = None + # target_relative_heading = None # Calculate extra data if target is known if ( @@ -1016,8 +1016,8 @@ def live_plot( lines.extend([]) target_class_name = f"Target {t} particles" - for class_name, class_idx in self.class_map.items(): - if t == class_idx: + for class_name, class_idx in self.class_map.items(): + if t == class_idx: target_class_name = f"{class_name} particles" legend_elements.append( @@ -1047,7 +1047,6 @@ def live_plot( # Plot sensor if env.simulated or self.sensor_gps_hist: - if self.sensor_gps_hist: temp_np = np.array(self.sensor_gps_hist) sensor_x = temp_np[:, 0] @@ -1064,9 +1063,7 @@ def live_plot( sensor_x += self.transform[0] sensor_y += self.transform[1] - arrow_x, arrow_y = pol2cart( - 6, np.radians(env.state.sensor_state[2]) - ) + arrow_x, arrow_y = pol2cart(6, np.radians(env.state.sensor_state[2])) line4 = mpatches.FancyArrow( sensor_x[-1], sensor_y[-1], @@ -1085,8 +1082,8 @@ def live_plot( lines.extend([line4]) if len(self.sensor_hist) > 1: ax.plot( - sensor_x[len(sensor_x)-self.history_length:], - sensor_y[len(sensor_x)-self.history_length:], + sensor_x[len(sensor_x) - self.history_length :], + sensor_y[len(sensor_x) - self.history_length :], linewidth=5, color=sensor_color, # markeredgecolor="black", @@ -1098,14 +1095,13 @@ def live_plot( mpatches.Patch( facecolor=sensor_color, edgecolor="black", label="Sensor" ) - ) + ) # Plot targets if self.target_hist or self.target_gps_hist: - - if env.simulated: + if env.simulated: n_target_hist = env.state.n_targets - else: + else: n_target_hist = len(self.target_gps_hist) for t in range(n_target_hist): if env.simulated: @@ -1132,30 +1128,30 @@ def live_plot( zorder=3, markersize=4, ) - + (line5,) = ax.plot( target_x[-1], target_y[-1], "X", color="black", markeredgecolor="black", - #label="Targets", + # label="Targets", markersize=8, zorder=3, ) - + target_class_name = f"Target {t}" if self.target_gps_hist: target_class_name = list(self.target_gps_hist)[t] ax.text( - target_x[-1], - target_y[-1], - f"{target_class_name}", - color="black", + target_x[-1], + target_y[-1], + f"{target_class_name}", + color="black", fontsize=16, fontweight="bold", ) - #lines.extend([line5]) + # lines.extend([line5]) legend_elements.append( Line2D( [0], @@ -1394,8 +1390,8 @@ def build_multitarget_plots( # these are matplotlib.patch.Patch properties props = dict(boxstyle="round", facecolor="wheat", alpha=0.5) - self.history_length = 50 - + self.history_length = 50 + if len(self.abs_target_hist) < self.history_length: self.abs_target_hist = [abs_target] * self.history_length self.abs_sensor_hist = [abs_sensor] * self.history_length diff --git a/geolocate.py b/geolocate.py index 3ebaed6c..6c7d83ae 100644 --- a/geolocate.py +++ b/geolocate.py @@ -11,7 +11,16 @@ import time from datetime import datetime -from flask import Flask, render_template, send_from_directory, redirect, request, url_for, make_response, jsonify +from flask import ( + Flask, + render_template, + send_from_directory, + redirect, + request, + url_for, + make_response, + jsonify, +) from io import BytesIO from timeit import default_timer as timer @@ -69,24 +78,26 @@ def real_observation(self): if observation is None: return default_observation - - if type(observation) == dict: + + if type(observation) == dict: format_observation = default_observation for class_name in observation: - self.class_map[class_name] = self.class_map.get(class_name, len(self.class_map)) + self.class_map[class_name] = self.class_map.get( + class_name, len(self.class_map) + ) format_observation[self.class_map[class_name]] = observation[class_name] observation = format_observation if type(observation) != list: observation = [observation] - if len(observation) != self.n_targets: + if len(observation) != self.n_targets: raise ValueError("len(observation) != n_targets") - for i in range(len(observation)): + for i in range(len(observation)): if observation[i] and observation[i] < self.threshold: observation[i] = None - + return observation @@ -152,20 +163,22 @@ def __init__(self, config_path="geolocate.ini"): self.config = default_config def target_handler(self, message_data): - logging.info(f"Received gamutrf/target MQTT message: {message_data}") - if message_data["gps_stale"].lower() != "false" or int(message_data["gps_fix_type"]) < 2: + if ( + message_data["gps_stale"].lower() != "false" + or int(message_data["gps_fix_type"]) < 2 + ): return "No target GPS." self.data["target_gps"] = "fix" target_name = message_data["target_name"] - if target_name not in self.data["targets"]: - self.data["targets"][target_name] = {"idx":len(self.data["targets"])} - + if target_name not in self.data["targets"]: + self.data["targets"][target_name] = {"idx": len(self.data["targets"])} + self.data["targets"][target_name]["position"] = ( - message_data["latitude"], - message_data["longitude"] + message_data["latitude"], + message_data["longitude"], ) def data_handler(self, message_data): @@ -198,19 +211,21 @@ def data_handler(self, message_data): metadata = message_data.get("metadata", None) if predictions: self.data["rssi"] = {} - for class_name in predictions.keys(): - + for class_name in predictions.keys(): self.data["rssi"][class_name] = np.mean( - [float(prediction.get("rssi_max", metadata["rssi_max"])) for prediction in predictions[class_name]] + [ + float(prediction.get("rssi_max", metadata["rssi_max"])) + for prediction in predictions[class_name] + ] ) - elif metadata: + elif metadata: # TODO: how to handle when tracking multiple targets - #self.data["rssi"] = float(metadata["rssi_mean"]) - #self.data["rssi"] = float(metadata["rssi_min"]) - #self.data["rssi"] = float(metadata.get["rssi_max"]) + # self.data["rssi"] = float(metadata["rssi_mean"]) + # self.data["rssi"] = float(metadata["rssi_min"]) + # self.data["rssi"] = float(metadata.get["rssi_max"]) pass - else: + else: self.data["rssi"] = message_data.get("rssi", None) self.data["position"] = message_data.get("position", self.data["position"]) @@ -256,32 +271,34 @@ def run_flask(self, flask_host, flask_port, fig, results): """ app = Flask(__name__) - @app.route('/gui/') + @app.route("/gui/") def gui_file(filename): - return send_from_directory('gui', filename) + return send_from_directory("gui", filename) @app.route("/gui/data") def gui_data(): data = base64.b64encode(self.image_buf.getvalue()).decode("ascii") - img = "data:image/png;base64,"+data + img = "data:image/png;base64," + data return jsonify(img) - @app.route('/refresh') + @app.route("/refresh") def refresh(): - os.makedirs('gui', exist_ok=True) + os.makedirs("gui", exist_ok=True) with open("gui/map.png", "wb") as img: img.write(self.image_buf.getbuffer()) - # newmapp = np.random.rand(500,500,3) * 255 + # newmapp = np.random.rand(500,500,3) * 255 # data = Image.fromarray(newmapp.astype('uint8')).convert('RGBA') # data.save('gui/map.png') - return "OK" + return "OK" - @app.route('/gui/form', methods = ['POST', 'GET']) + @app.route("/gui/form", methods=["POST", "GET"]) def gui_form(): - if request.method == 'POST': - user = request.form.get('name', None) - self.config['n_targets'] = request.form.get('n_targets', self.config['n_targets']) - reset = request.form.get('reset', None) + if request.method == "POST": + user = request.form.get("name", None) + self.config["n_targets"] = request.form.get( + "n_targets", self.config["n_targets"] + ) + reset = request.form.get("reset", None) if reset == "reset": self.stop() self.image_buf = BytesIO() @@ -300,7 +317,6 @@ def gui_from_buffer(): return render_template("loading.html") return render_template("gui_from_buffer.html", config=self.config) - @app.route("/") def index(): flask_start_time = timer() @@ -332,13 +348,13 @@ def index(): host=host_name, port=port, debug=False, use_reloader=False ) ).start() - + def get_replay_json(self, replay_file): with open(replay_file, "r", encoding="UTF-8") as open_file: replay_data = json.load(open_file) for ts in replay_data: yield replay_data[ts] - + def get_replay_log(self, replay_file): with open(replay_file, "r", encoding="UTF-8") as open_file: for line in open_file: @@ -347,15 +363,17 @@ def get_replay_log(self, replay_file): def start(self): self.stop_threads = False - self.main_thread = threading.Thread(target=self.main, args=[lambda: self.stop_threads]) + self.main_thread = threading.Thread( + target=self.main, args=[lambda: self.stop_threads] + ) self.main_thread.start() logging.info("Main thread started.") - def stop(self): + def stop(self): self.stop_threads = True self.main_thread.join() logging.info("Main thread stopped.") - + def main(self, stopped): """ Main loop @@ -421,9 +439,7 @@ def main(self, stopped): ("gamutrf/inference", self.data_handler), ("gamutrf/targets", self.target_handler), ] - mqtt_client = birdseye.mqtt.BirdsEyeMQTT( - mqtt_host, mqtt_port, topics - ) + mqtt_client = birdseye.mqtt.BirdsEyeMQTT(mqtt_host, mqtt_port, topics) else: if replay_file.endswith(".log"): get_replay_data = self.get_replay_log(replay_file) @@ -543,7 +559,12 @@ def main(self, stopped): control_actions = [] step_time = 0 - while self.data["gps"] != "fix" and self.data["target_gps"] != "fix" and not replay_file and not stopped(): + while ( + self.data["gps"] != "fix" + and self.data["target_gps"] != "fix" + and not replay_file + and not stopped() + ): time.sleep(1) logging.info("Waiting for GPS...") @@ -553,11 +574,11 @@ def main(self, stopped): if replay_file: # load data from saved file - try: + try: replay_data = next(get_replay_data) except StopIteration: break - + self.data_handler(replay_data) action_start = timer() @@ -652,7 +673,9 @@ def main(self, stopped): if __name__ == "__main__": # pragma: no cover parser = argparse.ArgumentParser() - parser.add_argument("config_path", help="Path to config file, geolocate.ini provided as example.") + parser.add_argument( + "config_path", help="Path to config file, geolocate.ini provided as example." + ) parser.add_argument("--log", default="INFO", help="Log level") args = parser.parse_args()