Skip to content

Commit

Permalink
Quad-SDK datasets, and improved graph functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielChaseButterfield authored Jun 10, 2024
2 parents f69c73a + e74ee4d commit 5ca56a7
Show file tree
Hide file tree
Showing 25 changed files with 2,695 additions and 1,199 deletions.
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,10 @@ models/
# Ignore wandb information and datasets
datasets/
wandb/
lightning_logs/
lightning_logs/

# Ignore vscode config
.vscode

# Ignore generated animation files
media/
38 changes: 14 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,15 @@ git submodule init
git submodule update
```

## Dataset Installation
Next, install Dr. Xiong's Quadruped dataset by going to the following link (https://gtvault-my.sharepoint.com/:u:/g/personal/lgan31_gatech_edu/Ee5lmlVVQTZCreMujfQOTFABPJn6RyjK8UDABFXPL86UcA?e=tBGhhO), unzipping the folder, and then placing all of the bag
files within the following folder:
```
<repository base directory>/datasets/xiong_simulated/raw/
```

So, for example, this should be a valid path to a bag file:
```
<repository base directory>/datasets/xiong_simulated/raw/traj_0000.bag
```
There should be about 100 bag files.

## Training a new model
To train a new model from QuadSDK data, run the following command within your Conda environment:

```
python research/train.py
```

First, this command will process the dataset to ensure quick data access during training. Next, it will begin
training a Heterogeneous GNN and log the results to WandB (Weights and Biases). Note that you may need to put in
a WandB key in order to use this logging feature. You can disable this logging following the instructions in
[#Editing this repository](#editing-this-repository), since the logger code is found on line 386 of
`src/grfgnn/gnnLightning.py` and can be commented out.
Expand All @@ -57,10 +44,10 @@ type and the randomly chosen model name (which is output in the terminal when tr

If you used logging on a previous step, you can see the losses and other relevant info in WandB (Weights and Biases).

Regardless of whether you used logging, you can evaluate the model on the test subset of the Quad-SDK data
and see the predicted and ground truth GRF values for all four legs.

First, edit the file `research/evaluator.py` following the provided comments; this will tell the code what model you want to visualize, and how many entries in the dataset to use.

Then, run the following command to evaluate on the model:
```
Expand All @@ -69,12 +56,15 @@ python research/evaluator.py

The visualization of the predicted and GT GRF will be found in a file called `model_eval_results.pdf`.

## Editing this repository

