Skip to content

Commit

Permalink
Fix for AssembledMatrix (#3798)
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward authored Oct 15, 2024
1 parent 2251fae commit 6cc4ff8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
10 changes: 9 additions & 1 deletion firedrake/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@ def __init__(self, a, bcs, mat_type):
# on different processes)

ufl.Matrix.__init__(self, test.function_space(), trial.function_space())
# Define arguments after `Matrix.__init__` since BaseForm sets `self._arguments` to None

# ufl.Matrix._analyze_form_arguments sets the _arguments attribute to
# non-Firedrake objects, which breaks things. To avoid this we overwrite
# this property after the fact.
self._analyze_form_arguments()
self._arguments = arguments

self.bcs = bcs
self.comm = test.function_space().comm
self._comm = internal_comm(self.comm, self)
Expand All @@ -54,6 +59,9 @@ def arguments(self):
else:
return self._arguments

def ufl_domains(self):
return self._domains

@property
def has_bcs(self):
"""Return True if this :class:`MatrixBase` has any boundary
Expand Down
3 changes: 2 additions & 1 deletion firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ufl

import firedrake.linear_solver as ls
from firedrake.matrix import AssembledMatrix
import firedrake.variational_solver as vs
from firedrake import dmhooks, function, solving_utils, vector
import firedrake
Expand Down Expand Up @@ -162,7 +163,7 @@ def _solve_varproblem(*args, **kwargs):

appctx = kwargs.get("appctx", {})
# Solve linear variational problem
if isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, ufl.BaseForm):
if isinstance(eq.lhs, (ufl.Form, AssembledMatrix)) and isinstance(eq.rhs, ufl.BaseForm):
# Create problem
problem = vs.LinearVariationalProblem(eq.lhs, eq.rhs, u, bcs, Jp,
form_compiler_parameters=form_compiler_parameters,
Expand Down
34 changes: 30 additions & 4 deletions tests/regression/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,18 @@ def V():


@pytest.fixture
def a(V):
u = TrialFunction(V)
v = TestFunction(V)
def test(V):
return TestFunction(V)


@pytest.fixture
def trial(V):
return TrialFunction(V)


return inner(u, v)*dx
@pytest.fixture
def a(test, trial):
return inner(trial, test)*dx


@pytest.fixture(params=["nest", "aij", "matfree"])
Expand All @@ -27,3 +34,22 @@ def test_assemble_returns_matrix(a):
A = assemble(a)

assert isinstance(A, matrix.Matrix)


def test_solve_with_assembled_matrix():
mesh = UnitIntervalMesh(3)
V = FunctionSpace(mesh, "CG", 1)

u = TrialFunction(V)
v = TestFunction(V)
x, = SpatialCoordinate(mesh)
f = Function(V).interpolate(x)

a = inner(u, v) * dx
A = AssembledMatrix((v, u), bcs=(), petscmat=assemble(a).M.handle)
L = inner(f, v) * dx

solution = Function(V)
solve(A == L, solution)

assert norm(assemble(f - solution)) < 1e-15

0 comments on commit 6cc4ff8

Please sign in to comment.