From 97f2cce087cc7f78715496c80cf6a697b89d3687 Mon Sep 17 00:00:00 2001 From: Daniel Butterfield Date: Fri, 29 Mar 2024 18:12:56 -0400 Subject: [PATCH] Unit test bug fixes --- src/grfgnn/urdfParser.py | 25 +----- tests/testUrdfParser.py | 166 +++++++++++++++++---------------------- 2 files changed, 75 insertions(+), 116 deletions(-) diff --git a/src/grfgnn/urdfParser.py b/src/grfgnn/urdfParser.py index ba513dc..ea7bab9 100644 --- a/src/grfgnn/urdfParser.py +++ b/src/grfgnn/urdfParser.py @@ -95,7 +95,7 @@ def __init__(self, parent=connections[0], child=connections[1]) self.edges.append(new_edge) - if len(connections) > 2: + if len(connections) > 2: # TODO: Don't connect children to other children # Setup an edge for each pair of connections count = 0 for i in range(0, len(connections)): @@ -248,7 +248,6 @@ def get_edge_name_to_connections_dict(self): [[node_dict[edge.parent], node_dict[edge.child]], [node_dict[edge.child], node_dict[edge.parent]]]) - # print(edge_dict) return edge_dict def display_URDF_info(self): @@ -276,25 +275,3 @@ def display_URDF_info(self): print('{} -> {} <- {}'.format(joint.parent, joint.name, joint.child)) print("") - - -def main(): - """ - Simple code that demonstrates basic functionality of the RobotURDF class. - """ - - HyQ_URDF = RobotURDF('urdf_files\HyQ\hyq.urdf', - 'package://hyq_description/', 'hyq-description') - # HyQ_URDF.display_URDF_info() - # print("Edge Matrix (HyQ): ", HyQ_URDF.get_node_name_to_index_dict()) - - # A1_URDF = RobotURDF('urdf_files\A1\a1.urdf', 'package://a1_description/', - # 'unitree_ros/robots/a1_description', True) - # A1_URDF.display_URDF_info() - # print("Edge Matrix (A1): ", A1_URDF.get_edge_index_matrix()) - - HyQ_URDF.get_edge_name_to_connections_dict() - - -if __name__ == "__main__": - main() diff --git a/tests/testUrdfParser.py b/tests/testUrdfParser.py index f6f294b..4cb1c64 100644 --- a/tests/testUrdfParser.py +++ b/tests/testUrdfParser.py @@ -10,20 +10,21 @@ class TestRobotURDF(unittest.TestCase): def setUp(self): - hyq_path = Path(Path(__file__).cwd(), 'urdf_files', 'HyQ', - 'hyq.urdf').absolute() - self.HyQ_URDF = RobotURDF(hyq_path, 'package://hyq_description/', + self.hyq_path = Path( + Path(__file__).cwd(), 'urdf_files', 'HyQ', 'hyq.urdf').absolute() + + self.HyQ_URDF = RobotURDF(self.hyq_path, 'package://hyq_description/', 'hyq-description', False) - self.HyQ_URDF_swapped = RobotURDF(hyq_path, + self.HyQ_URDF_swapped = RobotURDF(self.hyq_path, 'package://hyq_description/', 'hyq-description', True) def test_constructor(self): """ - check if self.nodes has all the name in the URDF file (for False) - check if self.edges has all the name in the URDF file (for True) - + Check if self.nodes has all the name in the URDF file, and + check if self.edges has all the name in the URDF file. """ + nodes_name = { 'world', 'base_link', 'trunk', 'lf_hipassembly', 'lh_hipassembly', 'rf_hipassembly', 'rh_hipassembly', 'lf_upperleg', 'lh_upperleg', @@ -40,37 +41,34 @@ def test_constructor(self): 'lh_foot_joint', 'rf_foot_joint', 'rh_foot_joint' } - get_nodes_name = self.HyQ_URDF.nodes - - for get_nodes_name in (get_nodes_name): - if get_nodes_name.name not in nodes_name: - print(get_nodes_name.name) + for node in self.HyQ_URDF.nodes: + self.assertTrue(node.name in nodes_name) - get_edges_name = self.HyQ_URDF.edges - - for get_edges_name in (get_edges_name): - if get_edges_name.name not in edges_name: - raise Exception - - raise Exception + for edge in self.HyQ_URDF.edges: + self.assertTrue(edge.name in edges_name) def test_create_updated_urdf_file(self): """ - check if the updated urdf file exist - + Check that calling the constructor creates + the updated urdf file. """ - actual_url = os.path.join(os.getcwd(), - os.path.dirname('urdf_files\HyQ\hyq.urdf'), - 'package://hyq_description/', "temp")[:-4] - self.assertTrue(os.path.exists(actual_url)) - raise NotImplemented + # Delete the urdf file + hyq_path_updated = self.hyq_path.parent / "hyq_updated.urdf" + os.remove(str(hyq_path_updated)) + self.assertFalse(os.path.exists(hyq_path_updated)) + + # Rebuild it + RobotURDF(self.hyq_path, 'package://hyq_description/', + 'hyq-description', False) + self.assertTrue(os.path.exists(hyq_path_updated)) def test_get_node_name_to_index_dict(self): """ - check if the index is unique - + Check if all the indexes of the nodes in the dictionary + are unique. """ + key = list(self.HyQ_URDF.get_node_name_to_index_dict()) get_nodes_index = [] @@ -78,109 +76,93 @@ def test_get_node_name_to_index_dict(self): index = self.HyQ_URDF.get_node_name_to_index_dict()[key] get_nodes_index.append(index) - result = pd.Index(get_nodes_index).is_unique - - if result == False: - raise NotImplemented - else: - return result + self.assertTrue(pd.Index(get_nodes_index).is_unique) def test_get_node_index_to_name_dict(self): """ - check if the index matchse the dictionary - or the name is unique - (maybe call test_get_node_name_to_index_dict would be easier to check) - + Check the index_to_name dict by running making sure the + index_to_name dict and the name_to_index dict are consistent. """ - key_index = list(self.HyQ_URDF.get_node_index_to_name_dict()) - # get_nodes_index = self.test_get_node_name_to_index_dict() - - key = list(self.HyQ_URDF.get_node_name_to_index_dict()) + index_to_name = list(self.HyQ_URDF.get_node_index_to_name_dict()) + name_to_index = list(self.HyQ_URDF.get_node_name_to_index_dict()) get_nodes_index = [] - for key in key: + for key in name_to_index: index = self.HyQ_URDF.get_node_name_to_index_dict()[key] get_nodes_index.append(index) - self.assertEqual(key_index, get_nodes_index) - - # raise NotImplemented + self.assertEqual(index_to_name, get_nodes_index) def test_get_edge_index_matrix(self): """ - check the dimension of the edge matrix - + Check the dimensionality of the edge matrix. """ - edge_matrix = self.HyQ_URDF.get_edge_index_matrix() - # print(edge_matrix.shape) - - m = edge_matrix.shape[0] - n = edge_matrix.shape[1] - num_of_edges = self.HyQ_URDF.get_num_nodes() - 1 - self.assertEqual(m, 2) - self.assertEqual(2 * num_of_edges, n) + edge_matrix = self.HyQ_URDF.get_edge_index_matrix() - # raise NotImplemented + self.assertEqual(edge_matrix.shape[0], 2) + self.assertEqual(edge_matrix.shape[1], 36) def test_get_num_nodes(self): """ - check the number of the node(False)/ edges(True) - + Check that the number of nodes are correct. """ - a = self.HyQ_URDF.get_num_nodes() - self.assertEqual(a, 19) - a = self.HyQ_URDF_swapped.get_num_nodes() - self.assertEqual(a, 18) + self.assertEqual(self.HyQ_URDF.get_num_nodes(), 19) + self.assertEqual(self.HyQ_URDF_swapped.get_num_nodes(), 18) def test_get_edge_connections_to_name_dict(self): """ - check if the index matchse the dictionary - or the name is unique - (maybe call test_get_edge_name_to_connections_dict would be easier to check) - + Check the connections_to_name dict by running making sure the + connections_to_name dict and the name_to_connections dict are + consistent. """ - expected_index = list( - self.HyQ_URDF.get_edge_connections_to_name_dict()) - print(expected_index) - key = list(self.HyQ_URDF.get_edge_name_to_connections_dict()) - - connections_dict = [] + connections_to_name = list( + self.HyQ_URDF.get_edge_connections_to_name_dict()) + name_to_connections = list( + self.HyQ_URDF.get_edge_name_to_connections_dict()) - for key in key: - index = self.HyQ_URDF.get_edge_name_to_connections_dict()[key] - for i in range(index.shape[1]): - real_index = np.squeeze(index[:, i].reshape(1, -1)) - connections_dict.append(real_index) + result = [] + for key in name_to_connections: + connections = self.HyQ_URDF.get_edge_name_to_connections_dict( + )[key] + for i in range(connections.shape[1]): + real_reshaped = np.squeeze(connections[:, i].reshape(1, -1)) + result.append(real_reshaped) - connections_dict = [tuple(arr) for arr in connections_dict] - print(connections_dict) + result = [tuple(arr) for arr in result] - self.assertEqual(expected_index, connections_dict) + self.assertEqual(connections_to_name, result) def test_get_edge_name_to_connections_dict(self): """ - check the dictionary is unique - + Check each connection in the dictionary is unique. """ - key = list(self.HyQ_URDF.get_edge_name_to_connections_dict()) - connections_dict = [] - for key in key: - index = self.HyQ_URDF.get_edge_name_to_connections_dict()[key] - connections_dict.append(index) + name_to_connections = list( + self.HyQ_URDF.get_edge_name_to_connections_dict()) + all_connections = [] + + # Get all connections from dictionary + for key in name_to_connections: + connections = self.HyQ_URDF.get_edge_name_to_connections_dict( + )[key] + for i in range(connections.shape[1]): + real_reshaped = np.squeeze(connections[:, i].reshape(1, -1)) + all_connections.append(real_reshaped) seen_arrays = set() - for array in connections_dict: + for array in all_connections: # Convert the array to a tuple since lists are not hashable array_tuple = tuple(array) - # Check if the array is already seen - if array_tuple in seen_arrays: - return False + # Make sure the array hasn't been seen + self.assertTrue(array_tuple not in seen_arrays) + + # Add it to the seen arrays + seen_arrays.add(array_tuple) if __name__ == '__main__':