From 6b16e7e8a99bbb3224ca85a94c5b03076990621b Mon Sep 17 00:00:00 2001 From: Junjia Liu Date: Tue, 5 Mar 2024 16:13:58 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20the=20bug=20in=20TP-GMM?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rofunc/learning/ml/tpgmm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rofunc/learning/ml/tpgmm.py b/rofunc/learning/ml/tpgmm.py index 0bf8957bf..91553a4d3 100644 --- a/rofunc/learning/ml/tpgmm.py +++ b/rofunc/learning/ml/tpgmm.py @@ -158,8 +158,9 @@ def poe(self, model: HMM, show_demo_idx: int) -> GMM: # get transformation for given demonstration. # We use the transformation of the first timestep as they are constant - if len(self.task_params['frame_origins'][0]) == 1: - A, b = self.demos_A_xdx[0][0], self.demos_b_xdx[0][0] + if len(self.task_params['frame_origins'][0]) == 1: # For new task parameters generation + A, b = self.demo_A_xdx[0][0], self.demo_b_xdx[0][0] # Attention: here we use self.demo_A_xdx not + # self.demos_A_xdx cause we just called get_A_b not get_related_matrix else: A, b = self.demos_A_xdx[show_demo_idx][0], self.demos_b_xdx[show_demo_idx][0]