-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Deployable ML4H Dockers #577
Open
lucidtronix
wants to merge
117
commits into
master
Choose a base branch
from
sf_attn
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,153
−444
Open
Changes from all commits
Commits
Show all changes
117 commits
Select commit
Hold shift + click to select a range
7c903ed
test steps in diffusion plot
lucidtronix 088020e
test steps in diffusion plot
lucidtronix 9382159
test steps in diffusion plot
lucidtronix c4ca4cf
test steps in diffusion plot
lucidtronix 26ea28c
test steps in diffusion plot
lucidtronix a07b7e5
test steps in diffusion plot
lucidtronix a15baa2
test steps in diffusion plot
lucidtronix a627f1f
test steps in diffusion plot
lucidtronix 01b1756
test steps in diffusion plot
lucidtronix 84386f1
test steps in diffusion plot
lucidtronix 3435107
test steps in diffusion plot
lucidtronix ae083af
test steps in diffusion plot
lucidtronix 50f8875
test steps in diffusion plot
lucidtronix a5c2ca3
test steps in diffusion plot
lucidtronix 93978fd
test steps in diffusion plot
lucidtronix a95f727
test steps in diffusion plot
lucidtronix 7c90a47
test steps in diffusion plot
lucidtronix 53af85f
test steps in diffusion plot
lucidtronix 9bfeb36
test steps in diffusion plot
lucidtronix c548720
test steps in diffusion plot
lucidtronix 5617694
test steps in diffusion plot
lucidtronix 4cb6b04
test steps in diffusion plot
lucidtronix 16a3ba2
brain mri z index norm
lucidtronix 97af51e
brain mri z index norm
lucidtronix d476edf
brain mri z index norm
lucidtronix dde6044
brain mri z index norm
lucidtronix 0825553
brain mri z index norm
lucidtronix c716324
brain mri z index norm
lucidtronix 1d1a19a
log auc and mse
lucidtronix 3c538b3
log auc and mse
lucidtronix a543f43
log auc and mse
lucidtronix dfcdb33
log auc and mse
lucidtronix 5d8132f
log auc and mse
lucidtronix 8f357ab
log auc and mse
lucidtronix 50ae738
log auc and mse
lucidtronix 7b00f90
log auc and mse
lucidtronix 28a7ea3
log auc and mse
lucidtronix f4734e7
log auc and mse
lucidtronix 35b3a6d
log auc and mse
lucidtronix e510f82
log auc and mse
lucidtronix fa239b8
log auc and mse
lucidtronix c5f718c
log auc and mse
lucidtronix dca9f3e
log auc and mse
lucidtronix 953aa69
log auc and mse
lucidtronix 4e11ca2
log auc and mse
lucidtronix 1df78f7
log auc and mse
lucidtronix f892448
cumc docker
lucidtronix 43c857e
cumc docker
lucidtronix e495b48
cumc docker
lucidtronix 13ee745
cumc docker
lucidtronix aa5c162
cumc docker
lucidtronix 37db283
cumc docker
lucidtronix 5befd1e
cumc docker
lucidtronix 87b4cbc
cumc docker
lucidtronix 02dc15a
cumc docker
lucidtronix 1a93304
cumc docker
lucidtronix 3393380
cumc docker
lucidtronix baa1875
cumc docker
lucidtronix afa912d
cumc docker
lucidtronix 6fddd3d
cumc docker
lucidtronix 3cef024
fix
lucidtronix 82c8c9a
fix
lucidtronix aa2e4ea
fix
lucidtronix 8c6b203
fix
lucidtronix ab95a1e
fix
lucidtronix 3beef1b
fix
lucidtronix 2ab67e1
fix
lucidtronix 43381d5
fix
lucidtronix 414aafa
fix
lucidtronix 6921c4a
fix
lucidtronix 1eee793
ecg 512
lucidtronix 5928cd7
condition strategy
lucidtronix c53e449
condition strategy
lucidtronix 82a9d67
condition strategy
lucidtronix d85a97d
condition strategy
lucidtronix 0931fd7
condition strategy
lucidtronix 7b49ea0
condition strategy
lucidtronix ab3c844
condition strategy
lucidtronix 7f082be
condition strategy
lucidtronix ce66b0e
condition strategy
lucidtronix 3e94259
condition strategy
lucidtronix 4e7cebb
condition strategy
lucidtronix a3165b6
condition strategy
lucidtronix 306888a
condition strategy
lucidtronix b146aa6
condition strategy
lucidtronix b7d5afb
condition strategy
lucidtronix e8e5228
condition strategy
lucidtronix 46c5298
condition strategy
lucidtronix 7a4f445
condition strategy
lucidtronix db3530a
condition strategy
lucidtronix baf60a3
sigmoid loss unconditioned
lucidtronix 2df321c
sigmoid loss unconditioned
lucidtronix bb5feaa
sigmoid loss unconditioned
lucidtronix 44807ae
sigmoid loss unconditioned
lucidtronix 6b3ddfc
sigmoid loss unconditioned
lucidtronix 9bf044a
sigmoid loss unconditioned
lucidtronix f95ee86
sigmoid loss unconditioned
lucidtronix 1d40270
sigmoid loss unconditioned
lucidtronix 2f0602d
sigmoid loss unconditioned
lucidtronix a5d518d
sigmoid loss unconditioned
lucidtronix ff81dee
kernel inception distance
lucidtronix 7827909
condition and supervise
lucidtronix 3229037
condition and supervise
lucidtronix c9c16dd
condition and supervise
lucidtronix 1632042
condition and supervise
lucidtronix 1f6bc03
condition and supervise
lucidtronix d2566d7
condition and supervise
lucidtronix 2617ff7
condition and supervise
lucidtronix 6a90d26
condition and supervise
lucidtronix 44ac183
condition and supervise
lucidtronix 75e86e6
condition and supervise
lucidtronix d65c3e9
condition and supervise
lucidtronix 0a32e0b
condition and supervise
lucidtronix 4f6f062
condition and supervise
lucidtronix 5cd6e08
condition and supervise
lucidtronix 1ca98be
condition and supervise
lucidtronix fc30746
condition and supervise
lucidtronix File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
FROM ghcr.io/broadinstitute/ml4h:tf2.9-latest-cpu | ||
|
||
# Set the working directory | ||
WORKDIR /app | ||
|
||
# Install TensorFlow (or any other necessary libraries) | ||
RUN pip install tensorflow | ||
|
||
# Copy the Keras model file into the Docker image | ||
COPY ecg2af_quintuplet_v2024_01_13.h5 /app/ecg2af_quintuplet_v2024_01_13.h5 | ||
|
||
# Copy the Python script | ||
COPY process_files.py /app/process_files.py | ||
|
||
RUN pip3 install ml4h | ||
|
||
# Define the command to run the script | ||
CMD ["python", "process_files.py", "/data"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
import os | ||
import sys | ||
import base64 | ||
import struct | ||
from collections import defaultdict | ||
|
||
import h5py | ||
import xmltodict | ||
import numpy as np | ||
import pandas as pd | ||
from tensorflow.keras.models import load_model | ||
from ml4h.TensorMap import TensorMap, Interpretation | ||
from ml4h.defines import ECG_REST_AMP_LEADS | ||
from ml4h.models.model_factory import get_custom_objects | ||
|
||
n_intervals = 25 | ||
|
||
ecg_tmap = TensorMap( | ||
'ecg_5000_std', | ||
Interpretation.CONTINUOUS, | ||
shape=(5000, 12), | ||
channel_map=ECG_REST_AMP_LEADS | ||
) | ||
|
||
af_tmap = TensorMap( | ||
'survival_curve_af', | ||
Interpretation.SURVIVAL_CURVE, | ||
shape=(n_intervals*2,), | ||
) | ||
|
||
death_tmap = TensorMap( | ||
'death_event', | ||
Interpretation.SURVIVAL_CURVE, | ||
shape=(n_intervals*2,), | ||
) | ||
|
||
sex_tmap = TensorMap(name='sex', interpretation=Interpretation.CATEGORICAL, channel_map={'Female': 0, 'Male':1}) | ||
age_tmap = TensorMap(name='age_in_days', interpretation=Interpretation.CONTINUOUS, channel_map={'age_in_days': 0}) | ||
af_in_read_tmap = TensorMap(name='af_in_read', interpretation=Interpretation.CATEGORICAL, channel_map={'no_af_in_read': 0, 'af_in_read':1}) | ||
|
||
output_tensormaps = {tm.output_name(): tm for tm in [af_tmap, death_tmap, sex_tmap, age_tmap, af_in_read_tmap]} | ||
custom_dict = get_custom_objects(list(output_tensormaps.values())) | ||
model = load_model('ecg2af_quintuplet_v2024_01_13.h5', custom_objects=custom_dict) | ||
space_dict = defaultdict(list) | ||
|
||
def process_ukb_hd5(filepath, space_dict): | ||
# Placeholder for file processing logic | ||
print(f"Processing file: {filepath}") | ||
with h5py.File(filepath, 'r') as hd5: | ||
ecg_array = np.zeros(ecg_tmap.shape, dtype=np.float32) | ||
for lead in ecg_tmap.channel_map: | ||
ecg_array[:, ecg_tmap.channel_map[lead]] = hd5[f'/ukb_ecg_rest/strip_{lead}/instance_0'] | ||
|
||
ecg_array -= ecg_array.mean() | ||
ecg_array /= (ecg_array.std() + 1e-6) | ||
#print(f"Got tensor: {tensor.mean():0.3f}") | ||
prediction = model.predict(np.expand_dims(ecg_array, axis=0), verbose=0) | ||
if len(model.output_names) == 1: | ||
prediction = [prediction] | ||
predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)} | ||
#print(f"Got predictions: {predictions_dict}") | ||
space_dict['sample_id'].append(os.path.basename(filepath).replace('.hd5', '')) | ||
space_dict['ecg_path'].append(filepath) | ||
if '/dates/atrial_fibrillation_or_flutter_date' in hd5: | ||
space_dict['has_af'].append(1) | ||
else: | ||
space_dict['has_af'].append(0) | ||
|
||
for otm in output_tensormaps.values(): | ||
y = predictions_dict[otm.output_name()] | ||
if otm.is_categorical(): | ||
space_dict[f'{otm.name}_prediction'].append(y[0, 1]) | ||
elif otm.is_continuous(): | ||
space_dict[f'{otm.name}_prediction'].append(y[0, 0]) | ||
elif otm.is_survival_curve(): | ||
intervals = otm.shape[-1] // 2 | ||
days_per_bin = 1 + (2 * otm.days_window) // intervals | ||
predicted_survivals = np.cumprod(y[:, :intervals], axis=1) | ||
space_dict[f'{otm.name}_prediction'].append(str(1 - predicted_survivals[0, -1])) | ||
# print(f' got target: {target[otm.output_name()].numpy().shape}') | ||
# sick = np.sum(target[otm.output_name()].numpy()[:, intervals:], axis=-1) | ||
# follow_up = np.cumsum(target[otm.output_name()].numpy()[:, :intervals], axis=-1)[:, -1] * days_per_bin | ||
# space_dict[f'{otm.name}_event'].append(str(sick[b])) | ||
# space_dict[f'{otm.name}_follow_up'].append(str(follow_up[b])) | ||
# Example: Use the model to make a prediction (add real processing logic here) | ||
|
||
def decode_ekg_muse(raw_wave): | ||
""" | ||
Ingest the base64 encoded waveforms and transform to numeric | ||
""" | ||
# covert the waveform from base64 to byte array | ||
arr = base64.b64decode(bytes(raw_wave, 'utf-8')) | ||
|
||
# unpack every 2 bytes, little endian (16 bit encoding) | ||
unpack_symbols = ''.join([char * int(len(arr) / 2) for char in 'h']) | ||
byte_array = struct.unpack(unpack_symbols, arr) | ||
return byte_array | ||
|
||
def decode_ekg_muse_to_array(raw_wave, downsample=1): | ||
""" | ||
Ingest the base64 encoded waveforms and transform to numeric | ||
|
||
downsample: 0.5 takes every other value in the array. Muse samples at 500/s and the sample model requires 250/s. So take every other. | ||
""" | ||
try: | ||
dwnsmpl = int(1 // downsample) | ||
except ZeroDivisionError: | ||
print("You must downsample by more than 0") | ||
# covert the waveform from base64 to byte array | ||
arr = base64.b64decode(bytes(raw_wave, 'utf-8')) | ||
|
||
# unpack every 2 bytes, little endian (16 bit encoding) | ||
unpack_symbols = ''.join([char * int(len(arr) / 2) for char in 'h']) | ||
byte_array = struct.unpack(unpack_symbols, arr) | ||
return np.array(byte_array)[::dwnsmpl] | ||
|
||
def process_ge_muse_xml(filepath, space_dict): | ||
with open(filepath, 'rb') as fd: | ||
dic = xmltodict.parse(fd.read().decode('utf8')) | ||
|
||
""" | ||
|
||
Upload the ECG as numpy array with shape=[2500,12,1] ([time, leads, 1]). | ||
|
||
The voltage unit should be in 1 mv/unit and the sampling rate should be 250/second (total 10 second). | ||
|
||
The leads should be ordered as follow I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6. | ||
|
||
""" | ||
try: | ||
patient_id = dic['RestingECG']['PatientDemographics']['PatientID'] | ||
except: | ||
print("no PatientID") | ||
patient_id = "none" | ||
try: | ||
pharma_unique_ecg_id = dic['RestingECG']['PharmaData']['PharmaUniqueECGID'] | ||
except: | ||
print("no PharmaUniqueECGID") | ||
pharma_unique_ecg_id = "none" | ||
try: | ||
acquisition_date_time = dic['RestingECG']['TestDemographics']['AcquisitionDate'] + "_" + \ | ||
dic['RestingECG']['TestDemographics']['AcquisitionTime'].replace(":", "-") | ||
except: | ||
print("no AcquisitionDateTime") | ||
acquisition_date_time = "none" | ||
|
||
# try: | ||
# requisition_number = dic['RestingECG']['Order']['RequisitionNumber'] | ||
# except: | ||
# print("no requisition_number") | ||
# requisition_number = "none" | ||
|
||
# need to instantiate leads in the proper order for the model | ||
lead_order = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'] | ||
|
||
""" | ||
Each EKG will have this data structure: | ||
lead_data = { | ||
'I': np.array | ||
} | ||
""" | ||
|
||
lead_data = dict.fromkeys(lead_order) | ||
# lead_data = {leadid: None for k in lead_order} | ||
|
||
# for all_lead_data in dic['RestingECG']['Waveform']: | ||
# for single_lead_data in lead['LeadData']: | ||
# leadname = single_lead_data['LeadID'] | ||
# if leadname in (lead_order): | ||
|
||
for lead in dic['RestingECG']['Waveform']: | ||
for leadid in range(len(lead['LeadData'])): | ||
sample_length = len(decode_ekg_muse_to_array(lead['LeadData'][leadid]['WaveFormData'])) | ||
# sample_length is equivalent to dic['RestingECG']['Waveform']['LeadData']['LeadSampleCountTotal'] | ||
if sample_length == 5000: | ||
lead_data[lead['LeadData'][leadid]['LeadID']] = decode_ekg_muse_to_array( | ||
lead['LeadData'][leadid]['WaveFormData'], downsample=1) | ||
elif sample_length == 2500: | ||
lead_data[lead['LeadData'][leadid]['LeadID']] = decode_ekg_muse_to_array( | ||
lead['LeadData'][leadid]['WaveFormData'], downsample=2) | ||
else: | ||
continue | ||
# ensures all leads have 2500 samples and also passes over the 3 second waveform | ||
|
||
lead_data['III'] = (np.array(lead_data["II"]) - np.array(lead_data["I"])) | ||
lead_data['aVR'] = -(np.array(lead_data["I"]) + np.array(lead_data["II"])) / 2 | ||
lead_data['aVF'] = (np.array(lead_data["II"]) + np.array(lead_data["III"])) / 2 | ||
lead_data['aVL'] = (np.array(lead_data["I"]) - np.array(lead_data["III"])) / 2 | ||
|
||
lead_data = {k: lead_data[k] for k in lead_order} | ||
# drops V3R, V4R, and V7 if it was a 15-lead ECG | ||
|
||
# now construct and reshape the array | ||
# converting the dictionary to an np.array | ||
temp = [] | ||
for key, value in lead_data.items(): | ||
temp.append(value) | ||
|
||
# transpose to be [time, leads, ] | ||
ecg_array = np.array(temp).T | ||
|
||
print(f'Writing row of ECG2AF predictions for ECG {patient_id}, at {acquisition_date_time}') | ||
ecg_array -= ecg_array.mean() | ||
ecg_array /= (ecg_array.std() + 1e-6) | ||
#print(f"Got tensor: {tensor.mean():0.3f}") | ||
prediction = model.predict(np.expand_dims(ecg_array, axis=0), verbose=0) | ||
if len(model.output_names) == 1: | ||
prediction = [prediction] | ||
predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)} | ||
#print(f"Got predictions: {predictions_dict}") | ||
space_dict['filepath'].append(os.path.basename(filepath)) | ||
space_dict['patient_id'].append(patient_id) | ||
space_dict['acquisition_datetime'].append(acquisition_date_time) | ||
space_dict['pharma_unique_ecg_id'].append(pharma_unique_ecg_id) | ||
|
||
for otm in output_tensormaps.values(): | ||
y = predictions_dict[otm.output_name()] | ||
if otm.is_categorical(): | ||
space_dict[f'{otm.name}_prediction'].append(y[0, 1]) | ||
elif otm.is_continuous(): | ||
space_dict[f'{otm.name}_prediction'].append(y[0, 0]) | ||
elif otm.is_survival_curve(): | ||
intervals = otm.shape[-1] // 2 | ||
days_per_bin = 1 + (2 * otm.days_window) // intervals | ||
predicted_survivals = np.cumprod(y[:, :intervals], axis=1) | ||
space_dict[f'{otm.name}_prediction'].append(str(1 - predicted_survivals[0, -1])) | ||
# print(f' got target: {target[otm.output_name()].numpy().shape}') | ||
# sick = np.sum(target[otm.output_name()].numpy()[:, intervals:], axis=-1) | ||
# follow_up = np.cumsum(target[otm.output_name()].numpy()[:, :intervals], axis=-1)[:, -1] * days_per_bin | ||
# space_dict[f'{otm.name}_event'].append(str(sick[b])) | ||
# space_dict[f'{otm.name}_follow_up'].append(str(follow_up[b])) | ||
# Example: Use the model to make a prediction (add real processing logic here) | ||
|
||
def main(directory): | ||
# Iterate over all files in the specified directory | ||
space_dict = defaultdict(list) | ||
for i,filename in enumerate(os.listdir(directory)): | ||
filepath = os.path.join(directory, filename) | ||
if os.path.isfile(filepath): | ||
process_ge_muse_xml(filepath, space_dict) | ||
if i > 10000: | ||
break | ||
|
||
df = pd.DataFrame.from_dict(space_dict) | ||
df.to_csv('/output/ecg2af_quintuplet.csv', index=False) | ||
|
||
if __name__ == "__main__": | ||
# Take directory path from command-line arguments | ||
directory = sys.argv[1] if len(sys.argv) > 1 else "/data" | ||
main(directory) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why are we running pip and pip3 almost consecutively?
I'm assuming it's a typo that would only be a problem in like a weird edge case of versioning, but if it's intentional.... I definitely have follow up questions