Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deployable ML4H Dockers #577

Open
wants to merge 117 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
117 commits
Select commit Hold shift + click to select a range
7c903ed
test steps in diffusion plot
lucidtronix Oct 11, 2024
088020e
test steps in diffusion plot
lucidtronix Oct 16, 2024
9382159
test steps in diffusion plot
lucidtronix Oct 16, 2024
c4ca4cf
test steps in diffusion plot
lucidtronix Oct 18, 2024
26ea28c
test steps in diffusion plot
lucidtronix Oct 18, 2024
a07b7e5
test steps in diffusion plot
lucidtronix Oct 18, 2024
a15baa2
test steps in diffusion plot
lucidtronix Oct 18, 2024
a627f1f
test steps in diffusion plot
lucidtronix Oct 18, 2024
01b1756
test steps in diffusion plot
lucidtronix Oct 18, 2024
84386f1
test steps in diffusion plot
lucidtronix Oct 18, 2024
3435107
test steps in diffusion plot
lucidtronix Oct 18, 2024
ae083af
test steps in diffusion plot
lucidtronix Oct 18, 2024
50f8875
test steps in diffusion plot
lucidtronix Oct 18, 2024
a5c2ca3
test steps in diffusion plot
lucidtronix Oct 18, 2024
93978fd
test steps in diffusion plot
lucidtronix Oct 18, 2024
a95f727
test steps in diffusion plot
lucidtronix Oct 18, 2024
7c90a47
test steps in diffusion plot
lucidtronix Oct 29, 2024
53af85f
test steps in diffusion plot
lucidtronix Oct 29, 2024
9bfeb36
test steps in diffusion plot
lucidtronix Oct 29, 2024
c548720
test steps in diffusion plot
lucidtronix Oct 29, 2024
5617694
test steps in diffusion plot
lucidtronix Oct 29, 2024
4cb6b04
test steps in diffusion plot
lucidtronix Oct 29, 2024
16a3ba2
brain mri z index norm
lucidtronix Oct 30, 2024
97af51e
brain mri z index norm
lucidtronix Oct 30, 2024
d476edf
brain mri z index norm
lucidtronix Oct 30, 2024
dde6044
brain mri z index norm
lucidtronix Oct 31, 2024
0825553
brain mri z index norm
lucidtronix Oct 31, 2024
c716324
brain mri z index norm
lucidtronix Nov 1, 2024
1d1a19a
log auc and mse
lucidtronix Nov 8, 2024
3c538b3
log auc and mse
lucidtronix Nov 8, 2024
a543f43
log auc and mse
lucidtronix Nov 8, 2024
dfcdb33
log auc and mse
lucidtronix Nov 8, 2024
5d8132f
log auc and mse
lucidtronix Nov 8, 2024
8f357ab
log auc and mse
lucidtronix Nov 8, 2024
50ae738
log auc and mse
lucidtronix Nov 8, 2024
7b00f90
log auc and mse
lucidtronix Nov 8, 2024
28a7ea3
log auc and mse
lucidtronix Nov 8, 2024
f4734e7
log auc and mse
lucidtronix Nov 8, 2024
35b3a6d
log auc and mse
lucidtronix Nov 8, 2024
e510f82
log auc and mse
lucidtronix Nov 11, 2024
fa239b8
log auc and mse
lucidtronix Nov 11, 2024
c5f718c
log auc and mse
lucidtronix Nov 15, 2024
dca9f3e
log auc and mse
lucidtronix Nov 15, 2024
953aa69
log auc and mse
lucidtronix Nov 15, 2024
4e11ca2
log auc and mse
lucidtronix Nov 26, 2024
1df78f7
log auc and mse
lucidtronix Nov 26, 2024
f892448
cumc docker
lucidtronix Nov 27, 2024
43c857e
cumc docker
lucidtronix Nov 27, 2024
e495b48
cumc docker
lucidtronix Nov 27, 2024
13ee745
cumc docker
lucidtronix Nov 27, 2024
aa5c162
cumc docker
lucidtronix Nov 27, 2024
37db283
cumc docker
lucidtronix Nov 27, 2024
5befd1e
cumc docker
lucidtronix Nov 27, 2024
87b4cbc
cumc docker
lucidtronix Nov 27, 2024
02dc15a
cumc docker
lucidtronix Nov 27, 2024
1a93304
cumc docker
lucidtronix Nov 27, 2024
3393380
cumc docker
lucidtronix Nov 27, 2024
baa1875
cumc docker
lucidtronix Nov 27, 2024
afa912d
cumc docker
lucidtronix Nov 27, 2024
6fddd3d
cumc docker
lucidtronix Nov 27, 2024
3cef024
fix
lucidtronix Dec 3, 2024
82c8c9a
fix
lucidtronix Dec 3, 2024
aa2e4ea
fix
lucidtronix Dec 10, 2024
8c6b203
fix
lucidtronix Dec 10, 2024
ab95a1e
fix
lucidtronix Dec 10, 2024
3beef1b
fix
lucidtronix Dec 11, 2024
2ab67e1
fix
lucidtronix Dec 11, 2024
43381d5
fix
lucidtronix Dec 11, 2024
414aafa
fix
lucidtronix Dec 11, 2024
6921c4a
fix
lucidtronix Dec 12, 2024
1eee793
ecg 512
lucidtronix Dec 13, 2024
5928cd7
condition strategy
lucidtronix Dec 13, 2024
c53e449
condition strategy
lucidtronix Dec 13, 2024
82a9d67
condition strategy
lucidtronix Dec 15, 2024
d85a97d
condition strategy
lucidtronix Dec 16, 2024
0931fd7
condition strategy
lucidtronix Dec 16, 2024
7b49ea0
condition strategy
lucidtronix Dec 16, 2024
ab3c844
condition strategy
lucidtronix Dec 16, 2024
7f082be
condition strategy
lucidtronix Dec 16, 2024
ce66b0e
condition strategy
lucidtronix Dec 16, 2024
3e94259
condition strategy
lucidtronix Dec 16, 2024
4e7cebb
condition strategy
lucidtronix Dec 16, 2024
a3165b6
condition strategy
lucidtronix Dec 16, 2024
306888a
condition strategy
lucidtronix Dec 16, 2024
b146aa6
condition strategy
lucidtronix Dec 16, 2024
b7d5afb
condition strategy
lucidtronix Dec 16, 2024
e8e5228
condition strategy
lucidtronix Dec 16, 2024
46c5298
condition strategy
lucidtronix Dec 16, 2024
7a4f445
condition strategy
lucidtronix Dec 16, 2024
db3530a
condition strategy
lucidtronix Dec 19, 2024
baf60a3
sigmoid loss unconditioned
lucidtronix Dec 19, 2024
2df321c
sigmoid loss unconditioned
lucidtronix Dec 19, 2024
bb5feaa
sigmoid loss unconditioned
lucidtronix Dec 19, 2024
44807ae
sigmoid loss unconditioned
lucidtronix Dec 19, 2024
6b3ddfc
sigmoid loss unconditioned
lucidtronix Dec 20, 2024
9bf044a
sigmoid loss unconditioned
lucidtronix Dec 20, 2024
f95ee86
sigmoid loss unconditioned
lucidtronix Jan 2, 2025
1d40270
sigmoid loss unconditioned
lucidtronix Jan 3, 2025
2f0602d
sigmoid loss unconditioned
lucidtronix Jan 3, 2025
a5d518d
sigmoid loss unconditioned
lucidtronix Jan 3, 2025
ff81dee
kernel inception distance
lucidtronix Jan 3, 2025
7827909
condition and supervise
lucidtronix Jan 3, 2025
3229037
condition and supervise
lucidtronix Jan 8, 2025
c9c16dd
condition and supervise
lucidtronix Jan 8, 2025
1632042
condition and supervise
lucidtronix Jan 9, 2025
1f6bc03
condition and supervise
lucidtronix Jan 10, 2025
d2566d7
condition and supervise
lucidtronix Jan 11, 2025
2617ff7
condition and supervise
lucidtronix Jan 12, 2025
6a90d26
condition and supervise
lucidtronix Jan 13, 2025
44ac183
condition and supervise
lucidtronix Jan 13, 2025
75e86e6
condition and supervise
lucidtronix Jan 13, 2025
d65c3e9
condition and supervise
lucidtronix Jan 13, 2025
0a32e0b
condition and supervise
lucidtronix Jan 13, 2025
4f6f062
condition and supervise
lucidtronix Jan 13, 2025
5cd6e08
condition and supervise
lucidtronix Jan 13, 2025
1ca98be
condition and supervise
lucidtronix Jan 13, 2025
fc30746
condition and supervise
lucidtronix Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docker/ml4h_deploy/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
FROM ghcr.io/broadinstitute/ml4h:tf2.9-latest-cpu

