Skip to content

Commit

Permalink
Added param to warm start TOP using trained neural network
Browse files Browse the repository at this point in the history
  • Loading branch information
somritabanerjee committed Jan 4, 2025
1 parent 1965e15 commit 9d6ee08
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 20 deletions.
4 changes: 4 additions & 0 deletions astrobee/config/mobility/planner_scp_gusto.config
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,9 @@ parameters = {
id = "enforce_obs_avoidance_const", reconfigurable = true, type = "boolean",
default = true, unit = "unitless",
description = "Flag for avoiding (virtual) obstacle"
},{
id = "use_nn_warm_start", reconfigurable = true, type = "boolean",
default = true, unit = "unitless",
description = "Flag for using NN to warm-start optimization"
}
}
6 changes: 6 additions & 0 deletions mobility/planner_scp_gusto/include/planner_scp_gusto/optim.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ struct Net : torch::nn::Module {
fc2 = register_module("fc2", torch::nn::Linear(128, 64));
fc3 = register_module("fc3", torch::nn::Linear(64, 12));
}
~Net() {
std::cout << "Destructor called for Net object." << std::endl;
}

// // Load weights from a file
// void loadWeights(const std::string& filename) {
Expand Down Expand Up @@ -278,8 +281,11 @@ class TOP {
const std::string& fname);

// Neural network for warm start
bool use_nn_warm_start;
std::string nn_model_path;
std::shared_ptr<Net> net;
torch::optim::Adam optimizer;
void InitTrajWarmStart();

// Function to read a single dataset file
std::tuple<torch::Tensor, torch::Tensor> ReadData(const std::string& filename);
Expand Down
123 changes: 103 additions & 20 deletions mobility/planner_scp_gusto/src/optim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ TOP::TOP(decimal_t Tf_, int N_)
dh = Tf / N;

// Network for warm start
use_nn_warm_start = false;
nn_model_path = "saved_NN_models/trained_model_27_2025-01-03_00-34-39.pt";
// Set weights to zero
// net.initializeWeightsToZero();
// OR Load weights from file
Expand Down Expand Up @@ -129,7 +131,7 @@ TOP::TOP(decimal_t Tf_, int N_)
ResetSCPParams();
UpdateProblemDimension(N);

// Run warm start
// Run warm start for OSQP initialization on demo problem
if (!solver->solve()) {
solver_ready_ = false;
}
Expand All @@ -142,7 +144,7 @@ TOP::TOP(decimal_t Tf_, int N_)
}

TOP::~TOP() {
std::cout << "Destructor called" << std::endl;
std::cout << "Destructor called for TOP object." << std::endl;
}

