From bab5d985dcdf0dd9e72ddb0cc99faec73cfd6d14 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 26 Dec 2023 16:30:05 +0100 Subject: [PATCH] clean up --- src/losses/mmd.py | 2 +- src/visualization/__init__.py | 1 - src/visualization/multiscatter.py | 2 +- src/visualization/rainbow.py | 8 -------- 4 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/losses/mmd.py b/src/losses/mmd.py index 6e0f433..bd274cc 100644 --- a/src/losses/mmd.py +++ b/src/losses/mmd.py @@ -46,7 +46,7 @@ def mmd_loss(samples: Tensor, target: Tensor, kernel=gaussian_kernel, scales: Te Returns ------- - mmd: Maximum-Mean-Discrepancy between data and code. Tensor of shape () + mmd: Maximum-Mean-Discrepancy between data and code. Tensor of shape (N,) """ if scales == "auto": scales = torch.logspace(-6, 6, 13) diff --git a/src/visualization/__init__.py b/src/visualization/__init__.py index e4f4905..42aa7ca 100644 --- a/src/visualization/__init__.py +++ b/src/visualization/__init__.py @@ -1,5 +1,4 @@ - from .multiscatter import multiscatter, multiscatter_bp from .rainbow import Rainbow from .scatter import scatter, scatter_bp diff --git a/src/visualization/multiscatter.py b/src/visualization/multiscatter.py index ab49075..c580acd 100644 --- a/src/visualization/multiscatter.py +++ b/src/visualization/multiscatter.py @@ -43,4 +43,4 @@ def multiscatter_bp(samples: np.ndarray, layout: (int, int), **render_kwargs): plt.subplots_adjust(left=0.0, right=1.0, bottom=0.0, top=1.0, wspace=0.0, hspace=0.0) - return fig \ No newline at end of file + return fig diff --git a/src/visualization/rainbow.py b/src/visualization/rainbow.py index 021851b..06f4d0c 100644 --- a/src/visualization/rainbow.py +++ b/src/visualization/rainbow.py @@ -38,11 +38,3 @@ def __call__(self, X, alpha=None, bytes=False): rgba = np.concatenate([rgb, alpha], axis=1) return rgba - - -r, g, b = 1, 1, 1 - -m = max(r, g, b) - - -