Skip to content

Commit

Permalink
feat(#30): interpolator serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
ddemidov committed Oct 5, 2024
1 parent 3b296b2 commit e901883
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 1 deletion.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
enable_testing()
add_subdirectory(python)
add_subdirectory(tests)
add_executable(example example.cpp)
add_test(NAME go_test COMMAND go test WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mba)
38 changes: 38 additions & 0 deletions example.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include<fstream>
#include <mba/mba.hpp>

int main() {
// Coordinates of data points.
std::vector<mba::point<2>> coo = {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0},
{1.0, 1.0}, {0.4, 0.4}, {0.6, 0.6}};

// Data values.
std::vector<double> val = {0.2, 0.0, 0.0, -0.2, -1.0, 1.0};

// Bounding box containing the data points.
mba::point<2> lo = {-0.1, -0.1};
mba::point<2> hi = {1.1, 1.1};

// Initial grid size.
mba::index<2> grid = {3, 3};

// Algorithm setup.
mba::MBA<2> interp(lo, hi, grid, coo, val);

// write interpolator to a file
std::ofstream fout("test.mba", std::ios::binary);
interp.write(fout);
fout.close();

// read interpolator from a file
std::ifstream fin("test.mba", std::ios::binary);
mba::MBA<2> interp2(fin);

// Get interpolated value at arbitrary location.
double w = interp(mba::point<2>{0.3, 0.7});
double w2 = interp2(mba::point<2>{0.3, 0.7});

std::cout
<< w << std::endl
<< w2 << std::endl;
}
121 changes: 120 additions & 1 deletion mba/mba.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ THE SOFTWARE.

#include <iostream>
#include <iomanip>
#include <fstream>
#include <vector>
#include <array>
#include <map>
Expand Down Expand Up @@ -62,6 +63,10 @@ using index = std::array<size_t, N>;

