Skip to content

Commit

Permalink
Refactor python api (#39)
Browse files Browse the repository at this point in the history
Refactor python api
  • Loading branch information
banasraf authored Sep 24, 2020
1 parent afd9534 commit 6d057bf
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 71 deletions.
5 changes: 4 additions & 1 deletion python/emt6ro/ga_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@ def predict(self, protocols):
"""
assert len(protocols) == self.num_protocols
for experiment in self.experiments:
experiment.run(protocols)
experiment.add_irradiations(protocols)
experiment.run(self.num_steps)
results = [experiment.get_results().reshape((1, self.num_protocols, -1)) for experiment in self.experiments]
results = np.concatenate(results)
results = results.mean(0).mean(1)
fitness = 1500 - results
for experiment in self.experiments:
experiment.reset()
return fitness
23 changes: 19 additions & 4 deletions python/emt6ro/simulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ def __init__(self, params, tumors, runs, protocols_num, gpu_id=0, simulation_ste
self.result_shape = (protocols_num, len(tumors), runs)
self._experiment = _Experiment(params, tumors, runs, protocols_num, simulation_steps, protocol_resolution, gpu_id)

def run(self, protocols):
def run(self, nsteps):
"""
Start the simulation for given protocols.
Run nsteps steps of the simulation
Parameters
----------
protocols - list of lists of pairs (irradiation time in steps, dose)
nsteps - number of simulation steps
"""
self._experiment.run(protocols)
self._experiment.run(nsteps)

def get_results(self):
"""
Expand All @@ -48,3 +48,18 @@ def get_results(self):
res = np.array(self._experiment.results())
return res.reshape(self.result_shape)

def add_irradiations(self, protocols):
"""
Add irradiations.
Parameters
----------
protocols - list of pairs (time, dose)
"""
self._experiment.add_irradiations(protocols)

def reset(self):
"""
Restore the initial states of the simulations
"""
self._experiment.reset()
99 changes: 59 additions & 40 deletions python/emt6ro/simulation/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ class Experiment {
int protocols_num; // number of protocols
uint32_t simulation_steps;
uint32_t protocol_resolution;
std::vector<HostGrid<Site>> tumors_data;
device::buffer<Site> tumors_data;
int64_t protocol_data_size;
device::Guard device_guard;
device::buffer<float> protocols_data;
device::unique_ptr<float[]> protocols_data;
std::vector<Protocol> protocols;
std::random_device rd{};
Parameters params;
Simulation simulation;
Expand All @@ -42,37 +43,69 @@ class Experiment {
, protocols_num{protocols_num}
, simulation_steps(sim_steps)
, protocol_resolution(prot_resolution)
, tumors_data{}
, tumors_data(tests_num * protocols_num * acc_volume(tumors))
, protocol_data_size{(sim_steps + prot_resolution - 1) / prot_resolution}
, device_guard{device_id}
, protocols_data(protocol_data_size * protocols_num)
, protocols_data()
, protocols(protocols_num * tests_num * tumors_num)
, params(params)
, simulation(tumors_num * tests_num * protocols_num, params, rd()) {
for (auto t: tumors) {
tumors_data.push_back(*t);
device::Guard d_g{device_id};
float *p_data;
cudaMallocManaged(&p_data, protocol_data_size * protocols_num * sizeof(float));
protocols_data = device::unique_ptr<float[]>(p_data, device::Deleter{device_id});
uint32_t index = 0;
for (int p = 0; p < protocols_num; ++p) {
for (auto tumor : tumors) {
for (int t = 0; t < tests_num; ++t) {
tumors_data.copyHost(tumor->view().data, tumor->view().dims.vol(), index, simulation.stream());
index += tumor->view().dims.vol();
}
}
}
for (int p = 0; p < protocols_num; ++p) {
Protocol prot{protocol_resolution, simulation_steps,
protocols_data.get() + p * protocol_data_size};
for (int t = 0; t < tumors_num * tests_num; ++t) {
protocols[p * tumors_num * tests_num + t] = prot;
}
}
simulation.setProtocols(protocols.data());
reset();
}

static int32_t acc_volume(const std::vector<HostGrid<Site>*> &tumors) {
int size = 0;
for (auto &t : tumors) size += t->view().dims.vol();
return size;
}

void reset() {
device::Guard d_g{device_id};
simulation.reset();
simulation.setState(tumors_data.data());
memset(protocols_data.get(), 0 , protocol_data_size * protocols_num * sizeof(float));
}

void run(const std::vector<std::vector<std::pair<int, float>>> &protocols) {
if (running)
throw std::runtime_error("Experiment already running.");
if (protocols.size() != protocols_num)
throw std::runtime_error("Wrong number of protocols.");
void addIrradiations(const std::vector<std::vector<std::pair<int, float>>> &ps) {
ENFORCE(protocols.size() == protocols_num, "Wrong number of protocols.");
device::Guard d_g{device_id};
for (int i = 0; i < protocols_num; ++i) {
auto p = protocols[i * tumors_num * tests_num];
for (auto time_dose : ps[i]) {
p.closestDose(time_dose.first) += time_dose.second;
}
}
}

void run(int nsteps) {
ENFORCE(!running, "Experiment already running.");
running = true;
results_ = std::async(std::launch::async, [&, protocols]() {
results_ = std::async(std::launch::async, [&, nsteps]() {
device::Guard d_g{device_id};
prepareProtocolsData(protocols);
for (int p = 0; p < protocols_num; ++p) {
Protocol prot{protocol_resolution, simulation_steps,
protocols_data.data() + p * protocol_data_size};
for (auto &tumor : tumors_data) {
simulation.sendData(tumor, prot, tests_num);
}
}
simulation.run(simulation_steps);
simulation.run(nsteps);
std::vector<uint32_t> res(tumors_num * tests_num * protocols_num);
simulation.getResults(res.data());
simulation = Simulation(tumors_num * tests_num * protocols_num, params, rd());
return res;
});
}
Expand All @@ -85,22 +118,6 @@ class Experiment {
return results_.get();
}

private:
void prepareProtocolsData(const std::vector<std::vector<std::pair<int, float>>> &protocols) {
std::vector<float> host_protocol(protocol_data_size);
size_t p_i = 0;
for (const auto &protocol: protocols) {
std::fill(host_protocol.begin(), host_protocol.end(), 0);
for (const auto &irradiation: protocol) {
auto i = irradiation.first / protocol_resolution;
host_protocol[i] += irradiation.second;
}
protocols_data.copyHost(host_protocol.data(), protocol_data_size, p_i * protocol_data_size,
simulation.stream());
++p_i;
}
}

};

PYBIND11_MODULE(backend, m) {
Expand All @@ -121,9 +138,11 @@ PYBIND11_MODULE(backend, m) {
py::class_<Experiment>(m, "_Experiment")
.def(py::init<const Parameters&, std::vector<HostGrid<Site>*>, int, int, int, int, int>())
.def("run", &Experiment::run)
.def("results", &Experiment::results);
.def("results", &Experiment::results)
.def("add_irradiations", &Experiment::addIrradiations)
.def("reset", &Experiment::reset);

m.def("load_parameters", &Parameters::loadFromJSONFile);
m.def("load_parameters", &Parameters::loadFromJSONFile);

m.def("load_state", &loadFromFile);

Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setuptools.setup(
name='emt6ro',
version='0.1.3.1',
version='0.1.4.0',
author="Rafal Banas",
author_email="[email protected]",
description="",
Expand Down
6 changes: 6 additions & 0 deletions src/emt6ro/common/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ struct Protocol {
return 0;
}

__host__ __device__ inline float &closestDose(uint32_t step) {
return data_[step / step_resolution_];
}

__device__ void reset();

uint32_t step_resolution_;
uint32_t length_;
float *data_;
Expand Down
5 changes: 1 addition & 4 deletions src/emt6ro/common/random-engine-test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ __global__ void fillUniform(float *data, curandState_t *states) {

TEST(RandomGenerator, FillUniform) {
device::buffer<float> data(2 * 1024);
std::vector<uint32_t> h_seeds(2*1024);
for (size_t i = 0; i < 2 * 1024; ++i) h_seeds[i] = i;
auto d_seeds = device::buffer<uint32_t>::fromHost(h_seeds.data(), 2*1024);
CuRandEngineState state(2 * 1024, d_seeds.data());
CuRandEngineState state(2 * 1024, 1997);
fillUniform<<<2, 1024>>>(data.data(), state.states());
auto h_data = data.toHost();
for (size_t i = 0; i < 2 * 1024; ++i) {
Expand Down
16 changes: 8 additions & 8 deletions src/emt6ro/common/random-engine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,29 @@ namespace emt6ro {

namespace detail {

__global__ void initializeState(curandState_t *state, const uint32_t *seeds, size_t size) {
__global__ void initializeState(curandState_t *state, uint32_t seed, size_t size) {
auto i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= size) return;
curand_init(seeds[i], 0, 0, &state[i]);
curand_init(seed, i, 0, &state[i]);
}

void init(curandState_t *state_data, const uint32_t *seeds, size_t size, cudaStream_t stream) {
void init(curandState_t *state_data, uint32_t seed, size_t size, cudaStream_t stream) {
auto mbs = CuBlockDimX * CuBlockDimY;
auto blocks = div_ceil(size, mbs);
auto block_size = (size > mbs) ? mbs : size;
detail::initializeState<<<blocks, block_size, 0, stream>>>(state_data, seeds, size);
detail::initializeState<<<blocks, block_size, 0, stream>>>(state_data, seed, size);
}

} // namespace detail

CuRandEngineState::CuRandEngineState(size_t size, const uint32_t* seeds) : state_(size) {
init(seeds);
CuRandEngineState::CuRandEngineState(size_t size, uint32_t seed, cudaStream_t stream) : state_(size) {
init(seed, stream);
}

CuRandEngineState::CuRandEngineState(size_t size): state_(size) {}

void CuRandEngineState::init(const uint32_t *seeds, cudaStream_t stream) {
detail::init(state_.data(), seeds, state_.size(), stream);
void CuRandEngineState::init(uint32_t seed, cudaStream_t stream) {
detail::init(state_.data(), seed, state_.size(), stream);
}

__device__ float CuRandEngine::uniform() {
Expand Down
4 changes: 2 additions & 2 deletions src/emt6ro/common/random-engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace emt6ro {

class CuRandEngineState {
public:
CuRandEngineState(size_t size, const uint32_t *seeds);
CuRandEngineState(size_t size, uint32_t seed, cudaStream_t stream = nullptr);

explicit CuRandEngineState(size_t size);

Expand All @@ -21,7 +21,7 @@ class CuRandEngineState {
* @param seeds - pointer to device data with random seeds
* @param stream - cuda stream
*/
void init(const uint32_t *seeds, cudaStream_t stream = nullptr);
void init(const uint32_t seed, cudaStream_t stream = nullptr);

inline const curandState_t *states() const {
return state_.data();
Expand Down
12 changes: 9 additions & 3 deletions src/emt6ro/diffusion/diffusion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@

namespace emt6ro {

constexpr int kBlockDimX = 32;
constexpr int kBlockDimY = 32;
constexpr int kBlockDiv = (55 + kBlockDimX - 1) / kBlockDimX;
constexpr int kSitesPerThread = kBlockDiv * kBlockDiv;

__device__ void fillBorderMask(GridView<uint8_t> mask, int32_t max_dist) {
float midr = mask.dims.height / 2.f;
float midc = mask.dims.width / 2.f;
Expand Down Expand Up @@ -134,10 +139,11 @@ __global__ void diffusionKernel(GridView<Site> *lattices, const ROI *rois,
extern __shared__ Substrates tmp_mem[];
auto lattice = lattices[blockIdx.x];
auto roi = rois[blockIdx.x];
if (roi.dims.height == 0 || roi.dims.width == 0) return;
Dims bordered_dims(roi.dims.height + 4, roi.dims.width + 4);
GridView<Substrates> tmp_grid{tmp_mem, bordered_dims};
Substrates diff[SitesPerThread/4];
Coords sites[SitesPerThread/4];
Substrates diff[kSitesPerThread];
Coords sites[kSitesPerThread];
int16_t nsites = 0;
GridView<const uint8_t> b_mask{border_masks + lattice.dims.vol() * blockIdx.x,
Dims(roi.dims.height + 2, roi.dims.width + 2)};
Expand Down Expand Up @@ -188,7 +194,7 @@ void batchDiffusion(GridView<Site> *lattices, const ROI *rois, const uint8_t *bo
const Parameters::Diffusion &params, Substrates external_levels, int16_t steps,
Dims dims, int32_t batch_size, cudaStream_t stream) {
auto shared_mem_size = sizeof(Substrates) * Dims(dims.height+4, dims.width+4).vol();
diffusionKernel<<<batch_size, dim3(CuBlockDimX*2, CuBlockDimY*2), shared_mem_size, stream>>>
diffusionKernel<<<batch_size, dim3(kBlockDimX, kBlockDimY), shared_mem_size, stream>>>
(lattices, rois, border_masks, params, external_levels, steps);
KERNEL_DEBUG("diffusion kernel")
}
Expand Down
34 changes: 26 additions & 8 deletions src/emt6ro/simulation/simulation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "emt6ro/statistics/statistics.h"
#include "emt6ro/common/cuda-utils.h"
#include "emt6ro/common/stack.cuh"
#include "emt6ro/common/error.h"

namespace emt6ro {

Expand Down Expand Up @@ -128,11 +129,7 @@ Simulation::Simulation(uint32_t batch_size, const Parameters &parameters, uint32
, occupied(batch_size * 1024)
, rand_state(batch_size * simulate_num_threads)
, results(batch_size) {
std::vector<uint32_t> h_seeds(batch_size * simulate_num_threads);
std::mt19937 rand{seed};
std::generate(h_seeds.begin(), h_seeds.end(), rand);
auto seeds = device::buffer<uint32_t>::fromHost(h_seeds.data(), h_seeds.size(), str.stream_);
rand_state.init(seeds.data(), str.stream_);
rand_state.init(seed, str.stream_);
populateLattices();
}

Expand All @@ -147,6 +144,7 @@ void Simulation::sendData(const HostGrid<Site> &grid, const Protocol &protocol,
KERNEL_DEBUG("protocol")
}
filled_samples += multi;
filled_protocols += multi;
}

void Simulation::step() {
Expand All @@ -168,7 +166,6 @@ void Simulation::diffuse() {
batchDiffusion(lattices.data(), rois.data(), border_masks.data(), params.diffusion_params,
params.external_levels, params.time_step/params.diffusion_params.time_step,
dims, batch_size, str.stream_);
// oldBatchDiffusion(data.data(), dims, params, batch_size);
}

void Simulation::simulateCells() {
Expand All @@ -189,15 +186,16 @@ void Simulation::getResults(uint32_t *h_results) {
cudaMemcpyDeviceToHost, str.stream_);
sync();
}

void Simulation::run(uint32_t nsteps) {
assert(filled_samples == batch_size);
ENFORCE(filled_samples == batch_size, "");
for (uint32_t s = 0; s < nsteps; ++s) {
step();
}
}

void Simulation::getData(Site *h_data, uint32_t sample) {
assert(sample < batch_size);
ENFORCE(sample < batch_size, "");
cudaMemcpyAsync(h_data, data.data() + sample * dims.vol(),
dims.vol() * sizeof(Site), cudaMemcpyDeviceToHost, str.stream_);
sync();
Expand All @@ -207,4 +205,24 @@ void Simulation::sync() {
cudaStreamSynchronize(str.stream_);
}

void Simulation::reset() {
sync();
step_ = 0;
filled_samples = 0;
}

void Simulation::setState(const Site *state) {
cudaMemcpyAsync(data.data(), state, batch_size * dims.vol() * sizeof(Site),
cudaMemcpyDeviceToDevice, str.stream_);
KERNEL_DEBUG("copy data");
filled_samples = batch_size;
}

void Simulation::setProtocols(const Protocol *ps) {
cudaMemcpyAsync(protocols.data(), ps, batch_size * sizeof(Protocol),
cudaMemcpyHostToDevice, str.stream_);
KERNEL_DEBUG("copy protocols");
filled_protocols = batch_size;
}

} // namespace emt6ro
Loading

0 comments on commit 6d057bf

Please sign in to comment.