Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update jax config import #199

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

alberthli
Copy link

This PR changes the import statement for jax.config, which resolves the error

File "/usr/local/lib/python3.10/dist-packages/torchquad/utils/set_precision.py", line 58, in set_precision
    from jax.config import config
ImportError: cannot import name 'config' from 'jax.config' (/usr/local/lib/python3.10/dist-packages/jax/config.py)

@gomezzz gomezzz changed the base branch from main to develop April 25, 2024 10:17
@gomezzz gomezzz changed the base branch from develop to main April 25, 2024 10:17
@gomezzz
Copy link
Collaborator

gomezzz commented Apr 25, 2024

Hi @alberthli , thanks for the PR!

To understand a bit better, it is my understand jax deprecated this import recently.

Did you check by any chance if this works for older jax version as well? Currently, we require jax>=0.2.22 , so we may need to update the version

(P.S. don't worry update the failing tests, the failures are unrelated to this PR, I think, and a regression in the CI)

@alberthli
Copy link
Author

alberthli commented Apr 25, 2024

Hi @gomezzz, I haven't checked whether it works for earlier versions. This commit from 6 months ago seems relevant if you want to control versioning, though.

@gomezzz gomezzz mentioned this pull request Jun 27, 2024
@gomezzz
Copy link
Collaborator

gomezzz commented Sep 12, 2024

Hi @alberthli ! Sorry for the long delay. Yes, it looks good and works mostly with jax>=0.4.17 (having trouble to test with earlier version due to issues with jax). Maybe we should bump the recommended version of jax though.

Changes for that would be:

Would you mind updating it, @alberthli in this PR? Otherwise I can do it.

I have been trying to run the tests locally on CPU too but I have a problem with jax now.

FAILED integrator_types_test.py::test_integrate_jax - AssertionError: assert 'float64' == 'float32'
FAILED monte_carlo_test.py::test_integrate_jax - assert 0.01089246934838961 < 0.01
FAILED utils_integration_test.py::test_setup_integration_domain - AssertionError: assert 'float64' == 'float32'

I think the middle one just has a too aggressive threshold but the other two seem to be changes in the way jax behaves?

Errors are thrown here

integrator_types_test.py:90: in _run_simple_integrations
    result = integrator.integrate(
../integration/trapezoid.py:25: in integrate
    return super().integrate(fn, dim, N, integration_domain, backend)
../integration/grid_integrator.py:50: in integrate
    function_values, num_points = self.evaluate_integrand(
../integration/base_integrator.py:65: in evaluate_integrand
    result = fn(points, *args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

x = Array([[ 0.        , -2.        ],
       [ 0.08333333, -2.        ],
       [ 0.16666667, -2.        ],
       [ 0.25...      [ 0.83333333,  0.        ],
       [ 0.91666667,  0.        ],
       [ 1.        ,  0.        ]], dtype=float64)

    def fn_const(x):
        assert infer_backend(x) == backend
>       assert get_dtype_name(x) == expected_dtype_name
E       AssertionError: assert 'float64' == 'float32'
E         
E         - float32
E         + float64

integrator_types_test.py:43: AssertionError

and

utils_integration_test.py:48: in _run_tests_with_all_backends
    func(dtype_name=dtype_name, backend=backend, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

dtype_name = 'float32', backend = 'jax'

    def _run_setup_integration_domain_tests(dtype_name, backend):
        """
        Test _setup_integration_domain with the given dtype and numerical backend
        """
        print(
            f"Testing _setup_integration_domain; backend: {backend}, precision: {dtype_name}"
        )
    
        # Domain given as List with Python floats
        domain = _setup_integration_domain(2, [[0.0, 1.0], [1.0, 2.0]], backend)
        assert infer_backend(domain) == backend
>       assert get_dtype_name(domain) == dtype_name
E       AssertionError: assert 'float64' == 'float32'
E         
E         - float32
E         + float64

So either we just change the type in the test or might have to make a change in the set_precision.py. If this is too much, @alberthli , we can also move that to a dedicated issue since it is not directly related to your changes, if you prefer?

Thanks and sorry again for the delays!

@alberthli
Copy link
Author

Hi @gomezzz, my bandwidth is very limited in the next couple of weeks, so I probably won't be able to write this PR. I think separating the testing issues into a separate issue is a good idea, though perhaps that fix should be merged with this one together.

@HGangloff
Copy link

Hi,
I would like to try torchquad with JAX for my project. Shouldn't we ask for a minimal JAX version of 0.4.25 following https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-25-feb-26-2024 ?
I can make the PR if needed

@gomezzz
Copy link
Collaborator

gomezzz commented Sep 13, 2024

Hi @HGangloff ,
ah yes, I was looking for that info unsuccessfully. I don't know jax so well :).

Sounds good, please go ahead, thanks! In case you have an idea why the datatypes in the tests changed (which should have been set via the here modified set_precision function, I think) I would appreciate your insight!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants