Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add .models.shift_period() #873

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion message_ix/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
from dataclasses import InitVar, dataclass, field
from functools import partial
from pathlib import Path
from typing import Mapping, MutableMapping, Optional, Tuple
from typing import TYPE_CHECKING, Mapping, MutableMapping, Optional, Tuple
from warnings import warn

import ixmp.model.gams
from ixmp import config
from ixmp.backend import ItemType
from ixmp.util import maybe_check_out, maybe_commit

if TYPE_CHECKING:
from .core import Scenario

log = logging.getLogger(__name__)

#: Solver options used by :meth:`.Scenario.solve`.
Expand Down Expand Up @@ -1009,3 +1012,62 @@ def __init__(self, *args, **kwargs):
def initialize(cls, scenario, with_data=False):
MESSAGE.initialize(scenario)
MACRO.initialize(scenario, with_data)


def shift_period(scenario: "Scenario", y0: int) -> None:
"""Shift the first period of the model horizon."""
from ixmp.backend.jdbc import JDBCBackend

# Retrieve existing cat_year information, including the current 'firstmodelyear'
cat_year = scenario.set("cat_year")
y0_pre = cat_year.query("type_year == 'firstmodelyear'")["year"].item()

if y0 == y0_pre:
log.info(f"First model period is already {y0!r}")
return
elif y0 < y0_pre:
raise NotImplementedError(
f"Shift first model period *earlier*, from {y0_pre!r} -> {y0}"
)

# Periods to be shifted from within to before the model horizon
periods = list(
filter(lambda y: y0_pre <= y < y0, map(int, sorted(cat_year["year"].unique())))
)
log.info(f"Shift data for period(s): {periods}")

# Handle historical_* parameters for which the dimensions are a subset of the
# corresponding variable's dimensions
data = {}
for var_name, par_name, filter_dim in (
("ACT", "historical_activity", "year_act"),
("CAP_NEW", "historical_new_capacity", "year_vtg"),
("EXT", "historical_extraction", "year"),
("GDP", "historical_gdp", "year"),
("LAND", "historical_land", "year"),
):
# - Filter data for `var_name` along the `filter_dim`, keeping only the periods
# to be shifted.
# - Drop the marginal column; rename the level column to "value".
# - Group according to the dimensions of the target `par_name`.
# - Sum within groups.
# - Restore index columns.
data[par_name] = (
scenario.var(var_name, filters={filter_dim: periods})
.drop("mrg", axis=1)
.rename(columns={"lvl": "value"})
.groupby(list(MESSAGE.items[par_name].dims))
.sum()["value"]
.reset_index()
)

# TODO Handle "EMISS:n-e-type_tec-y" →
# "historical_emission:n-type_emission-type_tec-type_year", in which dimension names
# are changed

# TODO Adjust cat_year

if isinstance(scenario.platform._backend, JDBCBackend):
raise NotImplementedError("Cannot set variable values with JDBCBackend")

# TODO Store new data
21 changes: 20 additions & 1 deletion message_ix/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import ixmp
import pytest

from message_ix.models import MESSAGE, MESSAGE_MACRO
from message_ix.models import MESSAGE, MESSAGE_MACRO, shift_period
from message_ix.testing import make_dantzig


def test_initialize(test_mp):
Expand Down Expand Up @@ -52,3 +53,21 @@ class _MM(MESSAGE_MACRO):
"--MAX_ITERATION=100",
]
assert all(e in mm.solve_args for e in expected)


@pytest.mark.parametrize(
"y0",
(
# Not implemented: shifting to an earlier period
pytest.param(1962, marks=pytest.mark.xfail(raises=NotImplementedError)),
# Does nothing
1963,
# Not implemented with ixmp.JDBCBackend
pytest.param(1964, marks=pytest.mark.xfail(raises=NotImplementedError)),
pytest.param(1965, marks=pytest.mark.xfail(raises=NotImplementedError)),
),
)
def test_shift_period(test_mp, y0):
s = make_dantzig(test_mp, solve=True, multi_year=True)

shift_period(s, y0)
Loading