# Set the working directory
WORKDIR /app

# Install TensorFlow (or any other necessary libraries)
RUN pip install tensorflow

# Copy the Keras model file into the Docker image
COPY ecg2af_quintuplet_v2024_01_13.h5 /app/ecg2af_quintuplet_v2024_01_13.h5

# Copy the Python script
COPY process_files.py /app/process_files.py

RUN pip3 install ml4h
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we running pip and pip3 almost consecutively?

I'm assuming it's a typo that would only be a problem in like a weird edge case of versioning, but if it's intentional.... I definitely have follow up questions


# Define the command to run the script
CMD ["python", "process_files.py", "/data"]
250 changes: 250 additions & 0 deletions docker/ml4h_deploy/process_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
import os
import sys
import base64
import struct
from collections import defaultdict

import h5py
import xmltodict
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from ml4h.TensorMap import TensorMap, Interpretation
from ml4h.defines import ECG_REST_AMP_LEADS
from ml4h.models.model_factory import get_custom_objects

n_intervals = 25

ecg_tmap = TensorMap(
'ecg_5000_std',
Interpretation.CONTINUOUS,
shape=(5000, 12),
channel_map=ECG_REST_AMP_LEADS
)

af_tmap = TensorMap(
'survival_curve_af',
Interpretation.SURVIVAL_CURVE,
shape=(n_intervals*2,),
)

