jax leak problems #710

Open
Laohusong opened this issue Dec 30, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@Laohusong

I followed the official tutorial for the BPTT trainer and ran into a JAX tracer leak.

with jax.checking_leaks():
    trainer.fit(train_data, num_epoch=30)

It caused:

UnexpectedTracerError                     Traceback (most recent call last)
Cell In[28], line 4
      2 import jax
      3 with jax.checking_leaks():
----> 4     trainer.fit(train_data, num_epoch=30)

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\brainpy\_src\train\back_propagation.py:285, in BPTrainer.fit(self, train_data, test_data, num_epoch, num_report, reset_state, shared_args, fun_after_report, batch_size)
    282   self.reset_state()
    284 # training
--> 285 res = self.f_train(shared_args, x, y)
    287 # loss
    288 fit_epoch_metric['loss'].append(res[0])

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\brainpy\_src\math\object_transform\jit.py:213, in JITTransform.__call__(self, *args, **kwargs)
    210     return rets
    212 # call the transformed function
--> 213 return _jit_call_take_care_of_rngs(self._transform, self._dyn_vars, *args, **kwargs)

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\brainpy\_src\math\object_transform\jit.py:94, in _jit_call_take_care_of_rngs(transform, stack, *args, **kwargs)
     91 def _jit_call_take_care_of_rngs(transform, stack, *args, **kwargs):
     92   # call the transformed function
     93   rng_keys = stack.call_on_subset(_is_rng, _rng_split_key)
---> 94   changes, out = transform(stack.dict_data(), *args, **kwargs)
     95   for key, v in changes.items():
     96     stack[key]._value = v

    [... skipping hidden 3 frame]

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\jax\_src\core.py:924, in check_eval_args(args)
    922 for arg in args:
    923   if isinstance(arg, Tracer):
--> 924     raise escaped_tracer_error(arg)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[128,100] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
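
For readers unfamiliar with this error, here is a minimal sketch of the mechanism the message describes. It is not BrainPy's code path and does not reproduce this issue. The names `leaked`, `step_with_side_effect`, and `step_pure` are hypothetical, and the sketch uses `jax.jit` instead of the gradient transform that produced the JVPTracer in the traceback above. The first function stores an intermediate value in global state, so a tracer escapes. The second returns the intermediate explicitly, which is the pattern JAX expects.

```python
import jax
import jax.numpy as jnp

leaked = []  # global state; hypothetical name used only for illustration


@jax.jit
def step_with_side_effect(w, x):
    h = jnp.tanh(x @ w)   # intermediate value (a tracer while jit is tracing)
    leaked.append(h)      # side effect: the tracer escapes the transformation
    return jnp.sum(h ** 2)


@jax.jit
def step_pure(w, x):
    h = jnp.tanh(x @ w)
    return jnp.sum(h ** 2), h   # return the intermediate instead of storing it


w = jnp.ones((3, 2))
x = jnp.ones((4, 3))

# The leaky call itself succeeds; the failure shows up when the
# escaped tracer is used outside the transformed function.
step_with_side_effect(w, x)
try:
    leaked[0] + 1
    leak_detected = False
except jax.errors.UnexpectedTracerError:
    leak_detected = True

# The pure version hands back a concrete array that is safe to use later.
loss, h = step_pure(w, x)
print("leak detected:", leak_detected)
print("intermediate shape:", h.shape)
```

In the traceback above the escaped value is a float64[128,100] array, so the analogous fix would be to find where that intermediate gets stored on an object or in global state during training. Whether that happens in user code or inside the library is not established in this thread.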

I used python=3.11, brainpy=2.60, GPU version (I also tried the CPU version and got the same problem), installed by following the latest tutorial.
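
Since this kind of error can depend on the exact combination of installed packages, it may help to attach precise versions to a report like this one. The following is a small sketch using only the standard library. The helper name `collect_versions` and the package list are assumptions chosen for illustration, not part of BrainPy or JAX.

```python
import platform
from importlib import metadata


def collect_versions(packages=("brainpy", "jax", "jaxlib")):
    """Return a dict of installed versions; packages not found map to None."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = collect_versions()
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```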

I tried this on both my Windows and Mac machines and hit the same bug.

Would it be OK to downgrade to brainpy=2.4?

Laohusong added the bug label on Dec 30, 2024
@Routhleck
Member

Thank you for your report! To address the issue, downgrading BrainPy might resolve it. In the meantime, I encourage you to explore our new BrainPy Dynamics Programming Ecosystem. Additionally, you can check out our ANN Training Tutorial for more insights.

Regarding this bug, I may need assistance from @chaoming0625. Rest assured, we will prioritize resolving this in the near future.