If you want to make changes to the model type, the training parameters, or anything else, modify the files
found in the `src/grfgnn` folder, and then rebuild the library following the instructions in [#Installation](#installation).

## Changing the model type
Currently, three model types are supported:
- `mlp`
- `gnn`
- `heterogeneous_gnn`

To change the model type, please change the `model_type` parameter in the `train.py` and `evaluator.py` files.

## Editing this repository

If you want to make changes to the source files, feel free to edit them in the `src/grfgnn` folder, and then
rebuild the library following the instructions in [#Installation](#installation).
153 changes: 145 additions & 8 deletions research/animations.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,157 @@
from manim import *
from manim.typing import Vector3D, Point3D
from grfgnn import NormalRobotGraph

from pathlib import Path

class CreateURDF(Scene):
    """Manim scene that animates the A1 quadruped's URDF as a graph.

    The scene draws one rounded rectangle per URDF node (colored by node
    type), connects parent/child nodes with arrows, then zooms in on one
    node of each type to explain the per-node input features, finally
    collapsing the explanations into their mathematical notation.

    Render with, e.g.:  manim -pql research/animations.py CreateURDF
    """

    def node_type_to_color(self, node_type: str) -> ManimColor:
        """Map a URDF graph node type to its display color.

        Args:
            node_type: one of 'base', 'joint', or 'foot'.

        Returns:
            The ManimColor used to draw nodes of that type.

        Raises:
            ValueError: if node_type is not one of the three known types.
        """
        # Fixed palette: red-ish base, orange joints, yellow feet.
        hex_by_type = {
            'base': '#D02536',
            'joint': '#F38C16',
            'foot': '#F4FF1F',
        }
        try:
            return ManimColor.from_hex(hex_by_type[node_type])
        except KeyError:
            raise ValueError(f"Unknown node type: {node_type!r}")

    def construct(self):
        # Load the A1 URDF into the project's graph representation.
        path_to_urdf = Path(Path('.').parent, 'urdf_files', 'A1',
                            'a1.urdf').absolute()
        A1_URDF = NormalRobotGraph(path_to_urdf, 'package://a1_description/',
                                   'unitree_ros/robots/a1_description')

        # Create a rectangle for each node.
        rectangles = []
        for i, node in enumerate(A1_URDF.nodes):
            # Make the rectangle, colored by the node's type.
            color = self.node_type_to_color(node.get_node_type())
            rect = RoundedRectangle(corner_radius=0.2, color=color,
                                    height=1.0, width=1.5)
            rect.set_fill(color, opacity=0.5)

            # Label the rectangle with the node's name.
            text = Text(node.name).scale(0.25)
            rect.add(text)

            # Lay the nodes out on a grid: node 0 (the base) sits alone at
            # the top; the remaining nodes fill 4 rows, one column per leg.
            if i == 0:
                rect.move_to([0, 3, 0])
            else:
                i_div = int((i - 1) / 4)  # column (leg) index
                i_mod = (i - 1) % 4       # row within the column
                if i_mod == 0:
                    rect.move_to([2 * (i_div - 1.5), 1.5, 0])
                elif i_mod == 1:
                    rect.move_to([2 * (i_div - 1.5), 0, 0])
                elif i_mod == 2:
                    rect.move_to([2 * (i_div - 1.5), -1.5, 0])
                elif i_mod == 3:
                    rect.move_to([2 * (i_div - 1.5), -3, 0])

            rectangles.append(rect)

        # For each connection, make an arrow. The edge matrix stores each
        # undirected edge twice (both directions); draw only even columns
        # so each edge gets exactly one arrow.
        edge_matrix = A1_URDF.get_edge_index_matrix()
        arrows = []
        for j in range(0, edge_matrix.shape[1]):
            if j % 2 == 1:
                continue
            col = edge_matrix[:, j]

            # Get the two corresponding rectangles.
            parent: RoundedRectangle = rectangles[col[0]]
            child: RoundedRectangle = rectangles[col[1]]

            # Arrow runs from the bottom edge of the parent to the top
            # edge of the child.
            start = parent.get_center() + [0, -0.5, 0]
            end = child.get_center() + [0, 0.5, 0]
            arrows.append(Arrow(start, end, buff=0))

        # Fade the whole graph in, then shift it left to make room for
        # the legend on the right.
        rectangles_vgroup = VGroup(*rectangles)
        arrows = VGroup(*arrows)
        shift_vector = UP * 0.25
        self.play(FadeIn(rectangles_vgroup, shift=shift_vector, scale=0.9,
                         run_time=1.0),
                  FadeIn(arrows, shift=shift_vector, scale=0.9,
                         run_time=1.0))
        self.wait(1)
        self.play(rectangles_vgroup.animate.shift(LEFT * 2),
                  arrows.animate.shift(LEFT * 2))

        # Display the title "URDF File" on the right.
        right_side_placement = 4.5
        text_title = Text("URDF File", weight=BOLD, font="sans-serif") \
            .scale(1).move_to([right_side_placement, 3, 0])
        self.play(FadeIn(text_title, shift=shift_vector, scale=0.9,
                         run_time=1.0))

        # Legend: one labeled circle per node type.
        def _legend_circle(node_type: str, y: float) -> Circle:
            # Build a filled circle with the node type's color and an
            # italic label to its right.
            color = self.node_type_to_color(node_type)
            circle = Circle(color=color, radius=0.25)
            circle.set_fill(color, opacity=0.5)
            label = Text(node_type, slant=ITALIC).scale(0.4)
            label.next_to(circle, RIGHT, buff=0.3)
            circle.add(label)
            circle.move_to([right_side_placement, y, 0])
            return circle

        circle_base = _legend_circle('base', 1)
        circle_joint = _legend_circle('joint', 0)
        circle_foot = _legend_circle('foot', -1)

        group = VGroup(circle_base, circle_joint, circle_foot)
        self.play(FadeIn(group, shift=shift_vector, scale=0.9, run_time=1.0))

        # Make most of the graph go away, keeping one node of each type
        # to move to the left as representative examples.
        base_rect = rectangles[0]
        joint_rect = rectangles[10]
        foot_rect = rectangles[8]

        text_new_title = Text("Node Representations", weight=BOLD,
                              font="sans-serif").scale(1).move_to([-3, 3, 0])

        embeddings_rects = VGroup(base_rect, joint_rect, foot_rect)
        self.play(FadeOut(rectangles_vgroup - base_rect - joint_rect - foot_rect,
                          shift=shift_vector, scale=0.9, run_time=1.0),
                  FadeOut(arrows, shift=shift_vector, scale=0.9, run_time=1.0),
                  embeddings_rects.animate.scale(1.5)
                                  .arrange_in_grid(rows=3)
                                  .move_to([-5, -0.5, 0]),
                  ReplacementTransform(text_title, text_new_title))

        # Add text explaining what each node type represents.
        text_base = Text('The center of the robot with the IMU. \n'
                         'Data: [linear acceleration, \n'
                         'angular velocity, angular acceleration]').scale(0.4)
        text_base.next_to(base_rect, RIGHT, buff=0.3)

        text_joint = Text('The joint motors on the quadruped legs. \n'
                          'Data: [position, velocity, \n'
                          'acceleration, torque]').scale(0.4)
        text_joint.next_to(joint_rect, RIGHT, buff=0.3)

        text_foot = Text('The feet on the end-effectors. \n'
                         'Data: [ground reaction force]').scale(0.4)
        text_foot.next_to(foot_rect, RIGHT, buff=0.3)

        # NOTE: the original passed shift/scale/run_time to self.play()
        # here instead of FadeIn(); moved inside FadeIn to match every
        # other call in this scene.
        self.play(FadeIn(text_base, shift=shift_vector, scale=0.9,
                         run_time=1.0))
        self.wait(2)
        self.play(FadeIn(text_joint, shift=shift_vector, scale=0.9,
                         run_time=1.0))
        self.wait(2)
        self.play(FadeIn(text_foot, shift=shift_vector, scale=0.9,
                         run_time=1.0))
        self.wait(2)

        # Mathematical notation for each node's input features (raw
        # strings so the LaTeX backslashes survive intact).
        text_base_new = MathTex(r'[a, \omega, \dot{\omega}]').scale(1)
        text_base_new.next_to(base_rect, RIGHT, buff=0.3)

        text_joint_new = MathTex(r'[x, \dot{x}, \ddot{x}, \tau]').scale(1)
        text_joint_new.next_to(joint_rect, RIGHT, buff=0.3)

        text_foot_new = MathTex(r'[f_{z}]').scale(1)
        text_foot_new.next_to(foot_rect, RIGHT, buff=0.3)

        # Create a rounded rectangle and give it text.
        # NOTE(review): this rect/text pair looks like leftover scratch
        # work (unpositioned, overlaps the scene) — confirm it is wanted.
        text = Text('base').scale(2)
        rect_1 = RoundedRectangle(corner_radius=0.5)
        text_emb_title = Text("Node Inputs", weight=BOLD,
                              font="sans-serif").scale(1).move_to([-4, 3, 0])

        self.play(Create(rect_1))
        self.play(Write(text))
        self.play(ReplacementTransform(text_new_title, text_emb_title),
                  Transform(text_base, text_base_new),
                  Transform(text_joint, text_joint_new),
                  Transform(text_foot, text_foot_new))
        self.wait(2)
Loading

0 comments on commit 5ca56a7

Please sign in to comment.