death_tmap = TensorMap(
'death_event',
Interpretation.SURVIVAL_CURVE,
shape=(n_intervals*2,),
)

sex_tmap = TensorMap(name='sex', interpretation=Interpretation.CATEGORICAL, channel_map={'Female': 0, 'Male':1})
age_tmap = TensorMap(name='age_in_days', interpretation=Interpretation.CONTINUOUS, channel_map={'age_in_days': 0})
af_in_read_tmap = TensorMap(name='af_in_read', interpretation=Interpretation.CATEGORICAL, channel_map={'no_af_in_read': 0, 'af_in_read':1})

output_tensormaps = {tm.output_name(): tm for tm in [af_tmap, death_tmap, sex_tmap, age_tmap, af_in_read_tmap]}
custom_dict = get_custom_objects(list(output_tensormaps.values()))
model = load_model('ecg2af_quintuplet_v2024_01_13.h5', custom_objects=custom_dict)
space_dict = defaultdict(list)

def process_ukb_hd5(filepath, space_dict):
# Placeholder for file processing logic
print(f"Processing file: {filepath}")
with h5py.File(filepath, 'r') as hd5:
ecg_array = np.zeros(ecg_tmap.shape, dtype=np.float32)
for lead in ecg_tmap.channel_map:
ecg_array[:, ecg_tmap.channel_map[lead]] = hd5[f'/ukb_ecg_rest/strip_{lead}/instance_0']

ecg_array -= ecg_array.mean()
ecg_array /= (ecg_array.std() + 1e-6)
#print(f"Got tensor: {tensor.mean():0.3f}")
prediction = model.predict(np.expand_dims(ecg_array, axis=0), verbose=0)
if len(model.output_names) == 1:
prediction = [prediction]
predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)}
#print(f"Got predictions: {predictions_dict}")
space_dict['sample_id'].append(os.path.basename(filepath).replace('.hd5', ''))
space_dict['ecg_path'].append(filepath)
if '/dates/atrial_fibrillation_or_flutter_date' in hd5:
space_dict['has_af'].append(1)
else:
space_dict['has_af'].append(0)

