Skip to content

Commit

Permalink
DRED: quantize r and p0 parameters with 8 bits
Browse files Browse the repository at this point in the history
Only code non-degenerate symbols, which makes the encoder faster
  • Loading branch information
jmvalin committed Nov 6, 2023
1 parent 98b8be0 commit 544b3e5
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 48 deletions.
2 changes: 1 addition & 1 deletion autogen.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ set -e
srcdir=`dirname $0`
test -n "$srcdir" && cd "$srcdir"

dnn/download_model.sh c99054d
dnn/download_model.sh 98b8be0

echo "Updating build configuration files, please wait...."

Expand Down
8 changes: 4 additions & 4 deletions dnn/dred_rdovae.c
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,19 @@ void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float
}


const opus_uint16 * DRED_rdovae_get_p0_pointer(void)
const opus_uint8 * DRED_rdovae_get_p0_pointer(void)
{
return &dred_p0_q15[0];
return &dred_p0_q8[0];
}

/* Accessor for the DRED dead-zone table: returns a pointer to the first
   entry of dred_dead_zone_q10. The table is read-only for callers. */
const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void)
{
    const opus_uint16 *dead_zone_table = dred_dead_zone_q10;
    return dead_zone_table;
}

const opus_uint16 * DRED_rdovae_get_r_pointer(void)
const opus_uint8 * DRED_rdovae_get_r_pointer(void)
{
return &dred_r_q15[0];
return &dred_r_q8[0];
}

const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void)
Expand Down
4 changes: 2 additions & 2 deletions dnn/dred_rdovae.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ void DRED_rdovae_dec_init_states(RDOVAEDecState *h, const RDOVAEDec *model, cons

void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float *qframe, const float * z);

const opus_uint16 * DRED_rdovae_get_p0_pointer(void);
const opus_uint8 * DRED_rdovae_get_p0_pointer(void);
const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void);
const opus_uint16 * DRED_rdovae_get_r_pointer(void);
const opus_uint8 * DRED_rdovae_get_r_pointer(void);
const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void);

#endif
12 changes: 6 additions & 6 deletions dnn/torch/rdovae/export_rdovae_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,20 @@ def dump_statistical_model(writer, qembedding):

quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16)
r_q15 = np.round(r * 2**15).astype(np.uint16)
p0_q15 = np.round(p0 * 2**15).astype(np.uint16)
r_q15 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
p0_q15 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)

print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False)
print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False)
print_vector(writer.source, r_q15, 'dred_r_q8', dtype='opus_uint8', static=False)
print_vector(writer.source, p0_q15, 'dred_p0_q8', dtype='opus_uint8', static=False)

writer.header.write(
f"""
extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
extern const opus_uint16 dred_r_q15[{levels * N}];
extern const opus_uint16 dred_p0_q15[{levels * N}];
extern const opus_uint8 dred_r_q8[{levels * N}];
extern const opus_uint8 dred_p0_q8[{levels * N}];
"""
)
Expand Down
53 changes: 28 additions & 25 deletions dnn/torch/rdovae/rdovae/rdovae.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ def clip_weights(module):

return clip_weights

def n(x):
    """Add small uniform noise to `x` and clamp the result to [-1, 1].

    The noise is drawn per element from a uniform distribution over
    [-0.5/127, 0.5/127) and has the same shape as `x`.
    NOTE(review): presumably this emulates 8-bit quantization noise on
    activations during training -- confirm against the training setup.
    """
    noise = (1./127.)*(torch.rand_like(x)-.5)
    return torch.clamp(x + noise, min=-1., max=1.)

# RDOVAE module and submodules

class MyConv(nn.Module):
Expand Down Expand Up @@ -295,17 +298,17 @@ def forward(self, features):
device = x.device

# run encoding layer stack
x = torch.tanh(self.dense_1(x))
x = torch.cat([x, self.gru1(x)[0]], -1)
x = torch.cat([x, self.conv1(x)], -1)
x = torch.cat([x, self.gru2(x)[0]], -1)
x = torch.cat([x, self.conv2(x)], -1)
x = torch.cat([x, self.gru3(x)[0]], -1)
x = torch.cat([x, self.conv3(x)], -1)
x = torch.cat([x, self.gru4(x)[0]], -1)
x = torch.cat([x, self.conv4(x)], -1)
x = torch.cat([x, self.gru5(x)[0]], -1)
x = torch.cat([x, self.conv5(x)], -1)
x = n(torch.tanh(self.dense_1(x)))
x = torch.cat([x, n(self.gru1(x)[0])], -1)
x = torch.cat([x, n(self.conv1(x))], -1)
x = torch.cat([x, n(self.gru2(x)[0])], -1)
x = torch.cat([x, n(self.conv2(x))], -1)
x = torch.cat([x, n(self.gru3(x)[0])], -1)
x = torch.cat([x, n(self.conv3(x))], -1)
x = torch.cat([x, n(self.gru4(x)[0])], -1)
x = torch.cat([x, n(self.conv4(x))], -1)
x = torch.cat([x, n(self.gru5(x)[0])], -1)
x = torch.cat([x, n(self.conv5(x))], -1)
z = self.z_dense(x)

