Problems Running Jax-Triton with an Nvidia 4090 #114

Closed
adam-hartshorne opened this issue Apr 2, 2023 · 8 comments

adam-hartshorne commented Apr 2, 2023

Running the quick-start example on an Nvidia 4090 with the suggested Triton version (2.0.0.dev20221202) produces the following error:

RuntimeError: Internal Triton PTX codegen error: 
ptxas /tmp/filePpgwy4, line 6; error   : PTX .version 7.4 does not support .target sm_89
ptxas fatal   : Ptx assembly aborted due to errors
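For context, the ptxas failure is a PTX ISA / GPU-architecture mismatch: sm_89 (Ada, i.e. the RTX 4090) requires a newer PTX ISA version than the 7.4 this Triton build emits. A minimal sketch of the constraint; the cutoff values are assumptions drawn from NVIDIA's PTX ISA release notes and should be verified against your CUDA toolkit documentation:

```python
# Illustrative check of the PTX ISA / target mismatch behind the ptxas error.
# The cutoffs below are assumptions; verify against NVIDIA's PTX ISA notes.
MIN_PTX_ISA = {
    "sm_80": 7.0,  # Ampere (A100)
    "sm_86": 7.1,  # Ampere (RTX 30xx)
    "sm_89": 7.8,  # Ada (RTX 4090)
    "sm_90": 7.8,  # Hopper (H100)
}

def ptx_supports(ptx_version: float, target: str) -> bool:
    """True if a given PTX ISA version can encode the given sm target."""
    return ptx_version >= MIN_PTX_ISA[target]

print(ptx_supports(7.4, "sm_89"))  # → False: exactly the failure above
```

In other words, the pinned 2022 Triton build predates Ada support, which is why only a newer Triton (built against a newer CUDA toolkit) can target the 4090.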

Upgrading to the latest development version of Triton makes it possible to run the PyTorch-based examples, but jax-triton fails with the following error:

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/main.py", line 135, in <module>
    print(add(x_val, y_val))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/main.py", line 124, in add
    return jt.triton_call(
  File "/home/adam/Downloads/jax-triton/jax_triton/triton_lib.py", line 531, in triton_call
    out_flat = triton_kernel_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: AttributeError: module 'triton.compiler' has no attribute '_compile'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/main.py", line 135, in <module>
    print(add(x_val, y_val))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/main.py", line 124, in add
    return jt.triton_call(
  File "/home/adam/Downloads/jax-triton/jax_triton/triton_lib.py", line 531, in triton_call
    out_flat = triton_kernel_call_p.bind(
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/core.py", line 807, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/dispatch.py", line 122, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/util.py", line 254, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/util.py", line 247, in cached
    return f(*args, **kwargs)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/dispatch.py", line 201, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/dispatch.py", line 353, in _xla_callable_uncached
    computation = sharded_lowering(fun, device, backend, name, donated_invars,
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/dispatch.py", line 343, in sharded_lowering
    return pxla.lower_sharding_computation(
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 3082, in lower_sharding_computation
    lowering_result = mlir.lower_jaxpr_to_module(
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 742, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1044, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1179, in jaxpr_subcomp
    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
  File "/home/adam/Downloads/jax-triton/jax_triton/triton_lib.py", line 320, in triton_kernel_call_lowering
    kernel, specialization = get_or_create_triton_kernel(
  File "/home/adam/Downloads/jax-triton/jax_triton/triton_lib.py", line 173, in get_or_create_triton_kernel
    asm, shared_mem, name = tc._compile(
AttributeError: module 'triton.compiler' has no attribute '_compile'. Did you mean: 'compile'?

sharadmv commented Apr 3, 2023

We don't yet support the latest Triton dev version, but we soon will!


adam-hartshorne commented Apr 3, 2023

Am I correct in thinking that Nvidia 4090s aren't supported by any earlier version of Triton, so for the moment I can't use jax-triton until further updates?


sharadmv commented Apr 3, 2023

There's an open PR that rebases jax-triton on top of Triton at HEAD. You could try checking it out and using that.


sharadmv commented Apr 3, 2023

Link here: #50
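One way to try that PR locally is to fetch it into a branch. A sketch, assuming a local clone of jax-triton with `origin` pointing at the upstream repository:

```shell
# Fetch PR #50 into a local branch and install it in editable mode
git fetch origin pull/50/head:pr-50
git checkout pr-50
pip install -e .
```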


adam-hartshorne commented Apr 15, 2023

After seeing that the PR had been merged into the main branch, I revisited this issue. The original error is gone; however, I now receive the following error when I run the provided examples, e.g. https://github.com/jax-ml/jax-triton/blob/main/examples/add.py

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/main.py", line 18, in <module>
    import jax_triton as jt
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax_triton/__init__.py", line 19, in <module>
    from jax_triton.triton_lib import triton_call
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 47, in <module>
    from triton.compiler import code_generator as code_gen
ImportError: cannot import name 'code_generator' from 'triton.compiler' (/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/triton/compiler.py)

This obviously suggests that the function doesn't exist or isn't installed, but it is definitely there, and I can run the example Triton / PyTorch tutorials without error.

@sharadmv

You'll need to use triton installed from HEAD or nightly.
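At the time, installing Triton from the development head looked roughly like the following. A sketch only; it assumes the then-current openai/triton repository layout with `setup.py` under `python/`, which may have changed since:

```shell
# Install Triton from the development head (layout as of early 2023)
pip install -U "git+https://github.com/openai/triton.git#subdirectory=python"
```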

@adam-hartshorne

I have; I am using the Triton nightly.

@adam-hartshorne

A total reinstall from scratch seems to have fixed the issue.
