Skip to content

Commit

Permalink
Add multithreaded support to TrajOpt
Browse files Browse the repository at this point in the history
  • Loading branch information
Levi-Armstrong committed Jun 19, 2023
1 parent 6bef2ea commit 44ec270
Show file tree
Hide file tree
Showing 13 changed files with 341 additions and 77 deletions.
4 changes: 3 additions & 1 deletion trajopt_sco/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ elseif(NOT TARGET jsoncpp_lib)
endif()
find_package(ros_industrial_cmake_boilerplate REQUIRED)
find_package(Boost REQUIRED)
find_package(OpenMP REQUIRED)

# Load variable for clang tidy args, compiler options and cxx version
trajopt_variables()
Expand Down Expand Up @@ -108,7 +109,8 @@ target_link_libraries(
Boost::boost
Eigen3::Eigen
${CMAKE_DL_LIBS}
jsoncpp_lib)
jsoncpp_lib
OpenMP::OpenMP_CXX)
target_compile_options(${PROJECT_NAME} PRIVATE ${TRAJOPT_COMPILE_OPTIONS_PRIVATE})
target_compile_options(${PROJECT_NAME} PUBLIC ${TRAJOPT_COMPILE_OPTIONS_PUBLIC})
target_compile_definitions(${PROJECT_NAME} PUBLIC ${TRAJOPT_COMPILE_DEFINITIONS})
Expand Down
1 change: 1 addition & 0 deletions trajopt_sco/cmake/trajopt_sco-config.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ if(${CMAKE_VERSION} VERSION_LESS "3.15.0")
else()
find_dependency(Boost)
endif()
find_dependency(OpenMP)

include("${CMAKE_CURRENT_LIST_DIR}/@[email protected]")
20 changes: 14 additions & 6 deletions trajopt_sco/include/trajopt_sco/bpmpd_interface.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#pragma once
#include <trajopt_utils/macros.h>
TRAJOPT_IGNORE_WARNINGS_PUSH
#include <mutex>
TRAJOPT_IGNORE_WARNINGS_POP
#include <trajopt_sco/solver_interface.hpp>

namespace sco
Expand All @@ -19,26 +23,30 @@ class BPMPDModel : public Model
int m_pipeOut{ 0 };
int m_pid{ 0 };

std::mutex m_mutex; /**< The mutex */

BPMPDModel();
~BPMPDModel() override = default;
BPMPDModel(const BPMPDModel&) = default;
BPMPDModel& operator=(const BPMPDModel&) = default;
BPMPDModel(BPMPDModel&&) = default;
BPMPDModel& operator=(BPMPDModel&&) = default;
BPMPDModel(const BPMPDModel&) = delete;
BPMPDModel& operator=(const BPMPDModel&) = delete;
BPMPDModel(BPMPDModel&&) = delete;
BPMPDModel& operator=(BPMPDModel&&) = delete;

// Must be threadsafe
Var addVar(const std::string& name) override;
Cnt addEqCnt(const AffExpr&, const std::string& name) override;
Cnt addIneqCnt(const AffExpr&, const std::string& name) override;
Cnt addIneqCnt(const QuadExpr&, const std::string& name) override;
void removeVars(const VarVector& vars) override;
void removeCnts(const CntVector& cnts) override;

// These do not need to be threadsafe
void update() override;
void setVarBounds(const VarVector& vars, const DblVec& lower, const DblVec& upper) override;
DblVec getVarValues(const VarVector& vars) const override;
CvxOptStatus optimize() override;
void setObjective(const AffExpr&) override;
void setObjective(const QuadExpr&) override;
void setVarBounds(const VarVector& vars, const DblVec& lower, const DblVec& upper) override;
DblVec getVarValues(const VarVector& vars) const override;
void writeToFile(const std::string& fname) const override;
VarVector getVars() const override;
};
Expand Down
23 changes: 12 additions & 11 deletions trajopt_sco/include/trajopt_sco/gurobi_interface.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#pragma once
#include <trajopt_utils/macros.h>
TRAJOPT_IGNORE_WARNINGS_PUSH
#include <mutex>
TRAJOPT_IGNORE_WARNINGS_POP
#include <trajopt_sco/solver_interface.hpp>

