Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JAX config updates and expose penalty_param #894

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dev_tools/requirements/envs/dev.env.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ iniconfig==2.0.0
# via pytest
isort==5.13.2
# via pylint
jax==0.4.23
jax==0.4.31
# via -r deps/resource_estimates_runtime.txt
jaxlib==0.4.23
jaxlib==0.4.31
# via -r deps/resource_estimates_runtime.txt
jsonschema==4.21.0
# via nbformat
Expand Down
4 changes: 2 additions & 2 deletions dev_tools/requirements/envs/pytest-extra.env.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ iniconfig==2.0.0
# via
# -c envs/dev.env.txt
# pytest
jax==0.4.23
jax==0.4.31
# via
# -c envs/dev.env.txt
# -r deps/resource_estimates_runtime.txt
jaxlib==0.4.23
jaxlib==0.4.31
# via
# -c envs/dev.env.txt
# -r deps/resource_estimates_runtime.txt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,10 @@
from pyscf.pbc import scf
from scipy.optimize import minimize

from jax.config import config
import jax

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import jax.typing as jnpt

Expand Down
6 changes: 5 additions & 1 deletion src/openfermion/resource_estimates/thc/factorize_thc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def thc_via_cp3(
bfgs_maxiter=5000,
random_start_thc=True,
verify=False,
penalty_param=None,
):
"""
THC-CP3 performs an SVD decomposition of the eri matrix followed by a CP
Expand All @@ -36,6 +37,7 @@ def thc_via_cp3(
random_start_thc - Perform random start for CP3.
If false perform HOSVD start.
verify - check eri properties. Default is False
penalty_param - penalty parameter for L2 regularization. Default is None.

returns:
eri_thc - (N x N x N x N) reconstructed ERIs from THC factorization
Expand Down Expand Up @@ -115,7 +117,9 @@ def thc_via_cp3(
if perform_bfgs_opt:
x = np.hstack((thc_leaf.ravel(), thc_central.ravel()))
# lbfgs_start_time = time.time()
x = lbfgsb_opt_thc_l2reg(eri_full, nthc, initial_guess=x, maxiter=bfgs_maxiter)
x = lbfgsb_opt_thc_l2reg(
eri_full, nthc, initial_guess=x, maxiter=bfgs_maxiter, penalty_param=penalty_param
)
# lbfgs_calc_time = time.time() - lbfgs_start_time
thc_leaf = x[: norb * nthc].reshape(nthc, norb) # leaf tensor nthc x norb
thc_central = x[norb * nthc : norb * nthc + nthc * nthc].reshape(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# coverage:ignore
# pylint: disable=wrong-import-position
import os
from uuid import uuid4
import h5py
import numpy
import numpy.random
import numpy.linalg
from scipy.optimize import minimize

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.config import config
from jax import jit, grad
from .adagrad import adagrad
from .thc_objectives import (
Expand All @@ -22,7 +26,6 @@
# set mkl thread count for numpy einsum/tensordot calls
# leave one CPU un used so we can still access this computer
os.environ["MKL_NUM_THREADS"] = "{}".format(os.cpu_count() - 1)
config.update("jax_enable_x64", True)


class CallBackStore:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# coverage:ignore
# pylint: disable=wrong-import-position
import os
from uuid import uuid4
import scipy.optimize

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.config import config
from jax import jit, grad
import h5py
import numpy
Expand All @@ -15,7 +20,6 @@
# set mkl thread count for numpy einsum/tensordot calls
# leave one CPU un used so we can still access this computer
os.environ["MKL_NUM_THREADS"] = "{}".format(os.cpu_count() - 1)
config.update("jax_enable_x64", True)


def thc_objective_jax(xcur, norb, nthc, eri):
Expand Down
Loading