Skip to content

Commit

Permalink
Update test cases to match the inclusion of two more dataset entries …
Browse files Browse the repository at this point in the history
…(due to removal of derivative calculation), add functionality to check that a dataset path provided by the user matches what the dataset class expects
  • Loading branch information
DanielChaseButterfield committed Jun 21, 2024
1 parent 593c28c commit d495770
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/grfgnn/datasets_py/LinTzuYaunDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def process(self):

# Write a txt file to save the dataset length, first sequence index,
# and the download id (used later by flexibleDataset.__init__ to verify
# that the folder the user passed matches this Dataset class).
with open(str(Path(self.processed_dir, "info.txt")), "w") as f:
    # BUG FIX: the file id must be separated from the previous fields by a
    # space; without it, "0" and the id fuse into one whitespace-split token,
    # so the reader's data[2] lookup (and int(data[1])) breaks. This also
    # matches the format written by quadSDKDataset.process().
    f.write(str(dataset_entries) + " " + str(0) + " " + self.get_google_drive_file_id())

# ============= DATA SORTING ORDER AND MAPPINGS ==================
def get_base_node_sorted_order(self) -> list[str]:
Expand Down
11 changes: 11 additions & 0 deletions src/grfgnn/datasets_py/flexibleDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ def __init__(self,
)
self.first_index = int(data[1])

# Check to make sure that this dataset id matches what we expect.
# Protects against users passing a folder path to a different
# dataset sequence, causing a different dataset to be used than
# expected.
if self.get_google_drive_file_id() != data[2]:
raise ValueError("'root' parameter points to a Dataset sequence that doesn't match this Dataset class. Either fix the path to point to the correct sequence, or delete the data in the folder so that the proper sequence can be downloaded.")

# Parse the robot graph from the URDF file
if self.data_format == 'heterogeneous_gnn':
self.robotGraph = HeterogeneousRobotGraph(urdf_path,
Expand Down Expand Up @@ -134,6 +141,10 @@ def get_google_drive_file_id(self):
"""
Method for child classes to choose which sequence to load;
used if the dataset is downloaded.
Additionally, used to check already downloaded datasets to
make sure the user didn't accidentally pass a path to
a different sequence.
"""
raise self.notImplementedError

Expand Down
5 changes: 3 additions & 2 deletions src/grfgnn/datasets_py/quadSDKDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,10 @@ def process(self):
# Track how many entries we have
dataset_entries += 1

# Write a txt file to save the dataset length and first sequence index
# Write a txt file to save the dataset length, first sequence index,
# and the download id (for ensuring we have the right dataset later)
with open(os.path.join(self.processed_dir, "info.txt"), "w") as f:
f.write(str(dataset_entries) + " " + str(0))
f.write(str(dataset_entries) + " " + str(0) + " " + self.get_google_drive_file_id())

# ============= DATA SORTING ORDER AND MAPPINGS ==================

Expand Down
95 changes: 59 additions & 36 deletions tests/testDatasets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest
from pathlib import Path
from grfgnn import CerberusStreetDataset, CerberusTrackDataset, Go1SimulatedDataset, FlexibleDataset, QuadSDKDataset_A1Speed1_0, QuadSDKDataset
from rosbags.highlevel import AnyReader
from grfgnn import FlexibleDataset, QuadSDKDataset_A1Speed1_0, QuadSDKDataset, QuadSDKDataset_A1Speed0_5
from torch_geometric.data import Data, HeteroData
import numpy as np
import torch
Expand Down Expand Up @@ -36,31 +35,56 @@ class TestQuadSDKDatasets(unittest.TestCase):

def setUp(self):
# Get the paths to the URDF file and the dataset
path_to_a1_urdf = Path(
self.path_to_a1_urdf = Path(
Path('.').parent, 'urdf_files', 'A1', 'a1.urdf').absolute()
path_to_normal_sequence = Path(
self.path_to_normal_sequence = Path(
Path('.').parent, 'datasets', 'QuadSDK-A1Speed1.0').absolute()

# Set up the QuadSDK datasets
self.dataset_hgnn_1 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence,
path_to_a1_urdf, 'package://a1_description/',
self.dataset_hgnn_1 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence,
self.path_to_a1_urdf, 'package://a1_description/',
'unitree_ros/robots/a1_description', 'heterogeneous_gnn', 1)
self.dataset_gnn_1 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence,
path_to_a1_urdf, 'package://a1_description/',
self.dataset_gnn_1 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence,
self.path_to_a1_urdf, 'package://a1_description/',
'unitree_ros/robots/a1_description', 'gnn', 1)
self.dataset_mlp_1 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence,
path_to_a1_urdf, 'package://a1_description/',
self.dataset_mlp_1 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence,
self.path_to_a1_urdf, 'package://a1_description/',
'unitree_ros/robots/a1_description', 'mlp', 1)
self.dataset_hgnn_3 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence,
path_to_a1_urdf, 'package://a1_description/',
self.dataset_hgnn_3 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence,
self.path_to_a1_urdf, 'package://a1_description/',
'unitree_ros/robots/a1_description', 'heterogeneous_gnn', 3)
self.dataset_gnn_3 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence,
path_to_a1_urdf, 'package://a1_description/',
self.dataset_gnn_3 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence,
self.path_to_a1_urdf, 'package://a1_description/',
'unitree_ros/robots/a1_description', 'gnn', 3)
self.dataset_mlp_3 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence,
path_to_a1_urdf, 'package://a1_description/',
self.dataset_mlp_3 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence,
self.path_to_a1_urdf, 'package://a1_description/',
'unitree_ros/robots/a1_description', 'mlp', 3)

