Skip to content

Commit

Permalink
saving work in progress on dl
Browse files Browse the repository at this point in the history
  • Loading branch information
JBris committed Sep 20, 2024
1 parent 91694e0 commit c213d39
Show file tree
Hide file tree
Showing 7 changed files with 487 additions and 48 deletions.
10 changes: 5 additions & 5 deletions app/conf/calibration_form/common.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -386,21 +386,21 @@ components:
collapsible: true
children:
- id: upload-summary-data-file-button
label: Upload
label: Upload statistics
help: Upload summary statistics data from a csv file
class_name: dash.dcc.Upload
handler: file_upload
kwargs:
children: Load statistics data
- id: upload-obs-data-file-button
label: Upload
label: Upload simulation
help: Upload simulated root data from a csv file
class_name: dash.dcc.Upload
handler: file_upload
kwargs:
children: Load simulation data
- id: upload-edge-data-file-button
label: Upload
label: Upload edges
help: Upload edge root data from a csv file
class_name: dash.dcc.Upload
handler: file_upload
Expand Down Expand Up @@ -694,7 +694,7 @@ components:
type: number
min: 1
step: 1
value: 10
value: 5
persistence: true
- id: draws-input
param: pp_samples
Expand All @@ -716,7 +716,7 @@ components:
type: number
min: 1
step: 1
value: 50
value: 5
persistence: true
- id: num-transforms-input
param: nn_num_transforms
Expand Down
209 changes: 199 additions & 10 deletions app/flows/run_snpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Imports
######################################

# mypy: ignore-errors

# isort: off

# This is for compatibility with Prefect.
Expand All @@ -13,15 +15,18 @@
# isort: on

import os.path as osp
from random import choices
from typing import Callable

import mlflow
import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from joblib import dump as calibrator_dump
from matplotlib import pyplot as plt
from prefect import flow, task
Expand All @@ -36,6 +41,8 @@
prepare_for_sbi,
simulate_for_sbi,
)
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.pool import global_max_pool

