[scan] Test we don't recompile under debugging env flags (#8555)
tengyifei authored and qihqi committed Jan 16, 2025
1 parent b94da1d commit d2348a4
Showing 4 changed files with 74 additions and 2 deletions.
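
For context, torch_xla.experimental.scan follows a jax.lax.scan-style contract, which the new test below relies on: the combine function threads a carry across the leading dimension of xs and emits one stacked output per step. A minimal pure-Python sketch of that contract (illustrative only, not torch_xla's implementation):

# Illustrative reference for the scan contract exercised by the new test:
# the carry threads through the steps; ys collects the per-step outputs.
def reference_scan(fn, init, xs):
  carry, ys = init, []
  for x in xs:
    carry, y = fn(carry, x)
    ys.append(y)
  return carry, ys

# Same combine-function shape as the test's fn:
carry, ys = reference_scan(lambda c, x: (c + x, x + 42), 0, [1, 2, 3])
assert carry == 6 and ys == [43, 44, 45]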
5 changes: 3 additions & 2 deletions test/neuron/run_tests.sh
@@ -174,7 +174,8 @@ function run_xla_op_tests1 {
 function run_xla_op_tests2 {
   run_test "$CDIR/pjrt/test_dtypes.py"
   #run_test "$CDIR/test_while_loop.py"
-  run_test "$CDIR/test_scan.py"
+  run_test "$CDIR/scan/test_scan.py"
+  run_xla_hlo_debug "$CDIR/scan/test_scan_debug.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/test_grad_checkpoint.py"
   run_test "$CDIR/test_grad_checkpoint.py" "$@" --test_autocast
@@ -321,4 +322,4 @@ if [ "$LOGFILE" != "" ]; then
   run_tests 2>&1 | tee $LOGFILE
 else
   run_tests
-fi
+fi
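
Note: run_xla_hlo_debug is a helper defined in these test scripts; at each call site it runs the given test with the XLA_HLO_DEBUG=1 environment variable set, so the lowered HLO carries source-location metadata. A hypothetical Python equivalent of its effect (illustrative only; the real helper is a shell function in the run_tests.sh scripts):

# Hypothetical launcher mirroring run_xla_hlo_debug's effect: run the test
# with XLA_HLO_DEBUG=1 in the environment. Illustrative sketch, not the
# actual shell helper.
import os
import subprocess

env = dict(os.environ, XLA_HLO_DEBUG="1")
subprocess.run(["python3", "test/scan/test_scan_debug.py"], env=env, check=True)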
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -194,6 +194,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/scan/test_scan.py"
run_test "$CDIR/scan/test_scan_spmd.py"
run_test "$CDIR/scan/test_scan_layers.py"
run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
69 changes: 69 additions & 0 deletions test/scan/test_scan_debug.py
@@ -0,0 +1,69 @@
import sys
import os
import unittest

import torch
import torch_xla
import torch_xla.debug.metrics as met
from torch_xla.experimental.scan import scan

parent_folder = os.path.dirname(os.path.dirname(__file__))
sys.path.append(parent_folder)
from test_utils import XlaTestCase # type:ignore


class ScanDebugTest(XlaTestCase):

  def test_scan_no_recompile_with_debug_annotations(self):
    """
    When someone adds debugging annotations to the HLO via env vars, the
    HLO graph of the combine function captured by scan gains additional
    metadata such as line numbers and scopes. Still, that should not cause
    the final IR graph hash to change. This is subtle because the IR of the
    `scan` operation references the HLO computation within.
    """
    assert os.environ["XLA_HLO_DEBUG"] == "1"
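    # This test is meant to be launched via the run_xla_hlo_debug helper in
    # the run_tests.sh scripts, which sets XLA_HLO_DEBUG=1 before invoking
    # it; hence the environment check above.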
    met.clear_all()

    def fn(carry, x):
      carry = carry + x
      y = x + 42
      return carry, y

    # fn2 should trace to the same graph despite having different line numbers.
    def fn2(carry, x):
      carry = carry + x
      y = x + 42
      return carry, y

    init = torch.tensor([0.0, 0.0],
                        requires_grad=True,
                        device=torch_xla.device())
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      requires_grad=True,
                      device=torch_xla.device())

    # Run some graph involving a scan operation two times.
    for i in range(2):
      init.grad = None
      xs.grad = None
      carry, ys = scan(fn, init, xs)
      (carry.sum() + ys.sum()).backward()
      torch_xla.sync()

    # Use a differently named but semantically identical combine function.
    # This should still trace to identical HLO and hence reuse the cache.
    init.grad = None
    xs.grad = None
    carry, ys = scan(fn2, init, xs)
    (carry.sum() + ys.sum()).backward()
    torch_xla.sync()

    # Should compile only once; the other two runs should hit the cache.
    self.assertEqual(int(met.counter_value("UncachedCompile")), 1)
    self.assertEqual(int(met.counter_value("CachedCompile")), 2)


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
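
If the counter assertions above ever fail, torch_xla's metrics module can dump a fuller picture for debugging. A minimal sketch using torch_xla.debug.metrics (counter_names() and metrics_report() are existing functions in that module):

# Minimal sketch: inspect torch_xla's compile counters after a run.
import torch_xla.debug.metrics as met

print(met.counter_names())   # after a run, includes "UncachedCompile"/"CachedCompile"
print(met.metrics_report())  # full human-readable metrics report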
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -33,6 +33,7 @@ python3 "$TEST_CDIR/test_while_loop.py"
python3 "$TEST_CDIR/scan/test_scan.py"
python3 "$TEST_CDIR/scan/test_scan_spmd.py"
python3 "$TEST_CDIR/scan/test_scan_layers.py"
run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
python3 "$TEST_CDIR/test_pallas.py" -v
python3 "$TEST_CDIR/test_pallas_spmd.py"
python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py"
