Skip to content

Commit

Permalink
Implement snapshotting for the acoustic wave equation
Browse files Browse the repository at this point in the history
  • Loading branch information
malfarhan7 committed Oct 26, 2024
1 parent f95007d commit 44694f5
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 35 deletions.
89 changes: 69 additions & 20 deletions examples/seismic/acoustic/operators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from devito import Eq, Operator, Function, TimeFunction, Inc, solve, sign
from devito import Eq, Operator, Function, TimeFunction, Inc, solve, sign, ConditionalDimension
from devito.symbolics import retrieve_functions, INT, retrieve_derivatives


Expand Down Expand Up @@ -108,7 +108,7 @@ def iso_stencil(field, model, kernel, **kwargs):


def ForwardOperator(model, geometry, space_order=4,
save=False, kernel='OT2', **kwargs):
save=False, kernel='OT2', factor=None, **kwargs):
"""
Construct a forward modelling operator in an acoustic medium.
Expand All @@ -126,6 +126,8 @@ def ForwardOperator(model, geometry, space_order=4,
Defaults to False.
kernel : str, optional
Type of discretization, 'OT2' or 'OT4'.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.
"""
m = model.m

Expand All @@ -144,10 +146,28 @@ def ForwardOperator(model, geometry, space_order=4,

# Create interpolation expression for receivers
rec_term = rec.interpolate(expr=u)

# Build operator equations
equations = eqn + src_term + rec_term

if factor:
# Implement snapshotting
nsnaps = (geometry.nt + factor - 1) // factor
time_subsampled = ConditionalDimension(
't_sub', parent=model.grid.time_dim, factor=factor)
usnaps = TimeFunction(name='usnaps', grid=model.grid,
time_order=2, space_order=space_order,
save=nsnaps, time_dim=time_subsampled)
# Add equation to save snapshots
snapshot_eq = Eq(usnaps, u)
equations += [snapshot_eq]
else:
usnaps = None
# Substitute spacing terms to reduce flops
return Operator(eqn + src_term + rec_term, subs=model.spacing_map,
name='Forward', **kwargs)
op = Operator(equations, subs=model.spacing_map, name='Forward', **kwargs)
if usnaps is not None:
return op, usnaps
else:
return op


def AdjointOperator(model, geometry, space_order=4,
Expand Down Expand Up @@ -189,8 +209,8 @@ def AdjointOperator(model, geometry, space_order=4,


def GradientOperator(model, geometry, space_order=4, save=True,
kernel='OT2', **kwargs):
"""
kernel='OT2', factor=None, **kwargs):
"""
Construct a gradient operator in an acoustic media.
Parameters
Expand All @@ -206,30 +226,59 @@ def GradientOperator(model, geometry, space_order=4, save=True,
Option to store the entire (unrolled) wavefield.
kernel : str, optional
Type of discretization, centered or shifted.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.
"""
m = model.m

# Gradient symbol and wavefield symbols
# Gradient symbol
grad = Function(name='grad', grid=model.grid)
u = TimeFunction(name='u', grid=model.grid, save=geometry.nt if save
else None, time_order=2, space_order=space_order)
v = TimeFunction(name='v', grid=model.grid, save=None,
time_order=2, space_order=space_order)
rec = geometry.rec

# Create the adjoint wavefield
v = TimeFunction(name='v', grid=model.grid, time_order=2, space_order=space_order)

s = model.grid.stepping_dim.spacing
eqn = iso_stencil(v, model, kernel, forward=False)

if kernel == 'OT2':
gradient_update = Inc(grad, - u * v.dt2)
elif kernel == 'OT4':
gradient_update = Inc(grad, - u * v.dt2 - s**2 / 12.0 * u.biharmonic(m**(-2)) * v)
# Add expression for receiver injection
rec = geometry.rec
receivers = rec.inject(field=v.backward, expr=rec * s**2 / m)

time = model.grid.time_dim

if factor is not None:
# Condition to apply gradient update only at snapshot times
condition = Eq(time % factor, 0)
# Create the ConditionalDimension for subsampling
time_subsampled = ConditionalDimension('t_sub', parent=time, factor=factor)
# Define usnaps with time_subsampled as its time dimension
nsnaps = (geometry.nt + factor - 1) // factor
usnaps = TimeFunction(name='usnaps', grid=model.grid,
time_order=2, space_order=space_order,
save=nsnaps, time_dim=time_subsampled)
# Gradient update without indexing usnaps
if kernel == 'OT2':
gradient_update = Inc(grad, - usnaps * v.dt2, implicit_dims=[time_subsampled],
condition=condition)
elif kernel == 'OT4':
gradient_update = Inc(grad, - usnaps * v.dt2
- s**2 / 12.0 * usnaps.biharmonic(m**(-2)) * v,
implicit_dims=[time_subsampled],
condition=condition)
else:
u = TimeFunction(name='u', grid=model.grid,
save=geometry.nt if save else None,
time_order=2, space_order=space_order)
if kernel == 'OT2':
gradient_update = Inc(grad, - u * v.dt2)
elif kernel == 'OT4':
gradient_update = Inc(grad, - u * v.dt2
- s**2 / 12.0 * u.biharmonic(m**(-2)) * v)

# Substitute spacing terms to reduce flops
return Operator(eqn + receivers + [gradient_update], subs=model.spacing_map,
name='Gradient', **kwargs)
op = Operator(eqn + receivers + [gradient_update], subs=model.spacing_map,
name='Gradient', **kwargs)
return op


def BornOperator(model, geometry, space_order=4,
Expand Down Expand Up @@ -274,4 +323,4 @@ def BornOperator(model, geometry, space_order=4,

# Substitute spacing terms to reduce flops
return Operator(eqn1 + source + eqn2 + receivers, subs=model.spacing_map,
name='Born', **kwargs)
name='Born', **kwargs)
66 changes: 51 additions & 15 deletions examples/seismic/acoustic/wavesolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator, Revolver
from devito.tools import memoized_meth
from examples.seismic.acoustic.operators import (
from devitofwi.devito.acoustic.operators import (
ForwardOperator, AdjointOperator, GradientOperator, BornOperator
)

Expand All @@ -23,6 +23,7 @@ class AcousticWaveSolver:
space_order: int, optional
Order of the spatial stencil discretisation. Defaults to 4.
"""

def __init__(self, model, geometry, kernel='OT2', space_order=4, **kwargs):
self.model = model
self.model._initialize_bcs(bcs="damp")
Expand All @@ -44,11 +45,11 @@ def dt(self):
return self.model.critical_dt

@memoized_meth
def op_fwd(self, save=None):
def op_fwd(self, save=None, factor=None):
"""Cached operator for forward runs with buffered wavefield"""
return ForwardOperator(self.model, save=save, geometry=self.geometry,
kernel=self.kernel, space_order=self.space_order,
**self._kwargs)
factor=factor, **self._kwargs)

@memoized_meth
def op_adj(self):
Expand All @@ -58,11 +59,11 @@ def op_adj(self):
**self._kwargs)

@memoized_meth
def op_grad(self, save=True):
def op_grad(self, save=True, factor=None):
"""Cached operator for gradient runs"""
return GradientOperator(self.model, save=save, geometry=self.geometry,
kernel=self.kernel, space_order=self.space_order,
**self._kwargs)
factor=factor, **self._kwargs)

@memoized_meth
def op_born(self):
Expand All @@ -71,7 +72,7 @@ def op_born(self):
kernel=self.kernel, space_order=self.space_order,
**self._kwargs)

