From 25b0fb0c624766826758ec7ab526b4be5e85dcf2 Mon Sep 17 00:00:00 2001 From: Boris Okner Date: Mon, 16 Dec 2024 10:09:52 -0500 Subject: [PATCH] Minor changes to propagator code; 'reduce' fun for subsets of damain values --- lib/solver/core/constraint.ex | 22 +++++++++++++------ lib/solver/core/propagator/propagator.ex | 6 +++++- lib/solver/domain/bitvector_domain.ex | 27 ++++++++++++------------ lib/solver/model/model.ex | 9 +++++++- 4 files changed, 43 insertions(+), 21 deletions(-) diff --git a/lib/solver/core/constraint.ex b/lib/solver/core/constraint.ex index 8888861..6c7710c 100644 --- a/lib/solver/core/constraint.ex +++ b/lib/solver/core/constraint.ex @@ -29,18 +29,28 @@ defmodule CPSolver.Constraint do {constraint_impl, args} end - def constraint_to_propagators({constraint_mod, args}) when is_list(args) do - constraint_mod.propagators(args) + def constraint_to_propagators(constraint, reducer_fun \\ &Function.identity/1) + + def constraint_to_propagators({constraint_mod, args}, reducer_fun) when is_list(args) do + List.foldr(constraint_mod.propagators(args), [], fn p, plist_acc -> + [reducer_fun.(p) | plist_acc] + end) end - def constraint_to_propagators(constraint) when is_tuple(constraint) do + def constraint_to_propagators(constraint, reducer_fun) when is_tuple(constraint) do [constraint_mod | args] = Tuple.to_list(constraint) - constraint_to_propagators({constraint_mod, args}) + constraint_to_propagators({constraint_mod, args}, reducer_fun) end def post(constraint) when is_tuple(constraint) do - propagators = constraint_to_propagators(constraint) - Enum.map(propagators, fn p -> Propagator.filter(p) end) + constraint_to_propagators(constraint, + fn p -> + case Propagator.filter(p) do + :fail -> throw({:fail, p.id}) + %{state: state} -> Propagator.update_state(p, state) + _ -> p + end + end) end def extract_variables(constraint) do diff --git a/lib/solver/core/propagator/propagator.ex b/lib/solver/core/propagator/propagator.ex index 2c8b6c8..2d0676d 100644 --- a/lib/solver/core/propagator/propagator.ex +++ b/lib/solver/core/propagator/propagator.ex @@ -108,7 +108,11 @@ defmodule CPSolver.Propagator do end def reset(%{mod: mod, args: args} = propagator, opts \\ []) do - Map.put(propagator, :state, mod.reset(args, Map.get(propagator, :state), opts)) + update_state(propagator, mod.reset(args, Map.get(propagator, :state), opts)) + end + + def update_state(propagator, state) do + Map.put(propagator, :state, state) end def bind(%{mod: mod} = propagator, source, var_field \\ :domain) do diff --git a/lib/solver/domain/bitvector_domain.ex b/lib/solver/domain/bitvector_domain.ex index 8b4a116..94b87ec 100644 --- a/lib/solver/domain/bitvector_domain.ex +++ b/lib/solver/domain/bitvector_domain.ex @@ -59,38 +59,39 @@ defmodule CPSolver.BitVectorDomain do to_list(domain, mapper_fun) end - def to_list( - {{:bit_vector, ref} = bit_vector, offset} = domain, - mapper_fun \\ &Function.identity/1 - ) do + ## Reduce over domain values + def reduce( {{:bit_vector, ref} = bit_vector, offset} = domain, value_mapper_fun, reduce_fun \\ &MapSet.union/2, acc_init \\ MapSet.new()) do %{ min_addr: %{block: current_min_block, offset: _min_offset}, max_addr: %{block: current_max_block, offset: _max_offset} } = get_bound_addrs(bit_vector) - mapped_lb = mapper_fun.(min(domain)) - mapped_ub = mapper_fun.(max(domain)) + mapped_lb = value_mapper_fun.(min(domain)) + mapped_ub = value_mapper_fun.(max(domain)) - ## Note: this will only work for monotonic mapper functions. - ## We don't have non-monotonic mappers for the moment. - ## - ## Adjust bounds {lb, ub} = (mapped_lb <= mapped_ub && {mapped_lb, mapped_ub}) || {mapped_ub, mapped_lb} - Enum.reduce(current_min_block..current_max_block, MapSet.new(), fn idx, acc -> + Enum.reduce(current_min_block..current_max_block, acc_init, fn idx, acc -> n = :atomics.get(ref, idx) if n == 0 do acc else - MapSet.union( + reduce_fun.( acc, - bit_positions(n, fn val -> {lb, ub, mapper_fun.(val + 64 * (idx - 1) - offset)} end) + bit_positions(n, fn val -> {lb, ub, value_mapper_fun.(val + 64 * (idx - 1) - offset)} end) ) end end) end + def to_list( + {{:bit_vector, ref} = bit_vector, offset} = domain, + value_mapper_fun \\ &Function.identity/1 + ) do + reduce(domain, value_mapper_fun, &MapSet.union/2, MapSet.new()) + end + def fixed?({bit_vector, _offset} = _domain) do {current_min_max, _min_max_idx, current_min, current_max} = get_min_max(bit_vector) current_max == current_min && current_min_max != @failure_value diff --git a/lib/solver/model/model.ex b/lib/solver/model/model.ex index 3931b04..4c9175f 100644 --- a/lib/solver/model/model.ex +++ b/lib/solver/model/model.ex @@ -16,7 +16,14 @@ defmodule CPSolver.Model do } def new(variables, constraints, opts \\ []) do - constraints = normalize_constraints(constraints) + constraints = + normalize_constraints(constraints) + # TODO: + # consider posting constraints and/or building constraint graph/list of propagators here + ## For instance: + # tap(fn constraints -> Enum.each(constraints, fn c -> Constraint.post(c) end) end) + # + {all_variables, objective} = init_model(variables, constraints, opts[:objective]) %__MODULE__{