size_t TOP::GetNumTOPVariables() {
Expand Down Expand Up @@ -280,9 +282,24 @@ void TOP::UpdateProblemDimension(size_t N_) {
qp_soln.resize(num_vars);

// UpdateDoubleIntegrator();
InitTrajStraightline();
// UpdateRotationalDynamics();

if (use_nn_warm_start) {
std::cout << "TOP::UpdateProblemDimension: Using NN warm start" << std::endl;
InitTrajWarmStart();
} else {
std::cout << "TOP::UpdateProblemDimension: Using straight line warm start" << std::endl;
InitTrajStraightline();
}

// Set warm start
for (size_t ii = 0; ii < N; ii++) {
qp_soln.segment(state_dim*ii, state_dim) = Xprev[ii];
}
for (size_t ii = 0; ii < N-1; ii++) {
qp_soln.segment(state_dim*N+control_dim*ii, control_dim) = Uprev[ii];
}

for (size_t ii = 0; ii < num_cons; ii++) {
lower_bound(ii) = -OsqpEigen::INFTY;
upper_bound(ii) = OsqpEigen::INFTY;
Expand Down Expand Up @@ -371,6 +388,54 @@ void TOP::InitTrajStraightline() {
WriteTrajectoryToFile(Xprev, Uprev, fname);
}

void TOP::InitTrajWarmStart() {
// Settings
bool U_linear_only = true;
std::string Xinit_method = "linear_interpolation"; // "forward_dynamics" or "linear_interpolation"

// Load model
LoadModel(nn_model_path);

// Call InferenceNN(x0, xg) to get U0, Uf
Vec6 U0, Uf;
std::tie(U0, Uf) = InferenceNN(x0, xg);
// Interpolate linearly for N steps to get Uprev
Vec6Vec U_inter;
for (size_t i = 0; i < N; ++i) {
Vec6 U = U0 + (i/(N-1))*(Uf - U0);
if (U_linear_only) {
U(3) = 0.0;
U(4) = 0.0;
U(5) = 0.0;
}
U_inter.push_back(U);
}
Uprev = U_inter;

if (Xinit_method == "forward_dynamics") {
// Use dynamics to get Xprev
Vec13Vec X_inter;
X_inter.push_back(x0);
for (size_t i = 0; i < N; ++i) {
Vec13 X = ForwardDynamics(X_inter[i], Uprev[i]);
X_inter.push_back(X);
}
Xprev = X_inter;
} else if (Xinit_method == "linear_interpolation") {
// Linearly interpolate between x0 and xg
Vec13Vec X_inter;
for (size_t i = 0; i < N; i++) {
X_inter.push_back(x0 + (xg - x0) * i / (N - 1));
}
Xprev = X_inter;
}

std::string fname =
"output_trajs/" + std::string(is_granite ? "granite" : "iss") + "_initial_nn_warm_start_trajectory.txt";
WriteTrajectoryToFile(Xprev, Uprev, fname);
return;
}

// void TOP::UpdateF(Vec7& f, Vec13& X, Vec6& U) {
// f.setZero();

Expand Down Expand Up @@ -1249,13 +1314,10 @@ bool TOP::Solve() {
UpdateProblemDimension(N);
std::cout << "TOP::Solve: Updated problem dimension" << std::endl;
std::cout << "linear_con_mat size: " << linear_con_mat.rows() << " x " << linear_con_mat.cols() << std::endl;
InitTrajStraightline();
bool add_custom_keep_out_zone = true;

std::cout << "SCP::InitTrajStraightline complete" << std::endl;
std::cout << "SCP:: start of init traj is " << Xprev[0].transpose() << std::endl;
std::cout << "SCP:: end of init traj is " << Xprev[N-1].transpose() << std::endl;
// ROS_ERROR_STREAM("SCP::InitTrajStraightline done");

std::cout << "TOP:: mass: " << mass << std::endl;
std::cout << "TOP:: inertia: " << J << std::endl;
Expand Down Expand Up @@ -1287,19 +1349,6 @@ bool TOP::Solve() {
// }
// std::cout << std::endl;

// Set warm start
for (size_t ii = 0; ii < N; ii++) {
qp_soln.segment(state_dim*ii, state_dim) = Xprev[ii];
}
for (size_t ii = 0; ii < N-1; ii++) {
qp_soln.segment(state_dim*N+control_dim*ii, control_dim) = Uprev[ii];
}
if (!solver->setPrimalVariable(qp_soln)) {
solver_ready_ = false;
}

std::cout << "SCP::Warm start complete" << std::endl;

// TODO(somrita): Reset max_iter
max_iter = 1;
for (size_t kk = 0; kk < max_iter; kk++) {
Expand Down Expand Up @@ -3007,7 +3056,8 @@ int main() {

bool create_training_data = false;
bool train_and_save_model = false;
bool load_and_run_inference = true;
bool load_and_run_inference = false;
bool test_warm_start = true;

int num_problems = 0;

Expand Down Expand Up @@ -3250,5 +3300,38 @@ int main() {
std::cout << "Uprev final: " << Uprev[Uprev.size() - 1].transpose() << std::endl;
}

if (test_warm_start) {
// Cold start with straight line initialization
scp::TOP top_cold(20., 801);
top_cold.use_nn_warm_start = false;

// Warm start from NN
scp::TOP top_warm(20., 801);
top_warm.use_nn_warm_start = true;

// Set common parameters
top_cold.is_granite = false;
top_warm.is_granite = false;
top_cold.x0 << 10.28, -9.81, 4.30, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0;
top_cold.xg << 10.48, -9.81, 4.30, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0;
top_warm.x0 << 10.28, -9.81, 4.30, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0;
top_warm.xg << 10.48, -9.81, 4.30, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0;

// Solve problems
if (!top_cold.Solve()) {
std::cout << "Cold start: Problem could not be solved!" << std::endl;
} else {
std::cout << "Cold start: Problem solved!" << std::endl;
// TODO(somrita): Log number of iterations or time to solve and quality of solution
}
if (!top_warm.Solve()) {
std::cout << "Warm start: Problem could not be solved!" << std::endl;
} else {
std::cout << "Warm start: Problem solved!" << std::endl;
// TODO(somrita): Log number of iterations or time to solve and quality of solution
}
std::cout << "--------------------------------------------" << std::endl;
}

return 0;
}
4 changes: 4 additions & 0 deletions mobility/planner_scp_gusto/src/planner_scp_gusto_nodelet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class PlannerSCPGustoNodelet : public planner::PlannerImplementation {
double epsilon_;
bool enforce_obs_avoidance_const_;
bool is_granite_;
bool use_nn_warm_start;
bool use_2d; // true for granite table
std::string flight_mode_;
ros::NodeHandle *nh_;
Expand All @@ -97,6 +98,7 @@ class PlannerSCPGustoNodelet : public planner::PlannerImplementation {
epsilon_ = cfg_.Get<double>("epsilon");
enforce_obs_avoidance_const_ = cfg_.Get<bool>("enforce_obs_avoidance_const");
is_granite_ = cfg_.Get<bool>("is_granite");
use_nn_warm_start = cfg_.Get<bool>("use_nn_warm_start");
// Notify initialization complete
NODELET_DEBUG_STREAM("Initialization complete");
// Success
Expand All @@ -115,6 +117,7 @@ class PlannerSCPGustoNodelet : public planner::PlannerImplementation {
enforce_obs_avoidance_const_ = cfg_.Get<bool>("enforce_obs_avoidance_const");
std::cout << "set enforce_obs_avoidance_const_ to " << enforce_obs_avoidance_const_ << std::endl;
is_granite_ = cfg_.Get<bool>("is_granite");
use_nn_warm_start = cfg_.Get<bool>("use_nn_warm_start");
return true;
}

Expand Down Expand Up @@ -411,6 +414,7 @@ class PlannerSCPGustoNodelet : public planner::PlannerImplementation {
top->x_min(7) = -0.05; // qy
top->x_max(7) = 0.05;
}
top->use_nn_warm_start = cfg_.Get<bool>("use_nn_warm_start");

if (keep_in_zones_.size() == 0) {
ROS_ERROR("Zero keepin zones!! Plan failed");
Expand Down

0 comments on commit 9d6ee08

Please sign in to comment.