def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
def forward(self, src=None, rec=None, u=None, model=None, save=None, factor=None, **kwargs):
"""
Forward modelling function that creates the necessary
data objects for running a forward modelling operator.
Expand All @@ -90,6 +91,8 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
The time-constant velocity.
save : bool, optional
Whether or not to save the entire (unrolled) wavefield.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.
Returns
-------
Expand All @@ -108,12 +111,24 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
model = model or self.model
# Pick vp from model unless explicitly provided
kwargs.update(model.physical_params(**kwargs))
# Get the operator
op_fwd = self.op_fwd(save=save, factor=factor)
# Prepare parameters for operator apply
op_args = {'src': src, 'rec': rec, 'u': u, 'dt': kwargs.pop('dt', self.dt)}
op_args.update(kwargs)

# Execute operator and return wavefield and receiver data
summary = self.op_fwd(save).apply(src=src, rec=rec, u=u,
dt=kwargs.pop('dt', self.dt), **kwargs)

return rec, u, summary
if factor:
# Operator returned is op, usnaps
op, usnaps = op_fwd
op_args['usnaps'] = usnaps
summary = op.apply(**op_args)

else:
op = op_fwd
usnaps = None
summary = op.apply(**op_args)
return rec, u, usnaps, summary

def adjoint(self, rec, srca=None, v=None, model=None, **kwargs):
"""
Expand Down Expand Up @@ -155,8 +170,8 @@ def adjoint(self, rec, srca=None, v=None, model=None, **kwargs):
dt=kwargs.pop('dt', self.dt), **kwargs)
return srca, v, summary