namespace detail {

const int32_t initial_tag = 42000;
const int32_t dense_tag = 42001;
const int32_t sparse_tag = 42002;

template <class Cond, class Msg>
void precondition(const Cond &cond, const Msg &msg) {
if (!static_cast<bool>(cond)) throw std::runtime_error(msg);
Expand Down Expand Up @@ -117,6 +122,23 @@ class multi_array {
T* data() {
return buf.data();
}

void read(std::ifstream &f) {
precondition(static_cast<bool>(f.read((char*)sizes.data(), sizeof(sizes))), "file i/o error");
precondition(static_cast<bool>(f.read((char*)stride.data(), sizeof(stride))), "file i/o error");

size_t n = 1;
for (size_t i = 0; i < N; ++i) n *= sizes[i];

buf.resize(n);
precondition(static_cast<bool>(f.read((char*)buf.data(), sizeof(T) * n)), "file i/o error");
}

void write(std::ofstream &f) const {
precondition(static_cast<bool>(f.write((char*)sizes.data(), sizeof(sizes))), "file i/o error");
precondition(static_cast<bool>(f.write((char*)stride.data(), sizeof(stride))), "file i/o error");
precondition(static_cast<bool>(f.write((char*)buf.data(), sizeof(T) * buf.size())), "file i/o error");
}
private:
std::array<int, N> sizes;
std::array<int, N> stride;
Expand Down Expand Up @@ -289,6 +311,8 @@ class control_lattice {
virtual double operator()(const point<NDim> &p) const = 0;

virtual void report(std::ostream&) const = 0;
virtual int32_t tag() const = 0;
virtual void write(std::ofstream&) const = 0;

template <class CooIter, class ValIter>
double residual(CooIter coo_begin, CooIter coo_end, ValIter val_begin) const {
Expand Down Expand Up @@ -319,6 +343,12 @@ class initial_approximation : public control_lattice<NDim> {
void report(std::ostream &os) const {
os << "initial approximation";
}

int32_t tag() const {
return detail::initial_tag;
}

void write(std::ofstream&) const {}
private:
std::function<double(const point<NDim>&)> f;
};
Expand Down Expand Up @@ -425,6 +455,29 @@ class control_lattice_dense : public control_lattice<NDim> {
}
}

control_lattice_dense(std::ifstream &f) {
precondition(static_cast<bool>(f.read((char*)cmin.data(), sizeof(cmin))), "file i/o error");
precondition(static_cast<bool>(f.read((char*)cmax.data(), sizeof(cmax))), "file i/o error");
precondition(static_cast<bool>(f.read((char*)hinv.data(), sizeof(hinv))), "file i/o error");
precondition(static_cast<bool>(f.read((char*)grid.data(), sizeof(grid))), "file i/o error");

phi.resize(grid);
phi.read(f);
}

void write(std::ofstream &f) const {
precondition(static_cast<bool>(f.write((char*)cmin.data(), sizeof(cmin))), "file i/o error");
precondition(static_cast<bool>(f.write((char*)cmax.data(), sizeof(cmax))), "file i/o error");
precondition(static_cast<bool>(f.write((char*)hinv.data(), sizeof(hinv))), "file i/o error");
precondition(static_cast<bool>(f.write((char*)grid.data(), sizeof(grid))), "file i/o error");

phi.write(f);
}

int32_t tag() const {
return detail::dense_tag;
}

double operator()(const point<NDim> &p) const {
index<NDim> i;
point<NDim> s;
Expand Down Expand Up @@ -497,7 +550,6 @@ class control_lattice_dense : public control_lattice<NDim> {

return static_cast<double>(nonzeros) / total;
}

private:
point<NDim> cmin, cmax, hinv;
index<NDim> grid;
Expand Down Expand Up @@ -564,6 +616,39 @@ class control_lattice_sparse : public control_lattice<NDim> {
detail::make_transform_iterator(dw.end(), delta_over_omega));
}

control_lattice_sparse(std::ifstream &f) {
precondition(static_cast<bool>(f.read((char*)cmin.data(), sizeof(cmin))), "file i/o error");
precondition(static_cast<bool>(f.read((char*)cmax.data(), sizeof(cmax))), "file i/o error");
precondition(static_cast<bool>(f.read((char*)hinv.data(), sizeof(hinv))), "file i/o error");
precondition(static_cast<bool>(f.read((char*)grid.data(), sizeof(grid))), "file i/o error");

size_t n;
typename sparse_grid::value_type v;

precondition(static_cast<bool>(f.read((char*)&n, sizeof(n))), "file i/o error");
for (size_t i = 0; i < n; ++i) {
precondition(static_cast<bool>(f.read((char*)&v, sizeof(v))), "file i/o error");
phi.insert(phi.end(), v);
}
}

void write(std::ofstream &f) const {
precondition(static_cast<bool>(f.write((char*)cmin.data(), sizeof(cmin))), "file i/o error");
precondition(static_cast<bool>(f.write((char*)cmax.data(), sizeof(cmax))), "file i/o error");
precondition(static_cast<bool>(f.write((char*)hinv.data(), sizeof(hinv))), "file i/o error");
precondition(static_cast<bool>(f.write((char*)grid.data(), sizeof(grid))), "file i/o error");

size_t n = phi.size();
precondition(static_cast<bool>(f.write((char*)&n, sizeof(n))), "file i/o error");
for (auto &p : phi) {
precondition(static_cast<bool>(f.write((char*)&p, sizeof(p))), "file i/o error");
}
}

int32_t tag() const {
return detail::sparse_tag;
}

double operator()(const point<NDim> &p) const {
index<NDim> i;
point<NDim> s;
Expand Down Expand Up @@ -752,6 +837,30 @@ class MBA {
);
}

MBA(std::ifstream &f, std::function<double(point<NDim>)> initial = std::function<double(point<NDim>)>()) {
size_t n;
detail::precondition(static_cast<bool>(f.read((char*)&n, sizeof(n))), "file i/o error");
for (size_t i = 0; i < n; ++i) {
int32_t tag;
detail::precondition(static_cast<bool>(f.read((char*)&tag, sizeof(tag))), "file i/o error");

switch (tag) {
case detail::initial_tag:
detail::precondition(static_cast<bool>(initial), "initial function definition should be provided");
cl.push_back(std::make_shared<initial_approximation>(initial));
break;
case detail::dense_tag:
cl.push_back(std::make_shared<dense_lattice>(f));
break;
case detail::sparse_tag:
cl.push_back(std::make_shared<sparse_lattice>(f));
break;
default:
detail::precondition(false, "unknown lattice tag in input file");
}
}
}

double operator()(const point<NDim> &p) const {
double f = 0.0;

Expand All @@ -772,6 +881,16 @@ class MBA {
return os;
}

void write(std::ofstream &f) const {
size_t n = cl.size();
detail::precondition(static_cast<bool>(f.write((char*)&n, sizeof(n))), "file i/o error");
for (auto &l : cl) {
auto tag = l->tag();
detail::precondition(static_cast<bool>(f.write((char*)&tag, sizeof(tag))), "file i/o error");
l->write(f);
}
}

private:
typedef detail::control_lattice<NDim> lattice;
typedef detail::initial_approximation<NDim> initial_approximation;
Expand Down

0 comments on commit e901883

Please sign in to comment.