/**
Expand All @@ -22,34 +26,31 @@ class GurobiModel : public Model
GRBmodel* m_model;
VarVector m_vars;
CntVector m_cnts;
std::mutex m_mutex;

GurobiModel();
~GurobiModel();

// Must be threadsafe
Var addVar(const std::string& name) override;
Var addVar(const std::string& name, double lower, double upper) override;

Cnt addEqCnt(const AffExpr&, const std::string& name) override;
Cnt addIneqCnt(const AffExpr&, const std::string& name) override;
Cnt addIneqCnt(const QuadExpr&, const std::string& name) override;

void removeVars(const VarVector&) override;
void removeCnts(const CntVector&) override;

// These do not need to be threadsafe
void update() override;
void setVarBounds(const VarVector&, const DblVec& lower, const DblVec& upper) override;
DblVec getVarValues(const VarVector&) const override;

CvxOptStatus optimize() override;
/** Don't use this function, because it adds constraints that aren't tracked
*/
CvxOptStatus optimizeFeasRelax();

void setObjective(const AffExpr&) override;
void setObjective(const QuadExpr&) override;
void setVarBounds(const VarVector&, const DblVec& lower, const DblVec& upper) override;
DblVec getVarValues(const VarVector&) const override;
void writeToFile(const std::string& fname) const override;

VarVector getVars() const override;

~GurobiModel();
/** Don't use this function, because it adds constraints that aren't tracked*/
CvxOptStatus optimizeFeasRelax();
};
} // namespace sco
76 changes: 72 additions & 4 deletions trajopt_sco/include/trajopt_sco/optimizers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,71 @@ struct BasicTrustRegionSQPParameters
bool log_results; // Log results to file
std::string log_dir; // Directory to store log results (Default: /tmp)

/** @brief If greater than one, multi threaded functions are called */
int num_threads;

BasicTrustRegionSQPParameters();
};

struct BasicTrustRegionSQPUtilFunctions
{
virtual ~BasicTrustRegionSQPUtilFunctions() = default;

virtual DblVec evaluateCosts(const std::vector<Cost::Ptr>& costs, const DblVec& x) const;

virtual DblVec evaluateConstraintViols(const std::vector<Constraint::Ptr>& constraints, const DblVec& x) const;

virtual std::vector<ConvexObjective::Ptr> convexifyCosts(const std::vector<Cost::Ptr>& costs,
const DblVec& x,
Model* model) const;

virtual std::vector<ConvexConstraints::Ptr> convexifyConstraints(const std::vector<Constraint::Ptr>& cnts,
const DblVec& x,
Model* model) const;

virtual DblVec evaluateModelCosts(const std::vector<ConvexObjective::Ptr>& costs, const DblVec& x) const;

virtual DblVec evaluateModelCntViols(const std::vector<ConvexConstraints::Ptr>& cnts, const DblVec& x) const;

virtual std::vector<std::string> getCostNames(const std::vector<Cost::Ptr>& costs) const;

virtual std::vector<std::string> getCntNames(const std::vector<Constraint::Ptr>& cnts) const;

virtual std::vector<std::string> getVarNames(const VarVector& vars) const;
};

struct BasicTrustRegionSQPUtilFunctionsThreaded : public BasicTrustRegionSQPUtilFunctions
{
BasicTrustRegionSQPUtilFunctionsThreaded();
BasicTrustRegionSQPUtilFunctionsThreaded(int num_threads);
~BasicTrustRegionSQPUtilFunctionsThreaded() override = default;

DblVec evaluateCosts(const std::vector<Cost::Ptr>& costs, const DblVec& x) const override final;

DblVec evaluateConstraintViols(const std::vector<Constraint::Ptr>& constraints, const DblVec& x) const override final;

std::vector<ConvexObjective::Ptr> convexifyCosts(const std::vector<Cost::Ptr>& costs,
const DblVec& x,
Model* model) const override final;

std::vector<ConvexConstraints::Ptr> convexifyConstraints(const std::vector<Constraint::Ptr>& cnts,
const DblVec& x,
Model* model) const override final;

DblVec evaluateModelCosts(const std::vector<ConvexObjective::Ptr>& costs, const DblVec& x) const override final;

DblVec evaluateModelCntViols(const std::vector<ConvexConstraints::Ptr>& cnts, const DblVec& x) const override final;

std::vector<std::string> getCostNames(const std::vector<Cost::Ptr>& costs) const override final;

std::vector<std::string> getCntNames(const std::vector<Constraint::Ptr>& cnts) const override final;

std::vector<std::string> getVarNames(const VarVector& vars) const override final;

private:
int num_threads_;
};