for otm in output_tensormaps.values():
y = predictions_dict[otm.output_name()]
if otm.is_categorical():
space_dict[f'{otm.name}_prediction'].append(y[0, 1])
elif otm.is_continuous():
space_dict[f'{otm.name}_prediction'].append(y[0, 0])
elif otm.is_survival_curve():
intervals = otm.shape[-1] // 2
days_per_bin = 1 + (2 * otm.days_window) // intervals
predicted_survivals = np.cumprod(y[:, :intervals], axis=1)
space_dict[f'{otm.name}_prediction'].append(str(1 - predicted_survivals[0, -1]))
# print(f' got target: {target[otm.output_name()].numpy().shape}')
# sick = np.sum(target[otm.output_name()].numpy()[:, intervals:], axis=-1)
# follow_up = np.cumsum(target[otm.output_name()].numpy()[:, :intervals], axis=-1)[:, -1] * days_per_bin
# space_dict[f'{otm.name}_event'].append(str(sick[b]))
# space_dict[f'{otm.name}_follow_up'].append(str(follow_up[b]))
# Example: Use the model to make a prediction (add real processing logic here)

def decode_ekg_muse(raw_wave):
"""
Ingest the base64 encoded waveforms and transform to numeric
"""
# covert the waveform from base64 to byte array
arr = base64.b64decode(bytes(raw_wave, 'utf-8'))

# unpack every 2 bytes, little endian (16 bit encoding)
unpack_symbols = ''.join([char * int(len(arr) / 2) for char in 'h'])
byte_array = struct.unpack(unpack_symbols, arr)
return byte_array

def decode_ekg_muse_to_array(raw_wave, downsample=1):
"""
Ingest the base64 encoded waveforms and transform to numeric

downsample: 0.5 takes every other value in the array. Muse samples at 500/s and the sample model requires 250/s. So take every other.
"""
try:
dwnsmpl = int(1 // downsample)
except ZeroDivisionError:
print("You must downsample by more than 0")
# covert the waveform from base64 to byte array
arr = base64.b64decode(bytes(raw_wave, 'utf-8'))

# unpack every 2 bytes, little endian (16 bit encoding)
unpack_symbols = ''.join([char * int(len(arr) / 2) for char in 'h'])
byte_array = struct.unpack(unpack_symbols, arr)
return np.array(byte_array)[::dwnsmpl]

def process_ge_muse_xml(filepath, space_dict):
with open(filepath, 'rb') as fd:
dic = xmltodict.parse(fd.read().decode('utf8'))

"""

Upload the ECG as numpy array with shape=[2500,12,1] ([time, leads, 1]).

The voltage unit should be in 1 mv/unit and the sampling rate should be 250/second (total 10 second).

The leads should be ordered as follow I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6.

"""
try:
patient_id = dic['RestingECG']['PatientDemographics']['PatientID']
except:
print("no PatientID")
patient_id = "none"
try:
pharma_unique_ecg_id = dic['RestingECG']['PharmaData']['PharmaUniqueECGID']
except:
print("no PharmaUniqueECGID")
pharma_unique_ecg_id = "none"
try:
acquisition_date_time = dic['RestingECG']['TestDemographics']['AcquisitionDate'] + "_" + \
dic['RestingECG']['TestDemographics']['AcquisitionTime'].replace(":", "-")
except:
print("no AcquisitionDateTime")
acquisition_date_time = "none"

# try:
# requisition_number = dic['RestingECG']['Order']['RequisitionNumber']
# except:
# print("no requisition_number")
# requisition_number = "none"

# need to instantiate leads in the proper order for the model
lead_order = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

"""
Each EKG will have this data structure:
lead_data = {
'I': np.array
}
"""

lead_data = dict.fromkeys(lead_order)
# lead_data = {leadid: None for k in lead_order}

# for all_lead_data in dic['RestingECG']['Waveform']:
# for single_lead_data in lead['LeadData']:
# leadname = single_lead_data['LeadID']
# if leadname in (lead_order):

for lead in dic['RestingECG']['Waveform']:
for leadid in range(len(lead['LeadData'])):
sample_length = len(decode_ekg_muse_to_array(lead['LeadData'][leadid]['WaveFormData']))
# sample_length is equivalent to dic['RestingECG']['Waveform']['LeadData']['LeadSampleCountTotal']
if sample_length == 5000:
lead_data[lead['LeadData'][leadid]['LeadID']] = decode_ekg_muse_to_array(
lead['LeadData'][leadid]['WaveFormData'], downsample=1)
elif sample_length == 2500:
lead_data[lead['LeadData'][leadid]['LeadID']] = decode_ekg_muse_to_array(
lead['LeadData'][leadid]['WaveFormData'], downsample=2)
else:
continue
# ensures all leads have 2500 samples and also passes over the 3 second waveform

lead_data['III'] = (np.array(lead_data["II"]) - np.array(lead_data["I"]))
lead_data['aVR'] = -(np.array(lead_data["I"]) + np.array(lead_data["II"])) / 2
lead_data['aVF'] = (np.array(lead_data["II"]) + np.array(lead_data["III"])) / 2
lead_data['aVL'] = (np.array(lead_data["I"]) - np.array(lead_data["III"])) / 2

lead_data = {k: lead_data[k] for k in lead_order}
# drops V3R, V4R, and V7 if it was a 15-lead ECG

# now construct and reshape the array
# converting the dictionary to an np.array
temp = []
for key, value in lead_data.items():
temp.append(value)

# transpose to be [time, leads, ]
ecg_array = np.array(temp).T

print(f'Writing row of ECG2AF predictions for ECG {patient_id}, at {acquisition_date_time}')
ecg_array -= ecg_array.mean()
ecg_array /= (ecg_array.std() + 1e-6)
#print(f"Got tensor: {tensor.mean():0.3f}")
prediction = model.predict(np.expand_dims(ecg_array, axis=0), verbose=0)
if len(model.output_names) == 1:
prediction = [prediction]
predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)}
#print(f"Got predictions: {predictions_dict}")
space_dict['filepath'].append(os.path.basename(filepath))
space_dict['patient_id'].append(patient_id)
space_dict['acquisition_datetime'].append(acquisition_date_time)
space_dict['pharma_unique_ecg_id'].append(pharma_unique_ecg_id)

