Skip to content

Commit

Permalink
Merge pull request #291 from leggedrobotics/dev/open_source_jf
Browse files — browse the repository at this point in the history
enjoy reviewing the pull request
  • Loading branch information
JonasFrey96 authored Feb 18, 2024
2 parents 0650dbe + f47c6c9 commit ff242f9
Show file tree
Hide file tree
Showing 212 changed files with 15,827 additions and 2,941 deletions.
1 change: 1 addition & 0 deletions .deprecated/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .graph_trav_dataset import get_ablation_module
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from torch_geometric.data import InMemoryDataset, DataListLoader, DataLoader
from torch_geometric.data import LightningDataset
# TODO
# from torch_geometric.data import InMemoryDataset, DataLoader
# from torch_geometric.data import LightningDataset
# from torch_geometric.data import Dataset

from wild_visual_navigation import WVN_ROOT_DIR
import os
import torch
from pathlib import Path
from torch_geometric.data import Dataset
from torchvision import transforms as T
from typing import Optional, Callable
from torch_geometric.data import Data
import random


Expand Down Expand Up @@ -186,13 +185,28 @@ def get_test_dataset(perugia_root, env, feature_key, test_all_datasets, training
)
]
else:
val_dataset = get_test_dataset(perugia_root, env, feature_key, test_all_datasets, training_data_percentage)
val_dataset = get_test_dataset(
perugia_root,
env,
feature_key,
test_all_datasets,
training_data_percentage,
)

train_loader = DataLoader(
dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=False
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=False,
)
val_loader = [
DataLoader(dataset=v, batch_size=batch_size, num_workers=num_workers, pin_memory=False) for v in val_dataset
DataLoader(
dataset=v,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=False,
)
for v in val_dataset
]

else:
Expand All @@ -212,9 +226,20 @@ def get_test_dataset(perugia_root, env, feature_key, test_all_datasets, training
)
]
else:
test_dataset = get_test_dataset(perugia_root, env, feature_key, test_all_datasets, training_data_percentage)
test_dataset = get_test_dataset(
perugia_root,
env,
feature_key,
test_all_datasets,
training_data_percentage,
)
test_loader = [
DataLoader(dataset=t, batch_size=batch_size, num_workers=num_workers, pin_memory=False)
DataLoader(
dataset=t,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=False,
)
for t in test_dataset
]
else:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

# Frameworks
import torch
import pytorch_lightning
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -60,7 +59,11 @@ def training_routine(exp: ExperimentParams, seed=42) -> torch.Tensor:

if exp.cb_checkpoint.active:
checkpoint_callback = ModelCheckpoint(
dirpath=model_path, save_top_k=1, monitor="epoch", mode="max", save_last=True
dirpath=model_path,
save_top_k=1,
monitor="epoch",
mode="max",
save_last=True,
)
cb_ls.append(checkpoint_callback)

Expand Down Expand Up @@ -93,7 +96,7 @@ def training_routine(exp: ExperimentParams, seed=42) -> torch.Tensor:
ckpt = torch.load(exp.model.load_ckpt)
try:
res = model.load_state_dict(ckpt.state_dict, strict=False)
except:
except Exception:
res = model.load_state_dict(ckpt, strict=False)
print("Loaded model checkpoint:", res)
trainer = Trainer(**exp.trainer, callbacks=cb_ls, logger=logger)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from os.path import join
from wild_visual_navigation.visu import LearningVisualizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
from torch_geometric.data import Data

# from torch_geometric.data import Data
from torchmetrics import ROC

from wild_visual_navigation.utils import TraversabilityLoss, MetricLogger
import os
import pickle


class LightningTrav(pl.LightningModule):
Expand Down Expand Up @@ -70,8 +69,20 @@ def training_step(self, batch: any, batch_idx: int) -> torch.Tensor:

