Skip to content

Commit

Permalink
Added CI job with TSAN and free-threading
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Jan 9, 2025
1 parent 640cb00 commit 65e7f58
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 53 deletions.
110 changes: 110 additions & 0 deletions .github/workflows/tsan.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
name: CI - Free-threading and Thread Sanitizer (nightly)

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

on:
schedule:
- cron: "0 12 * * *" # Daily at 12:00 UTC
workflow_dispatch: # allows triggering the workflow run manually
pull_request: # Automatically trigger on pull requests affecting this file
# branches:
# - main
paths:
- '**/workflows/tsan.yaml'

jobs:
tsan:
runs-on: linux-x86-n2-64
container:
image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04
strategy:
fail-fast: false
defaults:
run:
shell: bash -l {0}
steps:
# Install git before actions/checkout as otherwise it will download the code with the GitHub
# REST API and therefore any subsequent git commands will fail.
- name: Install clang 18
env:
DEBIAN_FRONTEND: noninteractive
run: |
apt update
apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \
zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
libffi-dev liblzma-dev
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
path: jax
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: python/cpython
path: cpython
ref: v3.13.0
- name: Build CPython with TSAN enabled
run: |
cd cpython
mkdir ${GITHUB_WORKSPACE}/cpython-tsan
CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpython-tsan --disable-gil --with-thread-sanitizer
make -j64
make install
# Check whether free-threading mode is enabled
PYTHON_GIL=0 ${GITHUB_WORKSPACE}/cpython-tsan/bin/python3 -c "import sys; assert not sys._is_gil_enabled()"
${GITHUB_WORKSPACE}/cpython-tsan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv
- name: Install JAX test requirements
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
python -m pip install -r build/test-requirements.txt
- name: Build and install JAX
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
python build/build.py build --wheels=jaxlib \
--bazel_options=--color=yes \
--bazel_options=--copt=-fsanitize=thread \
--bazel_options=--linkopt="-fsanitize=thread" \
--bazel_options=--@rules_python//python/config_settings:py_freethreaded="yes" \
--bazel_options=--@nanobind//:enabled_free_threading=True \
--clang_path=/usr/bin/clang-18
# We have to manually install nightly scipy, otherwise default scipy installation
# is failing to build it here: ../meson.build:84:0: ERROR: Unknown compiler(s)
python -m pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple scipy
python -m pip install dist/jaxlib-*.whl
python -m pip install -e .
- name: Run tests
timeout-minutes: 30
env:
JAX_NUM_GENERATED_CASES: 1
JAX_ENABLE_X64: true
JAX_SKIP_SLOW_TESTS: true
PY_COLORS: 1
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
# As we do not have yet free-threading support
# there will be the following warning:
# RuntimeWarning: The global interpreter lock (GIL) has been enabled to load module 'jaxlib.utils',
# which has not declared that it can run safely without the GIL.
# To avoid that we temporarily define PYTHON_GIL
export PYTHON_GIL=0
# Continue running all commands even if they failing
set +e
python -m pytest -s -vvv tests/jaxpr_effects_test.py::EffectOrderingTest::test_different_threads_get_different_tokens
exit_code=$?
python -m pytest -s -vvv tests/api_test.py::CustomJVPTest::test_concurrent_initial_style
exit_code=$(( $exit_code | $? ))
python -m pytest -s -vvv tests/api_test.py::APITest::test_concurrent_device_get_and_put
exit_code=$(( $exit_code | $? ))
python -m pytest -s -vvv tests/api_test.py::JitTest::test_concurrent_jit
exit_code=$(( $exit_code | $? ))
exit $exit_code
79 changes: 46 additions & 33 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2984,21 +2984,27 @@ def e(x):
self.assertIn("stablehlo.sine", stablehlo)

def test_concurrent_device_get_and_put(self):
def f(x):
for _ in range(100):
y = jax.device_put(x)
x = jax.device_get(y)
return x
# Capture ThreadSanitizer warnings and fail the test if anything reported
with jtu.capture_stderr() as get_output:
def f(x):
for _ in range(100):
y = jax.device_put(x)
x = jax.device_get(y)
return x