def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
checkpointing=False, **kwargs):
def jacobian_adjoint(self, rec, u=None, usnaps=None, src=None, v=None, grad=None, model=None,
factor=None, checkpointing=False, **kwargs):
"""
Gradient modelling function for computing the adjoint of the
Linearized Born modelling function, ie. the action of the
Expand All @@ -168,6 +183,8 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
Receiver data.
u : TimeFunction
Full wavefield `u` (created with save=True).
usnaps : TimeFunction
Snapshots of the wavefield `u`.
v : TimeFunction, optional
Stores the computed wavefield.
grad : Function, optional
Expand All @@ -176,12 +193,22 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
Object containing the physical parameters.
vp : Function or float, optional
The time-constant velocity.
checkpointing : boolean, optional
Flag to enable checkpointing (default False).
Cannot be used with snapshotting.
factor : int, optional
Downsampling factor for the saved snapshots of the wavefield `u`.
Cannot be used with checkpointing.
Returns
-------
Gradient field and performance summary.
"""
dt = kwargs.pop('dt', self.dt)
# Check that snapshotting and checkpointing are not used together
if factor is not None and checkpointing:
raise ValueError("Cannot use snapshotting (factor) and checkpointing simultaneously.")

# Gradient symbol
grad = grad or Function(name='grad', grid=self.model.grid)

Expand Down Expand Up @@ -209,8 +236,17 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
summary = self.op_grad().apply(rec=rec, grad=grad, v=v, u=u, dt=dt,
**kwargs)
if factor is not None:
# Get the gradient operator
op = self.op_grad(save=False, factor=factor)
op_args = {'rec': rec, 'grad': grad, 'v': v, 'dt': dt, 'usnaps': usnaps}
else:
op = self.op_grad(save=True, factor=None)
op_args = {'rec': rec, 'grad': grad, 'v': v, 'dt': dt, 'u': u}

op_args.update(kwargs)
summary = op.apply(**op_args)

return grad, summary

def jacobian(self, dmin, src=None, rec=None, u=None, U=None, model=None, **kwargs):
Expand Down Expand Up @@ -255,4 +291,4 @@ def jacobian(self, dmin, src=None, rec=None, u=None, U=None, model=None, **kwarg

# Backward compatibility
born = jacobian
gradient = jacobian_adjoint
gradient = jacobian_adjoint
2 changes: 2 additions & 0 deletions examples/seismic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,6 @@ def __call__(self, parser, args, values, option_string=None):
choices=['float32', 'float64'])
parser.add_argument("-interp", dest="interp", default="linear",
choices=['linear', 'sinc'])
parser.add_argument("--factor", type=int, default=None,
help="Downsampling factor to use snapshotting, default is None")
return parser

0 comments on commit 44694f5

Please sign in to comment.