Skip to content

Commit

Permalink
Silencing TensorFlow warning (#332)
Browse files Browse the repository at this point in the history
**Description of the Change:**
Using python built-in logger, this PR silences the annoying
`ComplexWarning`s coming from `tf.cast`.
Additionally, it gets rid of tensorflow's import messages.

---------

Co-authored-by: Filippo Miatto <[email protected]>
  • Loading branch information
SamFerracin and ziofil authored Feb 1, 2024
1 parent 4cb1fd7 commit 774681f
Show file tree
Hide file tree
Showing 13 changed files with 136 additions and 20 deletions.
4 changes: 2 additions & 2 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ which uses the old Numba code. When setting to a higher value, the new Julia cod
Hermite polynomials over a batch of B vectors.
[(#308)](https://github.com/XanaduAI/MrMustard/pull/308)

* Changed the ``cast`` functions in the numpy and tensorflow backends to avoid ``ComplexWarning``s.
[(#307)](https://github.com/XanaduAI/MrMustard/pull/307)
* Added suite to filter undesired warnings, and used it to filter tensorflow's ``ComplexWarning``s.
[(#332)](https://github.com/XanaduAI/MrMustard/pull/332)


### Bug fixes
Expand Down
6 changes: 5 additions & 1 deletion mrmustard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

"""This is the top-most `__init__.py` file of MrMustard package."""


from rich import print

from ._version import __version__
from .utils.settings import *
from .utils.filters import add_complex_warning_filter


def version():
Expand Down Expand Up @@ -79,3 +79,7 @@ def about():
print("Scipy version: {}".format(scipy.__version__))
print("The Walrus version: {}".format(thewalrus.__version__))
print("TensorFlow version: {}".format(tensorflow.__version__))


# filter tensorflow cast warnings
add_complex_warning_filter()
6 changes: 2 additions & 4 deletions mrmustard/lab/abstract/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,10 @@ def outcome(self):
...

@abstractmethod
def _measure_fock(self, other: State) -> Union[State, float]:
...
def _measure_fock(self, other: State) -> Union[State, float]: ...

@abstractmethod
def _measure_gaussian(self, other: State) -> Union[State, float]:
...
def _measure_gaussian(self, other: State) -> Union[State, float]: ...

def primal(self, other: State) -> Union[State, float]:
"""performs the measurement procedure according to the representation of the incoming state"""
Expand Down
4 changes: 1 addition & 3 deletions mrmustard/lab/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,9 +1041,7 @@ def primal(self, state):

coeff = math.cast(
math.exp(
-0.5
* self.phase_stdev.value**2
* math.arange(-dm.shape[-2] + 1, dm.shape[-1]) ** 2
-0.5 * self.phase_stdev.value**2 * math.arange(-dm.shape[-2] + 1, dm.shape[-1]) ** 2
),
dm.dtype,
)
Expand Down
11 changes: 7 additions & 4 deletions mrmustard/math/backend_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,20 @@

"""This module contains the tensorflow backend."""

# pylint: disable = missing-function-docstring, missing-class-docstring
# pylint: disable = missing-function-docstring, missing-class-docstring, wrong-import-position

from typing import Callable, List, Optional, Sequence, Tuple, Union

import logging
import numpy as np

logging.getLogger("tensorflow").setLevel(logging.ERROR)
import tensorflow as tf
import tensorflow_probability as tfp

logging.getLogger("tensorflow").setLevel(logging.INFO)


from mrmustard.math.lattice.strategies.compactFock.inputValidation import (
grad_hermite_multidimensional_1leftoverMode,
grad_hermite_multidimensional_diagonal,
Expand Down Expand Up @@ -106,9 +112,6 @@ def boolean_mask(self, tensor: tf.Tensor, mask: tf.Tensor) -> Tensor:
def cast(self, array: tf.Tensor, dtype=None) -> tf.Tensor:
if dtype is None:
return array

if dtype not in [self.complex64, self.complex128, "complex64", "complex128"]:
array = self.real(array)
return tf.cast(array, dtype)

def clip(self, array, a_min, a_max) -> tf.Tensor:
Expand Down
15 changes: 12 additions & 3 deletions mrmustard/training/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,23 @@ def rolling_cost_cb(optimizer, cost, **kwargs):
"""

# pylint: disable = wrong-import-position


from dataclasses import dataclass
from datetime import datetime
import hashlib
from pathlib import Path
from typing import Callable, Optional, Mapping, Sequence, Union

import logging
import numpy as np

logging.getLogger("tensorflow").setLevel(logging.ERROR)
import tensorflow as tf

logging.getLogger("tensorflow").setLevel(logging.INFO)


@dataclass
class Callback:
Expand Down Expand Up @@ -254,9 +263,9 @@ def call(
orig_cost = np.array(optimizer.callback_history["orig_cost"][-1]).item()
obj_scalars[f"{obj_tag}/orig_cost"] = orig_cost
if self.cost_converter is not None:
obj_scalars[
f"{obj_tag}/{self.cost_converter.__name__}(orig_cost)"
] = self.cost_converter(orig_cost)
obj_scalars[f"{obj_tag}/{self.cost_converter.__name__}(orig_cost)"] = (
self.cost_converter(orig_cost)
)

for k, v in obj_scalars.items():
tf.summary.scalar(k, data=v, step=self.optimizer_step)
Expand Down
66 changes: 66 additions & 0 deletions mrmustard/utils/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains a class to filter undesired warnings.
"""

import logging


class WarningFilters(logging.Filter):
r"""
A custom logging filter to selectively allow log records based on specific warnings.
Args:
warnings: A list of warning messages that must be filtered.
"""

def __init__(self, warnings: list[str]):
super().__init__()
self.warnings = warnings

def filter(self, record) -> bool:
r"""
Determine if the log record should be allowed based on specific warnings.
Args:
record: The ``LogRecord`` to be filtered.
Returns:
``True`` if the log record should be allowed, ``False`` otherwise.
"""
return any(w in record.getMessage() for w in self.warnings)


# ComplexWarning filter for tensorflow.
msg = "WARNING:tensorflow:You are casting an input of type complex128 to an incompatible dtype float64."
msg += " This will discard the imaginary part and may not be what you intended."
complex_warninig_filter = WarningFilters([msg])


def add_complex_warning_filter():
r"""
Adds the filter for tensorflow's ComplexWarning, or does nothing if the filter is already in place.
"""
logger = logging.getLogger("tensorflow")
logger.addFilter(complex_warninig_filter)


def remove_complex_warning_filter():
r"""
Removes the filter for tensorflow's ComplexWarning, or does nothing if no such filter is present.
"""
logger = logging.getLogger("tensorflow")
logger.removeFilter(complex_warninig_filter)
16 changes: 16 additions & 0 deletions mrmustard/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import rich.table
import numpy as np

from mrmustard.utils.filters import add_complex_warning_filter, remove_complex_warning_filter

__all__ = ["Settings", "settings"]


Expand Down Expand Up @@ -105,6 +107,7 @@ def __init__(self):
self._julia_initialized = (
False # set to True when Julia is initialized (cf. PRECISION_BITS_HERMITE_POLY.setter)
)
self._complex_warning = False

def _force_hbar(self, value):
r"can set the value of HBAR at any time. use with caution."
Expand Down Expand Up @@ -146,6 +149,19 @@ def CIRCUIT_DECIMALS(self):
def CIRCUIT_DECIMALS(self, value: int):
self._circuit_decimals = value

@property
def COMPLEX_WARNING(self):
r"""Whether tensorflow's ``ComplexWarning``s should be raised when a complex is casted to a float. Default is ``False``."""
return self._complex_warning

@COMPLEX_WARNING.setter
def COMPLEX_WARNING(self, value: bool):
self._complex_warning = value
if value:
remove_complex_warning_filter()
else:
add_complex_warning_filter()

@property
def DEBUG(self):
r"""Whether or not to print the vector of means and the covariance matrix alongside the
Expand Down
3 changes: 1 addition & 2 deletions mrmustard/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,4 @@
class Batch(Protocol[T_co]):
r"""Anything that can iterate over objects of type T_co."""

def __iter__(self) -> Iterator[T_co]:
...
def __iter__(self) -> Iterator[T_co]: ...
1 change: 0 additions & 1 deletion polyval2d-high-precision
Submodule polyval2d-high-precision deleted from 7c4a4a
1 change: 1 addition & 0 deletions tests/test_math/test_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class TestBackendManager:
r"""
Tests the BackendManager.
"""

l1 = [1.0]
l2 = [1.0 + 0.0j, -2.0 + 2.0j]
l3 = [[1.0, 2.0], [-3.0, 4.0]]
Expand Down
1 change: 1 addition & 0 deletions tests/test_math/test_compactFock.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Unit tests for mrmustard.math.compactFock.compactFock~
"""

import importlib

import numpy as np
Expand Down
22 changes: 22 additions & 0 deletions tests/test_utils/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
Tests for the Settings class.
"""

from mrmustard import math
from mrmustard.utils.settings import Settings, ImmutableSetting
import pytest

from ..conftest import skip_np


class TestImmutableSettings:
"""Tests the ImmutableSettings class"""
Expand Down Expand Up @@ -125,3 +128,22 @@ def test_reproducibility(self):
settings.SEED = 42
seq1 = [settings.rng.integers(0, 2**31 - 1) for _ in range(10)]
assert seq0 == seq1

def test_complex_warnings(self, caplog):
"""Tests that complex warnings can be correctly activated and deactivated."""
skip_np()

settings = Settings()

assert settings.COMPLEX_WARNING is False
math.cast(1 + 1j, math.float64)
assert len(caplog.records) == 0

settings.COMPLEX_WARNING = True
math.cast(1 + 1j, math.float64)
assert len(caplog.records) == 1
assert "You are casting an input of type complex128" in caplog.records[0].msg

settings.COMPLEX_WARNING = False
math.cast(1 + 1j, math.float64)
assert len(caplog.records) == 1

0 comments on commit 774681f

Please sign in to comment.