Skip to content

Commit

Permalink
Handle values defined on different period structures
Browse files Browse the repository at this point in the history
  • Loading branch information
guillett committed Jun 8, 2021
1 parent 7253220 commit ac70e28
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 53 deletions.
82 changes: 82 additions & 0 deletions openfisca_core/holders/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from openfisca_core.errors import PeriodMismatchError
from openfisca_core.data_storage import InMemoryStorage, OnDiskStorage
from openfisca_core.indexed_enums import Enum
from openfisca_core.holders.helpers import set_input_divide_by_period


class Holder:
Expand Down Expand Up @@ -130,6 +131,87 @@ def get_known_periods(self):
return list(self._memory_storage.get_known_periods()) + list((
self._disk_storage.get_known_periods() if self._disk_storage else []))

def solve_input_data(self, buffer):
"""
Determines individual period values given a input buffer with potentiel overlapping periods and data
"""
if self.variable.set_input == set_input_divide_by_period and self.variable.definition_period == periods.MONTH:
# Create grid and determine start and end month
keys = buffer.keys()
periodMap = {k: periods.period(k) for k in keys}
for k in keys:
p = periodMap[k]
if p.unit == periods.ETERNITY and self.variable.definition_period != periods.ETERNITY:
error_message = os.linesep.join([
'Unable to set a value for variable {0} for ETERNITY.',
'{0} is only defined for {1}s. Please adapt your input.',
]).format(
self.variable.name,
self.variable.definition_period
)
raise PeriodMismatchError(
self.variable.name,
p,
self.variable.definition_period,
error_message
)

periodValues = periodMap.values()
starts = [p.start for p in periodValues]
start = min(starts)
ends = [p.offset(p.size_in_months - 1, unit=periods.MONTH).start for p in periodValues]
end = max(ends)

def month_index(instant):
return (instant.year - start.year) * 12 + (instant.month - start.month)

size = month_index(end) + 1

full_size = (self.population.count, size)
presence = numpy.full(full_size, False)
values = self.variable.default_array(full_size)
dim1, dim2 = numpy.indices(full_size)

# Set single period values
for k in keys:
p = periodMap[k]
if p.size_in_months == 1:
column = month_index(p.start)
p_presence, p_values = buffer[k]
tile_count = self.population.count // p_values.size
presence[:, column] = numpy.tile(p_presence, tile_count)
values[:, column] = numpy.tile(p_values, tile_count)

# Set values for multiple period input
for k in keys:
p = periodMap[k]
if p.size_in_months != 1:
# Determine period indexes
start_index = month_index(p.start)
idx = slice(start_index, start_index + p.size_in_months)

counts = p.size_in_months - presence[:, idx].sum(axis=1)
current_sum = values[:, idx].sum(axis=1)

p_presence, p_values = buffer[k]
spread_value = (p_values - current_sum) / counts

i1 = dim1[p_presence, idx][~presence[p_presence, idx]]
i2 = dim2[p_presence, idx][~presence[p_presence, idx]]
values[i1, i2] = numpy.repeat(spread_value[p_presence], counts[p_presence])
presence[i1, i2] = True

# Extract relevant slices
first = start.period(periods.MONTH)
new_buffer = {}
for i in range(size):
p = first.offset(i)
if presence[:, i].any():
new_buffer[str(p)] = values[:, i]
return new_buffer

return {periodKey: buffer[periodKey][1] for periodKey in buffer}

