I followed the official tutorial to use the BPTT trainer and ran into a JAX tracer-leak problem. Running the training loop under `jax.checking_leaks()`:
import jax
with jax.checking_leaks():
    trainer.fit(train_data, num_epoch=30)
raised the following error:
UnexpectedTracerError                     Traceback (most recent call last)
Cell In[28], line 4
      2 import jax
      3 with jax.checking_leaks():
----> 4     trainer.fit(train_data, num_epoch=30)

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\brainpy\_src\train\back_propagation.py:285, in BPTrainer.fit(self, train_data, test_data, num_epoch, num_report, reset_state, shared_args, fun_after_report, batch_size)
    282 self.reset_state()
    284 # training
--> 285 res = self.f_train(shared_args, x, y)
    287 # loss
    288 fit_epoch_metric['loss'].append(res[0])

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\brainpy\_src\math\object_transform\jit.py:213, in JITTransform.__call__(self, *args, **kwargs)
    210 return rets
    212 # call the transformed function
--> 213 return _jit_call_take_care_of_rngs(self._transform, self._dyn_vars, *args, **kwargs)

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\brainpy\_src\math\object_transform\jit.py:94, in _jit_call_take_care_of_rngs(transform, stack, *args, **kwargs)
     91 def _jit_call_take_care_of_rngs(transform, stack, *args, **kwargs):
     92     # call the transformed function
     93     rng_keys = stack.call_on_subset(_is_rng, _rng_split_key)
---> 94     changes, out = transform(stack.dict_data(), *args, **kwargs)
     95     for key, v in changes.items():
     96         stack[key]._value = v

    [... skipping hidden 3 frame]

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\jax\_src\core.py:924, in check_eval_args(args)
    922 for arg in args:
    923     if isinstance(arg, Tracer):
--> 924         raise escaped_tracer_error(arg)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[128,100] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
I used Python 3.11 and BrainPy 2.6.0 (GPU version; the CPU version gives the same error), installed by following the latest tutorial.
I see the same bug on both Windows and macOS.
Would it be OK to downgrade to BrainPy 2.4?
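The error message in the report above describes the general mechanism: a value traced inside a JAX transformation is kept somewhere outside it (for example in a Python list or an object attribute) and is used after the transformation has finished. Below is a minimal sketch in plain JAX. It is not a diagnosis of BrainPy's BPTT trainer, and `step_with_side_effect`, `detect_leak_early`, `loss_and_hidden` and `train_step` are hypothetical names. It uses `jax.jit` for the leak because that is the simplest case to reproduce, whereas the escaped value in the report was a `JVPTracer` produced during differentiation. Exact exception types and messages may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

# --- 1. Anti-pattern: stashing a traced intermediate in Python state ---
leaked = []  # module-level list standing in for "global state" (hypothetical)


@jax.jit
def step_with_side_effect(x):
    y = x + 1.0
    leaked.append(y)  # during tracing, y is a tracer; keeping it leaks it
    return y * 2.0


def use_leaked_value():
    """Run the leaky function, then use the stored tracer outside of jit."""
    step_with_side_effect(1.0)
    try:
        _ = leaked[0] + 1.0  # the tracer has escaped its transformation
    except jax.errors.UnexpectedTracerError:
        return "UnexpectedTracerError"
    return "no error"


# --- 2. Same pattern under jax.checking_leaks(): reported when tracing ends ---
def detect_leak_early():
    stash = []

    @jax.jit
    def leaky(x):
        y = x * 3.0
        stash.append(y)  # leak, detected while the trace is being closed
        return y

    try:
        with jax.checking_leaks():
            leaky(1.0)
    except Exception:  # caught broadly: type/message vary across JAX versions
        return True
    return False


# --- 3. Fix: return intermediates instead of storing them ---
def loss_and_hidden(w, x):
    hidden = jnp.tanh(x @ w)  # intermediate we want to inspect later
    loss = jnp.mean(hidden ** 2)
    return loss, hidden  # hidden leaves the function as an auxiliary output


@jax.jit
def train_step(w, x):
    # has_aux=True: the function returns (value, aux); only value is differentiated
    (loss, hidden), grad = jax.value_and_grad(loss_and_hidden, has_aux=True)(w, x)
    return loss, hidden, grad


if __name__ == "__main__":
    print(use_leaked_value())
    print(detect_leak_early())
    loss, hidden, grad = train_step(jnp.zeros((4, 2)), jnp.ones((3, 4)))
    print(float(loss), hidden.shape, grad.shape)
```

Part 3 shows the usual user-side remedy: anything needed outside a transformed function is returned as an output (here through `has_aux=True`) rather than written to global state, so the value that comes back is a concrete array that can be used freely.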
Thank you for your report! Downgrading BrainPy may work around the issue. In the meantime, I encourage you to explore our new BrainPy Dynamics Programming Ecosystem; our ANN Training Tutorial may also be helpful.
Regarding this bug, I may need assistance from @chaoming0625. Rest assured, we will prioritize resolving this in the near future.