for otm in output_tensormaps.values():
y = predictions_dict[otm.output_name()]
if otm.is_categorical():
space_dict[f'{otm.name}_prediction'].append(y[0, 1])
elif otm.is_continuous():
space_dict[f'{otm.name}_prediction'].append(y[0, 0])
elif otm.is_survival_curve():
intervals = otm.shape[-1] // 2
days_per_bin = 1 + (2 * otm.days_window) // intervals
predicted_survivals = np.cumprod(y[:, :intervals], axis=1)
space_dict[f'{otm.name}_prediction'].append(str(1 - predicted_survivals[0, -1]))
# print(f' got target: {target[otm.output_name()].numpy().shape}')
# sick = np.sum(target[otm.output_name()].numpy()[:, intervals:], axis=-1)
# follow_up = np.cumsum(target[otm.output_name()].numpy()[:, :intervals], axis=-1)[:, -1] * days_per_bin
# space_dict[f'{otm.name}_event'].append(str(sick[b]))
# space_dict[f'{otm.name}_follow_up'].append(str(follow_up[b]))
# Example: Use the model to make a prediction (add real processing logic here)

def main(directory):
# Iterate over all files in the specified directory
space_dict = defaultdict(list)
for i,filename in enumerate(os.listdir(directory)):
filepath = os.path.join(directory, filename)
if os.path.isfile(filepath):
process_ge_muse_xml(filepath, space_dict)
if i > 10000:
break

df = pd.DataFrame.from_dict(space_dict)
df.to_csv('/output/ecg2af_quintuplet.csv', index=False)