def test_init(self):
    """
    Test that the __init__ method properly detects when the user
    erroneously gives a path to a dataset folder that doesn't match
    the Dataset Sequence class created.
    """
    # Sanity check: __init__ runs cleanly when the class and the folder
    # actually match (a fresh sequence not built in setUp).
    path_to_slow_sequence = Path(
        Path('.').parent, 'datasets', 'QuadSDK-A1Speed0.5').absolute()
    QuadSDKDataset_A1Speed0_5(path_to_slow_sequence,
        self.path_to_a1_urdf, 'package://a1_description/',
        'unitree_ros/robots/a1_description', 'heterogeneous_gnn', 1)

    # Try to create a slow sequence, pointing to the normal dataset
    # directory: the download-id check must reject the mismatch.
    # (Comment fixed: it previously described the opposite pairing.)
    with self.assertRaises(ValueError):
        QuadSDKDataset_A1Speed0_5(self.path_to_normal_sequence,
            self.path_to_a1_urdf, 'package://a1_description/',
            'unitree_ros/robots/a1_description', 'heterogeneous_gnn', 1)

    # Try to create a normal sequence, pointing to the slow dataset
    # directory: same mismatch, opposite direction.
    with self.assertRaises(ValueError):
        QuadSDKDataset_A1Speed1_0(path_to_slow_sequence,
            self.path_to_a1_urdf, 'package://a1_description/',
            'unitree_ros/robots/a1_description', 'heterogeneous_gnn', 1)

