Train script doesn't run #17

Open
grgkopanas opened this issue Jan 18, 2022 · 4 comments

Comments

@grgkopanas

Hi,

When I follow the instructions for building the environment:

conda env create -f environment.yml
conda activate plenoctree
pip install --upgrade pip
pip install --upgrade jaxlib==0.1.65+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

My setup is: cuda/11.0 cudnn/8.0-cuda-11.0

I get

(plenoctree) [gkopanas@nefgpu37 plenoctree]$ python -m nerf_sh.train --train_dir ckpts/chair/ --config nerf_sh/config/blender --data_dir ../scenes/nerf/chair/
Traceback (most recent call last):
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data/graphdeco/user/gkopanas/plenoctree/nerf_sh/train.py", line 29, in <module>
    import flax
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/site-packages/flax/__init__.py", line 36, in <module>
    from . import core
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/site-packages/flax/core/axes_scan.py", line 17, in <module>
    import jax
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/site-packages/jax/__init__.py", line 93, in <module>
    from . import image
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/site-packages/jax/image/__init__.py", line 18, in <module>
    from jax._src.image.scale import (
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/site-packages/jax/_src/image/scale.py", line 20, in <module>
    from jax import lax
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/site-packages/jax/lax/__init__.py", line 324, in <module>
    from jax._src.lax.fft import (
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/site-packages/jax/_src/lax/fft.py", line 87, in <module>
    def _rfft_transpose(t, fft_lengths):
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/site-packages/jax/api.py", line 184, in jit
    return _cpp_jit(fun, static_argnums, device, backend, donate_argnums)
  File "/home/gkopanas/.conda/envs/plenoctree/lib/python3.8/site-packages/jax/api.py", line 370, in _cpp_jit
    cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
TypeError: jit(): incompatible function arguments. The following argument types are supported:
    1. (fun: function, cache_miss: function, get_device: function, static_argnums: List[int], cache_size: int = 4096) -> jaxlib.xla_extension.jax_jit.CompiledFunction

Invoked with: <function _rfft_transpose at 0x7f71ac928a60>, <function _cpp_jit.<locals>.cache_miss at 0x7f71ac928af0>, <function _cpp_jit.<locals>.get_device_info at 0x7f71ac928b80>, <function _cpp_jit.<locals>.get_jax_enable_x64 at 0x7f71ac928c10>, <function _cpp_jit.<locals>.get_jax_disable_jit_flag at 0x7f71ac928ca0>, (0, 2)

Does anyone have a clue what went wrong?

@grgkopanas
Author

I also tried to manually install all dependencies by following the guidelines of each respective package (jax, flax, tensorflow, etc.)
and ended up with another error, which occurs deeper in the code:

I0119 11:16:00.161019 140446576289600 checkpoints.py:249] Found no checkpoint files in ckpts/chair
/home/gkopanas/.conda/envs/plenoctrees4/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:413: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
* Prefetch
/home/gkopanas/.conda/envs/plenoctrees4/lib/python3.8/site-packages/jax/interpreters/xla.py:799: UserWarning: Some donated buffers were not usable: f32[1024,3]{1,0}, f32[1024,3]{1,0}, f32[1024,3]{1,0}, f32[1024,3]{1,0}
  warnings.warn("Some donated buffers were not usable: {}".format(
Traceback (most recent call last):
  File "/home/gkopanas/.conda/envs/plenoctrees4/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/gkopanas/.conda/envs/plenoctrees4/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data/graphdeco/user/gkopanas/plenoctree/nerf_sh/train.py", line 314, in <module>
    app.run(main)
  File "/home/gkopanas/.conda/envs/plenoctrees4/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/gkopanas/.conda/envs/plenoctrees4/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/data/graphdeco/user/gkopanas/plenoctree/nerf_sh/train.py", line 198, in main
    state, stats, keys = train_pstep(keys, state, batch, lr)
ValueError: INTERNAL: Failed to launch CUDA kernel: fusion_201 with block dimensions: 1024x1x1 and grid dimensions: 576x1x1: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
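
Since this second failure happens inside a compiled kernel rather than at import time, one way to narrow it down might be a tiny smoke test that is independent of the training code. The sketch below is only an illustration: the function name `gpu_smoke_test` and the matrix size are hypothetical choices, not part of the plenoctree repository. It does not diagnose the illegal memory access itself; it only helps separate a broken jax/CUDA install from a problem specific to `train_pstep`.

```python
def gpu_smoke_test(size=256):
    """Run a tiny jitted matmul and report the backend, or the reason it failed.

    Hypothetical helper for illustration; not part of the plenoctree code.
    """
    try:
        import jax
        import jax.numpy as jnp
    except Exception as exc:  # missing install, or a jax/jaxlib mismatch at import time
        return f"import failed: {exc!r}"

    try:
        devices = jax.devices()  # devices of the default backend (gpu if available)
        x = jnp.ones((size, size), dtype=jnp.float32)
        y = jax.jit(lambda a: a @ a)(x)
        y.block_until_ready()  # force execution so kernel errors surface here
        return f"ok on {devices[0].platform}: sum={float(y.sum())}"
    except Exception as exc:
        return f"kernel failed: {exc!r}"


if __name__ == "__main__":
    print(gpu_smoke_test())
```

If this reports a failure on the GPU as well, the install is the more likely culprit; if it succeeds, the problem is more likely tied to the training step or its memory use.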

@Holmes-Alan

Have you resolved the problem?

@drcfts

drcfts commented Nov 24, 2022

Were you or @Holmes-Alan able to solve this problem?

@chenyuntc

The key is to set up the jax environment correctly.
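
As a rough illustration of what checking the jax environment could involve: in the first traceback, `jax/api.py` passes six positional arguments to `jax_jit.jit`, while the installed jaxlib lists a signature with at most five, which suggests that the jax and jaxlib versions do not match. Because `import jax` itself fails in that state, the sketch below reads the installed versions from package metadata instead of importing anything. The helper names are hypothetical, and which version pair is the right one is not stated in this thread.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def report_versions(packages=("jax", "jaxlib", "flax")):
    """Collect versions from package metadata without importing the packages."""
    return {name: installed_version(name) for name in packages}


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Comparing the reported jax version against the jaxlib pinned in the install command above (`0.1.65+cuda110`) is a reasonable first step before digging further.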


4 participants