xs = [self.rng().randn(i) for i in range(10)]
# Make sure JAX backend is initialised on the main thread since some JAX
# backends install signal handlers.
jax.device_put(0)
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(partial(f, x)) for x in xs]
ys = [f.result() for f in futures]
for x, y in zip(xs, ys):
self.assertAllClose(x, y)
xs = [self.rng().randn(i) for i in range(10)]
# Make sure JAX backend is initialised on the main thread since some JAX
# backends install signal handlers.
jax.device_put(0)
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(partial(f, x)) for x in xs]
ys = [f.result() for f in futures]
for x, y in zip(xs, ys):
self.assertAllClose(x, y)

captured = get_output()
if len(captured) > 0 and "ThreadSanitizer" in captured:
raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured}")

def test_dtype_from_builtin_types(self):
for dtype in [bool, int, float, complex]:
Expand Down Expand Up @@ -7593,25 +7599,32 @@ def f(x, y):

def test_concurrent_initial_style(self):
# https://github.com/jax-ml/jax/issues/3843
def unroll(param, sequence):
def scan_f(prev_state, inputs):
return prev_state, jax.nn.sigmoid(param * inputs)
return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1])

def run():
return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0]))

expected = run()

# we just don't want this to crash
n_workers = 2
with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e:
futures = []
for _ in range(n_workers):
futures.append(e.submit(run))
results = [f.result() for f in futures]
for ans in results:
self.assertAllClose(ans, expected)

# Capture ThreadSanitizer warnings and fail the test if anything reported
with jtu.capture_stderr() as get_output:
def unroll(param, sequence):
def scan_f(prev_state, inputs):
return prev_state, jax.nn.sigmoid(param * inputs)
return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1])

def run():
return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0]))

expected = run()

# we just don't want this to crash
n_workers = 2
with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e:
futures = []
for _ in range(n_workers):
futures.append(e.submit(run))
results = [f.result() for f in futures]
for ans in results:
self.assertAllClose(ans, expected)

captured = get_output()
if len(captured) > 0 and "ThreadSanitizer" in captured:
raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured}")

def test_nondiff_argnums_vmap_tracer(self):
# https://github.com/jax-ml/jax/issues/3964
Expand Down
51 changes: 31 additions & 20 deletions tests/jaxpr_effects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,27 +567,38 @@ def g(x):
def test_different_threads_get_different_tokens(self):
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
tokens = []
def _noop(_):
return ()

def f(x):
# Runs in a thread.
res = jax.jit(
lambda x: callback_p.bind(
x, callback=_noop, effect=log_effect, out_avals=[])
)(x)
tokens.append(dispatch.runtime_tokens.current_tokens[log_effect])
return res

t1 = threading.Thread(target=lambda: f(2.))
t2 = threading.Thread(target=lambda: f(3.))
t1.start()
t2.start()
t1.join()
t2.join()
token1, token2 = tokens
self.assertIsNot(token1, token2)
# Capture ThreadSanitizer warnings and fail the test if anything reported
with jtu.capture_stderr() as get_output:
tokens = []
def _noop(_):
return ()

def f(x):
# Runs in a thread.
res = jax.jit(
lambda x: callback_p.bind(
x, callback=_noop, effect=log_effect, out_avals=[])
)(x)
# This is necessary for free-threading mode
with threading.Lock():
tokens.append(dispatch.runtime_tokens.current_tokens[log_effect])
return res

t1 = threading.Thread(target=lambda: f(2.))
t2 = threading.Thread(target=lambda: f(3.))
t1.start()
t2.start()
t1.join()
t2.join()
assert len(tokens) == 2, tokens
token1, token2 = tokens
self.assertIsNot(token1, token2)

captured = get_output()
if len(captured) > 0 and "ThreadSanitizer" in captured:
raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured}")


class ParallelEffectsTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 65e7f58

Please sign in to comment.