for k, v in loss_aux.items():
if k.find("loss") != -1:
self.log(f"{self._mode}_{k}", v.item(), on_epoch=True, prog_bar=True, batch_size=BS)
self.log(f"{self._mode}_loss", loss.item(), on_epoch=True, prog_bar=True, batch_size=BS)
self.log(
f"{self._mode}_{k}",
v.item(),
on_epoch=True,
prog_bar=True,
batch_size=BS,
)
self.log(
f"{self._mode}_loss",
loss.item(),
on_epoch=True,
prog_bar=True,
batch_size=BS,
)

self.visu(graph, res_updated, loss_aux["confidence"])

Expand Down Expand Up @@ -112,9 +123,21 @@ def validation_step(self, batch: any, batch_idx: int, dataloader_id: int = 0) ->

for k, v in loss_aux.items():
if k.find("loss") != -1:
self.log(f"{self._mode}_{k}", v.item(), on_epoch=True, prog_bar=True, batch_size=BS)
self.log(
f"{self._mode}_{k}",
v.item(),
on_epoch=True,
prog_bar=True,
batch_size=BS,
)

self.log(f"{self._mode}_loss", loss.item(), on_epoch=True, prog_bar=True, batch_size=BS)
self.log(
f"{self._mode}_loss",
loss.item(),
on_epoch=True,
prog_bar=True,
batch_size=BS,
)
self.visu(graph, res_updated, loss_aux["confidence"])

return loss
Expand All @@ -133,7 +156,7 @@ def update_threshold(self):
fpr, tpr, thresholds = self._auxiliary_training_roc.compute()
index = torch.where(fpr > 0.15)[0][0]
self.threshold[0] = thresholds[index]
except:
except Exception:
pass
else:
self.threshold[0] = 0.5
Expand All @@ -158,13 +181,25 @@ def test_step(self, batch: any, batch_idx: int, dataloader_id: int = 0) -> torch

for k, v in loss_aux.items():
if k.find("loss") != -1:
self.log(f"{self._mode}_{k}", v.item(), on_epoch=True, prog_bar=True, batch_size=BS)
self.log(f"{self._mode}_loss", loss.item(), on_epoch=True, prog_bar=True, batch_size=BS)
self.log(
f"{self._mode}_{k}",
v.item(),
on_epoch=True,
prog_bar=True,
batch_size=BS,
)
self.log(
f"{self._mode}_loss",
loss.item(),
on_epoch=True,
prog_bar=True,
batch_size=BS,
)
self.visu(graph, res_updated, loss_aux["confidence"])
return loss

def test_epoch_end(self, outputs: any, plot=False):
################ NEW VERSION ################
# NEW VERSION
res = self._metric_logger.get_epoch_results("test")
dic2 = {
"trainer_logged_metric" + k: v.item()
Expand Down Expand Up @@ -199,7 +234,7 @@ def visu(self, graph: Data, res: torch.tensor, confidence: torch.tensor):
on_epoch=True,
batch_size=graph.ptr.shape[0] - 1,
)
except:
except Exception:
pass

for b in range(graph.ptr.shape[0] - 1):
Expand Down Expand Up @@ -232,14 +267,22 @@ def visu(self, graph: Data, res: torch.tensor, confidence: torch.tensor):

# Visualize Graph with Segmentation
t1 = self._visualizer.plot_traversability_graph_on_seg(
pred[:, 0], seg, graph[b], center, img, not_log=True, colorize_invalid_centers=True
pred[:, 0],
seg,
graph[b],
center,
img,
not_log=True,
colorize_invalid_centers=True,
)
t2 = self._visualizer.plot_traversability_graph_on_seg(
graph[b].y, seg, graph[b], center, img, not_log=True
)
t_img = self._visualizer.plot_image(img, not_log=True)
self._visualizer.plot_list(
imgs=[t1, t2, t_img], tag=f"C{c}_{self._mode}_GraphTrav", store_folder=f"{self._mode}/graph_trav"
imgs=[t1, t2, t_img],
tag=f"C{c}_{self._mode}_GraphTrav",
store_folder=f"{self._mode}/graph_trav",
)

nr_channel_reco = graph[b].x.shape[1]
Expand All @@ -248,14 +291,27 @@ def visu(self, graph: Data, res: torch.tensor, confidence: torch.tensor):

