Skip to content

Commit

Permalink
Unit test bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielChaseButterfield committed Mar 29, 2024
1 parent 195c38d commit 97f2cce
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 116 deletions.
25 changes: 1 addition & 24 deletions src/grfgnn/urdfParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self,
parent=connections[0],
child=connections[1])
self.edges.append(new_edge)
if len(connections) > 2:
if len(connections) > 2: # TODO: Don't connect children to other children
# Setup an edge for each pair of connections
count = 0
for i in range(0, len(connections)):
Expand Down Expand Up @@ -248,7 +248,6 @@ def get_edge_name_to_connections_dict(self):
[[node_dict[edge.parent], node_dict[edge.child]],
[node_dict[edge.child], node_dict[edge.parent]]])

# print(edge_dict)
return edge_dict

def display_URDF_info(self):
Expand Down Expand Up @@ -276,25 +275,3 @@ def display_URDF_info(self):
print('{} -> {} <- {}'.format(joint.parent, joint.name,
joint.child))
print("")


def main():
"""
Simple code that demonstrates basic functionality of the RobotURDF class.
"""

HyQ_URDF = RobotURDF('urdf_files\HyQ\hyq.urdf',
'package://hyq_description/', 'hyq-description')
# HyQ_URDF.display_URDF_info()
# print("Edge Matrix (HyQ): ", HyQ_URDF.get_node_name_to_index_dict())

# A1_URDF = RobotURDF('urdf_files\A1\a1.urdf', 'package://a1_description/',
# 'unitree_ros/robots/a1_description', True)
# A1_URDF.display_URDF_info()
# print("Edge Matrix (A1): ", A1_URDF.get_edge_index_matrix())

HyQ_URDF.get_edge_name_to_connections_dict()


if __name__ == "__main__":
main()
166 changes: 74 additions & 92 deletions tests/testUrdfParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,21 @@
class TestRobotURDF(unittest.TestCase):

