Skip to content

Commit

Permalink
Merge pull request #24 from zhaoyi11/rp1m
Browse files Browse the repository at this point in the history
Lift the requirement of human fingering with RP1M
  • Loading branch information
kevinzakka authored Nov 2, 2024
2 parents d9cde23 + ad4ca3d commit 0d9736c
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions robopianist/suite/tasks/piano_with_shadow_hands.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import List, Optional, Sequence, Tuple

import numpy as np
from scipy.optimize import linear_sum_assignment
from dm_control import mjcf
from dm_control.composer import variation as base_variation
from dm_control.composer.observation import observable
Expand Down Expand Up @@ -134,6 +135,11 @@ def _set_rewards(self) -> None:
)
if not self._disable_fingering_reward:
self._reward_fn.add("fingering_reward", self._compute_fingering_reward)
else:
# use OT based fingering
print('Fingering is unavailable. OT fingering reward is used.')
self._reward_fn.add("ot_fingering_reward", self._compute_ot_fingering_reward)

if not self._disable_forearm_reward:
self._reward_fn.add("forearm_reward", self._compute_forearm_reward)

Expand Down Expand Up @@ -324,6 +330,44 @@ def _distance_finger_to_key(
)
return float(np.mean(rews))

def _compute_ot_fingering_reward(self, physics: mjcf.Physics) -> float:
""" OT reward calculation from RP1M https://arxiv.org/abs/2408.11048 """
# calcuate fingertip positions
fingertip_pos = [physics.bind(finger).xpos.copy() for finger in self.left_hand.fingertip_sites]
fingertip_pos += [physics.bind(finger).xpos.copy() for finger in self.right_hand.fingertip_sites]

# calcuate the positions of piano keys to press.
keys_to_press = np.flatnonzero(self._goal_current[:-1]) # keys to press
# if no key is pressed
if keys_to_press.shape[0] == 0:
return 1.

# calculate key pos
key_pos = []
for key in keys_to_press:
key_geom = self.piano.keys[key].geom[0]
key_geom_pos = physics.bind(key_geom).xpos.copy()
key_geom_pos[-1] += 0.5 * physics.bind(key_geom).size[2]
key_geom_pos[0] += 0.35 * physics.bind(key_geom).size[0]
key_pos.append(key_geom_pos.copy())

# calcualte the distance between keys and fingers
dist = np.full((len(fingertip_pos), len(key_pos)), 100.)
for i, finger in enumerate(fingertip_pos):
for j, key in enumerate(key_pos):
dist[i, j] = np.linalg.norm(key - finger)

# calculate the shortest distance
row_ind, col_ind = linear_sum_assignment(dist)
dist = dist[row_ind, col_ind]
rews = tolerance(
dist,
bounds=(0, _FINGER_CLOSE_ENOUGH_TO_KEY),
margin=(_FINGER_CLOSE_ENOUGH_TO_KEY * 10),
sigmoid="gaussian",
)
return float(np.mean(rews))

def _update_goal_state(self) -> None:
# Observable callables get called after `after_step` but before
# `should_terminate_episode`. Since we increment `self._t_idx` in `after_step`,
Expand Down

0 comments on commit 0d9736c

Please sign in to comment.