Skip to content

Commit

Permalink
Added out argument to modify grid inplace
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Jun 20, 2023
1 parent 2a57da8 commit 530324c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/sisl/physics/densitymatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def density(self, grid, spinor=None, psi_values=None, use_sparse=False, **kwargs
# Compute orbital values on the grid
psi_values = uc_dm.geometry.orbital_values(grid.shape, pbc=True)

return psi_values.reduce_orbital_products(csrDM, uc_dm.lattice, use_sparse=use_sparse, **kwargs)
psi_values.reduce_orbital_products(csrDM, uc_dm.lattice, use_sparse=use_sparse, out=grid.grid, **kwargs)

def old_density(self, grid, spinor=None, tol=1e-7, eta=None):
r""" Expand the density matrix to the charge density on a grid
Expand Down
22 changes: 17 additions & 5 deletions src/sisl/sparse_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def reduce_orbitals(self, weights, k=(0,0,0), **kwargs):

return self.reduce_dimension(weights, **kwargs)

def reduce_orbital_products(self, weights, weights_sc, grid_axes=(), use_sparse=False):
def reduce_orbital_products(self, weights, weights_sc, grid_axes=(), use_sparse=False, out=None):
csr = self._csr

# Find out the reduced shape, and the reduce factor. The reduced factor is the number
Expand Down Expand Up @@ -243,7 +243,10 @@ def reduce_orbital_products(self, weights, weights_sc, grid_axes=(), use_sparse=
sparse_coeffs = isinstance(weights, SparseCSR)
multi_coeffs = len(weights.shape) == 3 and weights.shape[2] > 1

dtype = np.result_type(weights, csr)
if out is None:
dtype = np.result_type(weights, csr)
else:
dtype = out.dtype

# Decide which function to use.
if sparse_coeffs:
Expand All @@ -263,12 +266,21 @@ def reduce_orbital_products(self, weights, weights_sc, grid_axes=(), use_sparse=
else:
reduce_func = reduce_sc_products

if out is not None:
grid = out

if multi_coeffs:
grid = np.zeros((*reduced_shape, weights.shape[2]), dtype=dtype)
if out is None:
grid = np.zeros((*reduced_shape, weights.shape[2]), dtype=dtype)
out = grid

out = grid.reshape(-1, weights.shape[2])
else:
grid = Grid(reduced_shape, geometry=self.geometry, dtype=dtype)
out = grid.grid.ravel()
if out is None:
grid = Grid(reduced_shape, geometry=self.geometry, dtype=dtype)
out = grid.grid

out = out.ravel()

reduce_func(
csr.data[: , 0].astype(dtype), csr.ptr, csr.col, weights.astype(dtype=dtype),
Expand Down

0 comments on commit 530324c

Please sign in to comment.