From 070384aac428583d0b817c0e5393078bfef982f0 Mon Sep 17 00:00:00 2001 From: Albert Gu Date: Sat, 8 Jul 2023 17:13:45 +0000 Subject: [PATCH] Update logic for real-valued SSM (S4D-Real) --- configs/experiment/mega/lra-image/README.md | 8 ++-- .../mega/lra-image/large-mega-s4d-real.yaml | 2 +- .../mega/lra-image/large-mega-s4d.yaml | 2 +- .../mega/lra-image/large-s4d-real.yaml | 2 +- .../experiment/mega/lra-image/large-s4d.yaml | 2 +- .../mega/lra-image/small-mega-s4d-real.yaml | 2 +- .../mega/lra-image/small-mega-s4d.yaml | 2 +- .../mega/lra-image/small-s4d-real.yaml | 2 +- .../experiment/mega/lra-image/small-s4d.yaml | 2 +- models/s4/s4.py | 41 +++++++++++-------- src/models/sequence/kernels/ssm.py | 41 +++++++++++-------- 11 files changed, 62 insertions(+), 44 deletions(-) diff --git a/configs/experiment/mega/lra-image/README.md b/configs/experiment/mega/lra-image/README.md index 589b0b0..6eb9d2b 100644 --- a/configs/experiment/mega/lra-image/README.md +++ b/configs/experiment/mega/lra-image/README.md @@ -77,7 +77,7 @@ python -m train experiment=mega/lra-image/large-mega-s4d Same model but replacing the EMA component with original (complex) S4D. ``` -python -m train experiment=mega/lra-image/large-mega-s4d '~model.layer.disc' '~model.layer.force_real' model.layer.mode=nplr model.layer.measure=legs +python -m train experiment=mega/lra-image/large-mega-s4d '~model.layer.disc' '~model.layer.is_real' model.layer.mode=nplr model.layer.measure=legs ``` Same model but replacing S4D with S4. @@ -117,7 +117,7 @@ python -m train experiment=mega/lra-image/small-mega-s4d-real ``` Same as above, but replaces EMA with an S4D layer that is forced to be real-valued instead of complex-valued. -**Details**: Exactly the same as above but replaces `EMAKernel` with `SSMKernelDiag` with the option `force_real=True`. 
Note that the latter has more features, but the minimal version of it ([here](https://github.com/HazyResearch/state-spaces/blob/17663f26f7e91f88757e1d61318ed216dfb8a8a5/src/models/s4/s4d.py#L16)) is nearly identical to the EMA kernel. +**Details**: Exactly the same as above but replaces `EMAKernel` with `SSMKernelDiag` with the option `is_real=True`. Note that the latter has more features, but the minimal version of it ([here](https://github.com/HazyResearch/state-spaces/blob/17663f26f7e91f88757e1d61318ed216dfb8a8a5/src/models/s4/s4d.py#L16)) is nearly identical to the EMA kernel. ``` python -m train experiment=mega/lra-image/small-mega-s4d @@ -125,7 +125,7 @@ python -m train experiment=mega/lra-image/small-mega-s4d Same as above, but with the original (complex-valued) S4D layer. ``` -python -m train experiment=mega/lra-image/small-mega-s4d '~model.layer.disc' '~model.layer.force_real' model.layer.mode=nplr model.layer.measure=legs +python -m train experiment=mega/lra-image/small-mega-s4d '~model.layer.disc' '~model.layer.is_real' model.layer.mode=nplr model.layer.measure=legs ``` Same model but replacing S4D with S4. @@ -159,7 +159,7 @@ python -m train experiment=mega/lra-image/small-ema-with-s4d Same as above, but use settings to match the parameter count of S4D. ``` -python -m train experiment=mega/lra-image/small-s4 '~model.layer.disc' '~model.layer.force_real' model.layer.mode=nplr model.layer.measure=legs +python -m train experiment=mega/lra-image/small-s4 '~model.layer.disc' '~model.layer.is_real' model.layer.mode=nplr model.layer.measure=legs ``` Same model but replacing S4 with S4D. 
diff --git a/configs/experiment/mega/lra-image/large-mega-s4d-real.yaml b/configs/experiment/mega/lra-image/large-mega-s4d-real.yaml index 253d9ad..4b9f9b3 100644 --- a/configs/experiment/mega/lra-image/large-mega-s4d-real.yaml +++ b/configs/experiment/mega/lra-image/large-mega-s4d-real.yaml @@ -22,7 +22,7 @@ model: disc: zoh init: diag-real # Simple initialization for real lr: 0.001 - force_real: true # Throw away imaginary part of A + is_real: true # Throw away imaginary part of A dataset: grayscale: true diff --git a/configs/experiment/mega/lra-image/large-mega-s4d.yaml b/configs/experiment/mega/lra-image/large-mega-s4d.yaml index 1f37d1f..bb8fdf8 100644 --- a/configs/experiment/mega/lra-image/large-mega-s4d.yaml +++ b/configs/experiment/mega/lra-image/large-mega-s4d.yaml @@ -22,7 +22,7 @@ model: disc: zoh init: diag-lin lr: 0.001 - force_real: false # Throw away imaginary part + is_real: false # Keep the complex-valued SSM (imaginary part is retained) dataset: grayscale: true diff --git a/configs/experiment/mega/lra-image/large-s4d-real.yaml b/configs/experiment/mega/lra-image/large-s4d-real.yaml index 0b23ab8..c40cf45 100644 --- a/configs/experiment/mega/lra-image/large-s4d-real.yaml +++ b/configs/experiment/mega/lra-image/large-s4d-real.yaml @@ -17,7 +17,7 @@ model: init: diag-real # Simple initialization for real disc: zoh lr: 0.001 - force_real: true # Throw away imaginary part + is_real: true # Throw away imaginary part n_ssm: null # Set to 1 for smaller parameter count of original config dataset: diff --git a/configs/experiment/mega/lra-image/large-s4d.yaml b/configs/experiment/mega/lra-image/large-s4d.yaml index be52f28..08abbec 100644 --- a/configs/experiment/mega/lra-image/large-s4d.yaml +++ b/configs/experiment/mega/lra-image/large-s4d.yaml @@ -17,7 +17,7 @@ model: init: diag-lin disc: zoh lr: 0.001 - force_real: false + is_real: false n_ssm: null # Set to 1 for smaller parameter count of original config dataset: diff --git 
a/configs/experiment/mega/lra-image/small-mega-s4d-real.yaml b/configs/experiment/mega/lra-image/small-mega-s4d-real.yaml index f2f7431..9190c88 100644 --- a/configs/experiment/mega/lra-image/small-mega-s4d-real.yaml +++ b/configs/experiment/mega/lra-image/small-mega-s4d-real.yaml @@ -22,7 +22,7 @@ model: disc: zoh init: diag-real # Simple initialization for real lr: 0.001 - force_real: true # Throw away imaginary part + is_real: true # Throw away imaginary part dataset: grayscale: true diff --git a/configs/experiment/mega/lra-image/small-mega-s4d.yaml b/configs/experiment/mega/lra-image/small-mega-s4d.yaml index 7c3a6ef..8627cc8 100644 --- a/configs/experiment/mega/lra-image/small-mega-s4d.yaml +++ b/configs/experiment/mega/lra-image/small-mega-s4d.yaml @@ -22,7 +22,7 @@ model: disc: zoh init: diag-lin lr: 0.001 - force_real: false + is_real: false dataset: grayscale: true diff --git a/configs/experiment/mega/lra-image/small-s4d-real.yaml b/configs/experiment/mega/lra-image/small-s4d-real.yaml index 8f8d679..2305c88 100644 --- a/configs/experiment/mega/lra-image/small-s4d-real.yaml +++ b/configs/experiment/mega/lra-image/small-s4d-real.yaml @@ -19,7 +19,7 @@ model: disc: zoh n_ssm: null # Don't tie A/B params to param match EMA lr: 0.001 - force_real: true # Throw away imaginary part + is_real: true # Throw away imaginary part dataset: grayscale: true diff --git a/configs/experiment/mega/lra-image/small-s4d.yaml b/configs/experiment/mega/lra-image/small-s4d.yaml index 41882c8..eeee961 100644 --- a/configs/experiment/mega/lra-image/small-s4d.yaml +++ b/configs/experiment/mega/lra-image/small-s4d.yaml @@ -20,7 +20,7 @@ model: disc: zoh lr: 0.001 n_ssm: null # Don't tie A/B params to param match EMA - force_real: false + is_real: false dataset: grayscale: true diff --git a/models/s4/s4.py b/models/s4/s4.py index 8614243..caec2a8 100644 --- a/models/s4/s4.py +++ b/models/s4/s4.py @@ -986,7 +986,7 @@ class SSMKernelDiag(SSMKernel): bandlimit: Mask high frequencies of 
the kernel (indices corresponding to diagonal elements with large imaginary part). Introduced in S4ND paper. backend: ['cuda' | 'keops' | 'naive'] Options for Vandermonde/Cauchy kernel (in order of efficiency). - force_real : Force A to have 0 imaginary part, to emulate EMA. + is_real : Real-valued SSM; can be interpreted as EMA. """ def __init__( @@ -997,9 +997,12 @@ def __init__( imag_transform: str = 'none', bandlimit: Optional[float] = None, backend: str = 'cuda', - force_real: bool = False, + is_real: bool = False, **kwargs, ): + # Special case: for real-valued, d_state semantics change + if is_real and 'd_state' in kwargs: + kwargs['d_state'] = kwargs['d_state'] * 2 super().__init__(**kwargs) self.disc = disc self.dt_fast = dt_fast @@ -1007,7 +1010,7 @@ def __init__( self.imag_transform = imag_transform self.bandlimit = bandlimit self.backend = backend - self.force_real = force_real + self.is_real = is_real # Initialize dt, A, B, C inv_dt = self.init_dt() @@ -1053,32 +1056,38 @@ def register_params(self, A, B, C, inv_dt, P): # Broadcast everything to correct shapes C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N) # TODO originally this was only in DPLR, check safe for Diag B = B.unsqueeze(0) # (1, H, N) - assert self.channels == C.shape[0] - self.C = nn.Parameter(_c2r(_resolve_conj(C))) - # Register dt, B, A + # Register dt self.register("inv_dt", inv_dt, self.lr_dict['dt'], self.wd_dict['dt']) - self.register("B", _c2r(B), self.lr_dict['B'], self.wd_dict['B']) - self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A']) - self.register("A_imag", inv_transform(-A.imag, self.imag_transform), self.lr_dict['A'], self.wd_dict['A']) + # Register ABC + if self.is_real: + self.register("C", C.real, self.lr_dict['C'], None) + self.register("B", B.real, self.lr_dict['B'], self.wd_dict['B']) + self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], 
self.wd_dict['A']) + else: + self.register("C", _c2r(_resolve_conj(C)), self.lr_dict['C'], None) + self.register("B", _c2r(B), self.lr_dict['B'], self.wd_dict['B']) + self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A']) + self.register("A_imag", inv_transform(-A.imag, self.imag_transform), self.lr_dict['A'], self.wd_dict['A']) def _get_params(self, rate=1.0): """Process the internal parameters.""" # (S N) where S=n_ssm - A = -param_transform(self.A_real, self.real_transform) - 1j * param_transform(self.A_imag, self.imag_transform) - B = _r2c(self.B) # (1 S N) - C = _r2c(self.C) # (C H N) + if self.is_real: + A = -param_transform(self.A_real, self.real_transform) + B = self.B # (1 S N) + C = self.C # (C H N) + else: + A = -param_transform(self.A_real, self.real_transform) - 1j * param_transform(self.A_imag, self.imag_transform) + B = _r2c(self.B) # (1 S N) + C = _r2c(self.C) # (C H N) if self.dt_fast: inv_dt = torch.sinh(self.inv_dt) else: inv_dt = self.inv_dt dt = param_transform(inv_dt, self.dt_transform) * rate # (H N) - # Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA" - if self.force_real: - A = A.real + 0j - if self.bandlimit is not None: freqs = dt / rate * A.imag.abs() / (2*math.pi) # (H N) mask = torch.where(freqs < self.bandlimit * .5, 1, 0) diff --git a/src/models/sequence/kernels/ssm.py b/src/models/sequence/kernels/ssm.py index 725540f..269ccb0 100644 --- a/src/models/sequence/kernels/ssm.py +++ b/src/models/sequence/kernels/ssm.py @@ -448,7 +448,7 @@ class SSMKernelDiag(SSMKernel): bandlimit: Mask high frequencies of the kernel (indices corresponding to diagonal elements with large imaginary part). Introduced in S4ND paper. backend: ['cuda' | 'keops' | 'naive'] Options for Vandermonde/Cauchy kernel (in order of efficiency). - force_real : Force A to have 0 imaginary part, to emulate EMA. + is_real : Real-valued SSM; can be interpreted as EMA. 
""" def __init__( @@ -459,9 +459,12 @@ def __init__( imag_transform: str = 'none', bandlimit: Optional[float] = None, backend: str = 'cuda', - force_real: bool = False, + is_real: bool = False, **kwargs, ): + # Special case: for real-valued, d_state semantics change + if is_real and 'd_state' in kwargs: + kwargs['d_state'] = kwargs['d_state'] * 2 super().__init__(**kwargs) self.disc = disc self.dt_fast = dt_fast @@ -469,7 +472,7 @@ def __init__( self.imag_transform = imag_transform self.bandlimit = bandlimit self.backend = backend - self.force_real = force_real + self.is_real = is_real # Initialize dt, A, B, C inv_dt = self.init_dt() @@ -515,32 +518,38 @@ def register_params(self, A, B, C, inv_dt, P): # Broadcast everything to correct shapes C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N) # TODO originally this was only in DPLR, check safe for Diag B = B.unsqueeze(0) # (1, H, N) - assert self.channels == C.shape[0] - self.C = nn.Parameter(_c2r(_resolve_conj(C))) - # Register dt, B, A + # Register dt self.register("inv_dt", inv_dt, self.lr_dict['dt'], self.wd_dict['dt']) - self.register("B", _c2r(B), self.lr_dict['B'], self.wd_dict['B']) - self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A']) - self.register("A_imag", inv_transform(-A.imag, self.imag_transform), self.lr_dict['A'], self.wd_dict['A']) + # Register ABC + if self.is_real: + self.register("C", C.real, self.lr_dict['C'], None) + self.register("B", B.real, self.lr_dict['B'], self.wd_dict['B']) + self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A']) + else: + self.register("C", _c2r(_resolve_conj(C)), self.lr_dict['C'], None) + self.register("B", _c2r(B), self.lr_dict['B'], self.wd_dict['B']) + self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A']) + self.register("A_imag", inv_transform(-A.imag, self.imag_transform), 
self.lr_dict['A'], self.wd_dict['A']) def _get_params(self, rate=1.0): """Process the internal parameters.""" # (S N) where S=n_ssm - A = -param_transform(self.A_real, self.real_transform) - 1j * param_transform(self.A_imag, self.imag_transform) - B = _r2c(self.B) # (1 S N) - C = _r2c(self.C) # (C H N) + if self.is_real: + A = -param_transform(self.A_real, self.real_transform) + B = self.B # (1 S N) + C = self.C # (C H N) + else: + A = -param_transform(self.A_real, self.real_transform) - 1j * param_transform(self.A_imag, self.imag_transform) + B = _r2c(self.B) # (1 S N) + C = _r2c(self.C) # (C H N) if self.dt_fast: inv_dt = torch.sinh(self.inv_dt) else: inv_dt = self.inv_dt dt = param_transform(inv_dt, self.dt_transform) * rate # (H N) - # Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA" - if self.force_real: - A = A.real + 0j - if self.bandlimit is not None: freqs = dt / rate * A.imag.abs() / (2*math.pi) # (H N) mask = torch.where(freqs < self.bandlimit * .5, 1, 0)