Commit 0661009
Merge branch 'f/vision_attention' into 'main'
Add vision transformer based search space

See merge request es/ai/hannah/hannah!379
cgerum committed Apr 11, 2024
2 parents ad24411 + 13ccc3b commit 0661009
Showing 10 changed files with 549 additions and 4 deletions.
37 changes: 37 additions & 0 deletions experiments/conv_vit/config.yaml
@@ -0,0 +1,37 @@
defaults:
  - base_config
  - override nas: aging_evolution_nas
  - override model: conv_vit
  - override dataset: cifar10 # Dataset configuration name
  - override features: identity # Feature extractor configuration name (use identity for vision datasets)
  - override scheduler: 1cycle # Learning rate scheduler config name
  - override optimizer: adamw # Optimizer config name
  - override normalizer: null # Feature normalizer (used for quantized neural networks)
  - override module: image_classifier # Lightning module config for the training loop (image classifier for image classification tasks)
  - _self_

model:
  num_classes: 10

module:
  batch_size: 128
  num_workers: 4

nas:
  budget: 500
  n_jobs: 4
  predictor:
    model:
      input_feature_size: 38

trainer:
  max_epochs: 10

scheduler:
  max_lr: 0.001

fx_mac_summary: True

seed: [1234]

experiment_id: "conv_vit"
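
For reference, the same configuration can be composed programmatically through Hydra's compose API. A minimal sketch, assuming hannah's root config is named `config` and lives under `hannah/conf` (both are assumptions, not confirmed by this diff):

# Sketch only: replicate this experiment's group overrides via hydra.compose.
# config_path and config_name are assumptions about hannah's Hydra layout.
from hydra import compose, initialize

with initialize(version_base=None, config_path="hannah/conf"):
    cfg = compose(
        config_name="config",
        overrides=[
            "nas=aging_evolution_nas",
            "model=conv_vit",
            "dataset=cifar10",
            "module=image_classifier",
        ],
    )
print(cfg.model.num_classes)  # 10, from hannah/conf/model/conv_vit.yaml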
22 changes: 21 additions & 1 deletion hannah/callbacks/summaries.py
@@ -29,7 +29,7 @@
from tabulate import tabulate
from torch.fx.graph_module import GraphModule

-from hannah.nas.functional_operators.operators import add, conv2d, linear, conv1d
+from hannah.nas.functional_operators.operators import add, conv2d, linear, conv1d, self_attention2d
from hannah.nas.graph_conversion import GraphConversionTracer

from ..models.factory import qat
@@ -478,6 +478,25 @@ def get_linear(node, output, args, kwargs):
    return num_weights, macs, attrs


def get_attn2d(node, output, args, kwargs):
    num_weights = 0

    qkv_input = args[0]  # shape: [B, h*d*3, H, W]
    batch_size = qkv_input.shape[0]
    numheads_x_headdim = qkv_input.shape[1] // 3  # h*d, kept as an integer
    height = qkv_input.shape[2]
    width = qkv_input.shape[3]
    hxw = height * width
    # query and key dot product
    qk_macs = batch_size * hxw * hxw * numheads_x_headdim
    # attention and value dot product
    av_macs = batch_size * hxw * hxw * numheads_x_headdim
    macs = qk_macs + av_macs

    attrs = ""
    return num_weights, macs, attrs
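
As a sanity check on the formula, a standalone sketch with concrete, purely illustrative shapes (`qkv_input` below is a plain tensor standing in for the traced argument):

import torch

# Illustrative: B=1, h*d*3 = 192 channels (so h*d = 64), 8x8 feature map.
qkv_input = torch.zeros(1, 192, 8, 8)
batch_size = qkv_input.shape[0]
numheads_x_headdim = qkv_input.shape[1] // 3   # 64
hxw = qkv_input.shape[2] * qkv_input.shape[3]  # 64 spatial positions

qk_macs = batch_size * hxw * hxw * numheads_x_headdim  # Q·K^T dot products
av_macs = batch_size * hxw * hxw * numheads_x_headdim  # attention·V dot products
assert qk_macs + av_macs == 524288  # = 2 * 64 * 64 * 64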


def get_type(node):
    try:
        return node.name.split("_")[-2]
@@ -498,6 +517,7 @@ def __init__(self, module: torch.nn.Module):
            conv2d: get_conv,
            linear: get_linear,
            add: get_zero_op,
            self_attention2d: get_attn2d,
        }

        self.data = {
3 changes: 3 additions & 0 deletions hannah/conf/model/conv_vit.yaml
@@ -0,0 +1,3 @@
_target_: hannah.models.conv_vit.models.search_space
name: conv_vit
num_classes: 10
34 changes: 34 additions & 0 deletions hannah/models/conv_vit/attention.py
@@ -0,0 +1,34 @@
from hannah.nas.functional_operators.op import scope
from hannah.models.conv_vit.operators import conv2d, self_attention2d


@scope
def attention2d(input, num_heads, d_model):
    # [B, C, H, W] --> 3 tensors, each of shape [B, h*d, H, W]
    inner_dim = num_heads * d_model
    q = q_proj(input, inner_dim)
    k = k_proj(input, inner_dim)
    v = v_proj(input, inner_dim)

    # 3 x [B, h*d, H, W] --> [B, h*d, H, W]
    out = self_attention2d(q, k, v, num_heads=num_heads, d_model=d_model)

    return out


@scope
def q_proj(input, out_dim):
    q = conv2d(input, out_dim, kernel_size=1)
    return q


@scope
def k_proj(input, out_dim):
    k = conv2d(input, out_dim, kernel_size=1)
    return k


@scope
def v_proj(input, out_dim):
    v = conv2d(input, out_dim, kernel_size=1)
    return v
174 changes: 174 additions & 0 deletions hannah/models/conv_vit/blocks.py
@@ -0,0 +1,174 @@
from functools import partial
from torch.nn import functional as F

from hannah.nas.parameters.parameters import IntScalarParameter
from hannah.nas.expressions.arithmetic import Ceil
from hannah.nas.expressions.types import Int
from hannah.nas.functional_operators.op import scope
from hannah.nas.functional_operators.lazy import lazy

from hannah.models.conv_vit.operators import (
    conv2d, batch_norm, relu, linear, add,
    max_pool, adaptive_avg_pooling,
    choice, dynamic_depth, grouped_conv2d
)
from hannah.models.conv_vit.attention import attention2d


@scope
def stem(input, kernel_size, stride, out_channels):
    out = conv2d(input, out_channels, kernel_size, stride)
    out = batch_norm(out)
    out = relu(out)
    out = max_pool(out, kernel_size=3, stride=2)

    return out


@scope
def classifier_head(input, num_classes):
    out = adaptive_avg_pooling(input)
    out = linear(out, num_classes)

    return out


@scope
def residual(input, main_branch_output_shape):
    input_shape = input.shape()
    in_fmap = input_shape[2]
    out_channels = main_branch_output_shape[1]
    out_fmap = main_branch_output_shape[2]
    stride = Int(Ceil(in_fmap / out_fmap))

    out = conv2d(input, out_channels=out_channels, kernel_size=1, stride=stride, padding=0)
    out = batch_norm(out)
    out = relu(out)

    return out
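
In plain PyTorch terms, the shortcut above is a 1x1 projection whose stride is derived from the ratio of input to output feature-map size. A small illustration with assumed shapes:

import math
import torch
import torch.nn as nn

# Assumed example: skip input has 64 channels at 32x32, the main branch
# produced 128 channels at 16x16.
x = torch.zeros(1, 64, 32, 32)
out_channels, out_fmap = 128, 16

stride = math.ceil(x.shape[2] / out_fmap)  # ceil(32 / 16) = 2
shortcut = nn.Sequential(
    nn.Conv2d(64, out_channels, kernel_size=1, stride=stride, padding=0),
    nn.BatchNorm2d(out_channels),
    nn.ReLU(),
)
assert shortcut(x).shape == (1, 128, 16, 16)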


@scope
def conv_layer(input, out_channels, kernel_size, stride):
    out = conv2d(input, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
    out = relu(out)

    out = conv2d(out, out_channels=out_channels, kernel_size=1, stride=1)
    out = batch_norm(out)
    out = relu(out)

    return out


@scope
def embedding(input, expand_ratio, out_channels, kernel_size, stride):
    in_channels = input.shape()[1]
    expanded_channels = Int(expand_ratio * in_channels)

    out = conv2d(input, expanded_channels, kernel_size=1, stride=1, padding=0)
    out = batch_norm(out)
    out = relu(out)

    out = grouped_conv2d(out, expanded_channels, kernel_size=kernel_size, stride=stride, groups=expanded_channels)
    out = batch_norm(out)
    out = relu(out)

    out = conv2d(out, out_channels, kernel_size=1, stride=1, padding=0)
    out = batch_norm(out)
    out = relu(out)

    res = residual(input, out.shape())
    out = add(out, res)

    return out


@scope
def attention_layer(input, num_heads, d_model, out_channels):
    out = attention2d(input, num_heads, d_model)
    out = conv2d(out, out_channels, kernel_size=1, stride=1, padding=0)
    out = batch_norm(out)
    out = relu(out)

    res = residual(input, out.shape())
    out = add(out, res)

    return out


@scope
def feed_forward(input, out_channels):
    out = conv2d(input, out_channels, kernel_size=1, stride=1, padding=0)
    out = batch_norm(out)
    out = relu(out)

    res = residual(input, out.shape())
    out = add(out, res)

    return out


@scope
def transformer_cell(input, expand_ratio, out_channels, kernel_size, stride, num_heads, d_model):
    out = embedding(input, expand_ratio, out_channels, kernel_size, stride)
    out = attention_layer(out, num_heads, d_model, out_channels)
    out = feed_forward(out, out_channels)

    return out


@scope
def attention_cell(input, out_channels, kernel_size, stride, num_heads, d_model):
    out = conv_layer(input, out_channels, kernel_size, stride)
    out = attention_layer(out, num_heads, d_model, out_channels)

    return out


@scope
def pattern(input, expand_ratio, kernel_size, stride, num_heads, d_model, out_channels):
    attn = partial(
        attention_cell,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        num_heads=num_heads,
        d_model=d_model
    )
    trf = partial(
        transformer_cell,
        expand_ratio=expand_ratio,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        num_heads=num_heads,
        d_model=d_model
    )

    out = choice(input, attn, trf)

    return out


@scope
def block(input, depth, expand_ratio, kernel_size, stride, num_heads, d_model, out_channels):
    assert isinstance(depth, IntScalarParameter), "block depth must be of type IntScalarParameter"
    out = input
    exits = []
    for i in range(depth.max + 1):
        out = pattern(
            out,
            expand_ratio=expand_ratio.new(),
            kernel_size=kernel_size.new(),
            stride=stride.new() if i == 0 else 1,
            num_heads=num_heads.new(),
            d_model=d_model.new(),
            out_channels=out_channels.new()
        )
        exits.append(out)

    out = dynamic_depth(*exits, switch=depth)
    res = residual(input, out.shape())
    out = add(out, res)

    return out
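
Like `choice` in `pattern`, `dynamic_depth` comes from the operators module, which this diff does not include. Judging from its use here, it multiplexes between the exit tensors of the unrolled patterns, so sampling `depth` decides how many patterns are active. Roughly, its selection semantics (a sketch only; the real operator is a graph node whose switch is sampled by the NAS, not a Python call):

def dynamic_depth_ref(*exits, switch):
    # exits[i] is the output after i + 1 stacked patterns; `switch` picks one.
    return exits[switch]

# With depth = IntScalarParameter(0, 2), the block unrolls three patterns;
# dynamic_depth_ref(e0, e1, e2, switch=1) would select the two-pattern exit e1.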
49 changes: 49 additions & 0 deletions hannah/models/conv_vit/models.py
@@ -0,0 +1,49 @@
from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter
from hannah.models.conv_vit.operators import dynamic_depth
from hannah.models.conv_vit.blocks import stem, block, classifier_head


def search_space(name, input, num_classes=10):
    # Stem parameters
    stem_kernel_size = CategoricalParameter([3, 5, 7, 9], name="kernel_size")
    stem_stride = CategoricalParameter([1, 2], name="stride")
    stem_channels = IntScalarParameter(min=16, max=32, step_size=4, name="out_channels")

    # Block parameters
    kernel_size = CategoricalParameter([3, 5, 7, 9], name="kernel_size")
    stride = CategoricalParameter([1, 2], name="stride")

    num_heads = IntScalarParameter(2, 8, step_size=2, name="num_heads")
    d_model = IntScalarParameter(16, 64, step_size=16, name="d_model")
    expand_ratio = IntScalarParameter(1, 2, name="expand_ratio")
    out_channels = IntScalarParameter(16, 64, step_size=4, name="out_channels")

    depth = IntScalarParameter(0, 2, name="depth")
    num_blocks = IntScalarParameter(0, 4, name="num_blocks")

    # Stem
    out = stem(input, stem_kernel_size, stem_stride, stem_channels)

    # Blocks
    exits = []
    for _ in range(num_blocks.max + 1):
        out = block(
            out,
            depth=depth.new(),
            expand_ratio=expand_ratio.new(),
            kernel_size=kernel_size.new(),
            stride=stride.new(),
            num_heads=num_heads.new(),
            d_model=d_model.new(),
            out_channels=out_channels.new()
        )
        exits.append(out)

    out = dynamic_depth(*exits, switch=num_blocks)
    output_fmap = out.shape()[2]
    out = classifier_head(out, num_classes=num_classes)

    stride_params = [v for k, v in out.parametrization(flatten=True).items() if k.split(".")[-1] == "stride"]
    out.cond(output_fmap > 1, allowed_params=stride_params)

    return out
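
The closing `cond` keeps sampled architectures from collapsing the feature map: `output_fmap > 1` must hold, and only the stride parameters may be mutated to satisfy it. A back-of-the-envelope check of why the constraint is needed, approximating every stride-2 stage as halving with rounding up (the 32x32 input is an assumption matching CIFAR-10):

import math

# Worst case: stem conv stride 2, stem max_pool stride 2, and a stride-2
# pattern in each of the num_blocks.max + 1 = 5 unrolled blocks.
fmap = 32
for stride in [2, 2] + [2] * 5:
    fmap = math.ceil(fmap / stride)
print(fmap)  # 1 -> violates output_fmap > 1, so such stride samples are rejected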
