Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[scan] Test we don't recompile under debugging env flags #8555

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions test/neuron/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_test "$CDIR/pjrt/test_dtypes.py"
#run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_scan.py"
run_test "$CDIR/scan/test_scan.py"
run_xla_hlo_debug "$CDIR/scan/test_scan_debug.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/test_grad_checkpoint.py"
run_test "$CDIR/test_grad_checkpoint.py" "$@" --test_autocast
Expand Down Expand Up @@ -321,4 +322,4 @@ if [ "$LOGFILE" != "" ]; then
run_tests 2>&1 | tee $LOGFILE
else
run_tests
fi
fi
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/scan/test_scan.py"
run_test "$CDIR/scan/test_scan_spmd.py"
run_test "$CDIR/scan/test_scan_layers.py"
run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
Expand Down
69 changes: 69 additions & 0 deletions test/scan/test_scan_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import sys
import os
import unittest

import torch
import torch_xla
import torch_xla.debug.metrics as met
from torch_xla.experimental.scan import scan

# Add the directory one level above this file's folder to the import path so
# that the shared `test_utils` module (expected to live there) can be imported
# when this file is executed directly as a script.
parent_folder = os.path.dirname(os.path.dirname(__file__))
sys.path.append(parent_folder)
from test_utils import XlaTestCase # type:ignore


class ScanDebugTest(XlaTestCase):
  """Tests for `scan` when HLO debug annotations are enabled via env vars."""

  def test_scan_no_recompile_with_debug_annotations(self):
    """
    When someone adds debugging annotations to the HLO via env vars, the
    HLO graph of the combine function captured by scan would have additional
    metadata such as line numbers and scopes. Still, that should not cause the
    final IR graph hash to change. This is subtle because the IR of the `scan`
    operation will reference the HLO computation within.
    """
    # Precondition: the test is only meaningful with XLA_HLO_DEBUG=1. Use a
    # unittest assertion (not a bare `assert`, which is stripped under
    # `python -O`) and `.get()` so that a missing variable is reported as a
    # readable test failure instead of a KeyError.
    self.assertEqual(
        os.environ.get("XLA_HLO_DEBUG"), "1",
        "This test must be run with the XLA_HLO_DEBUG=1 environment variable")
    # Reset metrics so the compile counters checked at the end only reflect
    # the work done in this test.
    met.clear_all()

    def fn(carry, x):
      carry = carry + x
      y = x + 42
      return carry, y

    # fn2 should trace to the same graph despite having different line numbers
    def fn2(carry, x):
      carry = carry + x
      y = x + 42
      return carry, y

    init = torch.tensor([0.0, 0.0],
                        requires_grad=True,
                        device=torch_xla.device())
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      requires_grad=True,
                      device=torch_xla.device())

    # Run some graph involving a scan operation two times.
    for _ in range(2):
      # Clear gradients from the previous iteration before the next backward
      # pass.
      init.grad = None
      xs.grad = None
      carry, ys = scan(fn, init, xs)
      (carry.sum() + ys.sum()).backward()
      torch_xla.sync()

    # Use a differently named but semantically the same combine function.
    # This should still trace to identical HLO and hence reuse the cache.
    init.grad = None
    xs.grad = None
    carry, ys = scan(fn2, init, xs)
    (carry.sum() + ys.sum()).backward()
    torch_xla.sync()

    # Should only compile once and cache the last two times.
    self.assertEqual(int(met.counter_value("UncachedCompile")), 1)
    self.assertEqual(int(met.counter_value("CachedCompile")), 2)


if __name__ == '__main__':
  # unittest.main() calls sys.exit() itself by default, which would make the
  # explicit exit-code handling below unreachable. Pass exit=False so the
  # result object is returned and the exit status is set here.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ python3 "$TEST_CDIR/test_while_loop.py"
python3 "$TEST_CDIR/scan/test_scan.py"
python3 "$TEST_CDIR/scan/test_scan_spmd.py"
python3 "$TEST_CDIR/scan/test_scan_layers.py"
run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
python3 "$TEST_CDIR/test_pallas.py" -v
python3 "$TEST_CDIR/test_pallas_spmd.py"
python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py"
Expand Down
Loading