
Commit

Add quantization-aware EfficientNetV2 implementation.
PiperOrigin-RevId: 611249382
sdenton4 authored and copybara-github committed Feb 28, 2024
1 parent 7ec6093 commit 57c205f
Showing 10 changed files with 590 additions and 30 deletions.
8 changes: 6 additions & 2 deletions chirp/configs/baseline.py
@@ -45,9 +45,10 @@ def get_config() -> config_dict.ConfigDict:
'efficientnet.EfficientNetModel',
value='b1',
),
op_set='default',
)
model_config.taxonomy_loss_weight = 0.001
-  model_config.frontend = presets.get_bio_pcen_melspec_config(config)
+  model_config.frontend = presets.get_new_pcen_melspec_config(config)
config.init_config.model_config = model_config
# Configure the training loop
config.train_config = presets.get_base_train_config(config)
@@ -63,4 +64,7 @@ def get_config() -> config_dict.ConfigDict:


def get_hyper(hyper):
-  return hyper.sweep('config.batch_size', hyper.discrete([256]))
+  return hyper.sweep(
+      'config.init_config.model_config.encoder.__config.op_set',
+      hyper.discrete(['default', 'qat']),
+  )
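
With the new op_set argument on the encoder and the updated sweep, this baseline now trains the same EfficientNet-B1 encoder twice: once with the default float ops and once with the quantization-aware ('qat') ops. Roughly what the swept config resolves to once the callable_config strings are evaluated (a sketch; the enum member name for 'b1' is assumed):

from chirp.models import efficientnet

# One encoder per sweep point: float ops vs. quantization-aware ops.
for op_set in ('default', 'qat'):
  encoder = efficientnet.EfficientNet(
      model=efficientnet.EfficientNetModel.B1,  # assumed member name for 'b1'
      op_set=op_set,
  )
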
67 changes: 67 additions & 0 deletions chirp/configs/baseline_effnet_v2.py
@@ -0,0 +1,67 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration to run baseline model."""
from chirp import config_utils
from chirp.configs import presets
from ml_collections import config_dict

_c = config_utils.callable_config


def get_config() -> config_dict.ConfigDict:
"""Create configuration dictionary for training."""
config = presets.get_base_config()
# Configure the data
config.train_dataset_config = presets.get_supervised_train_pipeline(
config,
mixin_prob=0.75,
train_dataset_dir='bird_taxonomy/slice_peaked:1.4.0',
)
config.eval_dataset_config = presets.get_supervised_eval_pipeline(
config, 'soundscapes/powdermill:1.3.0'
)
# Configure the experiment setup
config.init_config = presets.get_classifier_init_config(config)
config.init_config.optimizer = _c(
'optax.adam', learning_rate=config.init_config.get_ref('learning_rate')
)
model_config = config_dict.ConfigDict()
model_config.encoder = _c(
'efficientnet_v2.EfficientNetV2',
model_name='efficientnetv2-s',
op_set='qat',
)
model_config.taxonomy_loss_weight = 0.001
model_config.frontend = presets.get_new_pcen_melspec_config(config)
config.init_config.model_config = model_config
# Configure the training loop
config.train_config = presets.get_base_train_config(config)
config.eval_config = presets.get_base_eval_config(config)

config.export_config = config_dict.ConfigDict()
config.export_config.input_shape = (
config.get_ref('eval_window_size_s') * config.get_ref('sample_rate_hz'),
)
config.export_config.num_train_steps = config.get_ref('num_train_steps')

return config


def get_hyper(hyper):
return hyper.sweep(
'config.init_config.rng_seed',
hyper.discrete([17, 42, 666]),
)
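
For reference, the exported input shape is just the evaluation window length in samples, eval_window_size_s * sample_rate_hz. The worked numbers below are assumed for illustration, not read from presets:

# Assumed values, for illustration only.
eval_window_size_s = 5
sample_rate_hz = 32_000
input_shape = (eval_window_size_s * sample_rate_hz,)  # (160000,), a flat waveform
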
2 changes: 2 additions & 0 deletions chirp/configs/config_globals.py
@@ -22,6 +22,7 @@
from chirp.eval import eval_lib
from chirp.models import conformer
from chirp.models import efficientnet
from chirp.models import efficientnet_v2
from chirp.models import frontend
from chirp.models import handcrafted_features
from chirp.models import hubert
@@ -42,6 +43,7 @@ def get_globals() -> dict[str, Any]:
"config_utils": config_utils,
"conformer": conformer,
"efficientnet": efficientnet,
"efficientnet_v2": efficientnet_v2,
"eval_lib": eval_lib,
"hubert": hubert,
"quantizers": quantizers,
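
Registering the module here matters because the configs refer to model classes by string: 'efficientnet_v2.EfficientNetV2' in baseline_effnet_v2.py appears to be resolved against this globals dict when the config is parsed. A rough sketch of that lookup (paraphrased; the real helper lives in chirp/config_utils.py and its exact name is assumed):

from chirp.configs import config_globals

def resolve(name: str, globals_: dict):
  # Walk a dotted name ('module.Class') down from the registered globals.
  obj = globals_[name.split('.')[0]]
  for part in name.split('.')[1:]:
    obj = getattr(obj, part)
  return obj

encoder_cls = resolve('efficientnet_v2.EfficientNetV2', config_globals.get_globals())
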
8 changes: 5 additions & 3 deletions chirp/configs/presets.py
@@ -49,6 +49,7 @@ def get_base_config(**kwargs):
config.num_channels = 128
config.batch_size = 256
config.add_taxonomic_labels = True
config.taxonomy_loss_weight = 0.001
config.target_class_list = 'xenocanto'
config.num_train_steps = 1_000_000
config.loss_fn = _o('optax.sigmoid_binary_cross_entropy')
@@ -57,6 +58,7 @@ def get_base_config(**kwargs):
config.update(kwargs)
return config


def get_base_init_config(
config: config_dict.ConfigDict, **kwargs
) -> config_dict.ConfigDict:
@@ -96,21 +98,21 @@ def get_classifier_init_config(
'train_utils.OutputHeadMetadata.from_mapping',
key='genus',
source_class_list_name=config.get_ref('target_class_list'),
-          weight=0.1,
+          weight=config.get_ref('taxonomy_loss_weight'),
mapping_name='ebird2021_to_genus',
),
_c(
'train_utils.OutputHeadMetadata.from_mapping',
key='family',
source_class_list_name=config.get_ref('target_class_list'),
-          weight=0.1,
+          weight=config.get_ref('taxonomy_loss_weight'),
mapping_name='ebird2021_to_family',
),
_c(
'train_utils.OutputHeadMetadata.from_mapping',
key='order',
source_class_list_name=config.get_ref('target_class_list'),
-          weight=0.1,
+          weight=config.get_ref('taxonomy_loss_weight'),
mapping_name='ebird2021_to_order',
),
)
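
Swapping the literal 0.1 weights for config.get_ref('taxonomy_loss_weight') ties the genus, family and order output heads to the single field added in get_base_config, so one override or sweep moves all three together. A small standalone illustration of the get_ref behaviour (not code from the repo):

from ml_collections import config_dict

config = config_dict.ConfigDict()
config.taxonomy_loss_weight = 0.001

head = config_dict.ConfigDict()
head.weight = config.get_ref('taxonomy_loss_weight')  # a reference, not a copy

config.taxonomy_loss_weight = 0.01
print(head.weight)  # 0.01 -- follows the shared field
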
71 changes: 63 additions & 8 deletions chirp/models/efficientnet.py
@@ -17,12 +17,17 @@
Implementation of the EfficientNet model in Flax.
"""
import dataclasses
import enum
import math
-from typing import NamedTuple
+from typing import Callable, NamedTuple

from aqt.jax.v2 import aqt_conv_general
from aqt.jax.v2 import aqt_dot_general
from aqt.jax.v2 import config as aqt_cfg
from chirp.models import layers
from flax import linen as nn
import flax.typing as flax_typing
from jax import numpy as jnp


@@ -109,6 +114,31 @@ def round_num_blocks(num_blocks: int, depth_coefficient: float) -> int:
return int(math.ceil(depth_coefficient * num_blocks))


@dataclasses.dataclass
class OpSet:
activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
sigmoid: Callable[[jnp.ndarray], jnp.ndarray] = nn.sigmoid
stem_activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish
head_activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish
dot_general: flax_typing.DotGeneralT | None = None
conv_general_dilated: flax_typing.ConvGeneralDilatedT | None = None


op_sets = {
"default": OpSet(),
"qat": OpSet(
activation=nn.relu,
sigmoid=nn.hard_sigmoid,
stem_activation=nn.hard_swish,
head_activation=nn.hard_swish,
dot_general=aqt_dot_general.make_dot_general(None),
conv_general_dilated=aqt_conv_general.make_conv_general_dilated(
aqt_cfg.DotGeneralRaw.make_conv_general_dilated()
),
),
}


class Stem(nn.Module):
"""The stem of an EfficientNet model.
@@ -120,6 +150,8 @@ class Stem(nn.Module):
"""

features: int
conv_general_dilated: flax_typing.ConvGeneralDilatedT | None = None
activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish

@nn.compact
def __call__(
@@ -140,10 +172,11 @@ def __call__(
kernel_size=(3, 3),
strides=2,
use_bias=False,
conv_general_dilated=self.conv_general_dilated,
padding="VALID",
)(inputs)
x = nn.BatchNorm(use_running_average=use_running_average)(x)
-    x = nn.swish(x)
+    x = self.activation(x)
return x


@@ -155,9 +188,12 @@ class Head(nn.Module):
Attributes:
features: The number of filters.
conv_general_dilated: Convolution op.
"""

features: int
activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish
conv_general_dilated: flax_typing.ConvGeneralDilatedT | None = None

@nn.compact
def __call__(
@@ -174,10 +210,14 @@ def __call__(
A JAX array of `(batch size, height, width, features)`.
"""
x = nn.Conv(
-        features=self.features, kernel_size=(1, 1), strides=1, use_bias=False
+        features=self.features,
+        kernel_size=(1, 1),
+        strides=1,
+        use_bias=False,
+        conv_general_dilated=self.conv_general_dilated,
)(inputs)
x = nn.BatchNorm(use_running_average=use_running_average)(x)
-    x = nn.swish(x)
+    x = self.activation(x)
return x


@@ -192,13 +232,15 @@ class EfficientNet(nn.Module):
survival_probability: The survival probability to use for stochastic depth.
head: Optional Flax module to use as custom head.
stem: Optional Flax module to use as custom stem.
op_set: Named set of ops to use.
"""

model: EfficientNetModel
include_top: bool = True
survival_probability: float = 0.8
head: nn.Module | None = None
stem: nn.Module | None = None
op_set: str = "default"

@nn.compact
def __call__(
@@ -225,13 +267,19 @@ def __call__(
A JAX array of `(batch size, height, width, features)` if `include_top` is
false. If `include_top` is true the output is `(batch_size, features)`.
"""
ops = op_sets[self.op_set]

if use_running_average is None:
use_running_average = not train
scaling = SCALINGS[self.model]

if self.stem is None:
features = round_features(STEM_FEATURES, scaling.width_coefficient)
-      stem = Stem(features)
+      stem = Stem(
+          features,
+          activation=ops.stem_activation,
+          conv_general_dilated=ops.conv_general_dilated,
+      )
else:
stem = self.stem

@@ -248,11 +296,14 @@ def __call__(
strides=strides,
expand_ratio=stage.expand_ratio,
kernel_size=stage.kernel_size,
-            activation=nn.swish,
batch_norm=True,
reduction_ratio=REDUCTION_RATIO,
+            activation=ops.activation,
+            sigmoid_activation=ops.sigmoid,
+            dot_general=ops.dot_general,
+            conv_general_dilated=ops.conv_general_dilated,
)
-        y = mbconv(x, use_running_average=use_running_average)
+        y = mbconv(x, train=train, use_running_average=use_running_average)

# Stochastic depth
if block > 0 and self.survival_probability:
@@ -267,7 +318,11 @@ def __call__(

if self.head is None:
features = round_features(HEAD_FEATURES, scaling.width_coefficient)
-      head = Head(features)
+      head = Head(
+          features,
+          activation=ops.head_activation,
+          conv_general_dilated=ops.conv_general_dilated,
+      )
else:
head = self.head

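
The 'qat' op set works by swapping the contraction primitives rather than the modules: flax.linen layers accept a custom dot_general / conv_general_dilated, and the aqt.jax.v2 factories used above return drop-in replacements for jax.lax.dot_general and conv_general_dilated. Because only the primitives change, the module structure stays the same as the float variant. A minimal sketch of the same wiring on a bare Dense layer (illustrative, not part of this change; the factory call mirrors the 'qat' OpSet above):

from aqt.jax.v2 import aqt_dot_general
from flax import linen as nn
import jax
from jax import numpy as jnp

# Same factory call as the 'qat' OpSet above.
qat_dot_general = aqt_dot_general.make_dot_general(None)


class TinyQATDense(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    # The custom contraction is threaded in the same way ops.dot_general is
    # threaded into the MBConv blocks in the diff above.
    return nn.Dense(self.features, dot_general=qat_dot_general)(x)


params = TinyQATDense(4).init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
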
(Diffs for the remaining 5 changed files are not shown.)
