Skip to content

Commit

Permalink
Add dB/dt to ODE system instead of using trapezoidal integration
Browse files Browse the repository at this point in the history
  • Loading branch information
simlapointe committed Nov 19, 2024
1 parent 9ecd6d1 commit db8c2a5
Showing 1 changed file with 47 additions and 22 deletions.
69 changes: 47 additions & 22 deletions palace/models/timeoperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class TimeDependentFirstOrderOperator : public mfem::TimeDependentOperator
std::unique_ptr<KspSolver> kspM, kspA;
std::unique_ptr<Operator> A, B;
mutable Vector RHS;
int size_E;
int size_E, size_B;

const Operator *Curl;

// Bindings to SpaceOperator functions to get the system matrix and preconditioner, and
// construct the linear solver.
Expand All @@ -47,11 +49,12 @@ class TimeDependentFirstOrderOperator : public mfem::TimeDependentOperator
TimeDependentFirstOrderOperator(const IoData &iodata, SpaceOperator &space_op,
std::function<double(double)> &dJ_coef, double t0,
mfem::TimeDependentOperator::Type type)
: mfem::TimeDependentOperator(2 * space_op.GetNDSpace().GetTrueVSize(), t0, type),
: mfem::TimeDependentOperator(2 * space_op.GetNDSpace().GetTrueVSize() + space_op.GetRTSpace().GetTrueVSize(), t0, type),
comm(space_op.GetComm()), dJ_coef(dJ_coef)
{
// Get dimensions of E and Edot vectors.
size_E = space_op.GetNDSpace().GetTrueVSize();
size_B = space_op.GetRTSpace().GetTrueVSize();

// Construct the system matrices defining the linear operator. PEC boundaries are
// handled simply by setting diagonal entries of the mass matrix for the corresponding
Expand All @@ -61,10 +64,12 @@ class TimeDependentFirstOrderOperator : public mfem::TimeDependentOperator
C = space_op.GetDampingMatrix<Operator>(Operator::DIAG_ZERO);
M = space_op.GetMassMatrix<Operator>(Operator::DIAG_ONE);

Curl = &space_op.GetCurlMatrix();

// Set up RHS vector for the current source term: -g'(t) J, where g(t) handles the time
// dependence.
space_op.GetExcitationVector(NegJ);
RHS.SetSize(2 * size_E);
RHS.SetSize(2 * size_E + size_B);
RHS.UseDevice(true);

// Set up linear solvers.
Expand Down Expand Up @@ -103,21 +108,26 @@ class TimeDependentFirstOrderOperator : public mfem::TimeDependentOperator
// Form the RHS for the first-order ODE system
void FormRHS(const Vector &u, Vector &rhs) const
{
Vector u1, u2, rhs1, rhs2;
Vector u1, u2, u3, rhs1, rhs2, rhs3;
u1.UseDevice(true);
u2.UseDevice(true);
u3.UseDevice(true);
rhs1.UseDevice(true);
rhs2.UseDevice(true);
rhs3.UseDevice(true);
u.Read();
u1.MakeRef(const_cast<Vector &>(u), 0, size_E);
u2.MakeRef(const_cast<Vector &>(u), size_E, size_E);
u3.MakeRef(const_cast<Vector &>(u), 2 * size_E, size_B);
rhs.ReadWrite();
rhs1.MakeRef(rhs, 0, size_E);
rhs2.MakeRef(rhs, size_E, size_E);
rhs3.MakeRef(rhs, 2 * size_E, size_B);

// u1 = Edot, u2 = E
// u1 = Edot, u2 = E, u3 = B
// rhs1 = -(K * u2 + C * u1) - J(t)
// rhs2 = u1
// rhs3 = -curl u2
K->Mult(u2, rhs1);
if (C)
{
Expand All @@ -126,11 +136,15 @@ class TimeDependentFirstOrderOperator : public mfem::TimeDependentOperator
linalg::AXPBYPCZ(-1.0, rhs1, dJ_coef(t), NegJ, 0.0, rhs1);

rhs2 = u1;

Curl->Mult(u2, rhs3);
rhs3 *= -1;
}

// Solve M du = rhs
// |M 0| |du1| = |-(K * u2 + C * u1) - J(t) |
// |0 I| |du2| | u1 |
// |M 0 0| |du1| = |-(K * u2 + C * u1) - J(t) |
// |0 I 0| |du2| | u1 |
// |0 0 I| |du3| = |-curl u2 |
void Mult(const Vector &u, Vector &du) const override
{
if (kspM->NumTotalMult() == 0)
Expand All @@ -140,20 +154,25 @@ class TimeDependentFirstOrderOperator : public mfem::TimeDependentOperator
}
FormRHS(u, RHS);

Vector du1, du2, RHS1, RHS2;
Vector du1, du2, du3, RHS1, RHS2, RHS3;
du1.UseDevice(true);
du2.UseDevice(true);
du3.UseDevice(true);
RHS1.UseDevice(true);
RHS2.UseDevice(true);
RHS3.UseDevice(true);
du.ReadWrite();
du1.MakeRef(du, 0, size_E);
du2.MakeRef(du, size_E, size_E);
du3.MakeRef(du, 2 * size_E, size_B);
RHS.ReadWrite();
RHS1.MakeRef(RHS, 0, size_E);
RHS2.MakeRef(RHS, size_E, size_E);
RHS3.MakeRef(RHS, 2 * size_E, size_B);

kspM->Mult(RHS1, du1);
du2 = RHS2;
du3 = RHS3;
}

void ImplicitSolve(double dt, const Vector &u, Vector &k) override
Expand All @@ -171,24 +190,32 @@ class TimeDependentFirstOrderOperator : public mfem::TimeDependentOperator
Mpi::Print("\n");
FormRHS(u, RHS);

Vector k1, k2, RHS1, RHS2;
Vector k1, k2, k3, RHS1, RHS2, RHS3;
k1.UseDevice(true);
k2.UseDevice(true);
k3.UseDevice(true);
RHS1.UseDevice(true);
RHS2.UseDevice(true);
RHS3.UseDevice(true);
k.ReadWrite();
k1.MakeRef(k, 0, size_E);
k2.MakeRef(k, size_E, size_E);
k3.MakeRef(k, 2 * size_E, size_B);
RHS.ReadWrite();
RHS1.MakeRef(RHS, 0, size_E);
RHS2.MakeRef(RHS, size_E, size_E);
RHS3.MakeRef(RHS, 2 * size_E, size_B);

// A k1 = RHS1 - dt K RHS2
K->AddMult(RHS2, RHS1, -dt);
kspA->Mult(RHS1, k1);

// k2 = rhs2 + dt k1
linalg::AXPBYPCZ(1.0, RHS2, dt, k1, 0.0, k2);

// k3 = rhs3 - dt curl k2
k3 = RHS3;
Curl->AddMult(k2, RHS3, -dt);
}

