Skip to content

Commit

Permalink
TimeIndex field for ParameterNode and VariableNode
Browse files Browse the repository at this point in the history
  • Loading branch information
a-zakir committed Jan 16, 2025
1 parent 456a98b commit fbab834
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ namespace Antares::Solver::Visitors
class TimeIndexVisitor: public NodeVisitor<TimeIndex>
{
public:
// TODO if Node contains time and scenario dependency, do we need this ctor?
/**
* @brief Constructs a time index visitor with the specified context.
*
* @param context The context containing the time index for each node.
*/
explicit TimeIndexVisitor(std::unordered_map<const Nodes::Node*, TimeIndex> context);
explicit TimeIndexVisitor() = default;

std::string name() const override;

Expand Down
4 changes: 2 additions & 2 deletions src/solver/expressions/visitors/TimeIndexVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ TimeIndex TimeIndexVisitor::visit(const Nodes::GreaterThanOrEqualNode* gt)

TimeIndex TimeIndexVisitor::visit(const Nodes::VariableNode* var)
{
return context_.at(var);
return var->timeIndex();
}

TimeIndex TimeIndexVisitor::visit(const Nodes::ParameterNode* param)
{
return context_.at(param);
return param->timeIndex();
}

TimeIndex TimeIndexVisitor::visit([[maybe_unused]] const Nodes::LiteralNode* lit)
Expand Down
19 changes: 10 additions & 9 deletions src/solver/modelConverter/convertorVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ class ConvertorVisitor: public ExprVisitor
std::unordered_map<const Nodes::Node*, Visitors::TimeIndex> nodeTimeIndex;
};

