Hide PyTorch trace compilation warnings
Running the tests produces warnings that the traces may be incorrect because the Python control flow is not completely recorded.
This includes conditions on the shape of the integration domain tensor.
Since the only arguments of the compiled integration function are the integrand and the integration domain,
and the dimensionality of this integration domain is constant,
these warnings are false positives and can safely be ignored.

After this change,
the two `get_jit_compiled_integrate` functions hide PyTorch tracer warnings with `warnings.filterwarnings` via a shared `_torch_trace_without_warnings` helper.
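
For illustration (not part of the commit), a minimal sketch of the kind of tracer warning being hidden: converting a tensor comparison to a Python bool inside a traced function triggers a `torch.jit.TracerWarning`, which a `warnings.filterwarnings` filter suppresses. The function `validate_and_sum` and its example input are made up.

    import warnings

    import torch


    def validate_and_sum(integration_domain):
        # A Python-level condition on tensor values; the tracer records only
        # the branch taken for the example input and warns about it.
        if bool((integration_domain[:, 0] < integration_domain[:, 1]).all()):
            return integration_domain.sum()
        return integration_domain.new_zeros(())


    domain = torch.tensor([[0.0, 1.0], [0.0, 1.0]])

    # A bare torch.jit.trace(validate_and_sum, (domain,)) warns here;
    # the filter below hides the TracerWarning, as in this commit.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        traced = torch.jit.trace(validate_and_sum, (domain,))
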
FHof committed Oct 18, 2023
1 parent 1692758 commit f561a61
Showing 3 changed files with 26 additions and 19 deletions.
13 changes: 3 additions & 10 deletions torchquad/integration/grid_integrator.py
@@ -1,6 +1,7 @@
 from loguru import logger
 from autoray import numpy as anp, infer_backend
 
+from ..utils.torch_trace_without_warnings import _torch_trace_without_warnings
 from .base_integrator import BaseIntegrator
 from .integration_grid import IntegrationGrid
 from .utils import (
@@ -229,7 +230,7 @@ def step3(function_values, hs, integration_domain):
                 )
 
                 # Trace the first step
-                step1 = torch.jit.trace(step1, (integration_domain,))
+                step1 = _torch_trace_without_warnings(step1, (integration_domain,))
 
                 # Get example input for the third step
                 grid_points, hs, n_per_dim = step1(integration_domain)
@@ -241,15 +242,7 @@ def step3(function_values, hs, integration_domain):
                 )
 
                 # Trace the third step
-                # Avoid the warnings about a .grad attribute access of a
-                # non-leaf Tensor
-                if hs.requires_grad:
-                    hs = hs.detach()
-                    hs.requires_grad = True
-                if function_values.requires_grad:
-                    function_values = function_values.detach()
-                    function_values.requires_grad = True
-                step3 = torch.jit.trace(
+                step3 = _torch_trace_without_warnings(
                     step3, (function_values, hs, integration_domain)
                 )
 
16 changes: 7 additions & 9 deletions torchquad/integration/monte_carlo.py
@@ -2,6 +2,7 @@
 from autoray import infer_backend
 from loguru import logger
 
+from ..utils.torch_trace_without_warnings import _torch_trace_without_warnings
 from .base_integrator import BaseIntegrator
 from .utils import _setup_integration_domain, expand_func_values_and_squeeze_integral
 from .rng import RNG
@@ -195,8 +196,6 @@ def compiled_integrate(fn, integration_domain):
         elif backend == "torch":
             # Torch requires explicit tracing with example inputs.
             def do_compile(example_integrand):
-                import torch
-
                 # Define traceable first and third steps
                 def step1(integration_domain):
                     return self.calculate_sample_points(
@@ -206,7 +205,9 @@ def step1(integration_domain):
                 step3 = self.calculate_result
 
                 # Trace the first step (which is non-deterministic)
-                step1 = torch.jit.trace(step1, (integration_domain,), check_trace=False)
+                step1 = _torch_trace_without_warnings(
+                    step1, (integration_domain,), check_trace=False
+                )
 
                 # Get example input for the third step
                 sample_points = step1(integration_domain)
@@ -215,12 +216,9 @@ def step1(integration_domain):
                 )
 
                 # Trace the third step
-                if function_values.requires_grad:
-                    # Avoid the warning about a .grad attribute access of a
-                    # non-leaf Tensor
-                    function_values = function_values.detach()
-                    function_values.requires_grad = True
-                step3 = torch.jit.trace(step3, (function_values, integration_domain))
+                step3 = _torch_trace_without_warnings(
+                    step3, (function_values, integration_domain)
+                )
 
                 # Define a compiled integrate function
                 def compiled_integrate(fn, integration_domain):
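
For context, a usage sketch (not part of the commit) of the Monte Carlo path touched above; as the commit message notes, the compiled function's only arguments are the integrand and the integration domain. The integrand and domain values below are made up.

    import torch
    from torchquad import MonteCarlo, set_up_backend

    set_up_backend("torch", data_type="float32")

    mc = MonteCarlo()
    # The returned callable takes only (integrand, integration_domain).
    integrate_jit = mc.get_jit_compiled_integrate(dim=2, N=10000, backend="torch")

    domain = torch.tensor([[0.0, 1.0], [0.0, 1.0]])
    result = integrate_jit(lambda x: torch.sin(x[:, 0]) + x[:, 1], domain)
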
16 changes: 16 additions & 0 deletions torchquad/utils/torch_trace_without_warnings.py
@@ -0,0 +1,16 @@
+import warnings
+
+
+def _torch_trace_without_warnings(*args, **kwargs):
+    """Execute `torch.jit.trace` on the passed arguments and hide tracer warnings
+
+    PyTorch can show warnings about traces being potentially incorrect because
+    the Python3 control flow is not completely recorded.
+    This function can be used to hide the warnings in situations where they are
+    false positives.
+    """
+    import torch
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+        return torch.jit.trace(*args, **kwargs)
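
The helper is a drop-in replacement for `torch.jit.trace`. A short usage sketch (illustrative; the traced function and inputs below are made up):

    import torch

    from torchquad.utils.torch_trace_without_warnings import (
        _torch_trace_without_warnings,
    )


    def double(x):
        return x * 2.0


    # Same call signature and return value as torch.jit.trace,
    # with TracerWarnings suppressed during tracing.
    traced = _torch_trace_without_warnings(double, (torch.ones(3),))
    print(traced(torch.arange(3.0)))  # tensor([0., 2., 4.])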
