From d4957706a519f3385e27ebc6382598f80549dce5 Mon Sep 17 00:00:00 2001 From: Daniel Butterfield Date: Fri, 21 Jun 2024 11:52:36 -0400 Subject: [PATCH] Update test cases to match the inclusion of two more dataset entries (due to removal of derivative calculation), add functionality to check that a dataset path provided by the user matches what the dataset class expects --- src/grfgnn/datasets_py/LinTzuYaunDataset.py | 2 +- src/grfgnn/datasets_py/flexibleDataset.py | 11 +++ src/grfgnn/datasets_py/quadSDKDataset.py | 5 +- tests/testDatasets.py | 95 +++++++++++++-------- 4 files changed, 74 insertions(+), 39 deletions(-) diff --git a/src/grfgnn/datasets_py/LinTzuYaunDataset.py b/src/grfgnn/datasets_py/LinTzuYaunDataset.py index 58ccbf3..3b5926e 100644 --- a/src/grfgnn/datasets_py/LinTzuYaunDataset.py +++ b/src/grfgnn/datasets_py/LinTzuYaunDataset.py @@ -58,7 +58,7 @@ def process(self): # Write a txt file to save the dataset length & and first sequence index with open(str(Path(self.processed_dir, "info.txt")), "w") as f: - f.write(str(dataset_entries) + " " + str(0)) + f.write(str(dataset_entries) + " " + str(0) + self.get_google_drive_file_id()) # ============= DATA SORTING ORDER AND MAPPINGS ================== def get_base_node_sorted_order(self) -> list[str]: diff --git a/src/grfgnn/datasets_py/flexibleDataset.py b/src/grfgnn/datasets_py/flexibleDataset.py index 98a312e..3681c6c 100644 --- a/src/grfgnn/datasets_py/flexibleDataset.py +++ b/src/grfgnn/datasets_py/flexibleDataset.py @@ -77,6 +77,13 @@ def __init__(self, ) self.first_index = int(data[1]) + # Check to make sure that this dataset id matches what we expect. + # Protects against users passing a folder path to a different + # dataset sequence, causing a different dataset to be used than + # expected. + if self.get_google_drive_file_id() != data[2]: + raise ValueError("'root' parameter points to a Dataset sequence that doesn't match this Dataset class. Either fix the path to point to the correct sequence, or delete the data in the folder so that the proper sequence can be downloaded.") + # Parse the robot graph from the URDF file if self.data_format == 'heterogeneous_gnn': self.robotGraph = HeterogeneousRobotGraph(urdf_path, @@ -134,6 +141,10 @@ def get_google_drive_file_id(self): """ Method for child classes to choose which sequence to load; used if the dataset is downloaded. + + Additionally, used to check already downloaded datasets to + make sure the user didn't accidentally pass a path to + a different sequence. """ raise self.notImplementedError diff --git a/src/grfgnn/datasets_py/quadSDKDataset.py b/src/grfgnn/datasets_py/quadSDKDataset.py index a8327c0..6c49c0a 100644 --- a/src/grfgnn/datasets_py/quadSDKDataset.py +++ b/src/grfgnn/datasets_py/quadSDKDataset.py @@ -95,9 +95,10 @@ def process(self): # Track how many entries we have dataset_entries += 1 - # Write a txt file to save the dataset length & and first sequence index + # Write a txt file to save the dataset length, first sequence index, + # and the download id (for ensuring we have the right dataset later) with open(os.path.join(self.processed_dir, "info.txt"), "w") as f: - f.write(str(dataset_entries) + " " + str(0)) + f.write(str(dataset_entries) + " " + str(0) + " " + self.get_google_drive_file_id()) # ============= DATA SORTING ORDER AND MAPPINGS ================== diff --git a/tests/testDatasets.py b/tests/testDatasets.py index 956ecc2..2f5d4e5 100644 --- a/tests/testDatasets.py +++ b/tests/testDatasets.py @@ -1,7 +1,6 @@ import unittest from pathlib import Path -from grfgnn import CerberusStreetDataset, CerberusTrackDataset, Go1SimulatedDataset, FlexibleDataset, QuadSDKDataset_A1Speed1_0, QuadSDKDataset -from rosbags.highlevel import AnyReader +from grfgnn import FlexibleDataset, QuadSDKDataset_A1Speed1_0, QuadSDKDataset, QuadSDKDataset_A1Speed0_5 from torch_geometric.data import Data, HeteroData import numpy as np import torch @@ -36,31 +35,56 @@ class TestQuadSDKDatasets(unittest.TestCase): def setUp(self): # Get the paths to the URDF file and the dataset - path_to_a1_urdf = Path( + self.path_to_a1_urdf = Path( Path('.').parent, 'urdf_files', 'A1', 'a1.urdf').absolute() - path_to_normal_sequence = Path( + self.path_to_normal_sequence = Path( Path('.').parent, 'datasets', 'QuadSDK-A1Speed1.0').absolute() # Set up the QuadSDK datasets - self.dataset_hgnn_1 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence, - path_to_a1_urdf, 'package://a1_description/', + self.dataset_hgnn_1 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence, + self.path_to_a1_urdf, 'package://a1_description/', 'unitree_ros/robots/a1_description', 'heterogeneous_gnn', 1) - self.dataset_gnn_1 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence, - path_to_a1_urdf, 'package://a1_description/', + self.dataset_gnn_1 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence, + self.path_to_a1_urdf, 'package://a1_description/', 'unitree_ros/robots/a1_description', 'gnn', 1) - self.dataset_mlp_1 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence, - path_to_a1_urdf, 'package://a1_description/', + self.dataset_mlp_1 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence, + self.path_to_a1_urdf, 'package://a1_description/', 'unitree_ros/robots/a1_description', 'mlp', 1) - self.dataset_hgnn_3 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence, - path_to_a1_urdf, 'package://a1_description/', + self.dataset_hgnn_3 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence, + self.path_to_a1_urdf, 'package://a1_description/', 'unitree_ros/robots/a1_description', 'heterogeneous_gnn', 3) - self.dataset_gnn_3 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence, - path_to_a1_urdf, 'package://a1_description/', + self.dataset_gnn_3 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence, + self.path_to_a1_urdf, 'package://a1_description/', 'unitree_ros/robots/a1_description', 'gnn', 3) - self.dataset_mlp_3 = QuadSDKDataset_A1Speed1_0(path_to_normal_sequence, - path_to_a1_urdf, 'package://a1_description/', + self.dataset_mlp_3 = QuadSDKDataset_A1Speed1_0(self.path_to_normal_sequence, + self.path_to_a1_urdf, 'package://a1_description/', 'unitree_ros/robots/a1_description', 'mlp', 3) + def test_init(self): + """ + Test that the __init__ method properly detect when the user + erroneously gives a path to dataset folder that doesn't match + the Dataset Sequence class created. + """ + # Test the __init__ function properly runs on a new sequence + path_to_slow_sequence = Path( + Path('.').parent, 'datasets', 'QuadSDK-A1Speed0.5').absolute() + dataset_slow = QuadSDKDataset_A1Speed0_5(path_to_slow_sequence, + self.path_to_a1_urdf, 'package://a1_description/', + 'unitree_ros/robots/a1_description', 'heterogeneous_gnn', 1) + + # Try to create a normal sequence, pointing to the slow dataset directory + with self.assertRaises(ValueError): + dataset = QuadSDKDataset_A1Speed0_5(self.path_to_normal_sequence, + self.path_to_a1_urdf, 'package://a1_description/', + 'unitree_ros/robots/a1_description', 'heterogeneous_gnn', 1) + + # Try to create a slow sequence, pointing to the normal dataset directory + with self.assertRaises(ValueError) as e: + dataset = QuadSDKDataset_A1Speed1_0(path_to_slow_sequence, + self.path_to_a1_urdf, 'package://a1_description/', + 'unitree_ros/robots/a1_description', 'heterogeneous_gnn', 1) + def test_load_data_at_ros_seq(self): """ Make sure the data is loaded properly from the file, and that @@ -126,7 +150,7 @@ def test_load_data_sorted(self): def test_get_helper_mlp(self): # Get the inputs and labels - x, y = self.dataset_mlp_1.get_helper_mlp(9999) + x, y = self.dataset_mlp_1.get_helper_mlp(10000) # Define the desired data des_x = np.array([-0.06452160178213015, -0.366493877667443, 9.715652148737323, @@ -151,7 +175,7 @@ def test_get_helper_mlp(self): def test_get_helper_gnn(self): # Get the Data graph - data: Data = self.dataset_gnn_1.get_helper_gnn(9999) + data: Data = self.dataset_gnn_1.get_helper_gnn(10000) # Define the desired data des_x = np.array([[1, 1, 1], @@ -183,7 +207,7 @@ def test_get_helper_gnn(self): def test_get_helper_heterogeneous_gnn(self): # Get the HeteroData graph - heteroData: HeteroData = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9999) + heteroData: HeteroData = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(10000) # Test the desired edge matrices bj, jb, jj, fj, jf = self.dataset_hgnn_1.robotGraph.get_edge_index_matrices() @@ -243,9 +267,9 @@ def test_get(self): with self.assertRaises(IndexError): data = self.dataset_hgnn_1.get(-1) with self.assertRaises(IndexError): - data = self.dataset_hgnn_1.get(17529) + data = self.dataset_hgnn_1.get(17531) data = self.dataset_hgnn_1.get(0) - data = self.dataset_hgnn_1.get(17528) + data = self.dataset_hgnn_1.get(17530) # Test that it returns the proper values based on the model type x, y = self.dataset_mlp_1.get(0) @@ -270,18 +294,18 @@ def test_history_length_parameter(self): with self.assertRaises(IndexError): data = self.dataset_mlp_3.get(-1) with self.assertRaises(IndexError): - data = self.dataset_mlp_3.get(17527) + data = self.dataset_mlp_3.get(17529) data = self.dataset_mlp_3.get(0) - data = self.dataset_mlp_3.get(17526) + data = self.dataset_mlp_3.get(17528) # ================================= MLP ========================================== # Get the output - x_actual, y_actual = self.dataset_mlp_3.get(9997) + x_actual, y_actual = self.dataset_mlp_3.get(9998) # Calculated the desired x and y values - xb2, yb2 = self.dataset_mlp_1.get_helper_mlp(9997) - xb1, yb1 = self.dataset_mlp_1.get_helper_mlp(9998) - x, y_des = self.dataset_mlp_1.get_helper_mlp(9999) + xb2, yb2 = self.dataset_mlp_1.get_helper_mlp(9998) + xb1, yb1 = self.dataset_mlp_1.get_helper_mlp(9999) + x, y_des = self.dataset_mlp_1.get_helper_mlp(10000) x_comb = torch.stack((xb2, xb1, x), 0) x_des = torch.flatten(torch.transpose(x_comb, 0, 1), 0, 1) @@ -292,16 +316,16 @@ def test_history_length_parameter(self): # ================================= GNN ========================================== # Get the output - data_actual: Data = self.dataset_gnn_3.get_helper_gnn(9997) + data_actual: Data = self.dataset_gnn_3.get_helper_gnn(9998) # Check the labels labels_des = [0, 64.74924447333427, 64.98097097053076, 0] np.testing.assert_array_equal(data_actual.y.numpy(), np.array(labels_des, dtype=np.float64)) # Get desired node attributes - datab2 = self.dataset_gnn_1.get_helper_gnn(9997) - datab1 = self.dataset_gnn_1.get_helper_gnn(9998) - data = self.dataset_gnn_1.get_helper_gnn(9999) + datab2 = self.dataset_gnn_1.get_helper_gnn(9998) + datab1 = self.dataset_gnn_1.get_helper_gnn(9999) + data = self.dataset_gnn_1.get_helper_gnn(10000) x_des = torch.cat((datab2.x[:,0].unsqueeze(1), datab1.x[:,0].unsqueeze(1), data.x[:,0].unsqueeze(1), datab2.x[:,1].unsqueeze(1), datab1.x[:,1].unsqueeze(1), data.x[:,1].unsqueeze(1), datab2.x[:,2].unsqueeze(1), datab1.x[:,2].unsqueeze(1), data.x[:,2].unsqueeze(1)), 1) @@ -314,12 +338,12 @@ def test_history_length_parameter(self): # ================================= Heterogeneous GNN ========================================== # Get the HeteroData graph - heteroData: HeteroData = self.dataset_hgnn_3.get_helper_heterogeneous_gnn(9997) + heteroData: HeteroData = self.dataset_hgnn_3.get_helper_heterogeneous_gnn(9998) # Get the desired node attributes - hDatab2 = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9997) - hDatab1 = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9998) - hData = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9999) + hDatab2 = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9998) + hDatab1 = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(9999) + hData = self.dataset_hgnn_1.get_helper_heterogeneous_gnn(10000) base_x_cat = torch.cat((hDatab2['base'].x, hDatab1['base'].x, hData['base'].x), 0) base_x_des = torch.flatten(torch.transpose(base_x_cat, 0, 1), 0).unsqueeze(0) joint_x_des = torch.cat((hDatab2['joint'].x[:,0].unsqueeze(1), hDatab1['joint'].x[:,0].unsqueeze(1), hData['joint'].x[:,0].unsqueeze(1), @@ -334,6 +358,5 @@ def test_history_length_parameter(self): np.testing.assert_array_equal(heteroData['foot'].x.numpy(), foot_x.numpy()) np.testing.assert_array_equal(heteroData.y.numpy(), y.numpy()) - if __name__ == "__main__": unittest.main()