Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Make MAX_NET_DETUNING in AHS device capabilities optional and extend field validator utility functions. #272

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from decimal import Decimal
from typing import Optional

from pydantic import PositiveInt
from pydantic.v1.main import BaseModel
Expand All @@ -33,4 +33,4 @@ class CapabilitiesConstants(BaseModel):

MAGNITUDE_PATTERN_VALUE_MIN: Decimal
MAGNITUDE_PATTERN_VALUE_MAX: Decimal
MAX_NET_DETUNING: Decimal
MAX_NET_DETUNING: Optional[Decimal]
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,50 @@ def validate_net_detuning_with_warning(
# Return immediately if there is an atom has net detuning
# exceeding MAX_NET_DETUNING at a time point
return program


Copy link
Contributor

@maolinml maolinml Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just want to remind myself, are these four functions the only things that needs to be moved from the service side to the client side? Also it maybe good to comment on the top of each of these functions that they are only for device emulator and not for AHS local simulator [so that people won't be confused that why they are not used in this repo].

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, don't we also need "validate_pattern_precision"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validate_pattern_precision is never used by the Device validators so I chose not to bring it over to the Default Simulator repo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the additional comments makes sense for explaining the use of these helpers!

# Two time points cannot be too close, assuming the time points are sorted ascendingly
def validate_time_separation(times: List[Decimal], min_time_separation: Decimal, name: str):
for i in range(len(times) - 1):
time_diff = times[i + 1] - times[i]
if time_diff < min_time_separation:
raise ValueError(
f"Time points of {name} time_series, {i} ({times[i]}) and "
f"{i + 1} ({times[i + 1]}), are too close; they are separated "
f"by {time_diff} seconds. It must be at least {min_time_separation} seconds"
)


def validate_value_precision(values: List[Decimal], max_precision: Decimal, name: str):
# Raise ValueError if at any item in the values is beyond the max allowable precision
for idx, v in enumerate(values):
if v % max_precision != 0:
raise ValueError(
f"Value {idx} ({v}) in {name} time_series is defined with too many digits; "
f"it must be an integer multiple of {max_precision}"
)


def validate_max_absolute_slope(
times: List[Decimal], values: List[Decimal], max_slope: Decimal, name: str
):
# Raise ValueError if at any time the time series (times, values)
# rises/falls faster than allowed
for idx in range(len(values) - 1):
slope = (values[idx + 1] - values[idx]) / (times[idx + 1] - times[idx])
if abs(slope) > max_slope:
raise ValueError(
f"For the {name} field, rate of change of values "
f"(between the {idx}-th and the {idx + 1}-th times) "
f"is {abs(slope)}, more than {max_slope}"
)


def validate_time_precision(times: List[Decimal], time_precision: Decimal, name: str):
for idx, t in enumerate(times):
if t % time_precision != 0:
raise ValueError(
f"time point {idx} ({t}) of {name} time_series is "
f"defined with too many digits; it must be an "
f"integer multiple of {time_precision}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def net_detuning_must_not_exceed_max_net_detuning(cls, values):
# If no local detuning, we simply return the values
# because there are separate validators to validate
# the global driving fields in the program
if not len(local_detuning):
if not len(local_detuning) or not capabilities.MAX_NET_DETUNING:
return values

detuning_times = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from decimal import Decimal

import pytest

from braket.analog_hamiltonian_simulator.rydberg.validators.field_validator_util import (
validate_max_absolute_slope,
validate_time_precision,
validate_time_separation,
validate_value_precision,
)


@pytest.mark.parametrize(
"times, min_time_separation, fail",
[
(
[Decimal("0.0"), Decimal("1e-5"), Decimal("2e-5"), Decimal("2.5"), Decimal("4")],
Decimal("1e-3"),
True,
),
(
[Decimal("0.0"), Decimal("1e-5"), Decimal("2e-5"), Decimal("2.5"), Decimal("4")],
Decimal("1e-6"),
False,
),
(
[Decimal("0.0"), Decimal("1"), Decimal("2"), Decimal("3"), Decimal("4")],
Decimal("1e-3"),
False,
),
],
)
def test_validate_time_separation(times, min_time_separation, fail):
if fail:
with pytest.raises(ValueError):
validate_time_separation(times, min_time_separation, "test")
else:
try:
validate_time_separation(times, min_time_separation, "test")
except ValueError as e:
pytest.fail(f"Failed valid validate_min_time_separation: {str(e)}")


@pytest.mark.parametrize(
"values, max_precision, fail",
[
(
[Decimal("0.0"), Decimal("1e-5"), Decimal("2e-5"), Decimal("2.5"), Decimal("4")],
Decimal("1e-3"),
True,
),
(
[Decimal("0.0"), Decimal("1e-9"), Decimal("2e-5"), Decimal("3e-4"), Decimal("5.0")],
Decimal("1e-6"),
True,
),
(
[
Decimal("0.0"),
Decimal("0.00089"),
Decimal("2e-4"),
Decimal("0.003"),
Decimal("0.21"),
Decimal("1"),
],
Decimal("1e-5"),
False,
),
],
)
def test_validate_value_precision(values, max_precision, fail):
if fail:
with pytest.raises(ValueError):
validate_value_precision(values, max_precision, "test")
else:
try:
validate_value_precision(values, max_precision, "test")
except ValueError as e:
pytest.fail(f"Failed valid validate_value_precision: {str(e)}")


@pytest.mark.parametrize(
"times, values, max_slope, fail",
[
(
[Decimal("0.0"), Decimal("1.0"), Decimal("2.0"), Decimal("3.0")],
[Decimal("0.0"), Decimal("2.1"), Decimal("3.2"), Decimal("3.9")],
Decimal("2.0"),
True,
),
(
[Decimal("0.0"), Decimal("1e-5"), Decimal("2e-5"), Decimal("3")],
[Decimal("0.0"), Decimal("1.2"), Decimal("2.34"), Decimal("2.39")],
Decimal("1.5e5"),
False,
),
(
[Decimal("0.0"), Decimal("1.0"), Decimal("2e-5"), Decimal("3")],
[Decimal("0.0"), Decimal("1.2"), Decimal("2.34"), Decimal("2.39")],
Decimal("1e4"),
False,
),
],
)
def test_validate_max_absolute_slope(times, values, max_slope, fail):
if fail:
with pytest.raises(ValueError):
validate_max_absolute_slope(times, values, max_slope, "test")
else:
try:
validate_max_absolute_slope(times, values, max_slope, "test")
except ValueError as e:
pytest.fail(f"Failed valid validate_max_absolute_slope: {str(e)}")


@pytest.mark.parametrize(
"times, max_precision, fail",
[
(
[Decimal("0.0"), Decimal("1e-5"), Decimal("2e-5"), Decimal("2.5"), Decimal("4")],
Decimal("1.3"),
True,
),
(
[Decimal("0.0"), Decimal("1e-9"), Decimal("2e-5"), Decimal("3e-4"), Decimal("5.0")],
Decimal("1e-6"),
True,
),
(
[Decimal("0"), Decimal("1e-07"), Decimal("3.9e-06"), Decimal("4e-06")],
Decimal("1e-09"),
False,
),
],
)
def test_validate_time_precision(times, max_precision, fail):
if fail:
with pytest.raises(ValueError):
validate_time_precision(times, max_precision, "test")
else:
try:
validate_time_precision(times, max_precision, "test")
except ValueError as e:
pytest.fail(f"Failed valid validate_min_time_precision: {str(e)}")