Skip to content

Commit

Permalink
Seeding dynamic and associated initialisation logic (super-droplet in…
Browse files Browse the repository at this point in the history
…jection during simulation) (#1367)

Co-authored-by: claresinger <[email protected]>
Co-authored-by: jtbuch <[email protected]>
  • Loading branch information
3 people authored Sep 14, 2024
1 parent ec1f515 commit f1e1df9
Show file tree
Hide file tree
Showing 26 changed files with 1,449 additions and 12 deletions.
1 change: 1 addition & 0 deletions PySDM/backends/impl_numba/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from .pair_methods import PairMethods
from .physics_methods import PhysicsMethods
from .terminal_velocity_methods import TerminalVelocityMethods
from .seeding_methods import SeedingMethods
2 changes: 1 addition & 1 deletion PySDM/backends/impl_numba/methods/condensation_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def calculate_ml_new( # pylint: disable=too-many-branches,too-many-arguments,to
v_drop = formulae.particle_shape_and_density__mass_to_volume(
attributes.water_mass[drop]
)
if v_drop < 0:
if v_drop <= 0:
continue
x_old = formulae.condensation_coordinate__x(v_drop)
r_old = formulae.trivia__radius(v_drop)
Expand Down
68 changes: 68 additions & 0 deletions PySDM/backends/impl_numba/methods/seeding_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
""" CPU implementation of backend methods for particle injections """

from functools import cached_property

import numba

from PySDM.backends.impl_common.backend_methods import BackendMethods


class SeedingMethods(BackendMethods): # pylint: disable=too-few-public-methods
@cached_property
def _seeding(self):
@numba.njit(**{**self.default_jit_flags, "parallel": False})
def body( # pylint: disable=too-many-arguments
idx,
multiplicity,
extensive_attributes,
seeded_particle_index,
seeded_particle_multiplicity,
seeded_particle_extensive_attributes,
number_of_super_particles_to_inject: int,
):
number_of_super_particles_already_injected = 0
# TODO #1387 start enumerating from the end of valid particle set
for i, mult in enumerate(multiplicity):
if (
number_of_super_particles_to_inject
== number_of_super_particles_already_injected
):
break
if mult == 0:
idx[i] = -1
s = seeded_particle_index[
number_of_super_particles_already_injected
]
number_of_super_particles_already_injected += 1
multiplicity[i] = seeded_particle_multiplicity[s]
for a in range(len(extensive_attributes)):
extensive_attributes[a, i] = (
seeded_particle_extensive_attributes[a, s]
)
assert (
number_of_super_particles_to_inject
== number_of_super_particles_already_injected
)

return body

def seeding(
self,
*,
idx,
multiplicity,
extensive_attributes,
seeded_particle_index,
seeded_particle_multiplicity,
seeded_particle_extensive_attributes,
number_of_super_particles_to_inject: int,
):
self._seeding(
idx=idx.data,
multiplicity=multiplicity.data,
extensive_attributes=extensive_attributes.data,
seeded_particle_index=seeded_particle_index.data,
seeded_particle_multiplicity=seeded_particle_multiplicity.data,
seeded_particle_extensive_attributes=seeded_particle_extensive_attributes.data,
number_of_super_particles_to_inject=number_of_super_particles_to_inject,
)
2 changes: 2 additions & 0 deletions PySDM/backends/numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Numba( # pylint: disable=too-many-ancestors,duplicate-code
methods.DisplacementMethods,
methods.TerminalVelocityMethods,
methods.IsotopeMethods,
methods.SeedingMethods,
):
Storage = ImportedStorage
Random = ImportedRandom
Expand Down Expand Up @@ -75,3 +76,4 @@ def __init__(self, formulae=None, double_precision=True, override_jit_flags=None
methods.DisplacementMethods.__init__(self)
methods.TerminalVelocityMethods.__init__(self)
methods.IsotopeMethods.__init__(self)
methods.SeedingMethods.__init__(self)
4 changes: 4 additions & 0 deletions PySDM/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,8 @@ def build(
for key in self.particulator.dynamics:
self.particulator.timers[key] = WallTimer()

if (attributes["multiplicity"] == 0).any():
self.particulator.attributes.healthy = False
self.particulator.attributes.sanitize()

return self.particulator
1 change: 1 addition & 0 deletions PySDM/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
from PySDM.dynamics.eulerian_advection import EulerianAdvection
from PySDM.dynamics.freezing import Freezing
from PySDM.dynamics.relaxed_velocity import RelaxedVelocity
from PySDM.dynamics.seeding import Seeding
94 changes: 94 additions & 0 deletions PySDM/dynamics/seeding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
""" particle injection handling, requires initalising a simulation with
enough particles flagged with NaN multiplicity (translated to zeros
at multiplicity discretisation """

from collections.abc import Sized

import numpy as np

from PySDM.dynamics.impl import register_dynamic
from PySDM.initialisation import discretise_multiplicities


@register_dynamic()
class Seeding:
def __init__(
self,
*,
super_droplet_injection_rate: callable,
seeded_particle_extensive_attributes: dict,
seeded_particle_multiplicity: Sized,
):
for attr in seeded_particle_extensive_attributes.values():
assert len(seeded_particle_multiplicity) == len(attr)
self.particulator = None
self.super_droplet_injection_rate = super_droplet_injection_rate
self.seeded_particle_extensive_attributes = seeded_particle_extensive_attributes
self.seeded_particle_multiplicity = seeded_particle_multiplicity
self.rnd = None
self.u01 = None
self.index = None

def register(self, builder):
self.particulator = builder.particulator

def post_register_setup_when_attributes_are_known(self):
if tuple(self.particulator.attributes.get_extensive_attribute_keys()) != tuple(
self.seeded_particle_extensive_attributes.keys()
):
raise ValueError(
f"extensive attributes ({self.seeded_particle_extensive_attributes.keys()})"
" do not match those used in particulator"
f" ({self.particulator.attributes.get_extensive_attribute_keys()})"
)

self.index = self.particulator.Index.identity_index(
len(self.seeded_particle_multiplicity)
)
if len(self.seeded_particle_multiplicity) > 1:
self.rnd = self.particulator.Random(
len(self.seeded_particle_multiplicity), self.particulator.formulae.seed
)
self.u01 = self.particulator.Storage.empty(
len(self.seeded_particle_multiplicity), dtype=float
)
self.seeded_particle_multiplicity = (
self.particulator.IndexedStorage.from_ndarray(
self.index,
discretise_multiplicities(
np.asarray(self.seeded_particle_multiplicity)
),
)
)
self.seeded_particle_extensive_attributes = (
self.particulator.IndexedStorage.from_ndarray(
self.index,
np.asarray(list(self.seeded_particle_extensive_attributes.values())),
)
)

def __call__(self):
if self.particulator.n_steps == 0:
self.post_register_setup_when_attributes_are_known()

time = self.particulator.n_steps * self.particulator.dt
number_of_super_particles_to_inject = self.super_droplet_injection_rate(time)

if number_of_super_particles_to_inject > 0:
assert number_of_super_particles_to_inject <= len(
self.seeded_particle_multiplicity
)

if self.rnd is not None:
self.u01.urand(self.rnd)
# TODO #1387 make shuffle smarter
# e.g. don't need to shuffle if only one type of seed particle
# or if the number of super particles to inject
# is equal to the number of possible seeds
self.index.shuffle(self.u01)
self.particulator.seeding(
seeded_particle_index=self.index,
number_of_super_particles_to_inject=number_of_super_particles_to_inject,
seeded_particle_multiplicity=self.seeded_particle_multiplicity,
seeded_particle_extensive_attributes=self.seeded_particle_extensive_attributes,
)
5 changes: 5 additions & 0 deletions PySDM/impl/particle_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,8 @@ def get_extensive_attribute_keys(self):

def has_attribute(self, attr):
return attr in self.__attributes

def reset_idx(self):
self.__valid_n_sd = self.__idx.shape[0]
self.__idx.reset_index()
self.healthy = False
20 changes: 13 additions & 7 deletions PySDM/initialisation/discretise_multiplicities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,27 @@


def discretise_multiplicities(values_arg):
values_int = values_arg.round().astype(np.int64)
"""any NaN values in the input array are ignored and flagged
with zero multiplicities in the output array"""

values_int = np.where(np.isnan(values_arg), 0, values_arg).round().astype(np.int64)

if np.issubdtype(values_arg.dtype, np.floating):
if np.isnan(values_arg).all():
return values_int

if not np.logical_or(values_int > 0, np.isnan(values_arg)).all():
raise ValueError(
f"int-casting resulted in multiplicity of zero (min(y_float)={min(values_arg)})"
)

percent_diff = 100 * abs(
1 - np.sum(values_arg) / np.sum(values_int.astype(float))
1 - np.nansum(values_arg) / np.sum(values_int.astype(float))
)
if percent_diff > 1:
raise ValueError(
f"{percent_diff}% error in total real-droplet number"
f" due to casting multiplicities to ints"
)

if not (values_int > 0).all():
raise ValueError(
f"int-casting resulted in multiplicity of zero (min(y_float)={min(values_arg)})"
)

return values_int
3 changes: 2 additions & 1 deletion PySDM/initialisation/init_fall_momenta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np

from PySDM.backends import CPU
from PySDM.dynamics.terminal_velocity import GunnKinzer1949
from PySDM.formulae import Formulae
from PySDM.particulator import Particulator
Expand All @@ -31,6 +30,8 @@ def init_fall_momenta(
if zero:
return np.zeros_like(water_mass)

from PySDM.backends import CPU # pylint: disable=import-outside-toplevel

particulator = Particulator(0, CPU(Formulae())) # TODO #1155

approximation = terminal_velocity_approx(particulator=particulator)
Expand Down
43 changes: 42 additions & 1 deletion PySDM/particulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


class Particulator: # pylint: disable=too-many-public-methods,too-many-instance-attributes
def __init__(self, n_sd, backend: BackendMethods):
def __init__(self, n_sd, backend):
assert isinstance(backend, BackendMethods)
self.__n_sd = n_sd

Expand Down Expand Up @@ -438,3 +438,44 @@ def isotopic_fractionation(self, heavy_isotopes: tuple):
self.backend.isotopic_fractionation()
for isotope in heavy_isotopes:
self.attributes.mark_updated(f"moles_{isotope}")

def seeding(
self,
*,
seeded_particle_index,
seeded_particle_multiplicity,
seeded_particle_extensive_attributes,
number_of_super_particles_to_inject,
):
n_null = self.n_sd - self.attributes.super_droplet_count
if n_null == 0:
raise ValueError(
"No available seeds to inject. Please provide particles with nan filled attributes."
)

if number_of_super_particles_to_inject > n_null:
raise ValueError(
"Trying to inject more super particles than space available."
)

if number_of_super_particles_to_inject > len(seeded_particle_multiplicity):
raise ValueError(
"Trying to inject multiple super particles with the same attributes. \
Instead increase multiplicity of injected particles."
)

self.backend.seeding(
idx=self.attributes._ParticleAttributes__idx,
multiplicity=self.attributes["multiplicity"],
extensive_attributes=self.attributes.get_extensive_attribute_storage(),
seeded_particle_index=seeded_particle_index,
seeded_particle_multiplicity=seeded_particle_multiplicity,
seeded_particle_extensive_attributes=seeded_particle_extensive_attributes,
number_of_super_particles_to_inject=number_of_super_particles_to_inject,
)
self.attributes.reset_idx()
self.attributes.sanitize()

self.attributes.mark_updated("multiplicity")
for key in self.attributes.get_extensive_attribute_keys():
self.attributes.mark_updated(key)
2 changes: 2 additions & 0 deletions examples/PySDM_examples/seeding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .settings import Settings
from .simulation import Simulation
Loading

0 comments on commit f1e1df9

Please sign in to comment.