if __name__ == "__main__":
# Take directory path from command-line arguments
directory = sys.argv[1] if len(sys.argv) > 1 else "/data"
main(directory)
21 changes: 21 additions & 0 deletions ml4h/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,27 @@ def parse_args():
help='For diffusion models, when U-Net representation size is smaller than attention_window '
'Cross-Attention is applied',
)
parser.add_argument(
'--attention_modulo', default=3, type=int,
help='For diffusion models, this controls how frequently Cross-Attention is applied. '
'2 means every other residual block, 3 would mean every third.',
)
parser.add_argument(
'--diffusion_condition_strategy', default='concat', choices=['cross_attention', 'concat', 'film'],
help='For diffusion models, this controls conditional embeddings are integrated into the U-NET',
)
parser.add_argument(
'--diffusion_loss', default='sigmoid',
help='Loss function to use for diffusion models. Can be sigmoid, mean_absolute_error, or mean_squared_error',
)
parser.add_argument(
'--sigmoid_beta', default=-3, type=float,
help='Beta to use with sigmoid loss for diffusion models.',
)
parser.add_argument(
'--supervision_scalar', default=0.01, type=float,
help='For `train_diffusion_supervise` mode, this weights the supervision loss from phenotype prediction on denoised data.',
)
parser.add_argument(
'--transformer_size', default=32, type=int,
help='Number of output neurons in Transformer encoders and decoders, '
Expand Down
66 changes: 66 additions & 0 deletions ml4h/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# metrics.py
import logging

import keras
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
Expand Down Expand Up @@ -714,3 +716,67 @@ def concordance_index_censored(event_indicator, event_time, estimate, tied_tol=1
"""
w = np.ones_like(estimate)
return _estimate_concordance_index(event_indicator, event_time, estimate, w, tied_tol)


class KernelInceptionDistance(keras.metrics.Metric):
def __init__(self, name, input_shape, kernel_image_size, **kwargs):
super().__init__(name=name, **kwargs)

# KID is estimated per batch and is averaged across batches
self.kid_tracker = keras.metrics.Mean(name="kid_tracker")

# a pretrained InceptionV3 is used without its classification layer
# transform the pixel values to the 0-255 range, then use the same
# preprocessing as during pretraining
self.encoder = keras.Sequential(
[
keras.Input(shape=input_shape), # TODO: handle multi-channel
keras.layers.Lambda(lambda x: tf.tile(x, [1, 1, 1, 3])),
keras.layers.Rescaling(255.0),
keras.layers.Resizing(height=kernel_image_size, width=kernel_image_size),
keras.layers.Lambda(keras.applications.inception_v3.preprocess_input),
keras.applications.InceptionV3(
include_top=False,
input_shape=(kernel_image_size, kernel_image_size, 3),
weights="imagenet",
),
keras.layers.GlobalAveragePooling2D(),
],
name="inception_encoder",
)

def polynomial_kernel(self, features_1, features_2):
feature_dimensions = tf.cast(tf.shape(features_1)[1], dtype=tf.float32)
return (features_1 @ tf.transpose(features_2) / feature_dimensions + 1.0) ** 3.0

def update_state(self, real_images, generated_images, sample_weight=None):
real_features = self.encoder(real_images, training=False)
generated_features = self.encoder(generated_images, training=False)

# compute polynomial kernels using the two sets of features
kernel_real = self.polynomial_kernel(real_features, real_features)
kernel_generated = self.polynomial_kernel(
generated_features, generated_features
)
kernel_cross = self.polynomial_kernel(real_features, generated_features)

# estimate the squared maximum mean discrepancy using the average kernel values
batch_size = tf.shape(real_features)[0]
batch_size_f = tf.cast(batch_size, dtype=tf.float32)
mean_kernel_real = tf.reduce_sum(kernel_real * (1.0 - tf.eye(batch_size))) / (
batch_size_f * (batch_size_f - 1.0)
)
mean_kernel_generated = tf.reduce_sum(
kernel_generated * (1.0 - tf.eye(batch_size))
) / (batch_size_f * (batch_size_f - 1.0))
mean_kernel_cross = tf.reduce_mean(kernel_cross)
kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

# update the average KID estimate
self.kid_tracker.update_state(kid)

def result(self):
return self.kid_tracker.result()

def reset_state(self):
self.kid_tracker.reset_state()
Loading
Loading