from deeprootgen.calibration import (
SnpeModel,
Expand All @@ -46,6 +53,7 @@
)
from deeprootgen.data_model import RootCalibrationModel, SummaryStatisticsModel
from deeprootgen.io import save_graph_to_db
from deeprootgen.model import RootSystemGraph
from deeprootgen.pipeline import (
begin_experiment,
get_datetime_now,
Expand Down Expand Up @@ -101,6 +109,7 @@ def prepare_task(input_parameters: RootCalibrationModel) -> tuple:

data_type = v["data_type"]
if data_type == "discrete":
lower_bound, upper_bound = int(lower_bound), int(upper_bound)
replicates = np.floor(upper_bound - lower_bound).astype("int")
probabilities = torch.tensor([1 / replicates])
probabilities = probabilities.repeat(replicates)
Expand All @@ -122,6 +131,146 @@ def prepare_task(input_parameters: RootCalibrationModel) -> tuple:
return names, priors, limits, statistics_list


# @TODO this is a hack for compatibility with the sbi API,
# and should be replaced with a GNN feature extractor surrogate.
# i.e. we train a separate feature extractor to provide graph embeddings,
# then use that feature extractor instead of this embedding net.
class GraphFeatureExtractor(torch.nn.Module):
    """A graph feature extractor for density estimation.

    Two SAGEConv layers, a global max pool and a linear layer turn a single
    graph into one flat embedding with ``num_organ_features`` entries.
    Graphs are registered with ``add_graph`` and kept in ``G_list``;
    ``forward`` builds a batch of embeddings from the graphs in that list.
    """

    def __init__(self, organ_columns: list[str]) -> None:
        """GraphFeatureExtractor constructor.

        Args:
            organ_columns (list[str]):
                The list of organ columns for grouping organ features.
        """
        super().__init__()
        self.organ_columns = organ_columns

        # The normalisation transform is restricted to the organ attributes.
        self.transform = T.Compose([T.NormalizeFeatures(organ_columns)])

        # Flatten the per-group feature names into one list; its length
        # fixes the input width of the first convolution.
        # (presumably RootSystemGraph.organ_columns maps a group name to its
        # feature names - confirm against deeprootgen.model)
        root_graph = RootSystemGraph()
        organ_features = [
            feature
            for organ_column in organ_columns
            for feature in root_graph.organ_columns[organ_column]
        ]
        self.organ_features = organ_features

        num_organ_features = len(organ_features)
        self.num_organ_features = num_organ_features
        self.conv1 = SAGEConv(
            num_organ_features,
            num_organ_features * 4,
            aggr="mean",
            normalize=True,
            bias=True,
        )
        self.conv2 = SAGEConv(
            num_organ_features * 4,
            num_organ_features * 2,
            aggr="mean",
            normalize=True,
            bias=True,
        )

        self.fc = torch.nn.Linear(num_organ_features * 2, num_organ_features)
        self.pool = global_max_pool
        self.activation = F.elu

        # Processed (node features, edge index) pairs, filled by add_graph.
        self.G_list: list = []

    def process_graph(self, G: nx.Graph) -> tuple:
        """Process a new graph into node features and an edge index.

        NOTE(review): the parameter is annotated as a NetworkX graph, but the
        body relies on item assignment, ``self.transform`` and an
        ``edge_index`` attribute, and the callers visible in this file pass
        the result of ``as_torch(...)`` - confirm the intended type.

        The organ columns of ``G`` are overwritten in place with double
        tensors before the transform is applied.

        Args:
            G (nx.Graph):
                The graph to process.

        Returns:
            tuple:
                The node and edge features.
        """
        for column in self.organ_columns:
            G[column] = torch.Tensor(pd.DataFrame(G[column]).values).double()

        train_data = self.transform(G)
        organ_features = [train_data[column] for column in self.organ_columns]

        # Column-wise concatenation of the organ groups gives one feature
        # matrix; torch.Tensor(...) converts it to the default float dtype.
        x = torch.Tensor(np.hstack(organ_features))
        edge_index = train_data.edge_index
        return x, edge_index

    def add_graph(self, G: nx.Graph) -> int:
        """Add a graph to the graph list.

        Args:
            G (nx.Graph):
                The graph to process and store.

        Returns:
            int:
                The list index.
        """
        x, edge_index = self.process_graph(G)

        self.G_list.append((x, edge_index))
        return len(self.G_list) - 1

    def encode(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Construct graph embeddings from node and edges.

        Args:
            x (torch.Tensor):
                The node features.
            edge_index (torch.Tensor):
                The edge index.

        Returns:
            torch.Tensor:
                The graph embeddings.
        """
        # Every node is assigned to batch 0, so the pool yields a single row.
        # The index is created on the same device as the node features.
        batch_index = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device)

        x = self.conv1(x, edge_index)
        x = self.activation(x)
        x = self.conv2(x, edge_index)
        x = self.activation(x)
        x = self.pool(x, batch_index)
        x = self.activation(x)
        x = self.fc(x)
        # Flatten the pooled output into a 1-D embedding.
        x = x.view(-1)

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """The forward pass.

        Input with more than one column is returned unchanged. Otherwise one
        embedding is produced per row from graphs held in ``G_list``.

        NOTE(review): the values carried in ``x`` are never read; graphs are
        drawn at random, with replacement, from ``G_list``. Presumably this
        is part of the interim hack - confirm it is intended.

        Args:
            x (torch.Tensor):
                The batch tensor.

        Returns:
            torch.Tensor:
                The graph embedding.

        Raises:
            IndexError:
                If an embedding is requested while no graph is registered.
        """
        if x.shape[1] > 1:
            return x

        batch_size = x.shape[0]
        if not self.G_list:
            raise IndexError(
                "GraphFeatureExtractor.G_list is empty; "
                "register graphs with add_graph before calling forward."
            )

        sampled_graphs = choices(self.G_list, k=batch_size)
        embeddings = [
            self.encode(node_features, edge_index)
            for node_features, edge_index in sampled_graphs
        ]
        return torch.stack(embeddings)


@task
def perform_task(
input_parameters: RootCalibrationModel,
Expand All @@ -145,23 +294,41 @@ def perform_task(
tuple:
The trained model and samples.
"""
use_summary_statistics: bool = (
input_parameters.statistics_comparison.use_summary_statistics
)
if use_summary_statistics:
embedding_net = nn.Identity()
else:
organ_columns = ["organ_coordinates", "organ_hierarchy", "organ_size"]
embedding_net = GraphFeatureExtractor(organ_columns)

def simulator_func(theta: np.ndarray) -> np.ndarray:
theta = theta.detach().cpu().numpy()
parameter_specs = {}
for i, name in enumerate(names):
parameter_specs[name] = theta[i]

simulated, _ = calculate_summary_statistics(
parameter_specs, input_parameters, statistics_list
)
if use_summary_statistics:
simulated, _ = calculate_summary_statistics(
parameter_specs, input_parameters, statistics_list
)
else:
simulation, _ = run_calibration_simulation(
parameter_specs, input_parameters
)

G = simulation.G.as_torch(drop=True)
indx = embedding_net.add_graph(G)
simulated = np.array([indx]).astype("int")

return simulated

calibration_parameters = input_parameters.calibration_parameters
simulator, prior = prepare_for_sbi(simulator_func, priors)
neural_posterior = utils.posterior_nn(
model="nsf",
embedding_net=embedding_net,
hidden_features=calibration_parameters["nn_num_hidden_features"],
num_transforms=calibration_parameters["nn_num_transforms"],
)
Expand All @@ -177,16 +344,34 @@ def simulator_func(theta: np.ndarray) -> np.ndarray:
inference = inference.append_simulations(theta, x, data_device="cpu")
density_estimator = inference.train()
posterior = inference.build_posterior(density_estimator)

calibration_parameters = input_parameters.calibration_parameters
n_draws = calibration_parameters["pp_samples"]
observed_values = []
for statistic in statistics_list:
observed_values.append(statistic.statistic_value)
posterior.set_default_x(observed_values)
posterior_samples = posterior.sample((n_draws,), x=observed_values)

observed_values = [statistic.dict() for statistic in statistics_list]
if use_summary_statistics:
observed_values = []
for statistic in statistics_list:
observed_values.append(statistic.statistic_value)
posterior.set_default_x(observed_values)
posterior_samples = posterior.sample((n_draws,), x=observed_values)
observed_values = [statistic.dict() for statistic in statistics_list]
else:
root_g = RootSystemGraph()
observed_data_content = input_parameters.observed_data_content
raw_edge_content = input_parameters.raw_edge_content
node_df, edge_df = root_g.from_content_string(
observed_data_content, raw_edge_content
)
G = root_g.as_torch(node_df, edge_df, drop=True)
x, edge_index = embedding_net.process_graph(G)

with torch.no_grad():
observed_values = embedding_net.encode(x, edge_index)

posterior.set_default_x(observed_values)
posterior_samples = posterior.sample((n_draws,), x=observed_values)
embedding_net.G_list = []
observed_values = (node_df, edge_df)

return inference, simulator, prior, posterior, posterior_samples, observed_values


Expand Down Expand Up @@ -234,6 +419,9 @@ def log_task(
tuple:
The simulation and its parameters.
"""
# use_summary_statistics: bool = (
# input_parameters.statistics_comparison.use_summary_statistics
# )
time_now = get_datetime_now()
outdir = get_outdir()

Expand Down Expand Up @@ -293,6 +481,7 @@ def log_task(
description="# Simulation-based calibration metrics.",
)

num_bins = None
if sbc_draws <= 20: # type: ignore
num_bins = sbc_draws

Expand Down
Loading

0 comments on commit c213d39

Please sign in to comment.