# # Visualize Graph with Segmentation
c1 = self._visualizer.plot_traversability_graph_on_seg(
conf, seg, graph[b], center, img, not_log=True, colorize_invalid_centers=True
conf,
seg,
graph[b],
center,
img,
not_log=True,
colorize_invalid_centers=True,
)
c2 = self._visualizer.plot_traversability_graph_on_seg(
graph[b].y_valid.type(torch.float32), seg, graph[b], center, img, not_log=True
graph[b].y_valid.type(torch.float32),
seg,
graph[b],
center,
img,
not_log=True,
)
c_img = self._visualizer.plot_image(img, not_log=True)
self._visualizer.plot_list(
imgs=[c1, c2, c_img], tag=f"C{c}_{self._mode}_GraphConf", store_folder=f"{self._mode}/graph_conf"
imgs=[c1, c2, c_img],
tag=f"C{c}_{self._mode}_GraphConf",
store_folder=f"{self._mode}/graph_conf",
)

# if self._mode == "test":
Expand Down Expand Up @@ -316,10 +372,18 @@ def visu(self, graph: Data, res: torch.tensor, confidence: torch.tensor):
reco_loss = F.mse_loss(pred[:, -nr_channel_reco:], graph[b].x, reduction="none").mean(dim=1)

self._visualizer.plot_histogram(
reco_loss, graph[b].y, mean, std, tag=f"C{c}_{self._mode}__confidence_generator_prop"
reco_loss,
graph[b].y,
mean,
std,
tag=f"C{c}_{self._mode}__confidence_generator_prop",
)

if hasattr(graph[b], "y_gt"):
self._visualizer.plot_histogram(
reco_loss, graph[b].y_gt, mean, std, tag=f"C{c}_{self._mode}__confidence_generator_gt"
reco_loss,
graph[b].y_gt,
mean,
std,
tag=f"C{c}_{self._mode}__confidence_generator_gt",
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
raise ValueError("TODO: Not tested with new configuration!")
# Use the same config to load the data using the dataloader
from wild_visual_navigation.dataset import get_ablation_module
from wild_visual_navigation import WVN_ROOT_DIR
from wild_visual_navigation.cfg import ExperimentParams
import torch
from torchmetrics import Accuracy, AUROC
Expand All @@ -10,10 +9,8 @@

patch_sklearn()

from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
import numpy as np
import copy
import os
Expand Down Expand Up @@ -125,8 +122,14 @@

for features_test, y_gt, y_prop, test_scene in zip(features_tests, y_gts, y_props, test_scenes):
y_pred = v.predict_proba(features_test)
test_auroc_gt.update(torch.from_numpy(y_pred[:, 1]), torch.from_numpy(y_gt).type(torch.long))
test_auroc_prop.update(torch.from_numpy(y_pred[:, 1]), torch.from_numpy(y_prop).type(torch.long))
test_auroc_gt.update(
torch.from_numpy(y_pred[:, 1]),
torch.from_numpy(y_gt).type(torch.long),
)
test_auroc_prop.update(
torch.from_numpy(y_pred[:, 1]),
torch.from_numpy(y_prop).type(torch.long),
)

res = {
"test_auroc_gt_seg": test_auroc_gt.compute().item(),
Expand All @@ -146,15 +149,16 @@
ws = os.environ.get("ENV_WORKSTATION_NAME", "default")
# Store epoch output to disk.
p = os.path.join(
exp.env.results, f"ablations/classicial_learning_ablation_{ws}/classicial_learning_ablation_test_results.pkl"
exp.env.results,
f"ablations/classicial_learning_ablation_{ws}/classicial_learning_ablation_test_results.pkl",
)
print(results_epoch)
Path(p).parent.mkdir(parents=True, exist_ok=True)

try:
os.remove(p)
except OSError as error:
pass
print(error)

with open(p, "wb") as handle:
pickle.dump(results_epoch, handle, protocol=pickle.HIGHEST_PROTOCOL)
Loading

0 comments on commit ff242f9

Please sign in to comment.