def setUp(self):
hyq_path = Path(Path(__file__).cwd(), 'urdf_files', 'HyQ',
'hyq.urdf').absolute()
self.HyQ_URDF = RobotURDF(hyq_path, 'package://hyq_description/',
self.hyq_path = Path(
Path(__file__).cwd(), 'urdf_files', 'HyQ', 'hyq.urdf').absolute()

self.HyQ_URDF = RobotURDF(self.hyq_path, 'package://hyq_description/',
'hyq-description', False)
self.HyQ_URDF_swapped = RobotURDF(hyq_path,
self.HyQ_URDF_swapped = RobotURDF(self.hyq_path,
'package://hyq_description/',
'hyq-description', True)

def test_constructor(self):
"""
check if self.nodes has all the name in the URDF file (for False)
check if self.edges has all the name in the URDF file (for True)
Check if self.nodes has all the name in the URDF file, and
check if self.edges has all the name in the URDF file.
"""

nodes_name = {
'world', 'base_link', 'trunk', 'lf_hipassembly', 'lh_hipassembly',
'rf_hipassembly', 'rh_hipassembly', 'lf_upperleg', 'lh_upperleg',
Expand All @@ -40,147 +41,128 @@ def test_constructor(self):
'lh_foot_joint', 'rf_foot_joint', 'rh_foot_joint'
}

get_nodes_name = self.HyQ_URDF.nodes

for get_nodes_name in (get_nodes_name):
if get_nodes_name.name not in nodes_name:
print(get_nodes_name.name)
for node in self.HyQ_URDF.nodes:
self.assertTrue(node.name in nodes_name)

get_edges_name = self.HyQ_URDF.edges

for get_edges_name in (get_edges_name):
if get_edges_name.name not in edges_name:
raise Exception

raise Exception
for edge in self.HyQ_URDF.edges:
self.assertTrue(edge.name in edges_name)

def test_create_updated_urdf_file(self):
"""
check if the updated urdf file exist
Check that calling the constructor creates
the updated urdf file.
"""
actual_url = os.path.join(os.getcwd(),
os.path.dirname('urdf_files\HyQ\hyq.urdf'),
'package://hyq_description/', "temp")[:-4]
self.assertTrue(os.path.exists(actual_url))

raise NotImplemented
# Delete the urdf file
hyq_path_updated = self.hyq_path.parent / "hyq_updated.urdf"
os.remove(str(hyq_path_updated))
self.assertFalse(os.path.exists(hyq_path_updated))

# Rebuild it
RobotURDF(self.hyq_path, 'package://hyq_description/',
'hyq-description', False)
self.assertTrue(os.path.exists(hyq_path_updated))

def test_get_node_name_to_index_dict(self):
"""
check if the index is unique
Check if all the indexes of the nodes in the dictionary
are unique.
"""

key = list(self.HyQ_URDF.get_node_name_to_index_dict())
get_nodes_index = []

for key in key:
index = self.HyQ_URDF.get_node_name_to_index_dict()[key]
get_nodes_index.append(index)

result = pd.Index(get_nodes_index).is_unique

if result == False:
raise NotImplemented
else:
return result
self.assertTrue(pd.Index(get_nodes_index).is_unique)

def test_get_node_index_to_name_dict(self):
"""
check if the index matchse the dictionary
or the name is unique
(maybe call test_get_node_name_to_index_dict would be easier to check)
Check the index_to_name dict by running making sure the
index_to_name dict and the name_to_index dict are consistent.
"""
key_index = list(self.HyQ_URDF.get_node_index_to_name_dict())

# get_nodes_index = self.test_get_node_name_to_index_dict()

key = list(self.HyQ_URDF.get_node_name_to_index_dict())
index_to_name = list(self.HyQ_URDF.get_node_index_to_name_dict())
name_to_index = list(self.HyQ_URDF.get_node_name_to_index_dict())
get_nodes_index = []

for key in key:
for key in name_to_index:
index = self.HyQ_URDF.get_node_name_to_index_dict()[key]
get_nodes_index.append(index)

self.assertEqual(key_index, get_nodes_index)

# raise NotImplemented
self.assertEqual(index_to_name, get_nodes_index)

def test_get_edge_index_matrix(self):
"""
check the dimension of the edge matrix
Check the dimensionality of the edge matrix.
"""
edge_matrix = self.HyQ_URDF.get_edge_index_matrix()
# print(edge_matrix.shape)

m = edge_matrix.shape[0]
n = edge_matrix.shape[1]
num_of_edges = self.HyQ_URDF.get_num_nodes() - 1

self.assertEqual(m, 2)
self.assertEqual(2 * num_of_edges, n)
edge_matrix = self.HyQ_URDF.get_edge_index_matrix()

# raise NotImplemented
self.assertEqual(edge_matrix.shape[0], 2)
self.assertEqual(edge_matrix.shape[1], 36)

def test_get_num_nodes(self):
"""
check the number of the node(False)/ edges(True)
Check that the number of nodes are correct.
"""
a = self.HyQ_URDF.get_num_nodes()
self.assertEqual(a, 19)

a = self.HyQ_URDF_swapped.get_num_nodes()
self.assertEqual(a, 18)
self.assertEqual(self.HyQ_URDF.get_num_nodes(), 19)
self.assertEqual(self.HyQ_URDF_swapped.get_num_nodes(), 18)

def test_get_edge_connections_to_name_dict(self):
"""
check if the index matchse the dictionary
or the name is unique
(maybe call test_get_edge_name_to_connections_dict would be easier to check)
Check the connections_to_name dict by running making sure the
connections_to_name dict and the name_to_connections dict are
consistent.
"""
expected_index = list(
self.HyQ_URDF.get_edge_connections_to_name_dict())
print(expected_index)

key = list(self.HyQ_URDF.get_edge_name_to_connections_dict())

connections_dict = []
connections_to_name = list(
self.HyQ_URDF.get_edge_connections_to_name_dict())
name_to_connections = list(
self.HyQ_URDF.get_edge_name_to_connections_dict())

for key in key:
index = self.HyQ_URDF.get_edge_name_to_connections_dict()[key]
for i in range(index.shape[1]):
real_index = np.squeeze(index[:, i].reshape(1, -1))
connections_dict.append(real_index)
result = []
for key in name_to_connections:
connections = self.HyQ_URDF.get_edge_name_to_connections_dict(
)[key]
for i in range(connections.shape[1]):
real_reshaped = np.squeeze(connections[:, i].reshape(1, -1))
result.append(real_reshaped)

connections_dict = [tuple(arr) for arr in connections_dict]
print(connections_dict)
result = [tuple(arr) for arr in result]

self.assertEqual(expected_index, connections_dict)
self.assertEqual(connections_to_name, result)

def test_get_edge_name_to_connections_dict(self):
"""
check the dictionary is unique
Check each connection in the dictionary is unique.
"""
key = list(self.HyQ_URDF.get_edge_name_to_connections_dict())
connections_dict = []

for key in key:
index = self.HyQ_URDF.get_edge_name_to_connections_dict()[key]
connections_dict.append(index)
name_to_connections = list(
self.HyQ_URDF.get_edge_name_to_connections_dict())
all_connections = []

# Get all connections from dictionary
for key in name_to_connections:
connections = self.HyQ_URDF.get_edge_name_to_connections_dict(
)[key]
for i in range(connections.shape[1]):
real_reshaped = np.squeeze(connections[:, i].reshape(1, -1))
all_connections.append(real_reshaped)

seen_arrays = set()
for array in connections_dict:
for array in all_connections:
# Convert the array to a tuple since lists are not hashable
array_tuple = tuple(array)

# Check if the array is already seen
if array_tuple in seen_arrays:
return False
# Make sure the array hasn't been seen
self.assertTrue(array_tuple not in seen_arrays)

# Add it to the seen arrays
seen_arrays.add(array_tuple)


if __name__ == '__main__':
Expand Down

0 comments on commit 97f2cce

Please sign in to comment.