ExpressionConversionResults convertExpressionToNode(const std::string& exprStr,
const ModelParser::Model& model)
NodeRegistry convertExpressionToNode(const std::string& exprStr, const ModelParser::Model& model)
{
if (exprStr.empty())
{
Expand All @@ -92,7 +91,7 @@ ExpressionConversionResults convertExpressionToNode(const std::string& exprStr,
Antares::Solver::Registry<Node> registry;
ConvertorVisitor visitor(registry, model);
auto root = std::any_cast<Node*>(visitor.visit(tree));
return {NodeRegistry(root, std::move(registry)), visitor.getTimeIndex()};
return NodeRegistry(root, std::move(registry));
}

ConvertorVisitor::ConvertorVisitor(Antares::Solver::Registry<Node>& registry,
Expand Down Expand Up @@ -133,19 +132,21 @@ std::any ConvertorVisitor::visitIdentifier(ExprParser::IdentifierContext* contex
{
if (param.id == context->IDENTIFIER()->getText())
{
auto ret = static_cast<Node*>(registry_.create<ParameterNode>(param.id));
nodeTimeIndex[ret] = convertToTimeIndex(param.time_dependent, param.scenario_dependent);
return ret;
return static_cast<Node*>(
registry_.create<ParameterNode>(param.id,
convertToTimeIndex(param.time_dependent,
param.scenario_dependent)));
}
}

for (const auto& var: model_.variables)
{
if (var.id == context->getText())
{
auto ret = static_cast<Node*>(registry_.create<VariableNode>(var.id));
nodeTimeIndex[ret] = convertToTimeIndex(var.time_dependent, var.scenario_dependent);
return ret;
return static_cast<Node*>(
registry_.create<VariableNode>(var.id,
convertToTimeIndex(var.time_dependent,
var.scenario_dependent)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,9 @@
#include <antares/solver/expressions/NodeRegistry.h>
#include "antares/solver/modelParser/Library.h"

namespace Antares::Solver::Visitors
{
enum class TimeIndex : unsigned int;
}

namespace Antares::Solver::ModelConverter
{
struct ExpressionConversionResults
{
NodeRegistry nodeRegistry;
std::unordered_map<const Nodes::Node*, Visitors::TimeIndex> nodeTimeIndex;
};

ExpressionConversionResults convertExpressionToNode(const std::string& exprStr,
const ModelParser::Model& model);
NodeRegistry convertExpressionToNode(const std::string& exprStr, const ModelParser::Model& model);
} // namespace Antares::Solver::ModelConverter
13 changes: 7 additions & 6 deletions src/solver/modelConverter/modelConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ std::vector<Antares::Study::SystemModel::Variable> convertVariables(const ModelP
for (const auto& variable: model.variables)
{
Antares::Study::SystemModel::Expression lb(variable.lower_bound,
convertExpressionToNode(variable.lower_bound, model).nodeRegistry);
convertExpressionToNode(variable.lower_bound,
model));
Antares::Study::SystemModel::Expression ub(variable.upper_bound,
convertExpressionToNode(variable.upper_bound, model).nodeRegistry);
convertExpressionToNode(variable.upper_bound,
model));
variables.emplace_back(
variable.id,
std::move(lb),
Expand Down Expand Up @@ -159,11 +161,10 @@ std::vector<Antares::Study::SystemModel::Constraint> convertConstraints(
std::vector<Antares::Study::SystemModel::Constraint> constraints;
for (const auto& constraint: model.constraints)
{
auto [nodeRegistry, nodeTimeIndex] = convertExpressionToNode(constraint.expression, model);
auto nodeRegistry = convertExpressionToNode(constraint.expression, model);
constraints.emplace_back(constraint.id,
Antares::Study::SystemModel::Expression{constraint.expression,
std::move(nodeRegistry)},
nodeTimeIndex);
std::move(nodeRegistry)});
}
return constraints;
}
Expand All @@ -189,7 +190,7 @@ std::vector<Antares::Study::SystemModel::Model> convertModels(
model);

std::unordered_map<const Nodes::Node*, Visitors::TimeIndex> nodeTimeIndex;
auto [nodeObjective, _] = convertExpressionToNode(model.objective, model);
auto nodeObjective = convertExpressionToNode(model.objective, model);

auto modelObj = modelBuilder.withId(model.id)
.withObjective(
Expand Down
3 changes: 1 addition & 2 deletions src/solver/optim-model-filler/ComponentFiller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ void ComponentFiller::addConstraints(Solver::Modeler::Api::ILinearProblem& pb,
// TODO timesteps will be a parameter
if (checkTimeSteps(ctx))
{
Solver::Visitors::TimeIndexVisitor timeIndexVisitor(constraint.getNodeTimeIndex());
if (IsThisConstraintTimeDependent(root_node, constraint))

{
Expand Down Expand Up @@ -195,7 +194,7 @@ bool ComponentFiller::IsThisConstraintTimeDependent(
const Solver::Nodes::Node* node,
const Study::SystemModel::Constraint& constraint)
{
Solver::Visitors::TimeIndexVisitor timeIndexVisitor(constraint.getNodeTimeIndex());
Solver::Visitors::TimeIndexVisitor timeIndexVisitor;
const auto ret = timeIndexVisitor.dispatch(node);
return ret == Solver::Visitors::TimeIndex::VARYING_IN_TIME_ONLY
|| ret == Solver::Visitors::TimeIndex::VARYING_IN_TIME_AND_SCENARIO;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,9 @@ namespace Antares::Study::SystemModel
class Constraint
{
public:
Constraint(std::string name,
Expression expression,
std::unordered_map<const Solver::Nodes::Node*, Solver::Visitors::TimeIndex>
nodeTimeIndex): id_(std::move(name)),
expression_(std::move(expression)),
nodeTimeIndex(std::move(nodeTimeIndex))
Constraint(std::string name, Expression expression):
id_(std::move(name)),
expression_(std::move(expression))
{
}

Expand All @@ -57,16 +54,9 @@ class Constraint
return expression_;
}

std::unordered_map<const Solver::Nodes::Node*, Solver::Visitors::TimeIndex>
getNodeTimeIndex() const
{
return nodeTimeIndex;
}

private:
std::string id_;
Expression expression_;
std::unordered_map<const Solver::Nodes::Node*, Solver::Visitors::TimeIndex> nodeTimeIndex;
};

} // namespace Antares::Study::SystemModel
22 changes: 9 additions & 13 deletions src/tests/src/solver/expressions/test_TimeIndexVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,15 @@ BOOST_AUTO_TEST_SUITE(_TimeIndexVisitor_)
BOOST_FIXTURE_TEST_CASE(simple_time_dependant_expression, Registry<Node>)
{
PrintVisitor printVisitor;
std::unordered_map<const Node*, TimeIndex> context;
// LiteralNode --> constant in time and for all scenarios
LiteralNode literalNode(65.);

// Parameter --> constant in time and varying scenarios
ParameterNode parameterNode1("p1");
context[&parameterNode1] = TimeIndex::VARYING_IN_SCENARIO_ONLY;
ParameterNode parameterNode1("p1", TimeIndex::VARYING_IN_SCENARIO_ONLY);

// Variable time varying but constant across scenarios
VariableNode variableNode1("v1");
context[&variableNode1] = TimeIndex::VARYING_IN_TIME_ONLY;
TimeIndexVisitor timeIndexVisitor(context);
VariableNode variableNode1("v1", TimeIndex::VARYING_IN_TIME_ONLY);
TimeIndexVisitor timeIndexVisitor;

BOOST_CHECK_EQUAL(timeIndexVisitor.dispatch(&literalNode),
TimeIndex::CONSTANT_IN_TIME_AND_SCENARIO);
Expand All @@ -89,14 +86,15 @@ static const std::vector<TimeIndex> TimeIndex_ALL{TimeIndex::CONSTANT_IN_TIME_AN
TimeIndex::VARYING_IN_TIME_AND_SCENARIO};

template<class T>
static std::pair<Node*, ParameterNode*> s_(Registry<Node>& registry)
static std::pair<Node*, ParameterNode*> s_(Registry<Node>& registry, const TimeIndex& time_index)
{
Node* left = registry.create<LiteralNode>(42.);
ParameterNode* right = registry.create<ParameterNode>("param");
ParameterNode* right = registry.create<ParameterNode>("param", time_index);
return {registry.create<T>(left, right), right};
}

static const std::vector<std::pair<Node*, ParameterNode*> (*)(Registry<Node>& registry)>
static const std::vector<std::pair<Node*, ParameterNode*> (*)(Registry<Node>& registry,
const TimeIndex& time_index)>
operator_ALL{&s_<SumNode>,
&s_<SubtractionNode>,
&s_<MultiplicationNode>,
Expand All @@ -111,10 +109,8 @@ BOOST_DATA_TEST_CASE_F(Registry<Node>,
timeIndex,
binaryOperator)
{
auto [root, parameter] = binaryOperator(*this);
std::unordered_map<const Node*, TimeIndex> context;
context[parameter] = timeIndex;
TimeIndexVisitor timeIndexVisitor(context);
auto [root, parameter] = binaryOperator(*this, timeIndex);
TimeIndexVisitor timeIndexVisitor;
BOOST_CHECK_EQUAL(timeIndexVisitor.dispatch(root), timeIndex);
Node* neg = create<NegationNode>(root);
BOOST_CHECK_EQUAL(timeIndexVisitor.dispatch(neg), timeIndex);
Expand Down
4 changes: 2 additions & 2 deletions src/tests/src/solver/modelParser/testConvertorVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ExpressionToNodeConvertorEmptyModel

NodeRegistry run(const std::string& input)
{
return ModelConverter::convertExpressionToNode(input, model_).nodeRegistry;
return ModelConverter::convertExpressionToNode(input, model_);
}

private:
Expand Down Expand Up @@ -116,7 +116,7 @@ BOOST_AUTO_TEST_CASE(identifierNotFound)
.objective = "objectives"};

std::string expression = "abc"; // not a param or var
BOOST_CHECK_EXCEPTION(ModelConverter::convertExpressionToNode(expression, model).nodeRegistry,
BOOST_CHECK_EXCEPTION(ModelConverter::convertExpressionToNode(expression, model),
std::runtime_error,
expectedMessage);
}
Expand Down
42 changes: 14 additions & 28 deletions src/tests/src/solver/optim-model-filler/test_componentFiller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ struct ConstraintData
{
string id;
Node* expression;
std::unordered_map<const Node*, Antares::Solver::Visitors::TimeIndex> nodeTimeIndex;
};

struct LinearProblemBuildingFixture
Expand Down Expand Up @@ -94,14 +93,18 @@ struct LinearProblemBuildingFixture
return nodes.create<LiteralNode>(value);
}

Node* parameter(const string& paramId)
Node* parameter(const string& paramId,
const Antares::Solver::Visitors::TimeIndex& timeIndex = Antares::Solver::
Visitors::TimeIndex::CONSTANT_IN_TIME_AND_SCENARIO)
{
return nodes.create<ParameterNode>(paramId);
return nodes.create<ParameterNode>(paramId, timeIndex);
}

Node* variable(const string& varId)
Node* variable(const string& varId,
const Antares::Solver::Visitors::TimeIndex& timeIndex = Antares::Solver::
Visitors::TimeIndex::CONSTANT_IN_TIME_AND_SCENARIO)
{
return nodes.create<VariableNode>(varId);
return nodes.create<VariableNode>(varId, timeIndex);
}

Node* multiply(Node* node1, Node* node2)
Expand Down Expand Up @@ -146,9 +149,9 @@ void LinearProblemBuildingFixture::createModel(string modelId,
static_cast<ScenarioDependent>(scenarioDependent))));
}
vector<Constraint> constraints;
for (auto [id, expression, nodeTimeIndex]: constraintsData)
for (auto [id, expression]: constraintsData)
{
constraints.push_back(move(Constraint(id, createExpression(expression), nodeTimeIndex)));
constraints.push_back(move(Constraint(id, createExpression(expression))));
}
ModelBuilder model_builder;
model_builder.withId(modelId)
Expand Down Expand Up @@ -341,15 +344,6 @@ BOOST_AUTO_TEST_CASE(one_model_two_components__dont_clash)

BOOST_AUTO_TEST_SUITE_END()

static auto ConstantNodeContext(const std::vector<const Node*>& nodes)
{
std::unordered_map<const Node*, Antares::Solver::Visitors::TimeIndex> nodeTimeIndex;
for (auto node: nodes)
{
nodeTimeIndex[node] = Antares::Solver::Visitors::TimeIndex::CONSTANT_IN_TIME_AND_SCENARIO;
}
return nodeTimeIndex;
}

BOOST_FIXTURE_TEST_SUITE(_ComponentFiller_addConstraints_, LinearProblemBuildingFixture)

Expand All @@ -363,7 +357,7 @@ BOOST_AUTO_TEST_CASE(ct_one_var__pb_contains_the_ct)
createModel("model",
{},
{{"var1", ValueType::BOOL, literal(-5), literal(10), false, false}},
{{"ct1", ct_node, ConstantNodeContext({var_node, three})}});
{{"ct1", ct_node}});
createComponent("model", "componentToto");
buildLinearProblem();

Expand Down Expand Up @@ -396,7 +390,7 @@ BOOST_AUTO_TEST_CASE(ct_one_var_with_coef__pb_contains_the_ct)
"var__1",
literal(-5),
literal(10),
{{"ct_1", ct_node, ConstantNodeContext({var_node, five, three})}});
{{"ct_1", ct_node}});
createComponent("model", "componentTata");
buildLinearProblem();

Expand Down Expand Up @@ -440,11 +434,7 @@ BOOST_AUTO_TEST_CASE(ct_with_two_vars)
multiply(literal(5), param4));
auto ct_node = nodes.create<EqualNode>(sum_node_left, sum_node_right);

const auto constnode = ConstantNodeContext({v1, v2, param1, param2, param3, param4});
createModel("my_new_model",
params,
{var1Data, var2Data},
{{"constraint1", ct_node, constnode}});
createModel("my_new_model", params, {var1Data, var2Data}, {{"constraint1", ct_node}});
createComponent("my_new_model",
"my_component",
{{"param1", -16}, {"param2", 8}, {"param3", 5}, {"param4", -3}});
Expand Down Expand Up @@ -482,11 +472,7 @@ BOOST_AUTO_TEST_CASE(two_constraints__they_are_created)
auto two_2 = literal(2);
auto ct2_node = nodes.create<LessThanOrEqualNode>(v2, nodes.create<DivisionNode>(v1, two_2));

auto constnodes = ConstantNodeContext({v1, v2});
createModel("my_new_model",
{},
{var1Data, var2Data},
{{"ct1", ct1_node, constnodes}, {"ct2", ct2_node, constnodes}});
createModel("my_new_model", {}, {var1Data, var2Data}, {{"ct1", ct1_node}, {"ct2", ct2_node}});
createComponent("my_new_model", "my_component", {});
buildLinearProblem();

Expand Down

0 comments on commit fbab834

Please sign in to comment.