/**
* @brief This struct stores iteration results for the BasicTrustRegionSQP
*
Expand Down Expand Up @@ -183,7 +245,8 @@ struct BasicTrustRegionSQPResults

BasicTrustRegionSQPResults(std::vector<std::string> var_names,
std::vector<std::string> cost_names,
std::vector<std::string> cnt_names);
std::vector<std::string> cnt_names,
std::shared_ptr<BasicTrustRegionSQPUtilFunctions> util_funcs);

/**
* @brief Update the structure data for a new iteration
Expand Down Expand Up @@ -217,6 +280,9 @@ struct BasicTrustRegionSQPResults
void writeConstraints(std::FILE* stream, bool header = false) const;
/** @brief Prints the raw values to the terminal */
void printRaw() const;

private:
std::shared_ptr<BasicTrustRegionSQPUtilFunctions> util_funcs_;
};

class BasicTrustRegionSQP : public Optimizer
Expand All @@ -238,17 +304,19 @@ class BasicTrustRegionSQP : public Optimizer
BasicTrustRegionSQP() = default;
BasicTrustRegionSQP(const OptProb::Ptr& prob);
void setProblem(OptProb::Ptr prob) override;
void setParameters(const BasicTrustRegionSQPParameters& param) { param_ = param; }
const BasicTrustRegionSQPParameters& getParameters() const { return param_; }
BasicTrustRegionSQPParameters& getParameters() { return param_; }
void setParameters(const BasicTrustRegionSQPParameters& param);
const BasicTrustRegionSQPParameters& getParameters() const;
BasicTrustRegionSQPParameters& getParameters();
OptStatus optimize() override;

protected:
void ctor(const OptProb::Ptr& prob);
void adjustTrustRegion(double ratio);
void setTrustRegionSize(double trust_box_size);
void setTrustBoxConstraints(const DblVec& x);
void updateUtilsFunctions();
Model::Ptr model_;
BasicTrustRegionSQPParameters param_;
std::shared_ptr<BasicTrustRegionSQPUtilFunctions> util_funcs_;
};
} // namespace sco
15 changes: 10 additions & 5 deletions trajopt_sco/include/trajopt_sco/osqp_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
TRAJOPT_IGNORE_WARNINGS_PUSH
#include <Eigen/Core>
#include <osqp.h>
#include <mutex>
TRAJOPT_IGNORE_WARNINGS_POP

#include <trajopt_sco/solver_interface.hpp>
Expand Down Expand Up @@ -75,28 +76,32 @@ class OSQPModel : public Model

OSQPModelConfig config_; /**< The configuration settings */

std::mutex mutex_; /**< The mutex */

public:
OSQPModel(const ModelConfig::ConstPtr& config = nullptr);
~OSQPModel() override;
OSQPModel(const OSQPModel& model) = delete;
OSQPModel& operator=(const OSQPModel& model) = delete;
OSQPModel(OSQPModel&&) = default;
OSQPModel& operator=(OSQPModel&&) = default;
OSQPModel(OSQPModel&&) = delete;
OSQPModel& operator=(OSQPModel&&) = delete;

// Must be threadsafe
Var addVar(const std::string& name) override;
Cnt addEqCnt(const AffExpr&, const std::string& name) override;
Cnt addIneqCnt(const AffExpr&, const std::string& name) override;
Cnt addIneqCnt(const QuadExpr&, const std::string& name) override;
void removeVars(const VarVector& vars) override;
void removeCnts(const CntVector& cnts) override;

