Skip to content

Commit

Permalink
Refactor ScriptInterface
Browse files Browse the repository at this point in the history
Give each member of ScriptInterface::System a handle to the core
System class. Refactor AutoParameters serialization. The System
class can now handle checkpointing of its ObjectHandle members.
  • Loading branch information
jngrad committed Oct 23, 2023
1 parent 4cc12c1 commit 259b3af
Show file tree
Hide file tree
Showing 36 changed files with 273 additions and 172 deletions.
7 changes: 3 additions & 4 deletions src/core/analysis/statistics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,12 @@ double mindist(System::System const &system, std::vector<int> const &set1,
return std::sqrt(mindist_sq);
}

Utils::Vector3d calc_linear_momentum(bool include_particles,
Utils::Vector3d calc_linear_momentum(System::System const &system,
bool include_particles,
bool include_lbfluid) {
Utils::Vector3d momentum{};
auto const &system = System::get_system();
auto &cell_structure = *system.cell_structure;
if (include_particles) {
auto const particles = cell_structure.local_particles();
auto const particles = system.cell_structure->local_particles();
momentum =
std::accumulate(particles.begin(), particles.end(), Utils::Vector3d{},
[](Utils::Vector3d const &m, Particle const &p) {
Expand Down
4 changes: 3 additions & 1 deletion src/core/analysis/statistics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ Utils::Vector9d moment_of_inertia_matrix(System::System const &system,
int p_type);

/** Calculate total momentum of the system (particles & LB fluid).
* @param system particle system
* @param include_particles Add particles momentum
* @param include_lbfluid Add LB fluid momentum
*/
Utils::Vector3d calc_linear_momentum(bool include_particles,
Utils::Vector3d calc_linear_momentum(System::System const &system,
bool include_particles,
bool include_lbfluid);
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "nonbonded_interactions/nonbonded_interaction_data.hpp"

#include "electrostatics/coulomb.hpp"
#include "system/System.hpp"

#include <algorithm>
#include <cassert>
Expand Down Expand Up @@ -123,7 +122,3 @@ double InteractionsNonBonded::maximal_cutoff() const {
}
return max_cut_nonbonded;
}

void InteractionsNonBonded::on_non_bonded_ia_change() {
System::get_system().on_non_bonded_ia_change();
}
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,6 @@ class InteractionsNonBonded {

void set_ia_param(int i, int j, std::shared_ptr<IA_parameters> const &ia) {
m_nonbonded_ia_params[get_ia_param_key(i, j)] = ia;
on_non_bonded_ia_change();
}

auto get_max_seen_particle_type() const { return max_seen_particle_type; }
Expand All @@ -437,6 +436,4 @@ class InteractionsNonBonded {

/** @brief Get maximal cutoff. */
double maximal_cutoff() const;

void on_non_bonded_ia_change();
};
24 changes: 13 additions & 11 deletions src/python/espressomd/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def acf_1d(signal, n_with_padding, n):

@script_interface_register
class _ObservableStat(ScriptInterfaceHelper):
_so_name = "ScriptInterface::Analysis::ObservableStat"
_so_name = "Analysis::ObservableStat"
_so_creation_policy = "GLOBAL"

def _generate_summary(self, obj, dim, calc_sp):
Expand Down Expand Up @@ -381,7 +381,7 @@ class Analysis(ScriptInterfaceHelper):
and [1] contains the values of the rdf.
"""
_so_name = "ScriptInterface::Analysis::Analysis"
_so_name = "Analysis::Analysis"
_so_creation_policy = "GLOBAL"
_so_bind_methods = (
"linear_momentum",
Expand All @@ -395,6 +395,11 @@ class Analysis(ScriptInterfaceHelper):
"structure_factor",
"distribution")

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.obs_stat = _ObservableStat()
self.call_method("register_observable_stat", leaf=self.obs_stat)

def min_dist(self, p1='default', p2='default'):
"""
Minimal distance between two sets of particle types.
Expand Down Expand Up @@ -459,10 +464,9 @@ def pressure(self):
* ``"external_fields"``: external fields contribution
"""
obs_stat = _ObservableStat()
observable = obs_stat.call_method("calculate_scalar_pressure")
observable = self.obs_stat.call_method("calculate_scalar_pressure")
utils.handle_errors("calculate_pressure() failed")
return obs_stat._generate_summary(observable, 9, True)
return self.obs_stat._generate_summary(observable, 9, True)

def pressure_tensor(self):
"""
Expand Down Expand Up @@ -502,10 +506,9 @@ def pressure_tensor(self):
* ``"external_fields"``: external fields contribution
"""
obs_stat = _ObservableStat()
observable = obs_stat.call_method("calculate_pressure_tensor")
observable = self.obs_stat.call_method("calculate_pressure_tensor")
utils.handle_errors("calculate_pressure() failed")
return obs_stat._generate_summary(observable, 9, False)
return self.obs_stat._generate_summary(observable, 9, False)

def energy(self):
"""
Expand Down Expand Up @@ -551,10 +554,9 @@ def energy(self):
>>> print(energy["external_fields"])
"""
obs_stat = _ObservableStat()
observable = obs_stat.call_method("calculate_energy")
observable = self.obs_stat.call_method("calculate_energy")
utils.handle_errors("calculate_energy() failed")
return obs_stat._generate_summary(observable, 1, False)
return self.obs_stat._generate_summary(observable, 1, False)

def particle_energy(self, particle):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/python/espressomd/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ def __getstate__(self):
checkpointable_properties.append("_active_virtual_sites_handle")
checkpointable_properties += [
"non_bonded_inter", "bonded_inter", "cell_system", "lees_edwards",
"part", "analysis", "auto_update_accumulators",
"comfixed", "constraints", "galilei", "bond_breakage"
"part", "auto_update_accumulators",
"constraints", "bond_breakage"
]
if has_features("COLLISION_DETECTION"):
checkpointable_properties.append("collision_detection")
Expand Down
6 changes: 2 additions & 4 deletions src/script_interface/ObjectContainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef SCRIPT_INTERFACE_OBJECT_CONTAINER_HPP
#define SCRIPT_INTERFACE_OBJECT_CONTAINER_HPP

#pragma once

#include "script_interface/auto_parameters/AutoParameters.hpp"

Expand All @@ -38,5 +38,3 @@ using ObjectContainer = std::conditional_t<
AutoParameters<Container<ManagedType, BaseType>, BaseType>, BaseType>;

} // namespace ScriptInterface

#endif
12 changes: 6 additions & 6 deletions src/script_interface/ObjectHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,20 @@ Variant ObjectHandle::call_method(const std::string &name,
std::string ObjectHandle::serialize() const {
ObjectState state;

auto const params = this->get_parameters();
auto const params = serialize_parameters();
state.params.resize(params.size());

PackVisitor v;
PackVisitor visit;

/* Pack parameters and keep track of ObjectRef parameters */
boost::transform(params, state.params.begin(),
[&v](auto const &kv) -> PackedMap::value_type {
return {kv.first, boost::apply_visitor(v, kv.second)};
[&visit](auto const &kv) -> PackedMap::value_type {
return {kv.first, boost::apply_visitor(visit, kv.second)};
});

/* Packed Object parameters */
state.objects.resize(v.objects().size());
boost::transform(v.objects(), state.objects.begin(), [](auto const &kv) {
state.objects.resize(visit.objects().size());
boost::transform(visit.objects(), state.objects.begin(), [](auto const &kv) {
return std::make_pair(kv.first, kv.second->serialize());
});

Expand Down
16 changes: 13 additions & 3 deletions src/script_interface/ObjectHandle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef SCRIPT_INTERFACE_SCRIPT_INTERFACE_BASE_HPP
#define SCRIPT_INTERFACE_SCRIPT_INTERFACE_BASE_HPP
#pragma once

#include "Variant.hpp"

#include <utils/Span.hpp>
Expand All @@ -27,6 +27,7 @@

#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace ScriptInterface {
Expand Down Expand Up @@ -156,9 +157,18 @@ class ObjectHandle {
*/
static ObjectRef deserialize(const std::string &state, Context &ctx);

/**
* @brief Serialize parameters.
* Can be overriden to e.g. serialize parameters in a specific order.
*/
virtual std::vector<std::pair<std::string, Variant>>
serialize_parameters() const {
auto const params = this->get_parameters();
return {params.begin(), params.end()};
}

private:
virtual std::string get_internal_state() const { return {}; }
virtual void set_internal_state(std::string const &state) {}
};
} /* namespace ScriptInterface */
#endif
4 changes: 1 addition & 3 deletions src/script_interface/ObjectList.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef SCRIPT_INTERFACE_OBJECT_LIST_HPP
#define SCRIPT_INTERFACE_OBJECT_LIST_HPP
#pragma once

#include "script_interface/ObjectContainer.hpp"
#include "script_interface/ScriptInterface.hpp"
Expand Down Expand Up @@ -160,4 +159,3 @@ class ObjectList : public ObjectContainer<ObjectList, ManagedType, BaseType> {
}
};
} // Namespace ScriptInterface
#endif
4 changes: 1 addition & 3 deletions src/script_interface/ObjectMap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef SCRIPT_INTERFACE_OBJECT_MAP_HPP
#define SCRIPT_INTERFACE_OBJECT_MAP_HPP
#pragma once

#include "script_interface/ObjectContainer.hpp"
#include "script_interface/ScriptInterface.hpp"
Expand Down Expand Up @@ -194,4 +193,3 @@ class ObjectMap : public ObjectContainer<ObjectMap, ManagedType, BaseType> {
}
};
} // Namespace ScriptInterface
#endif
6 changes: 2 additions & 4 deletions src/script_interface/ObjectState.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef ESPRESSO_SCRIPT_INTERFACE_OBJECTSTATE_HPP
#define ESPRESSO_SCRIPT_INTERFACE_OBJECTSTATE_HPP

#pragma once

#include "packed_variant.hpp"

Expand Down Expand Up @@ -48,5 +48,3 @@ struct ObjectState {
}
};
} // namespace ScriptInterface

#endif
6 changes: 2 additions & 4 deletions src/script_interface/Variant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef SCRIPT_INTERFACE_VARIANT_HPP
#define SCRIPT_INTERFACE_VARIANT_HPP

#pragma once

#include "None.hpp"

Expand Down Expand Up @@ -127,5 +127,3 @@ template <class T> bool is_type(Variant const &v) {

inline bool is_none(Variant const &v) { return is_type<None>(v); }
} // namespace ScriptInterface

#endif
Loading

0 comments on commit 259b3af

Please sign in to comment.