void ExplicitMult(const Vector &u, Vector &v) const override { Mult(u, v); }
Expand All @@ -215,18 +242,22 @@ class TimeDependentFirstOrderOperator : public mfem::TimeDependentOperator
// Solve (Mass - dt Jacobian) x = Mass b
int SUNImplicitSolve(const Vector &b, Vector &x, double tol) override
{
Vector b1, b2, x1, x2, RHS1;
Vector b1, b2, b3, x1, x2, x3, RHS1;
b1.UseDevice(true);
b2.UseDevice(true);
b3.UseDevice(true);
x1.UseDevice(true);
x2.UseDevice(true);
x3.UseDevice(true);
RHS1.UseDevice(true);
b.Read();
b1.MakeRef(const_cast<Vector &>(b), 0, size_E);
b2.MakeRef(const_cast<Vector &>(b), size_E, size_E);
b3.MakeRef(const_cast<Vector &>(b), 2 * size_E, size_B);
x.ReadWrite();
x1.MakeRef(x, 0, size_E);
x2.MakeRef(x, size_E, size_E);
x3.MakeRef(x, 2 * size_E, size_B);
RHS.ReadWrite();
RHS1.MakeRef(RHS, 0, size_E);

Expand All @@ -238,6 +269,10 @@ class TimeDependentFirstOrderOperator : public mfem::TimeDependentOperator
// x2 = b2 + dt x1
linalg::AXPBYPCZ(1.0, b2, saved_gamma, x1, 0.0, x2);

// x3 = b3 - dt curl x2
x3 = b3;
Curl->AddMult(x2, x3, -saved_gamma);

return 0;
}
};
Expand All @@ -249,23 +284,19 @@ TimeOperator::TimeOperator(const IoData &iodata, SpaceOperator &space_op,
: rel_tol(iodata.solver.transient.rel_tol), abs_tol(iodata.solver.transient.abs_tol),
order(iodata.solver.transient.order)
{
// Construct discrete curl matrix for B-field time integration.
Curl = &space_op.GetCurlMatrix();

// Get sizes.
int size_E = space_op.GetNDSpace().GetTrueVSize();
int size_B = space_op.GetRTSpace().GetTrueVSize();

// Allocate space for solution vectors.
sol.SetSize(2 * size_E);
En.SetSize(size_E);
B.SetSize(size_B);
sol.SetSize(2 * size_E + size_B);
sol.UseDevice(true);
E.UseDevice(true);
En.UseDevice(true);
B.UseDevice(true);
sol.ReadWrite();
E.MakeRef(sol, size_E, size_E);
B.MakeRef(sol, 2 * size_E, size_B);

// Create ODE solver for 1st-order IVP.
mfem::TimeDependentOperator::Type type = mfem::TimeDependentOperator::IMPLICIT;
Expand Down Expand Up @@ -379,7 +410,6 @@ void TimeOperator::Init()
{
// Always use zero initial conditions.
sol = 0.0;
B = 0.0;
if (use_mfem_integrator)
{
ode->Init(*op);
Expand All @@ -388,15 +418,10 @@ void TimeOperator::Init()

void TimeOperator::Step(double &t, double &dt)
{
En = E;
double dt_input = dt;
ode->Step(sol, t, dt);
// Ensure user-specified dt does not change.
dt = dt_input;

// Trapezoidal integration for B-field: dB/dt = -∇ x E.
En += E;
Curl->AddMult(En, B, -0.5 * dt);
}

void TimeOperator::PrintStats()
Expand Down

0 comments on commit db8c2a5

Please sign in to comment.