Commit 98ef642: Remove useless latent/state dimensions
jmvalin committed Nov 7, 2023
1 parent e806d39
Showing 7 changed files with 73 additions and 35 deletions.
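In short, as the diffs below show: the export script now drops latent/state dimensions whose quantized entropy-model parameters never code anything (r_q8 is 0 at every level, or p0_q8 saturates at 255), pads the surviving dimensions up to a multiple of 8, and the C encoder, decoder, and constants follow suit, with the exported tables renamed from dred_latents_*/dred_states_* to dred_latent_*/dred_state_*.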
9 changes: 7 additions & 2 deletions dnn/dred_rdovae_enc.c
@@ -34,6 +34,7 @@

#include "dred_rdovae_enc.h"
#include "os_support.h"
#include "dred_rdovae_constants.h"

static void conv1_cond_init(float *mem, int len, int dilation, int *init)
{
@@ -52,6 +53,8 @@ void dred_rdovae_encode_dframe(
const float *input /* i: double feature frame (concatenated) */
)
{
+float padded_latents[DRED_PADDED_LATENT_DIM];
+float padded_state[DRED_PADDED_STATE_DIM];
float buffer[ENC_DENSE1_OUT_SIZE + ENC_GRU1_OUT_SIZE + ENC_GRU2_OUT_SIZE + ENC_GRU3_OUT_SIZE + ENC_GRU4_OUT_SIZE + ENC_GRU5_OUT_SIZE
+ ENC_CONV1_OUT_SIZE + ENC_CONV2_OUT_SIZE + ENC_CONV3_OUT_SIZE + ENC_CONV4_OUT_SIZE + ENC_CONV5_OUT_SIZE];
float state_hidden[GDENSE1_OUT_SIZE];
@@ -96,9 +99,11 @@
compute_generic_conv1d_dilation(&model->enc_conv5, &buffer[output_index], enc_state->conv5_state, buffer, output_index, 2, ACTIVATION_TANH);
output_index += ENC_CONV5_OUT_SIZE;

compute_generic_dense(&model->enc_zdense, latents, buffer, ACTIVATION_LINEAR);
compute_generic_dense(&model->enc_zdense, padded_latents, buffer, ACTIVATION_LINEAR);
OPUS_COPY(latents, padded_latents, DRED_LATENT_DIM);

/* next, calculate initial state */
compute_generic_dense(&model->gdense1, state_hidden, buffer, ACTIVATION_TANH);
compute_generic_dense(&model->gdense2, initial_state, state_hidden, ACTIVATION_LINEAR);
compute_generic_dense(&model->gdense2, padded_state, state_hidden, ACTIVATION_LINEAR);
OPUS_COPY(initial_state, padded_state, DRED_STATE_DIM);
}
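The pattern above, sketched in Python with numpy for clarity (example sizes, not the exported constants): the final dense layers now write into padded scratch buffers, and only the leading real dimensions are copied out, which is what the OPUS_COPY calls do.

import numpy as np

DRED_LATENT_DIM = 50                 # hypothetical pruned dimension
DRED_PADDED_LATENT_DIM = 56          # rounded up to a multiple of 8

padded_latents = np.zeros(DRED_PADDED_LATENT_DIM, dtype=np.float32)
# ... the dense layer fills all DRED_PADDED_LATENT_DIM entries here ...
latents = padded_latents[:DRED_LATENT_DIM].copy()   # OPUS_COPY equivalent
assert latents.shape[0] == DRED_LATENT_DIM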
61 changes: 47 additions & 14 deletions dnn/torch/rdovae/export_rdovae_weights.py
@@ -51,7 +51,6 @@

def dump_statistical_model(writer, w, name):
levels = w.shape[0]
-N = w.shape[-1]

print("printing statistical model")
quant_scales = torch.nn.functional.softplus(w[:, 0, :]).numpy()
@@ -62,13 +61,20 @@ def dump_statistical_model(writer, w, name):

quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16)
-r_q15 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
-p0_q15 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)
+r_q8 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
+p0_q8 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)

+mask = (np.max(r_q8,axis=0) > 0) * (np.min(p0_q8,axis=0) < 255)
+quant_scales_q8 = quant_scales_q8[:, mask]
+dead_zone_q10 = dead_zone_q10[:, mask]
+r_q8 = r_q8[:, mask]
+p0_q8 = p0_q8[:, mask]
+N = r_q8.shape[-1]

print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint16', static=False)
print_vector(writer.source, dead_zone_q10, f'dred_{name}_dead_zone_q10', dtype='opus_uint16', static=False)
-print_vector(writer.source, r_q15, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
-print_vector(writer.source, p0_q15, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)
+print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
+print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)

writer.header.write(
f"""
@@ -79,6 +85,7 @@ def dump_statistical_model(writer, w, name):
"""
)
+return N, mask
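A toy run of the new pruning mask (hypothetical values): a dimension survives only if some quantization level gives it a nonzero rate parameter r_q8 and a non-saturated zero probability p0_q8; otherwise the quantized model says the symbol always decodes to zero and the dimension can be dropped.

import numpy as np

r_q8 = np.array([[12, 0, 5], [30, 0, 0]], dtype=np.uint8)      # (levels, dims)
p0_q8 = np.array([[200, 255, 255], [180, 255, 255]], dtype=np.uint16)
mask = (np.max(r_q8, axis=0) > 0) * (np.min(p0_q8, axis=0) < 255)
print(mask)   # [ True False False]: dimensions 1 and 2 carry no information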


def c_export(args, model):
@@ -112,18 +119,40 @@ def c_export(args, model):
"""
)

-latent_out = model.get_submodule('core_encoder.module.z_dense').weight
-states_out = model.get_submodule('core_encoder.module.state_dense_2').weight
-nb_latents = latent_out.shape[0]
-nb_states = states_out.shape[0]
+latent_out = model.get_submodule('core_encoder.module.z_dense')
+state_out = model.get_submodule('core_encoder.module.state_dense_2')
+orig_latent_dim = latent_out.weight.shape[0]
+orig_state_dim = state_out.weight.shape[0]
# statistical model
qembedding = model.statistical_model.quant_embedding.weight.detach()
levels = qembedding.shape[0]
qembedding = torch.reshape(qembedding, (levels, 6, -1))

-dump_statistical_model(stats_writer, qembedding[:, :, :nb_latents], 'latents')
-dump_statistical_model(stats_writer, qembedding[:, :, nb_latents:], 'states')
+latent_dim, latent_mask = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
+state_dim, state_mask = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')

+padded_latent_dim = (latent_dim+7)//8*8
+latent_pad = padded_latent_dim - latent_dim
+w = latent_out.weight[latent_mask,:]
+w = torch.cat([w, torch.zeros(latent_pad, w.shape[1])], dim=0)
+b = latent_out.bias[latent_mask]
+b = torch.cat([b, torch.zeros(latent_pad)], dim=0)
+latent_out.weight = torch.nn.Parameter(w)
+latent_out.bias = torch.nn.Parameter(b)
+
+padded_state_dim = (state_dim+7)//8*8
+state_pad = padded_state_dim - state_dim
+w = state_out.weight[state_mask,:]
+w = torch.cat([w, torch.zeros(state_pad, w.shape[1])], dim=0)
+b = state_out.bias[state_mask]
+b = torch.cat([b, torch.zeros(state_pad)], dim=0)
+state_out.weight = torch.nn.Parameter(w)
+state_out.bias = torch.nn.Parameter(b)
+
+latent_in = model.get_submodule('core_decoder.module.dense_1')
+state_in = model.get_submodule('core_decoder.module.hidden_init')
+latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask])
+state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask])
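A standalone sketch of the pad-to-a-multiple-of-8 step (example dimension; the rounding presumably lets the C kernels work on fixed-width blocks): prune the rows with the mask, then append zero rows until the output dimension divides by 8. Note that the decoder-side input layers just above are only column-masked, not padded.

import torch

latent_dim = 50                                  # hypothetical pruned size
padded_latent_dim = (latent_dim + 7) // 8 * 8    # -> 56
latent_pad = padded_latent_dim - latent_dim      # -> 6

w = torch.randn(latent_dim, 128)   # stands in for latent_out.weight[latent_mask, :]
b = torch.randn(latent_dim)
w = torch.cat([w, torch.zeros(latent_pad, w.shape[1])], dim=0)   # zero rows
b = torch.cat([b, torch.zeros(latent_pad)], dim=0)
assert w.shape == (padded_latent_dim, 128) and b.shape == (padded_latent_dim,)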

# encoder
encoder_dense_layers = [
@@ -206,9 +235,13 @@ def c_export(args, model):
f"""
#define DRED_NUM_FEATURES {model.feature_dim}
-#define DRED_LATENT_DIM {model.latent_dim}
-#define DRED_STATE_DIME {model.state_dim}
+#define DRED_LATENT_DIM {latent_dim}
+#define DRED_STATE_DIM {state_dim}
+#define DRED_PADDED_LATENT_DIM {padded_latent_dim}
+#define DRED_PADDED_STATE_DIM {padded_state_dim}
#define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels}
5 changes: 3 additions & 2 deletions dnn/torch/weight-exchange/wexchange/c_export/common.py
@@ -124,6 +124,7 @@ def extract_diagonal(A):
return diag, B

def quantize_weight(weight, scale):
+scale = scale + 1e-30
Aq = np.round(weight / scale).astype('int')
if Aq.max() > 127 or Aq.min() <= -128:
raise ValueError("value out of bounds in quantize_weight")
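The added 1e-30 reads as a guard for all-zero columns, which the zero-padding introduced by this commit can produce: their computed scale is 0 and the division would otherwise yield NaNs. A quick check with hypothetical values:

import numpy as np

weight = np.zeros(4)
scale = 0.0
print(np.round(weight / (scale + 1e-30)))   # [0. 0. 0. 0.] rather than NaN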
@@ -227,7 +228,7 @@ def print_linear_layer(writer : CWriter,

nb_inputs, nb_outputs = weight.shape

-if scale is None:
+if scale is None and quantize:
scale = compute_scaling(weight)


@@ -359,4 +360,4 @@ def print_gru_layer(writer : CWriter,
writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {N}\n")
writer.header.write(f"\n#define {name.upper()}_STATE_SIZE {N}\n")

-return N
\ No newline at end of file
+return N
3 changes: 0 additions & 3 deletions silk/dred_config.h
@@ -39,9 +39,6 @@
#define DRED_MIN_BYTES 16

/* these are in part duplicates of the values defined in dred_rdovae_constants.h */
-#define DRED_NUM_FEATURES 20
-#define DRED_LATENT_DIM 80
-#define DRED_STATE_DIM 80
#define DRED_SILK_ENCODER_DELAY (79+12-80)
#define DRED_FRAME_SIZE 160
#define DRED_DFRAME_SIZE (2 * (DRED_FRAME_SIZE))
13 changes: 7 additions & 6 deletions silk/dred_decoder.c
@@ -37,6 +37,7 @@
#include "celt/entdec.h"
#include "celt/laplace.h"
#include "dred_rdovae_stats_data.h"
#include "dred_rdovae_constants.h"

/* From http://graphics.stanford.edu/~seander/bithacks.html#FixedSignExtend */
static int sign_extend(int x, int b) {
@@ -80,9 +81,9 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
dred_decode_latents(
&ec,
dec->state,
-dred_states_quant_scales_q8 + state_qoffset,
-dred_states_r_q8 + state_qoffset,
-dred_states_p0_q8 + state_qoffset,
+dred_state_quant_scales_q8 + state_qoffset,
+dred_state_r_q8 + state_qoffset,
+dred_state_p0_q8 + state_qoffset,
DRED_STATE_DIM);

/* decode newest to oldest and store oldest to newest */
@@ -96,9 +97,9 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
dred_decode_latents(
&ec,
&dec->latents[(i/2)*DRED_LATENT_DIM],
-dred_latents_quant_scales_q8 + offset,
-dred_latents_r_q8 + offset,
-dred_latents_p0_q8 + offset,
+dred_latent_quant_scales_q8 + offset,
+dred_latent_r_q8 + offset,
+dred_latent_p0_q8 + offset,
DRED_LATENT_DIM
);

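For reference, the variable-width sign-extension bithack that sign_extend implements (its body is cut off by the hunk above); the Python below is an assumed equivalent, not the file's exact code:

def sign_extend(x: int, b: int) -> int:
    m = 1 << (b - 1)       # mask holding only the sign bit of a b-bit value
    x &= (1 << b) - 1      # keep the low b bits
    return (x ^ m) - m     # xor then subtract propagates the sign upward

assert sign_extend(0b111, 3) == -1 and sign_extend(0b011, 3) == 3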
1 change: 1 addition & 0 deletions silk/dred_decoder.h
@@ -32,6 +32,7 @@
#include "dred_config.h"
#include "dred_rdovae.h"
#include "entcode.h"
#include "dred_rdovae_constants.h"

struct OpusDRED {
float fec_features[2*DRED_NUM_REDUNDANCY_FRAMES*DRED_NUM_FEATURES];
16 changes: 8 additions & 8 deletions silk/dred_encoder.c
@@ -266,10 +266,10 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
dred_encode_latents(
&ec_encoder,
enc->initial_state,
-dred_states_quant_scales_q8 + state_qoffset,
-dred_states_dead_zone_q10 + state_qoffset,
-dred_states_r_q8 + state_qoffset,
-dred_states_p0_q8 + state_qoffset,
+dred_state_quant_scales_q8 + state_qoffset,
+dred_state_dead_zone_q10 + state_qoffset,
+dred_state_r_q8 + state_qoffset,
+dred_state_p0_q8 + state_qoffset,
DRED_STATE_DIM);
if (ec_tell(&ec_encoder) > 8*max_bytes) {
return 0;
@@ -285,10 +285,10 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
dred_encode_latents(
&ec_encoder,
enc->latents_buffer + (i+enc->latent_offset) * DRED_LATENT_DIM,
-dred_latents_quant_scales_q8 + offset,
-dred_latents_dead_zone_q10 + offset,
-dred_latents_r_q8 + offset,
-dred_latents_p0_q8 + offset,
+dred_latent_quant_scales_q8 + offset,
+dred_latent_dead_zone_q10 + offset,
+dred_latent_r_q8 + offset,
+dred_latent_p0_q8 + offset,
DRED_LATENT_DIM
);
if (ec_tell(&ec_encoder) > 8*max_bytes) {