def set_input(self, period, array):
"""
Set a variable's value (``array``) for a given period (``period``)
Expand Down
72 changes: 43 additions & 29 deletions openfisca_core/simulations/simulation_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self):
self.persons_plural = None # Plural name for person entity in current tax and benefits system

# JSON input - Memory of known input values. Indexed by variable or axis name.
self.input_buffer: typing.Dict[Variable.name, typing.Dict[str(periods.period), numpy.array]] = {}
self.input_buffer: typing.Dict[Variable.name, typing.Dict[str(periods.period), tuple(numpy.array, numpy.array)]] = {}
self.populations: typing.Dict[Entity.key, Population] = {}
# JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes.
self.entity_counts: typing.Dict[Entity.plural, int] = {}
Expand Down Expand Up @@ -379,20 +379,21 @@ def add_variable_value(self, entity, variable, instance_index, instance_id, peri
if value is None:
return

array = self.get_input(variable.name, str(period_str))
data_tuple = self.get_input(variable.name, str(period_str))

if array is None:
if data_tuple is None:
array_size = self.get_count(entity.plural)
array = variable.default_array(array_size)

data_tuple = (numpy.full(array_size, False), array)
try:
value = variable.check_set_value(value)
except ValueError as error:
raise SituationParsingError(path_in_json, *error.args)

(present, array) = data_tuple
present[instance_index] = True
array[instance_index] = value

self.input_buffer[variable.name][str(periods.period(period_str))] = array
self.input_buffer[variable.name][str(periods.period(period_str))] = data_tuple

def finalize_variables_init(self, population):
# Due to set_input mechanism, we must bufferize all inputs, then actually set them,
Expand All @@ -409,10 +410,11 @@ def finalize_variables_init(self, population):
holder = population.get_holder(variable_name)
except ValueError: # Wrong entity, we can just ignore that
continue
buffer = self.input_buffer[variable_name]
unsorted_periods = [periods.period(period_str) for period_str in self.input_buffer[variable_name].keys()]
buffer = holder.solve_input_data(self.input_buffer[variable_name])

unsorted_periods = [periods.period(period_str) for period_str in buffer.keys()]
# We need to handle small periods first for set_input to work
sorted_periods = sorted(unsorted_periods, key = periods.key_period_size)
sorted_periods = sorted(unsorted_periods, key=periods.key_period_size)
for period_value in sorted_periods:
values = buffer[str(period_value)]
# Hack to replicate the values in the persons entity
Expand Down Expand Up @@ -480,9 +482,9 @@ def expand_axes(self):
# Adjust counts
self.axes_entity_counts[entity_name] = self.get_count(entity_name) * cell_count
# Adjust ids
original_ids = self.get_ids(entity_name) * cell_count
indices = numpy.arange(0, cell_count * self.entity_counts[entity_name])
adjusted_ids = [id + str(ix) for id, ix in zip(original_ids, indices)]
original_ids = self.get_ids(entity_name)
indices = numpy.arange(0, cell_count)
adjusted_ids = [id + str(ix) for ix in indices for id in original_ids]
self.axes_entity_ids[entity_name] = adjusted_ids
# Adjust roles
original_roles = self.get_roles(entity_name)
Expand All @@ -506,22 +508,30 @@ def expand_axes(self):
axis_entity_step_size = self.entity_counts[axis_entity.plural]
# Distribute values along axes
for axis in parallel_axes:
axis_index = axis.get('index', 0)
axis_period = axis.get('period', self.default_period)
axis_name = axis['name']
variable = axis_entity.get_variable(axis_name)
array = self.get_input(axis_name, str(axis_period))
if array is None:

data_tuple = self.get_input(axis_name, axis_period)
if data_tuple is None:
array = variable.default_array(axis_count * axis_entity_step_size)
elif array.size == axis_entity_step_size:
array = numpy.tile(array, axis_count)
array[axis_index:: axis_entity_step_size] = numpy.linspace(
axis['min'],
axis['max'],
num = axis_count,
present = numpy.full(array.size, True)
elif data_tuple[1].size == axis_entity_step_size:
array = numpy.tile(data_tuple[1], cell_count)
present = numpy.full(array.size, True)
else:
array, present = data_tuple
array[:] = numpy.repeat(
numpy.linspace(
axis['min'],
axis['max'],
axis_count
),
axis_entity_step_size
)

# Set input
self.input_buffer[axis_name][str(axis_period)] = array
self.input_buffer[axis_name][str(axis_period)] = (present, array)
else:
first_axes_count: typing.List[int] = (
parallel_axes[0]["count"]
Expand All @@ -541,18 +551,22 @@ def expand_axes(self):
axis_entity_step_size = self.entity_counts[axis_entity.plural]
# Distribute values along the grid
for axis in parallel_axes:
axis_index = axis.get('index', 0)
axis_period = axis['period'] or self.default_period
axis_name = axis['name']
variable = axis_entity.get_variable(axis_name)
array = self.get_input(axis_name, str(axis_period))
if array is None:
data_tuple = self.get_input(axis_name, axis_period)
if data_tuple is None:
array = variable.default_array(cell_count * axis_entity_step_size)
elif array.size == axis_entity_step_size:
array = numpy.tile(array, cell_count)
array[axis_index:: axis_entity_step_size] = axis['min'] \
present = numpy.full(array.size, True)
elif data_tuple[1].size == axis_entity_step_size:
array = numpy.tile(data_tuple[1], cell_count)
present = numpy.full(array.size, True)
else:
array, present = data_tuple

array[:] = axis['min'] \
+ mesh.reshape(cell_count) * (axis['max'] - axis['min']) / (axis_count - 1)
self.input_buffer[axis_name][str(axis_period)] = array
self.input_buffer[axis_name][str(axis_period)] = (present, array)

def get_variable_entity(self, variable_name):
return self.variable_entities[variable_name]
Expand Down
38 changes: 19 additions & 19 deletions tests/core/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_add_axis_without_period(persons):
simulation_builder.register_variable('salary', persons)
simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000})
simulation_builder.expand_axes()
assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000])
assert simulation_builder.get_input('salary', '2018-11')[1] == pytest.approx([0, 1500, 3000])


# With variables
Expand All @@ -35,12 +35,12 @@ def test_add_axis_on_an_existing_variable_with_input(persons):
simulation_builder.register_variable('salary', persons)
simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'})
simulation_builder.expand_axes()
assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000])
assert simulation_builder.get_input('salary', '2018-11')[1] == pytest.approx([0, 1500, 3000])
assert simulation_builder.get_count('persons') == 3
assert simulation_builder.get_ids('persons') == ['Alicia0', 'Alicia1', 'Alicia2']


# With entities
# # With entities


def test_add_axis_on_persons(persons):
Expand All @@ -49,7 +49,7 @@ def test_add_axis_on_persons(persons):
simulation_builder.register_variable('salary', persons)
simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'})
simulation_builder.expand_axes()
assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000])
assert simulation_builder.get_input('salary', '2018-11')[1] == pytest.approx([0, 1500, 3000])
assert simulation_builder.get_count('persons') == 3
assert simulation_builder.get_ids('persons') == ['Alicia0', 'Alicia1', 'Alicia2']

Expand All @@ -61,30 +61,30 @@ def test_add_two_axes(persons):
simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'})
simulation_builder.add_parallel_axis({'count': 3, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'})
simulation_builder.expand_axes()
assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000])
assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 1000, 2000])
assert simulation_builder.get_input('salary', '2018-11')[1] == pytest.approx([0, 1500, 3000])
assert simulation_builder.get_input('pension', '2018-11')[1] == pytest.approx([0, 1000, 2000])


def test_add_axis_with_group(persons):
simulation_builder = SimulationBuilder()
simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}})
simulation_builder.register_variable('salary', persons)
simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'})
simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11', 'index': 1})
# simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11', 'index': 1})
simulation_builder.expand_axes()
assert simulation_builder.get_count('persons') == 4
assert simulation_builder.get_ids('persons') == ['Alicia0', 'Javier1', 'Alicia2', 'Javier3']
assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 0, 3000, 3000])
assert simulation_builder.get_ids('persons') == ['Alicia0', 'Javier0', 'Alicia1', 'Javier1']
assert simulation_builder.get_input('salary', '2018-11')[1] == pytest.approx([0, 0, 3000, 3000])


def test_add_axis_with_group_int_period(persons):
simulation_builder = SimulationBuilder()
simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}})
simulation_builder.register_variable('salary', persons)
simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018})
simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018, 'index': 1})
# simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018, 'index': 1})
simulation_builder.expand_axes()
assert simulation_builder.get_input('salary', '2018') == pytest.approx([0, 0, 3000, 3000])
assert simulation_builder.get_input('salary', '2018')[1] == pytest.approx([0, 0, 3000, 3000])


def test_add_axis_on_households(persons, households):
Expand All @@ -98,8 +98,8 @@ def test_add_axis_on_households(persons, households):
simulation_builder.add_parallel_axis({'count': 2, 'name': 'rent', 'min': 0, 'max': 3000, 'period': '2018-11'})
simulation_builder.expand_axes()
assert simulation_builder.get_count('households') == 4
assert simulation_builder.get_ids('households') == ['housea0', 'houseb1', 'housea2', 'houseb3']
assert simulation_builder.get_input('rent', '2018-11') == pytest.approx([0, 0, 3000, 0])
assert simulation_builder.get_ids('households') == ['housea0', 'houseb0', 'housea1', 'houseb1']
assert simulation_builder.get_input('rent', '2018-11')[1] == pytest.approx([0, 0, 3000, 3000])


def test_axis_on_group_expands_persons(persons, households):
Expand Down Expand Up @@ -162,8 +162,8 @@ def test_add_perpendicular_axes(persons):
simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'})
simulation_builder.add_perpendicular_axis({'count': 2, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'})
simulation_builder.expand_axes()
assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000])
assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000])
assert simulation_builder.get_input('salary', '2018-11')[1] == pytest.approx([0, 1500, 3000, 0, 1500, 3000])
assert simulation_builder.get_input('pension', '2018-11')[1] == pytest.approx([0, 0, 0, 2000, 2000, 2000])


def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons):
Expand All @@ -179,11 +179,11 @@ def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons):
simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'})
simulation_builder.add_perpendicular_axis({'count': 2, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'})
simulation_builder.expand_axes()
assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000])
assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000])
assert simulation_builder.get_input('salary', '2018-11')[1] == pytest.approx([0, 1500, 3000, 0, 1500, 3000])
assert simulation_builder.get_input('pension', '2018-11')[1] == pytest.approx([0, 0, 0, 2000, 2000, 2000])


# Integration test
# # Integration test


def test_simulation_with_axes(tax_benefit_system):
Expand All @@ -208,4 +208,4 @@ def test_simulation_with_axes(tax_benefit_system):
data = test_runner.yaml.safe_load(input_yaml)
simulation = SimulationBuilder().build_from_dict(tax_benefit_system, data)
assert simulation.get_array('salary', '2018-11') == pytest.approx([0, 0, 0, 0, 0, 0])
assert simulation.get_array('rent', '2018-11') == pytest.approx([0, 0, 3000, 0])
assert simulation.get_array('rent', '2018-11') == pytest.approx([0, 0, 3000, 3000])
Loading

0 comments on commit ac70e28

Please sign in to comment.