Skip to content

Commit

Permalink
Update logic for real-valued SSM (S4D-Real)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertfgu committed Jul 8, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 94d0257 commit 070384a
Showing 11 changed files with 62 additions and 44 deletions.
8 changes: 4 additions & 4 deletions configs/experiment/mega/lra-image/README.md
Original file line number Diff line number Diff line change
@@ -77,7 +77,7 @@ python -m train experiment=mega/lra-image/large-mega-s4d
Same model but replacing the EMA component with original (complex) S4D.

```
python -m train experiment=mega/lra-image/large-mega-s4d '~model.layer.disc' '~model.layer.force_real' model.layer.mode=nplr model.layer.measure=legs
python -m train experiment=mega/lra-image/large-mega-s4d '~model.layer.disc' '~model.layer.is_real' model.layer.mode=nplr model.layer.measure=legs
```
Same model but replacing S4D with S4.

@@ -117,15 +117,15 @@ python -m train experiment=mega/lra-image/small-mega-s4d-real
```
Same as above, but replaces EMA with an S4D layer that is forced to be real-valued instead of complex-valued.

**Details**: Exactly the same as above but replaces `EMAKernel` with `SSMKernelDiag` with the option `force_real=True`. Note that the latter has more features, but the minimal version of it ([here](https://github.com/HazyResearch/state-spaces/blob/17663f26f7e91f88757e1d61318ed216dfb8a8a5/src/models/s4/s4d.py#L16)) is nearly identical to the EMA kernel.
**Details**: Exactly the same as above but replaces `EMAKernel` with `SSMKernelDiag` with the option `is_real=True`. Note that the latter has more features, but the minimal version of it ([here](https://github.com/HazyResearch/state-spaces/blob/17663f26f7e91f88757e1d61318ed216dfb8a8a5/src/models/s4/s4d.py#L16)) is nearly identical to the EMA kernel.

```
python -m train experiment=mega/lra-image/small-mega-s4d
```
Same as above, but with the original (complex-valued) S4D layer.

```
python -m train experiment=mega/lra-image/small-mega-s4d '~model.layer.disc' '~model.layer.force_real' model.layer.mode=nplr model.layer.measure=legs
python -m train experiment=mega/lra-image/small-mega-s4d '~model.layer.disc' '~model.layer.is_real' model.layer.mode=nplr model.layer.measure=legs
```
Same model but replacing S4D with S4.

@@ -159,7 +159,7 @@ python -m train experiment=mega/lra-image/small-ema-with-s4d
Same as above, but use settings to match the parameter count of S4D.

```
python -m train experiment=mega/lra-image/small-s4 '~model.layer.disc' '~model.layer.force_real' model.layer.mode=nplr model.layer.measure=legs
python -m train experiment=mega/lra-image/small-s4 '~model.layer.disc' '~model.layer.is_real' model.layer.mode=nplr model.layer.measure=legs
```
Same model but replacing S4 with S4D.

2 changes: 1 addition & 1 deletion configs/experiment/mega/lra-image/large-mega-s4d-real.yaml
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ model:
disc: zoh
init: diag-real # Simple initialization for real
lr: 0.001
force_real: true # Throw away imaginary part of A
is_real: true # Throw away imaginary part of A

dataset:
grayscale: true
2 changes: 1 addition & 1 deletion configs/experiment/mega/lra-image/large-mega-s4d.yaml
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ model:
disc: zoh
init: diag-lin
lr: 0.001
force_real: false # Throw away imaginary part
is_real: false # Throw away imaginary part

dataset:
grayscale: true
2 changes: 1 addition & 1 deletion configs/experiment/mega/lra-image/large-s4d-real.yaml
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ model:
init: diag-real # Simple initialization for real
disc: zoh
lr: 0.001
force_real: true # Throw away imaginary part
is_real: true # Throw away imaginary part
n_ssm: null # Set to 1 for smaller parameter count of original config

dataset:
2 changes: 1 addition & 1 deletion configs/experiment/mega/lra-image/large-s4d.yaml
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ model:
init: diag-lin
disc: zoh
lr: 0.001
force_real: false
is_real: false
n_ssm: null # Set to 1 for smaller parameter count of original config

dataset:
2 changes: 1 addition & 1 deletion configs/experiment/mega/lra-image/small-mega-s4d-real.yaml
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ model:
disc: zoh
init: diag-real # Simple initialization for real
lr: 0.001
force_real: true # Throw away imaginary part
is_real: true # Throw away imaginary part

dataset:
grayscale: true
2 changes: 1 addition & 1 deletion configs/experiment/mega/lra-image/small-mega-s4d.yaml
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ model:
disc: zoh
init: diag-lin
lr: 0.001
force_real: false
is_real: false

dataset:
grayscale: true
2 changes: 1 addition & 1 deletion configs/experiment/mega/lra-image/small-s4d-real.yaml
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@ model:
disc: zoh
n_ssm: null # Don't tie A/B params to param match EMA
lr: 0.001
force_real: true # Throw away imaginary part
is_real: true # Throw away imaginary part

dataset:
grayscale: true
2 changes: 1 addition & 1 deletion configs/experiment/mega/lra-image/small-s4d.yaml
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@ model:
disc: zoh
lr: 0.001
n_ssm: null # Don't tie A/B params to param match EMA
force_real: false
is_real: false

dataset:
grayscale: true
41 changes: 25 additions & 16 deletions models/s4/s4.py
Original file line number Diff line number Diff line change
@@ -986,7 +986,7 @@ class SSMKernelDiag(SSMKernel):
bandlimit: Mask high frequencies of the kernel (indices corresponding to
diagonal elements with large imaginary part). Introduced in S4ND paper.
backend: ['cuda' | 'keops' | 'naive'] Options for Vandermonde/Cauchy kernel (in order of efficiency).
force_real : Force A to have 0 imaginary part, to emulate EMA.
is_real : Real-valued SSM; can be interpreted as EMA.
"""

def __init__(
@@ -997,17 +997,20 @@ def __init__(
imag_transform: str = 'none',
bandlimit: Optional[float] = None,
backend: str = 'cuda',
force_real: bool = False,
is_real: bool = False,
**kwargs,
):
# Special case: for real-valued, d_state semantics change
if is_real and 'd_state' in kwargs:
kwargs['d_state'] = kwargs['d_state'] * 2
super().__init__(**kwargs)
self.disc = disc
self.dt_fast = dt_fast
self.real_transform = real_transform
self.imag_transform = imag_transform
self.bandlimit = bandlimit
self.backend = backend
self.force_real = force_real
self.is_real = is_real

# Initialize dt, A, B, C
inv_dt = self.init_dt()
@@ -1053,32 +1056,38 @@ def register_params(self, A, B, C, inv_dt, P):
# Broadcast everything to correct shapes
C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N) # TODO originally this was only in DPLR, check safe for Diag
B = B.unsqueeze(0) # (1, H, N)

assert self.channels == C.shape[0]
self.C = nn.Parameter(_c2r(_resolve_conj(C)))

# Register dt, B, A
# Register dt
self.register("inv_dt", inv_dt, self.lr_dict['dt'], self.wd_dict['dt'])
self.register("B", _c2r(B), self.lr_dict['B'], self.wd_dict['B'])
self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A'])
self.register("A_imag", inv_transform(-A.imag, self.imag_transform), self.lr_dict['A'], self.wd_dict['A'])
# Register ABC
if self.is_real:
self.register("C", C.real, self.lr_dict['C'], None)
self.register("B", B.real, self.lr_dict['B'], self.wd_dict['B'])
self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A'])
else:
self.register("C", _c2r(_resolve_conj(C)), self.lr_dict['C'], None)
self.register("B", _c2r(B), self.lr_dict['B'], self.wd_dict['B'])
self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A'])
self.register("A_imag", inv_transform(-A.imag, self.imag_transform), self.lr_dict['A'], self.wd_dict['A'])

def _get_params(self, rate=1.0):
"""Process the internal parameters."""

# (S N) where S=n_ssm
A = -param_transform(self.A_real, self.real_transform) - 1j * param_transform(self.A_imag, self.imag_transform)
B = _r2c(self.B) # (1 S N)
C = _r2c(self.C) # (C H N)
if self.is_real:
A = -param_transform(self.A_real, self.real_transform)
B = self.B # (1 S N)
C = self.C # (C H N)
else:
A = -param_transform(self.A_real, self.real_transform) - 1j * param_transform(self.A_imag, self.imag_transform)
B = _r2c(self.B) # (1 S N)
C = _r2c(self.C) # (C H N)

if self.dt_fast: inv_dt = torch.sinh(self.inv_dt)
else: inv_dt = self.inv_dt
dt = param_transform(inv_dt, self.dt_transform) * rate # (H N)

# Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA"
if self.force_real:
A = A.real + 0j

if self.bandlimit is not None:
freqs = dt / rate * A.imag.abs() / (2*math.pi) # (H N)
mask = torch.where(freqs < self.bandlimit * .5, 1, 0)
41 changes: 25 additions & 16 deletions src/models/sequence/kernels/ssm.py
Original file line number Diff line number Diff line change
@@ -448,7 +448,7 @@ class SSMKernelDiag(SSMKernel):
bandlimit: Mask high frequencies of the kernel (indices corresponding to
diagonal elements with large imaginary part). Introduced in S4ND paper.
backend: ['cuda' | 'keops' | 'naive'] Options for Vandermonde/Cauchy kernel (in order of efficiency).
force_real : Force A to have 0 imaginary part, to emulate EMA.
is_real : Real-valued SSM; can be interpreted as EMA.
"""

def __init__(
@@ -459,17 +459,20 @@ def __init__(
imag_transform: str = 'none',
bandlimit: Optional[float] = None,
backend: str = 'cuda',
force_real: bool = False,
is_real: bool = False,
**kwargs,
):
# Special case: for real-valued, d_state semantics change
if is_real and 'd_state' in kwargs:
kwargs['d_state'] = kwargs['d_state'] * 2
super().__init__(**kwargs)
self.disc = disc
self.dt_fast = dt_fast
self.real_transform = real_transform
self.imag_transform = imag_transform
self.bandlimit = bandlimit
self.backend = backend
self.force_real = force_real
self.is_real = is_real

# Initialize dt, A, B, C
inv_dt = self.init_dt()
@@ -515,32 +518,38 @@ def register_params(self, A, B, C, inv_dt, P):
# Broadcast everything to correct shapes
C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N) # TODO originally this was only in DPLR, check safe for Diag
B = B.unsqueeze(0) # (1, H, N)

assert self.channels == C.shape[0]
self.C = nn.Parameter(_c2r(_resolve_conj(C)))

# Register dt, B, A
# Register dt
self.register("inv_dt", inv_dt, self.lr_dict['dt'], self.wd_dict['dt'])
self.register("B", _c2r(B), self.lr_dict['B'], self.wd_dict['B'])
self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A'])
self.register("A_imag", inv_transform(-A.imag, self.imag_transform), self.lr_dict['A'], self.wd_dict['A'])
# Register ABC
if self.is_real:
self.register("C", C.real, self.lr_dict['C'], None)
self.register("B", B.real, self.lr_dict['B'], self.wd_dict['B'])
self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A'])
else:
self.register("C", _c2r(_resolve_conj(C)), self.lr_dict['C'], None)
self.register("B", _c2r(B), self.lr_dict['B'], self.wd_dict['B'])
self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A'])
self.register("A_imag", inv_transform(-A.imag, self.imag_transform), self.lr_dict['A'], self.wd_dict['A'])

def _get_params(self, rate=1.0):
"""Process the internal parameters."""

# (S N) where S=n_ssm
A = -param_transform(self.A_real, self.real_transform) - 1j * param_transform(self.A_imag, self.imag_transform)
B = _r2c(self.B) # (1 S N)
C = _r2c(self.C) # (C H N)
if self.is_real:
A = -param_transform(self.A_real, self.real_transform)
B = self.B # (1 S N)
C = self.C # (C H N)
else:
A = -param_transform(self.A_real, self.real_transform) - 1j * param_transform(self.A_imag, self.imag_transform)
B = _r2c(self.B) # (1 S N)
C = _r2c(self.C) # (C H N)

if self.dt_fast: inv_dt = torch.sinh(self.inv_dt)
else: inv_dt = self.inv_dt
dt = param_transform(inv_dt, self.dt_transform) * rate # (H N)

# Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA"
if self.force_real:
A = A.real + 0j

if self.bandlimit is not None:
freqs = dt / rate * A.imag.abs() / (2*math.pi) # (H N)
mask = torch.where(freqs < self.bandlimit * .5, 1, 0)

0 comments on commit 070384a

Please sign in to comment.