def test_load_data_at_ros_seq(self):
"""
Make sure the data is loaded properly from the file, and that
Expand Down Expand Up @@ -126,7 +150,7 @@ def test_load_data_sorted(self):

def test_get_helper_mlp(self):
# Get the inputs and labels
x, y = self.dataset_mlp_1.get_helper_mlp(9999)
x, y = self.dataset_mlp_1.get_helper_mlp(10000)

# Define the desired data
des_x = np.array([-0.06452160178213015, -0.366493877667443, 9.715652148737323,
Expand All @@ -151,7 +175,7 @@ def test_get_helper_mlp(self):

def test_get_helper_gnn(self):
# Get the Data graph
data: Data = self.dataset_gnn_1.get_helper_gnn(9999)
data: Data = self.dataset_gnn_1.get_helper_gnn(10000)

# Define the desired data
des_x = np.array([[1, 1, 1],
Expand Down Expand Up @@ -183,7 +207,7 @@ def test_get_helper_gnn(self):

def test_get_helper_heterogeneous_gnn(self):
# Get the HeteroData graph
heteroData: HeteroData = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9999)
heteroData: HeteroData = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(10000)

# Test the desired edge matrices
bj, jb, jj, fj, jf = self.dataset_hgnn_1.robotGraph.get_edge_index_matrices()
Expand Down Expand Up @@ -243,9 +267,9 @@ def test_get(self):
with self.assertRaises(IndexError):
data = self.dataset_hgnn_1.get(-1)
with self.assertRaises(IndexError):
data = self.dataset_hgnn_1.get(17529)
data = self.dataset_hgnn_1.get(17531)
data = self.dataset_hgnn_1.get(0)
data = self.dataset_hgnn_1.get(17528)
data = self.dataset_hgnn_1.get(17530)

# Test that it returns the proper values based on the model type
x, y = self.dataset_mlp_1.get(0)
Expand All @@ -270,18 +294,18 @@ def test_history_length_parameter(self):
with self.assertRaises(IndexError):
data = self.dataset_mlp_3.get(-1)
with self.assertRaises(IndexError):
data = self.dataset_mlp_3.get(17527)
data = self.dataset_mlp_3.get(17529)
data = self.dataset_mlp_3.get(0)
data = self.dataset_mlp_3.get(17526)
data = self.dataset_mlp_3.get(17528)

# ================================= MLP ==========================================
# Get the output
x_actual, y_actual = self.dataset_mlp_3.get(9997)
x_actual, y_actual = self.dataset_mlp_3.get(9998)

# Calculated the desired x and y values
xb2, yb2 = self.dataset_mlp_1.get_helper_mlp(9997)
xb1, yb1 = self.dataset_mlp_1.get_helper_mlp(9998)
x, y_des = self.dataset_mlp_1.get_helper_mlp(9999)
xb2, yb2 = self.dataset_mlp_1.get_helper_mlp(9998)
xb1, yb1 = self.dataset_mlp_1.get_helper_mlp(9999)
x, y_des = self.dataset_mlp_1.get_helper_mlp(10000)
x_comb = torch.stack((xb2, xb1, x), 0)
x_des = torch.flatten(torch.transpose(x_comb, 0, 1), 0, 1)

Expand All @@ -292,16 +316,16 @@ def test_history_length_parameter(self):
# ================================= GNN ==========================================

# Get the output
data_actual: Data = self.dataset_gnn_3.get_helper_gnn(9997)
data_actual: Data = self.dataset_gnn_3.get_helper_gnn(9998)

# Check the labels
labels_des = [0, 64.74924447333427, 64.98097097053076, 0]
np.testing.assert_array_equal(data_actual.y.numpy(), np.array(labels_des, dtype=np.float64))

# Get desired node attributes
datab2 = self.dataset_gnn_1.get_helper_gnn(9997)
datab1 = self.dataset_gnn_1.get_helper_gnn(9998)
data = self.dataset_gnn_1.get_helper_gnn(9999)
datab2 = self.dataset_gnn_1.get_helper_gnn(9998)
datab1 = self.dataset_gnn_1.get_helper_gnn(9999)
data = self.dataset_gnn_1.get_helper_gnn(10000)
x_des = torch.cat((datab2.x[:,0].unsqueeze(1), datab1.x[:,0].unsqueeze(1), data.x[:,0].unsqueeze(1),
datab2.x[:,1].unsqueeze(1), datab1.x[:,1].unsqueeze(1), data.x[:,1].unsqueeze(1),
datab2.x[:,2].unsqueeze(1), datab1.x[:,2].unsqueeze(1), data.x[:,2].unsqueeze(1)), 1)
Expand All @@ -314,12 +338,12 @@ def test_history_length_parameter(self):
# ================================= Heterogeneous GNN ==========================================

# Get the HeteroData graph
heteroData: HeteroData = self.dataset_hgnn_3.get_helper_heterogeneous_gnn(9997)
heteroData: HeteroData = self.dataset_hgnn_3.get_helper_heterogeneous_gnn(9998)

# Get the desired node attributes
hDatab2 = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9997)
hDatab1 = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9998)
hData = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9999)
hDatab2 = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9998)
hDatab1 = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9999)
hData = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(10000)
base_x_cat = torch.cat((hDatab2['base'].x, hDatab1['base'].x, hData['base'].x), 0)
base_x_des = torch.flatten(torch.transpose(base_x_cat, 0, 1), 0).unsqueeze(0)
joint_x_des = torch.cat((hDatab2['joint'].x[:,0].unsqueeze(1), hDatab1['joint'].x[:,0].unsqueeze(1), hData['joint'].x[:,0].unsqueeze(1),
Expand All @@ -334,6 +358,5 @@ def test_history_length_parameter(self):
np.testing.assert_array_equal(heteroData['foot'].x.numpy(), foot_x.numpy())
np.testing.assert_array_equal(heteroData.y.numpy(), y.numpy())


# Entry point: run the full unittest suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit d495770

Please sign in to comment.