Skip to content

Commit

Permalink
Fix aliasing issue for ComplexVector
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiangrimberg committed Feb 28, 2024
1 parent 7a5571a commit 2a2f00a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
19 changes: 7 additions & 12 deletions palace/linalg/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
namespace palace
{

ComplexVector::ComplexVector(int size)
: data(2 * size), xr(data, 0, size), xi(data, size, size)
{
}
ComplexVector::ComplexVector(int size) : xr(size), xi(size) {}

ComplexVector::ComplexVector(const ComplexVector &y) : ComplexVector(y.Size())
{
Expand Down Expand Up @@ -44,25 +41,23 @@ ComplexVector::ComplexVector(Vector &y, int offset, int size)

void ComplexVector::UseDevice(bool use_dev)
{
data.UseDevice(use_dev);
xr.UseDevice(use_dev);
xi.UseDevice(use_dev);
}

void ComplexVector::SetSize(int size)
{
data.SetSize(2 * size);
xr.MakeRef(data, 0, size);
xi.MakeRef(data, size, size);
xr.SetSize(size);
xi.SetSize(size);
}

void ComplexVector::MakeRef(Vector &y, int offset, int size)
{
MFEM_ASSERT(y.Size() <= 2 * size,
MFEM_ASSERT(y.Size() >= offset + 2 * size,
"Insufficient storage for ComplexVector alias reference of the given size!");
data.MakeRef(y, offset, 2 * size);
xr.MakeRef(data, offset, size);
xi.MakeRef(data, offset + size, size);
y.ReadWrite(); // Ensure memory is allocated on device before aliasing
xr.MakeRef(y, offset, size);
xi.MakeRef(y, offset + size, size);
}

void ComplexVector::Set(const ComplexVector &y)
Expand Down
6 changes: 3 additions & 3 deletions palace/linalg/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using Vector = mfem::Vector;
class ComplexVector
{
private:
Vector data, xr, xi;
Vector xr, xi;

public:
// Create a vector with the given size.
Expand All @@ -43,10 +43,10 @@ class ComplexVector

// Flag for runtime execution on the mfem::Device. See the documentation for mfem::Vector.
void UseDevice(bool use_dev);
bool UseDevice() const { return data.UseDevice(); }
bool UseDevice() const { return xr.UseDevice(); }

// Return the size of the vector.
int Size() const { return data.Size() / 2; }
int Size() const { return xr.Size(); }

// Set the size of the vector. See the notes for Vector::SetSize for behavior in the cases
// where the new size is less than or greater than Size() or Capacity().
Expand Down

0 comments on commit 2a2f00a

Please sign in to comment.