From 9d6ee085b1f7b131a89f027168a664c069b9bdf5 Mon Sep 17 00:00:00 2001 From: Somrita Banerjee Date: Fri, 3 Jan 2025 17:23:23 -0800 Subject: [PATCH] Added param to warm start TOP using trained neural network --- .../config/mobility/planner_scp_gusto.config | 4 + .../include/planner_scp_gusto/optim.h | 6 + mobility/planner_scp_gusto/src/optim.cc | 123 +++++++++++++++--- .../src/planner_scp_gusto_nodelet.cc | 4 + 4 files changed, 117 insertions(+), 20 deletions(-) diff --git a/astrobee/config/mobility/planner_scp_gusto.config b/astrobee/config/mobility/planner_scp_gusto.config index 5a92c069b2..d1a2e77536 100644 --- a/astrobee/config/mobility/planner_scp_gusto.config +++ b/astrobee/config/mobility/planner_scp_gusto.config @@ -29,5 +29,9 @@ parameters = { id = "enforce_obs_avoidance_const", reconfigurable = true, type = "boolean", default = true, unit = "unitless", description = "Flag for avoiding (virtual) obstacle" + },{ + id = "use_nn_warm_start", reconfigurable = true, type = "boolean", + default = true, unit = "unitless", + description = "Flag for using NN to warm-start optimization" } } diff --git a/mobility/planner_scp_gusto/include/planner_scp_gusto/optim.h b/mobility/planner_scp_gusto/include/planner_scp_gusto/optim.h index 3fc66e1223..7a8740ef75 100644 --- a/mobility/planner_scp_gusto/include/planner_scp_gusto/optim.h +++ b/mobility/planner_scp_gusto/include/planner_scp_gusto/optim.h @@ -45,6 +45,9 @@ struct Net : torch::nn::Module { fc2 = register_module("fc2", torch::nn::Linear(128, 64)); fc3 = register_module("fc3", torch::nn::Linear(64, 12)); } + ~Net() { + std::cout << "Destructor called for Net object." << std::endl; + } // // Load weights from a file // void loadWeights(const std::string& filename) { @@ -278,8 +281,11 @@ class TOP { const std::string& fname); // Neural network for warm start + bool use_nn_warm_start; + std::string nn_model_path; std::shared_ptr net; torch::optim::Adam optimizer; + void InitTrajWarmStart(); // Function to read a single dataset file std::tuple ReadData(const std::string& filename); diff --git a/mobility/planner_scp_gusto/src/optim.cc b/mobility/planner_scp_gusto/src/optim.cc index 3d9b954bb6..6232e42bd5 100644 --- a/mobility/planner_scp_gusto/src/optim.cc +++ b/mobility/planner_scp_gusto/src/optim.cc @@ -52,6 +52,8 @@ TOP::TOP(decimal_t Tf_, int N_) dh = Tf / N; // Network for warm start + use_nn_warm_start = false; + nn_model_path = "saved_NN_models/trained_model_27_2025-01-03_00-34-39.pt"; // Set weights to zero // net.initializeWeightsToZero(); // OR Load weights from file @@ -129,7 +131,7 @@ TOP::TOP(decimal_t Tf_, int N_) ResetSCPParams(); UpdateProblemDimension(N); - // Run warm start + // Run warm start for OSQP initialization on demo problem if (!solver->solve()) { solver_ready_ = false; } @@ -142,7 +144,7 @@ TOP::TOP(decimal_t Tf_, int N_) } TOP::~TOP() { - std::cout << "Destructor called" << std::endl; + std::cout << "Destructor called for TOP object." << std::endl; } size_t TOP::GetNumTOPVariables() { @@ -280,9 +282,24 @@ void TOP::UpdateProblemDimension(size_t N_) { qp_soln.resize(num_vars); // UpdateDoubleIntegrator(); - InitTrajStraightline(); // UpdateRotationalDynamics(); + if (use_nn_warm_start) { + std::cout << "TOP::UpdateProblemDimension: Using NN warm start" << std::endl; + InitTrajWarmStart(); + } else { + std::cout << "TOP::UpdateProblemDimension: Using straight line warm start" << std::endl; + InitTrajStraightline(); + } + + // Set warm start + for (size_t ii = 0; ii < N; ii++) { + qp_soln.segment(state_dim*ii, state_dim) = Xprev[ii]; + } + for (size_t ii = 0; ii < N-1; ii++) { + qp_soln.segment(state_dim*N+control_dim*ii, control_dim) = Uprev[ii]; + } + for (size_t ii = 0; ii < num_cons; ii++) { lower_bound(ii) = -OsqpEigen::INFTY; upper_bound(ii) = OsqpEigen::INFTY; @@ -371,6 +388,54 @@ void TOP::InitTrajStraightline() { WriteTrajectoryToFile(Xprev, Uprev, fname); } +void TOP::InitTrajWarmStart() { + // Settings + bool U_linear_only = true; + std::string Xinit_method = "linear_interpolation"; // "forward_dynamics" or "linear_interpolation" + + // Load model + LoadModel(nn_model_path); + + // Call InferenceNN(x0, xg) to get U0, Uf + Vec6 U0, Uf; + std::tie(U0, Uf) = InferenceNN(x0, xg); + // Interpolate linearly for N steps to get Uprev + Vec6Vec U_inter; + for (size_t i = 0; i < N; ++i) { + Vec6 U = U0 + (i/(N-1))*(Uf - U0); + if (U_linear_only) { + U(3) = 0.0; + U(4) = 0.0; + U(5) = 0.0; + } + U_inter.push_back(U); + } + Uprev = U_inter; + + if (Xinit_method == "forward_dynamics") { + // Use dynamics to get Xprev + Vec13Vec X_inter; + X_inter.push_back(x0); + for (size_t i = 0; i < N; ++i) { + Vec13 X = ForwardDynamics(X_inter[i], Uprev[i]); + X_inter.push_back(X); + } + Xprev = X_inter; + } else if (Xinit_method == "linear_interpolation") { + // Linearly interpolate between x0 and xg + Vec13Vec X_inter; + for (size_t i = 0; i < N; i++) { + X_inter.push_back(x0 + (xg - x0) * i / (N - 1)); + } + Xprev = X_inter; + } + + std::string fname = + "output_trajs/" + std::string(is_granite ? "granite" : "iss") + "_initial_nn_warm_start_trajectory.txt"; + WriteTrajectoryToFile(Xprev, Uprev, fname); + return; +} + // void TOP::UpdateF(Vec7& f, Vec13& X, Vec6& U) { // f.setZero(); @@ -1249,13 +1314,10 @@ bool TOP::Solve() { UpdateProblemDimension(N); std::cout << "TOP::Solve: Updated problem dimension" << std::endl; std::cout << "linear_con_mat size: " << linear_con_mat.rows() << " x " << linear_con_mat.cols() << std::endl; - InitTrajStraightline(); bool add_custom_keep_out_zone = true; - std::cout << "SCP::InitTrajStraightline complete" << std::endl; std::cout << "SCP:: start of init traj is " << Xprev[0].transpose() << std::endl; std::cout << "SCP:: end of init traj is " << Xprev[N-1].transpose() << std::endl; - // ROS_ERROR_STREAM("SCP::InitTrajStraightline done"); std::cout << "TOP:: mass: " << mass << std::endl; std::cout << "TOP:: inertia: " << J << std::endl; @@ -1287,19 +1349,6 @@ bool TOP::Solve() { // } // std::cout << std::endl; - // Set warm start - for (size_t ii = 0; ii < N; ii++) { - qp_soln.segment(state_dim*ii, state_dim) = Xprev[ii]; - } - for (size_t ii = 0; ii < N-1; ii++) { - qp_soln.segment(state_dim*N+control_dim*ii, control_dim) = Uprev[ii]; - } - if (!solver->setPrimalVariable(qp_soln)) { - solver_ready_ = false; - } - - std::cout << "SCP::Warm start complete" << std::endl; - // TODO(somrita): Reset max_iter max_iter = 1; for (size_t kk = 0; kk < max_iter; kk++) { @@ -3007,7 +3056,8 @@ int main() { bool create_training_data = false; bool train_and_save_model = false; - bool load_and_run_inference = true; + bool load_and_run_inference = false; + bool test_warm_start = true; int num_problems = 0; @@ -3250,5 +3300,38 @@ int main() { std::cout << "Uprev final: " << Uprev[Uprev.size() - 1].transpose() << std::endl; } + if (test_warm_start) { + // Cold start with straight line initialization + scp::TOP top_cold(20., 801); + top_cold.use_nn_warm_start = false; + + // Warm start from NN + scp::TOP top_warm(20., 801); + top_warm.use_nn_warm_start = true; + + // Set common parameters + top_cold.is_granite = false; + top_warm.is_granite = false; + top_cold.x0 << 10.28, -9.81, 4.30, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0; + top_cold.xg << 10.48, -9.81, 4.30, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0; + top_warm.x0 << 10.28, -9.81, 4.30, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0; + top_warm.xg << 10.48, -9.81, 4.30, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0; + + // Solve problems + if (!top_cold.Solve()) { + std::cout << "Cold start: Problem could not be solved!" << std::endl; + } else { + std::cout << "Cold start: Problem solved!" << std::endl; + // TODO(somrita): Log number of iterations or time to solve and quality of solution + } + if (!top_warm.Solve()) { + std::cout << "Warm start: Problem could not be solved!" << std::endl; + } else { + std::cout << "Warm start: Problem solved!" << std::endl; + // TODO(somrita): Log number of iterations or time to solve and quality of solution + } + std::cout << "--------------------------------------------" << std::endl; + } + return 0; } diff --git a/mobility/planner_scp_gusto/src/planner_scp_gusto_nodelet.cc b/mobility/planner_scp_gusto/src/planner_scp_gusto_nodelet.cc index 6be06faf8c..3281a91d55 100644 --- a/mobility/planner_scp_gusto/src/planner_scp_gusto_nodelet.cc +++ b/mobility/planner_scp_gusto/src/planner_scp_gusto_nodelet.cc @@ -71,6 +71,7 @@ class PlannerSCPGustoNodelet : public planner::PlannerImplementation { double epsilon_; bool enforce_obs_avoidance_const_; bool is_granite_; + bool use_nn_warm_start; bool use_2d; // true for granite table std::string flight_mode_; ros::NodeHandle *nh_; @@ -97,6 +98,7 @@ class PlannerSCPGustoNodelet : public planner::PlannerImplementation { epsilon_ = cfg_.Get("epsilon"); enforce_obs_avoidance_const_ = cfg_.Get("enforce_obs_avoidance_const"); is_granite_ = cfg_.Get("is_granite"); + use_nn_warm_start = cfg_.Get("use_nn_warm_start"); // Notify initialization complete NODELET_DEBUG_STREAM("Initialization complete"); // Success @@ -115,6 +117,7 @@ class PlannerSCPGustoNodelet : public planner::PlannerImplementation { enforce_obs_avoidance_const_ = cfg_.Get("enforce_obs_avoidance_const"); std::cout << "set enforce_obs_avoidance_const_ to " << enforce_obs_avoidance_const_ << std::endl; is_granite_ = cfg_.Get("is_granite"); + use_nn_warm_start = cfg_.Get("use_nn_warm_start"); return true; } @@ -411,6 +414,7 @@ class PlannerSCPGustoNodelet : public planner::PlannerImplementation { top->x_min(7) = -0.05; // qy top->x_max(7) = 0.05; } + top->use_nn_warm_start = cfg_.Get("use_nn_warm_start"); if (keep_in_zones_.size() == 0) { ROS_ERROR("Zero keepin zones!! Plan failed");