// These do not need to be threadsafe
void update() override;
void setVarBounds(const VarVector& vars, const DblVec& lower, const DblVec& upper) override;
DblVec getVarValues(const VarVector& vars) const override;
CvxOptStatus optimize() override;
void setObjective(const AffExpr&) override;
void setObjective(const QuadExpr&) override;
VarVector getVars() const override;
void setVarBounds(const VarVector& vars, const DblVec& lower, const DblVec& upper) override;
DblVec getVarValues(const VarVector& vars) const override;
void writeToFile(const std::string& fname) const override;
VarVector getVars() const override;
};
} // namespace sco
17 changes: 11 additions & 6 deletions trajopt_sco/include/trajopt_sco/qpoases_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
TRAJOPT_IGNORE_WARNINGS_PUSH
#include <Eigen/Core>
#include <qpOASES.hpp>
#include <mutex>
TRAJOPT_IGNORE_WARNINGS_POP

#include <trajopt_sco/solver_interface.hpp>
Expand Down Expand Up @@ -71,27 +72,31 @@ class qpOASESModel : public Model

QuadExpr objective_; /**< objective QuadExpr expression */

std::mutex mutex_; /**< The mutex */

public:
qpOASESModel();
~qpOASESModel() override;
qpOASESModel(const qpOASESModel&) = default;
qpOASESModel& operator=(const qpOASESModel&) = default;
qpOASESModel(qpOASESModel&&) = default;
qpOASESModel& operator=(qpOASESModel&&) = default;
qpOASESModel(const qpOASESModel&) = delete;
qpOASESModel& operator=(const qpOASESModel&) = delete;
qpOASESModel(qpOASESModel&&) = delete;
qpOASESModel& operator=(qpOASESModel&&) = delete;

// Must be thread safe
Var addVar(const std::string& name) override;
Cnt addEqCnt(const AffExpr&, const std::string& name) override;
Cnt addIneqCnt(const AffExpr&, const std::string& name) override;
Cnt addIneqCnt(const QuadExpr&, const std::string& name) override;
void removeVars(const VarVector& vars) override;
void removeCnts(const CntVector& cnts) override;

// These do not need to be threadsafe
void update() override;
void setVarBounds(const VarVector& vars, const DblVec& lower, const DblVec& upper) override;
DblVec getVarValues(const VarVector& vars) const override;
CvxOptStatus optimize() override;
void setObjective(const AffExpr&) override;
void setObjective(const QuadExpr&) override;
void setVarBounds(const VarVector& vars, const DblVec& lower, const DblVec& upper) override;
DblVec getVarValues(const VarVector& vars) const override;
void writeToFile(const std::string& fname) const override;
VarVector getVars() const override;
};
Expand Down
13 changes: 13 additions & 0 deletions trajopt_sco/include/trajopt_sco/solver_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,31 @@ class Model
Model(Model&&) = default;
Model& operator=(Model&&) = default;

/**
* @brief Add a var to the model
* @details These must be threadsafe
*/
virtual Var addVar(const std::string& name) = 0;
virtual Var addVar(const std::string& name, double lb, double ub);

/**
* @brief Add a equation to the model
* @details These must be threadsafe
*/
virtual Cnt addEqCnt(const AffExpr&, const std::string& name) = 0; // expr == 0
virtual Cnt addIneqCnt(const AffExpr&, const std::string& name) = 0; // expr <= 0
virtual Cnt addIneqCnt(const QuadExpr&, const std::string& name) = 0; // expr <= 0

/**
* @brief Remove items from model
* @details These must be threadsafe
*/
virtual void removeVar(const Var& var);
virtual void removeCnt(const Cnt& cnt);
virtual void removeVars(const VarVector& vars) = 0;
virtual void removeCnts(const CntVector& cnts) = 0;

/** @details It is not neccessary to make the following methods threadsafe */
virtual void update() = 0; // call after adding/deleting stuff
virtual void setVarBounds(const Var& var, double lower, double upper);
virtual void setVarBounds(const VarVector& vars, const DblVec& lower, const DblVec& upper) = 0;
Expand Down
Loading

0 comments on commit 44ec270

Please sign in to comment.