# init state for decoder
Expand Down Expand Up @@ -372,18 +375,18 @@ def forward(self, z, initial_state):
h5_state = gru_state[:,:,384:].contiguous()

# run decoding layer stack
x = torch.tanh(self.dense_1(z))

x = torch.cat([x, self.gru1(x, h1_state)[0]], -1)
x = torch.cat([x, self.conv1(x)], -1)
x = torch.cat([x, self.gru2(x, h2_state)[0]], -1)
x = torch.cat([x, self.conv2(x)], -1)
x = torch.cat([x, self.gru3(x, h3_state)[0]], -1)
x = torch.cat([x, self.conv3(x)], -1)
x = torch.cat([x, self.gru4(x, h4_state)[0]], -1)
x = torch.cat([x, self.conv4(x)], -1)
x = torch.cat([x, self.gru5(x, h5_state)[0]], -1)
x = torch.cat([x, self.conv5(x)], -1)
x = n(torch.tanh(self.dense_1(z)))

x = torch.cat([x, n(self.gru1(x, h1_state)[0])], -1)
x = torch.cat([x, n(self.conv1(x))], -1)
x = torch.cat([x, n(self.gru2(x, h2_state)[0])], -1)
x = torch.cat([x, n(self.conv2(x))], -1)
x = torch.cat([x, n(self.gru3(x, h3_state)[0])], -1)
x = torch.cat([x, n(self.conv3(x))], -1)
x = torch.cat([x, n(self.gru4(x, h4_state)[0])], -1)
x = torch.cat([x, n(self.conv4(x))], -1)
x = torch.cat([x, n(self.gru5(x, h5_state)[0])], -1)
x = torch.cat([x, n(self.conv5(x))], -1)

# output layer and reshaping
x10 = self.output(x)
Expand Down Expand Up @@ -451,7 +454,7 @@ def __init__(self,
cond_size2,
state_dim=24,
split_mode='split',
clip_weights=True,
clip_weights=False,
pvq_num_pulses=82,
state_dropout_rate=0):

Expand Down Expand Up @@ -487,7 +490,7 @@ def clip_weights(self):
if not type(self.weight_clip_fn) == type(None):
self.apply(self.weight_clip_fn)

def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 24):
def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):

enc_stride = self.enc_stride
dec_stride = self.dec_stride
Expand Down
2 changes: 1 addition & 1 deletion silk/dred_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#define DRED_EXTENSION_ID 126

/* Remove these two completely once DRED gets an extension number assigned. */
#define DRED_EXPERIMENTAL_VERSION 6
#define DRED_EXPERIMENTAL_VERSION 7
#define DRED_EXPERIMENTAL_BYTES 2


Expand Down
9 changes: 5 additions & 4 deletions silk/dred_decoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,21 @@ static int sign_extend(int x, int b) {
return (x ^ m) - m;
}

static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint16 *r, const opus_uint16 *p0, int dim) {
static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
int i;
for (i=0;i<dim;i++) {
int q;
q = ec_laplace_decode_p0(dec, p0[i], r[i]);
if (r[i] == 0 || p0[i] == 255) q = 0;
else q = ec_laplace_decode_p0(dec, p0[i]<<7, r[i]<<7);
x[i] = q*256.f/(scale[i] == 0 ? 1 : scale[i]);
}
}

int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames)
{
const opus_uint16 *p0 = DRED_rdovae_get_p0_pointer();
const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer();
const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
const opus_uint16 *r = DRED_rdovae_get_r_pointer();
const opus_uint8 *r = DRED_rdovae_get_r_pointer();
ec_dec ec;
int q_level;
int i;
Expand Down
10 changes: 5 additions & 5 deletions silk/dred_encoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex
}
}

static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint16 *r, const opus_uint16 *p0, int dim) {
static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
int i;
int q[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
float xq[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
Expand All @@ -238,16 +238,16 @@ static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *
}
for (i=0;i<dim;i++) {
/* Make the impossible actually impossible. */
if (r[i] == 0 || p0[i] >= 32767) q[i] = 0;
ec_laplace_encode_p0(enc, q[i], p0[i], r[i]);
if (r[i] == 0 || p0[i] == 255) q[i] = 0;
else ec_laplace_encode_p0(enc, q[i], p0[i]<<7, r[i]<<7);
}
}

int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunks, int max_bytes) {
const opus_uint16 *dead_zone = DRED_rdovae_get_dead_zone_pointer();
const opus_uint16 *p0 = DRED_rdovae_get_p0_pointer();
const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer();
const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
const opus_uint16 *r = DRED_rdovae_get_r_pointer();
const opus_uint8 *r = DRED_rdovae_get_r_pointer();
ec_enc ec_encoder;

int q_level;
Expand Down

0 comments on commit 544b3e5

Please sign in to comment.