diff --git a/kowalski/alert_brokers/alert_broker.py b/kowalski/alert_brokers/alert_broker.py index 013ade2b..5eb7bb90 100644 --- a/kowalski/alert_brokers/alert_broker.py +++ b/kowalski/alert_brokers/alert_broker.py @@ -1364,6 +1364,15 @@ def alert_filter__user_defined( }, } + if not isinstance(_filter.get("autosave", False), bool): + passed_filter["auto_followup"]["data"][ + "ignore_source_group_ids" + ] = [ + _filter.get("autosave", {}).get( + "ignore_group_ids", [] + ) + ] + passed_filters.append(passed_filter) except Exception as e: @@ -1927,9 +1936,7 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): response = self.api_skyportal( "POST", "/api/followup_request", - passed_filter["auto_followup"][ - "data" - ], # already contains the optional ignore_group_ids + passed_filter["auto_followup"]["data"], ) if ( response.json()["status"] == "success" @@ -1939,7 +1946,7 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): is False ): log( - f"Posted followup request for {alert['objectId']} to SkyPortal" + f"Posted followup request successfully for {alert['objectId']} to SkyPortal" ) # add it to the existing requests existing_requests.append( @@ -1971,7 +1978,14 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): "text": passed_filter["auto_followup"][ "comment" ], - "group_ids": [passed_filter["group_id"]], + "group_ids": list( + set( + [passed_filter["group_id"]] + + passed_filter.get("auto_followup", {}) + .get("data", {}) + .get("target_group_ids", []) + ) + ), } with timer( f"Posting followup comment {comment['text']} for {alert['objectId']} to SkyPortal", @@ -1989,7 +2003,7 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): .get("data", {}) .get( "message", - "unknow error posting comment", + "unknown error posting comment", ) ) except Exception as e: @@ -1997,15 +2011,20 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): f"Failed to post followup comment {comment['text']} for {alert['objectId']} to SkyPortal: {e}" ) else: - error_message = response.json().get( - "message", - response.json() - .get("data", {}) - .get( + try: + error_message = response.json().get( "message", - "unknow error posting followup request", - ), - ) + response.json() + .get("data", {}) + .get( + "message", + "unknown error posting followup request", + ), + ) + except Exception: + error_message = ( + "unknown error posting followup request" + ) raise ValueError(error_message) except Exception as e: log( @@ -2079,7 +2098,7 @@ def alert_sentinel_skyportal(self, alert, prv_candidates, passed_filters): raise ValueError( response.json().get( "message", - "unknow error updating followup request", + "unknown error updating followup request", ) ) except Exception as e: diff --git a/kowalski/alert_brokers/alert_broker_ztf.py b/kowalski/alert_brokers/alert_broker_ztf.py index d723f4f4..87863271 100644 --- a/kowalski/alert_brokers/alert_broker_ztf.py +++ b/kowalski/alert_brokers/alert_broker_ztf.py @@ -7,6 +7,7 @@ import threading import time import traceback +import numpy as np from abc import ABC from copy import deepcopy from typing import Mapping, Sequence @@ -43,7 +44,7 @@ def process_alert(alert: Mapping, topic: str): # get worker running current task worker = dask.distributed.get_worker() - alert_worker = worker.plugins["worker-init"].alert_worker + alert_worker: ZTFAlertWorker = worker.plugins["worker-init"].alert_worker log(f"{topic} {object_id} {candid} {worker.address}") @@ -129,9 
+130,12 @@ def process_alert(alert: Mapping, topic: str): "_id": object_id, "cross_matches": xmatches, "prv_candidates": prv_candidates, - "fp_hists": fp_hists, } + # only add the fp_hists if its a brand new object, not just if there is no entry there + if alert["candidate"]["ndethist"] <= 1: + alert_aux["fp_hists"] = alert_worker.format_fp_hists(alert, fp_hists) + with timer(f"Aux ingesting {object_id} {candid}", alert_worker.verbose > 1): retry(alert_worker.mongo.insert_one)( collection=alert_worker.collection_alerts_aux, document=alert_aux @@ -150,34 +154,24 @@ def process_alert(alert: Mapping, topic: str): upsert=True, ) - # FOR NOW: we decided to only store the forced photometry for the very first alert we get for an object - # so, no need to update anything here - - # update fp_hists - # existing_fp_hists = retry( - # alert_worker.mongo.db[alert_worker.collection_alerts_aux].find_one - # )({"_id": object_id}, {"fp_hists": 1}) - # if existing_fp_hists is not None: - # existing_fp_hists = existing_fp_hists.get("fp_hists", []) - # if len(existing_fp_hists) > 0: - # new_fp_hists = alert_worker.deduplicate_fp_hists( - # existing_fp_hists, fp_hists - # ) - # else: - # new_fp_hists = fp_hists - # else: - # new_fp_hists = fp_hists - # retry( - # alert_worker.mongo.db[alert_worker.collection_alerts_aux].update_one - # )( - # {"_id": object_id}, - # { - # "$set": { - # "fp_hists": new_fp_hists, - # } - # }, - # upsert=True, - # ) + # if there is no fp_hists for this object, we don't update anything + # the idea is that we start accumulating FP only for new objects, to avoid + # having some objects with incomplete FP history, which would be confusing for the filters + # either there is full FP, or there isn't any + if ( + retry( + alert_worker.mongo.db[ + alert_worker.collection_alerts_aux + ].count_documents + )( + {"_id": alert["objectId"], "fp_hists": {"$exists": True}}, + limit=1, + ) + == 1 + ): + alert_worker.update_fp_hists( + alert, alert_worker.format_fp_hists(alert, fp_hists) + ) if config["misc"]["broker"]: # execute user-defined alert filters @@ -495,25 +489,162 @@ def alert_put_photometry(self, alert): ) continue - def deduplicate_fp_hists(self, existing_fp=[], latest_fp=[]): - # for the forced photometry (fp_hists) unfortunately it's not as simple as deduplicating with a set - # the fp_hists of each candidate of an object is recomputed everytime, so datapoints - # at the same jd can be different, so we grab the existing fp_hists aggregate, and build a new one. 
+ def flux_to_mag(self, flux, fluxerr, zp): + """Convert flux to magnitude and calculate SNR + + :param flux: + :param fluxerr: + :param zp: + :param snr_threshold: + :return: + """ + # make sure its all numpy floats or nans + values = np.array([flux, fluxerr, zp], dtype=np.float64) + snr = values[0] / values[1] + mag = -2.5 * np.log10(values[0]) + values[2] + magerr = 1.0857 * (values[1] / values[0]) + limmag3sig = -2.5 * np.log10(3 * values[1]) + values[2] + limmag5sig = -2.5 * np.log10(5 * values[1]) + values[2] + if np.isnan(snr): + return {} + if snr < 0: + return { + "snr": snr, + } + mag_data = { + "mag": mag, + "magerr": magerr, + "snr": snr, + "limmag3sig": limmag3sig, + "limmag5sig": limmag5sig, + } + # remove all NaNs fields + mag_data = {k: v for k, v in mag_data.items() if not np.isnan(v)} + return mag_data + + def format_fp_hists(self, alert, fp_hists): + if len(fp_hists) == 0: + return [] + # sort by jd + fp_hists = sorted(fp_hists, key=lambda x: x["jd"]) + + # add the "alert_mag" field to the new fp_hist + # as well as alert_ra, alert_dec + for i, fp in enumerate(fp_hists): + fp_hists[i] = { + **fp, + **self.flux_to_mag( + flux=fp.get("forcediffimflux", np.nan), + fluxerr=fp.get("forcediffimfluxunc", np.nan), + zp=fp["magzpsci"], + ), + "alert_mag": alert["candidate"]["magpsf"], + "alert_ra": alert["candidate"]["ra"], + "alert_dec": alert["candidate"]["dec"], + } + + return fp_hists - # first find the oldest jd in the latest fp_hists - oldest_jd_in_latest = min([fp["jd"] for fp in latest_fp]) - # get all the datapoints in the existing fp_hists that are older than the oldest jd in the latest fp_hists - older_datapoints = [fp for fp in existing_fp if fp["jd"] < oldest_jd_in_latest] + def update_fp_hists(self, alert, formatted_fp_hists): + # update the existing fp_hist with the new one + # instead of treating it as a set, + # if some entries have the same jd, keep the one with the highest alert_mag - # TODO: implement a better logic here. Could be based on: - # - SNR (better SNR datapoints might be better) - # - position (centroid, closer to the avg position might be better) - # - mag (if 1 sigma brighter or dimmer than the current datapoints, use the newer ones) + # make sure this is an aggregate pipeline in mongodb + if len(formatted_fp_hists) == 0: + return - # for now, just append the latest fp_hists to the older ones, - # to prioritize newer datapoints which might come from an updated pipeline + with timer( + f"Updating fp_hists of {alert['objectId']} {alert['candid']}", + self.verbose > 1, + ): + update_pipeline = [ + # 1. concat the new fp_hists with the existing ones + { + "$project": { + "all_fp_hists": { + "$concatArrays": [ + {"$ifNull": ["$fp_hists", []]}, + formatted_fp_hists, + ] + } + } + }, + # 2. unwind the resulting array to get one document per fp_hist + {"$unwind": "$all_fp_hists"}, + # 3. group by jd and keep the one with the highest alert_mag for each jd + { + "$set": { + "all_fp_hists.alert_mag": { + "$cond": { + "if": { + "$eq": [ + {"$type": "$all_fp_hists.alert_mag"}, + "missing", + ] + }, + "then": -99999.0, + "else": "$all_fp_hists.alert_mag", + } + } + } + }, + # 4. sort by jd and alert_mag + { + "$sort": { + "all_fp_hists.jd": 1, + "all_fp_hists.alert_mag": 1, + } + }, + # 5. group all the deduplicated fp_hists back into an array, keeping the first one (the brightest at each jd) + { + "$group": { + "_id": "$all_fp_hists.jd", + "fp_hist": {"$first": "$$ROOT.all_fp_hists"}, + } + }, + # 6. sort by jd again + {"$sort": {"fp_hist.jd": 1}}, + # 7. 
group all the fp_hists documents back into a single array + {"$group": {"_id": None, "fp_hists": {"$push": "$fp_hist"}}}, + # 8. project only the new fp_hists array + {"$project": {"fp_hists": 1, "_id": 0}}, + ] + n_retries = 0 + while True: + # run the pipeline and then update the document + new_fp_hists = ( + self.mongo.db[self.collection_alerts_aux] + .aggregate( + update_pipeline, + ) + .next() + .get("fp_hists", []) + ) - return older_datapoints + latest_fp + # update the document, only if there is still less points in the DB than in the new fp_hists. + # Otherwise, rerun the pipeline. This is to help a little bit with concurrency issues + result = self.mongo.db[self.collection_alerts_aux].find_one_and_update( + { + "_id": alert["objectId"], + f"fp_hists.{len(new_fp_hists)}": {"$exists": False}, + }, + {"$set": {"fp_hists": new_fp_hists}}, + ) + if result is None: + n_retries += 1 + if n_retries > 10: + log( + f"Failed to update fp_hists of {alert['objectId']} {alert['candid']}" + ) + break + else: + log( + f"Retrying to update fp_hists of {alert['objectId']} {alert['candid']}" + ) + time.sleep(1) + else: + break class WorkerInitializer(dask.distributed.WorkerPlugin): diff --git a/kowalski/ingesters/ingester.py b/kowalski/ingesters/ingester.py index ff169699..62d151f2 100644 --- a/kowalski/ingesters/ingester.py +++ b/kowalski/ingesters/ingester.py @@ -4,6 +4,7 @@ import time from confluent_kafka import Producer from kowalski.log import log +import threading def delivery_report(err, msg): @@ -18,12 +19,21 @@ def delivery_report(err, msg): class KafkaStream: as_context_manager = False - def __init__(self, topic, path_alerts, config, test=False): + def __init__(self, topic, path_alerts, config, test=False, **kwargs): self.config = config self.topic = topic self.path_alerts = path_alerts self.test = test + if kwargs.get("max_alerts") is not None: + try: + int(kwargs.get("max_alerts")) + except ValueError: + raise ValueError("max_alerts must be an integer") + self.max_alerts = int(kwargs.get("max_alerts")) + else: + self.max_alerts = None + def start(self): # create a kafka topic and start a producer to stream the alerts path_logs = pathlib.Path("logs/") @@ -148,7 +158,11 @@ def start(self): } ) - for p in self.path_alerts.glob("*.avro"): + alerts = list(self.path_alerts.glob("*.avro")) + if self.max_alerts is not None: + alerts = alerts[: self.max_alerts] + log(f"Streaming {len(alerts)} alerts") + for p in alerts: with open(str(p), "rb") as data: # Trigger any available delivery report callbacks from previous produce() calls producer.poll(0) @@ -158,7 +172,18 @@ def start(self): # Asynchronously produce a message, the delivery report callback # will be triggered from poll() above, or flush() below, when the message has # been successfully delivered or failed permanently. - producer.produce(self.topic, data.read(), callback=delivery_report) + while True: + try: + producer.produce( + self.topic, data.read(), callback=delivery_report + ) + break + except BufferError: + print( + "Local producer queue is full (%d messages awaiting delivery): try again\n" + % len(producer) + ) + time.sleep(1) # Wait for any outstanding messages to be delivered and delivery report # callbacks to be triggered. 
@@ -208,7 +233,14 @@ def stop(self): os.remove(meta_properties) def __enter__(self): - self.start() + # when not in test mode, call the start method in a separate thread + # this helps so that you can start polling the topic right away + # otherwise, the start method will block until all alerts are ingested + # which is not possible if you are ingesting more alerts than the queue size + if self.test: + self.start() + else: + threading.Thread(target=self.start).start() self.as_context_manager = True time.sleep(15) # give it a chance to finish ingesting properly return self diff --git a/kowalski/tests/test_alert_broker_ztf.py b/kowalski/tests/test_alert_broker_ztf.py index 8741f05c..fef39ae7 100644 --- a/kowalski/tests/test_alert_broker_ztf.py +++ b/kowalski/tests/test_alert_broker_ztf.py @@ -67,15 +67,48 @@ def filter_template(upstream): return template -def post_alert(worker, alert): - alert, _, _ = worker.alert_mongify(alert) +def post_alert(worker: ZTFAlertWorker, alert, fp_cutoff=1): + delete_alert(worker, alert) + alert, prv_candidates, fp_hists = worker.alert_mongify(alert) # check if it already exists if worker.mongo.db[worker.collection_alerts].count_documents( {"candid": alert["candid"]} ): log(f"Alert {alert['candid']} already exists, skipping") - return - worker.mongo.insert_one(collection=worker.collection_alerts, document=alert) + else: + worker.mongo.insert_one(collection=worker.collection_alerts, document=alert) + + if worker.mongo.db[worker.collection_alerts_aux].count_documents( + {"_id": alert["objectId"]} + ): + # delete if it exists + worker.mongo.delete_one( + collection=worker.collection_alerts_aux, + document={"_id": alert["objectId"]}, + ) + + # fp_hists: pop nulls - save space + fp_hists = [ + {kk: vv for kk, vv in fp_hist.items() if vv not in [None, -99999, -99999.0]} + for fp_hist in fp_hists + ] + + fp_hists = worker.format_fp_hists(alert, fp_hists) + + # sort fp_hists by jd + fp_hists = sorted(fp_hists, key=lambda x: x["jd"]) + + if fp_cutoff < 1: + index = int(fp_cutoff * len(fp_hists)) + fp_hists = fp_hists[:index] + + aux = { + "_id": alert["objectId"], + "prv_candidates": prv_candidates, + "cross_matches": {}, + "fp_hists": fp_hists, + } + worker.mongo.insert_one(collection=worker.collection_alerts_aux, document=aux) def delete_alert(worker, alert): @@ -83,6 +116,10 @@ def delete_alert(worker, alert): collection=worker.collection_alerts, document={"candidate.candid": alert["candid"]}, ) + worker.mongo.delete_one( + collection=worker.collection_alerts_aux, + document={"_id": alert["objectId"]}, + ) class TestAlertBrokerZTF: @@ -156,6 +193,121 @@ def test_alert_mongofication_with_fphists(self): # for k in new_fp_hists[i].keys(): # assert new_fp_hists[i][k] == fp_hists[i][k] + def test_ingest_alert_with_fp_hists(self): + candid = 2475433850015010009 + sample_avro = f"data/ztf_alerts/20231012/{candid}.avro" + alert_mag = 21.0 + fp_hists = [] + with open(sample_avro, "rb") as f: + records = [record for record in fastavro.reader(f)] + for record in records: + # delete_alert(self.worker, record) + alert, prv_candidates, fp_hists = self.worker.alert_mongify(record) + alert_mag = alert["candidate"]["magpsf"] + post_alert(self.worker, record, fp_cutoff=0.7) + + # verify that the alert was ingested + assert ( + self.worker.mongo.db[self.worker.collection_alerts].count_documents( + {"candid": candid} + ) + == 1 + ) + assert ( + self.worker.mongo.db[self.worker.collection_alerts_aux].count_documents( + {"_id": record["objectId"]} + ) + == 1 + ) + + # verify that fp_hists 
was ingested correctly + aux = self.worker.mongo.db[self.worker.collection_alerts_aux].find_one( + {"_id": record["objectId"]} + ) + assert "fp_hists" in aux + assert len(aux["fp_hists"]) == 14 # we had a cutoff at 21 * 0.7 = 14.7, so 14 + + # print("---------- Original ----------") + # for fp in aux["fp_hists"]: + # print(f"{fp['jd']}: {fp['alert_mag']}") + + # fp_hists: pop nulls - save space and make sure its the same as what is in the DB + fp_hists = [ + {kk: vv for kk, vv in fp_hist.items() if vv not in [None, -99999, -99999.0]} + for fp_hist in fp_hists + ] + + # sort fp_hists by jd + fp_hists = sorted(fp_hists, key=lambda x: x["jd"]) + + # now, let's try the alert worker's update_fp_hists method, passing it the full fp_hists + # and verify that it we have 21 exactly + + # first we add some forced photometry, where we have new rows, but overlapping rows from a fainter alert + record["candidate"]["magpsf"] = 30.0 + # keep the last 10 fp_hists + fp_hists_copy = fp_hists[-10:] + # remove the last 2 + fp_hists_copy = fp_hists_copy[:-2] + + fp_hists_formatted = self.worker.format_fp_hists(alert, fp_hists_copy) + self.worker.update_fp_hists(record, fp_hists_formatted) + + aux = self.worker.mongo.db[self.worker.collection_alerts_aux].find_one( + {"_id": record["objectId"]} + ) + # print(aux) + assert "fp_hists" in aux + assert len(aux["fp_hists"]) == 19 + + # print("---------- First update ----------") + # for fp in aux["fp_hists"]: + # print(f"{fp['jd']}: {fp['alert_mag']}") + + # we should have the same first 14 fp_hists as before, then the new ones + assert all([aux["fp_hists"][i]["alert_mag"] == alert_mag for i in range(14)]) + assert all([aux["fp_hists"][i]["alert_mag"] == 30.0 for i in range(14, 19)]) + + # verify they are still in order by jd (oldest to newest) + assert all( + [ + aux["fp_hists"][i]["jd"] < aux["fp_hists"][i + 1]["jd"] + for i in range(len(aux["fp_hists"]) - 1) + ] + ) + + # now, the last 10 datapoints, but as if they were from a brighter alert + # we should have an overlap with both the original FP, and the datapoints (from faint alert) we just added + record["candidate"]["magpsf"] = 15.0 + # keep the last 10 fp_hists + fp_hists_copy = fp_hists[-10:] + + fp_hists_formatted = self.worker.format_fp_hists(alert, fp_hists_copy) + self.worker.update_fp_hists(record, fp_hists_formatted) + + aux = self.worker.mongo.db[self.worker.collection_alerts_aux].find_one( + {"_id": record["objectId"]} + ) + + assert "fp_hists" in aux + assert len(aux["fp_hists"]) == 21 + + # print("---------- Last update ----------") + # for fp in aux["fp_hists"]: + # print(f"{fp['jd']}: {fp['alert_mag']}") + + # the result should be 21 fp_hists, with the first 11 being the same as before, and the last 10 being the same as the new fp_hists + assert all([aux["fp_hists"][i]["alert_mag"] == alert_mag for i in range(11)]) + assert all([aux["fp_hists"][i]["alert_mag"] == 15.0 for i in range(11, 21)]) + + # verify they are still in order by jd (oldest to newest) + assert all( + [ + aux["fp_hists"][i]["jd"] < aux["fp_hists"][i + 1]["jd"] + for i in range(len(aux["fp_hists"]) - 1) + ] + ) + def test_make_photometry(self): df_photometry = self.worker.make_photometry(self.alert) assert len(df_photometry) == 32 diff --git a/kowalski/tools/kafka_stream.py b/kowalski/tools/kafka_stream.py index 3be6cbbf..9eb8af24 100644 --- a/kowalski/tools/kafka_stream.py +++ b/kowalski/tools/kafka_stream.py @@ -21,12 +21,18 @@ type=bool, help="test mode. 
if in test mode, alerts will be pushed to bootstarp.test.server", ) - +parser.add_argument( + "--max_alerts", + type=int, + default=None, + help="maximum number of alerts to stream (optional)", +) args = parser.parse_args() topic = args.topic path_alerts = args.path_alerts test = args.test +max_alerts = args.max_alerts if not isinstance(topic, str) or topic == "": raise ValueError("topic must be a non-empty string") @@ -48,6 +54,7 @@ path_alerts=pathlib.Path(f"data/{path_alerts}"), test=test, config=config, + max_alerts=max_alerts, ) running = True @@ -56,7 +63,9 @@ # if the user hits Ctrl+C, stop the stream try: stream.start() - time.sleep(1000000000) + while True: + time.sleep(60) + print("heartbeat") except KeyboardInterrupt: print("\nStopping Kafka stream...") stream.stop()
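A few implementation notes on the changes above follow; every identifier or value that does not appear in the diff itself is illustrative, and the sketches are references rather than the production code paths.

When a filter's autosave setting is an object rather than a plain boolean, its ignore_group_ids are forwarded on the auto-follow-up request as ignore_source_group_ids, and the follow-up comment is now posted to the union of the filter's own group and the request's target_group_ids instead of the filter group alone. A minimal sketch of that union, with made-up group ids:

# Hedged sketch of the comment group_ids union built in alert_sentinel_skyportal
# (ids and comment text are made up for illustration):
passed_filter = {
    "group_id": 41,
    "auto_followup": {
        "comment": "auto-trigger",
        "data": {"target_group_ids": [41, 48]},
    },
}

comment = {
    "text": passed_filter["auto_followup"]["comment"],
    # the comment is now visible to the follow-up target groups as well,
    # not only to the group that owns the filter
    "group_ids": list(
        set(
            [passed_filter["group_id"]]
            + passed_filter.get("auto_followup", {})
            .get("data", {})
            .get("target_group_ids", [])
        )
    ),
}
assert sorted(comment["group_ids"]) == [41, 48]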
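The new ZTFAlertWorker.flux_to_mag helper converts a forced-photometry difference-image flux into a magnitude using the science-image zero point: mag = zp - 2.5 * log10(flux), magerr = 1.0857 * fluxerr / flux (2.5 / ln 10 is about 1.0857), snr = flux / fluxerr, and the N-sigma limiting magnitude is zp - 2.5 * log10(N * fluxerr); NaN fields are dropped and a negative SNR only reports the snr value. A worked example with an arbitrary zero point of 26.3 (all numbers are illustrative, not taken from a real alert):

import numpy as np

# Illustrative inputs: a 10-sigma forced-photometry detection
flux, fluxerr, zp = 1000.0, 100.0, 26.3

snr = flux / fluxerr                            # 10.0
mag = -2.5 * np.log10(flux) + zp                # 26.3 - 7.5 = 18.8
magerr = 1.0857 * (fluxerr / flux)              # ~0.109
limmag3sig = -2.5 * np.log10(3 * fluxerr) + zp  # ~20.11
limmag5sig = -2.5 * np.log10(5 * fluxerr) + zp  # ~19.55

assert abs(snr - 10.0) < 1e-9 and abs(mag - 18.8) < 1e-9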
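The update_fp_hists aggregation pipeline merges newly formatted forced photometry into the stored fp_hists array and deduplicates on jd: after sorting by jd and then by alert_mag in ascending order, the $first accumulator keeps, for each jd, the datapoint contributed by the brightest alert (numerically smallest alert_mag, with a missing alert_mag treated as -99999.0), which is the behaviour test_ingest_alert_with_fp_hists exercises. A rough client-side equivalent of that merge, kept here only as a reference for what the server-side pipeline computes (the production path deliberately stays inside MongoDB so the merge and the read happen in a single aggregation):

def merge_fp_hists(existing: list, incoming: list) -> list:
    """Pure-Python sketch of the jd-deduplication performed by the pipeline:
    at each jd, keep the datapoint that came from the brightest alert
    (numerically smallest alert_mag); a missing alert_mag is treated as
    -99999.0, mirroring the pipeline's placeholder, so such a point always
    wins. In this sketch, existing points win ties because they are
    concatenated first."""
    best = {}
    for fp in existing + incoming:
        jd = fp["jd"]
        mag = fp.get("alert_mag", -99999.0)
        if jd not in best or mag < best[jd][1]:
            best[jd] = (fp, mag)
    # return the surviving datapoints sorted by jd, oldest first
    return [best[jd][0] for jd in sorted(best)]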
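The write-back at the end of update_fp_hists uses a small MongoDB trick as its concurrency guard: the array element at index N exists only when the array holds more than N elements, so filtering on fp_hists.<len(merged)> with $exists: False lets the $set apply only while the stored array is no longer than the freshly merged one; if a concurrent worker already grew it, find_one_and_update returns None and the broker re-runs the merge (up to 10 times). A stand-alone sketch of that guard, assuming pymongo and an illustrative collection name:

from pymongo.database import Database

def guarded_fp_hists_write(db: Database, object_id: str, merged: list) -> bool:
    """Sketch of the optimistic write used by update_fp_hists. Returns True
    if the merged array was written, False if a concurrent worker already
    grew fp_hists past len(merged), in which case the caller should re-run
    the merge. The collection name is illustrative."""
    result = db["ZTF_alerts_aux"].find_one_and_update(
        {
            "_id": object_id,
            # fp_hists.<N> exists only when the stored array holds strictly
            # more than N elements, so requiring index len(merged) to be
            # missing caps the stored length at len(merged)
            f"fp_hists.{len(merged)}": {"$exists": False},
        },
        {"$set": {"fp_hists": merged}},
    )
    return result is not None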
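The ingester's producer loop now backs off and retries when confluent-kafka raises BufferError (local queue full) instead of failing outright. A minimal stand-alone version of the same pattern; the broker address is a placeholder, and unlike the diff this sketch also calls poll(0) inside the retry branch, which is one common way to drain delivery callbacks and free queue slots while waiting:

import time

from confluent_kafka import Producer

producer = Producer({"bootstrap.servers": "localhost:9092"})  # placeholder address

def produce_with_backoff(topic: str, payload: bytes) -> None:
    """Enqueue a message, waiting for room whenever the local queue is full."""
    while True:
        try:
            producer.produce(topic, payload)
            return
        except BufferError:
            # local queue is full: service delivery callbacks, then retry
            producer.poll(0)
            time.sleep(1)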
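KafkaStream now accepts a max_alerts cap and, outside of test mode, runs start() in a background thread when used as a context manager, so the topic can be polled while alerts are still being produced. A hedged usage sketch; the topic name, the alert directory and the pre-loaded config dict are assumptions, and test=True keeps ingestion synchronous as in the unit tests:

import pathlib

from kowalski.ingesters.ingester import KafkaStream

# config is assumed to be the already-loaded Kowalski configuration dict
with KafkaStream(
    "ztf_20231012_programid1_test",            # illustrative topic name
    pathlib.Path("data/ztf_alerts/20231012"),  # illustrative alert directory
    config=config,
    test=True,
    max_alerts=100,  # stream at most 100 .avro files from the directory
):
    ...  # consume and inspect the topic here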