From ad4ca3d6e3ba0dee75b77428f549b876ed316a59 Mon Sep 17 00:00:00 2001 From: zhaoyi11 Date: Fri, 25 Oct 2024 17:44:34 +0300 Subject: [PATCH] Add OT based fingering. --- .../suite/tasks/piano_with_shadow_hands.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/robopianist/suite/tasks/piano_with_shadow_hands.py b/robopianist/suite/tasks/piano_with_shadow_hands.py index 80e0a92..a0ae9c2 100644 --- a/robopianist/suite/tasks/piano_with_shadow_hands.py +++ b/robopianist/suite/tasks/piano_with_shadow_hands.py @@ -17,6 +17,7 @@ from typing import List, Optional, Sequence, Tuple import numpy as np +from scipy.optimize import linear_sum_assignment from dm_control import mjcf from dm_control.composer import variation as base_variation from dm_control.composer.observation import observable @@ -134,6 +135,11 @@ def _set_rewards(self) -> None: ) if not self._disable_fingering_reward: self._reward_fn.add("fingering_reward", self._compute_fingering_reward) + else: + # use OT based fingering + print('Fingering is unavailable. OT fingering reward is used.') + self._reward_fn.add("ot_fingering_reward", self._compute_ot_fingering_reward) + if not self._disable_forearm_reward: self._reward_fn.add("forearm_reward", self._compute_forearm_reward) @@ -324,6 +330,44 @@ def _distance_finger_to_key( ) return float(np.mean(rews)) + def _compute_ot_fingering_reward(self, physics: mjcf.Physics) -> float: + """ OT reward calculation from RP1M https://arxiv.org/abs/2408.11048 """ + # calculate fingertip positions + fingertip_pos = [physics.bind(finger).xpos.copy() for finger in self.left_hand.fingertip_sites] + fingertip_pos += [physics.bind(finger).xpos.copy() for finger in self.right_hand.fingertip_sites] + + # calculate the positions of piano keys to press. + keys_to_press = np.flatnonzero(self._goal_current[:-1]) # keys to press + # if no key is pressed + if keys_to_press.shape[0] == 0: + return 1. 
+ + # calculate key pos + key_pos = [] + for key in keys_to_press: + key_geom = self.piano.keys[key].geom[0] + key_geom_pos = physics.bind(key_geom).xpos.copy() + key_geom_pos[-1] += 0.5 * physics.bind(key_geom).size[2] + key_geom_pos[0] += 0.35 * physics.bind(key_geom).size[0] + key_pos.append(key_geom_pos.copy()) + + # calculate the distance between keys and fingers + dist = np.full((len(fingertip_pos), len(key_pos)), 100.) + for i, finger in enumerate(fingertip_pos): + for j, key in enumerate(key_pos): + dist[i, j] = np.linalg.norm(key - finger) + + # find the minimum-cost finger-to-key assignment + row_ind, col_ind = linear_sum_assignment(dist) + dist = dist[row_ind, col_ind] + rews = tolerance( + dist, + bounds=(0, _FINGER_CLOSE_ENOUGH_TO_KEY), + margin=(_FINGER_CLOSE_ENOUGH_TO_KEY * 10), + sigmoid="gaussian", + ) + return float(np.mean(rews)) + def _update_goal_state(self) -> None: # Observable callables get called after `after_step` but before # `should_terminate_episode`. Since we increment `self._t_idx` in `after_step`,