Skip to content

Commit

Permalink
Support model version 15
Browse files Browse the repository at this point in the history
This commit refactors the `SWPolicyHeadDesc` struct and related
calculation code in `metalbackend.swift` to accommodate model versions
15 and higher. Specifically, it adds new fields to `SWPolicyHeadDesc`
for the bias layer description, pass activation function, and an
additional fully connected linear layer. The changes ensure appropriate
handling of these components in the calculation of the policy head.

These modifications are needed because newer model versions add these
layers to the pass-move branch of the neural network's policy head.
By properly configuring the policy head description and adjusting the
corresponding calculation code, the Metal backend can now handle model version
15 accurately.

The GPU error of the Metal backend is shown below:

```
: Loaded 2247 positions from: base-s1436726784.bin
: Running batched evaluations in fp32
: Running evaluations using current config
: Running batched evaluations using current config
: Computed stats on 2247 positions
: Reporting the average, 90%, 99%, and max abs error between the following configurations:
: batched fp32 - fp32 winrateError:  0.00003% 0.00008% 0.00016% 0.00032%
: batched fp32 - fp32 scoreError:     0.00001  0.00002  0.00004  0.00009
: batched fp32 - fp32 topPolicyDelta: 0.00006% 0.00013% 0.00023% 0.00038%
: batched fp32 - fp32 policyKLDiv:   -0.000000 0.000000 0.000000 0.000000
: current - fp32 winrateError:  0.00003% 0.00008% 0.00016% 0.00027%
: current - fp32 scoreError:     0.00001  0.00002  0.00004  0.00010
: current - fp32 topPolicyDelta: 0.00006% 0.00013% 0.00021% 0.00040%
: current - fp32 policyKLDiv:   -0.000000 0.000000 0.000000 0.000000
: batched current - fp32 winrateError:  0.00003% 0.00008% 0.00015% 0.00032%
: batched current - fp32 scoreError:     0.00001  0.00002  0.00004  0.00010
: batched current - fp32 topPolicyDelta: 0.00006% 0.00013% 0.00023% 0.00040%
: batched current - fp32 policyKLDiv:   -0.000000 0.000000 0.000000 0.000000
: GPU -1 finishing, processed 2247 rows 282 batches
: GPU -1 finishing, processed 4494 rows 2529 batches
```
  • Loading branch information
ChinChangYang committed Dec 11, 2023
1 parent ff76c07 commit 8380a78
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 28 deletions.
9 changes: 7 additions & 2 deletions cpp/neuralnet/metalbackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ SWPolicyHeadDesc MetalProcess::policyHeadDescToSwift(const PolicyHeadDesc * poli
ActivationKind p1Activation = activationLayerDescToSwift(&policyHead->p1Activation);
SWConvLayerDesc p2Conv = convLayerDescToSwift(&policyHead->p2Conv);
SWMatMulLayerDesc gpoolToPassMul = matMulLayerDescToSwift(&policyHead->gpoolToPassMul);
SWMatBiasLayerDesc gpoolToPassBias = matBiasLayerDescToSwift(&policyHead->gpoolToPassBias);
ActivationKind passActivation = activationLayerDescToSwift(&policyHead->passActivation);
SWMatMulLayerDesc gpoolToPassMul2 = matMulLayerDescToSwift(&policyHead->gpoolToPassMul2);

SWPolicyHeadDesc swPolicyHead = createSWPolicyHeadDesc(policyHead->modelVersion,
p1Conv,
Expand All @@ -221,7 +224,10 @@ SWPolicyHeadDesc MetalProcess::policyHeadDescToSwift(const PolicyHeadDesc * poli
p1BN,
p1Activation,
p2Conv,
gpoolToPassMul);
gpoolToPassMul,
gpoolToPassBias,
passActivation,
gpoolToPassMul2);

return swPolicyHead;
}
Expand Down Expand Up @@ -583,7 +589,6 @@ InputBuffers::InputBuffers(const LoadedModel* loadedModel, int maxBatchSz, int n
maxBatchSize = maxBatchSz;
policyResultChannels = m.policyHead.p2Conv.outChannels;
assert((m.modelVersion >= 12) ? (policyResultChannels == 2) : (policyResultChannels == 1));
assert(m.policyHead.p2Conv.outChannels == m.policyHead.gpoolToPassMul.outChannels);
singleSpatialElts = (size_t)m.numInputChannels * nnXLen * nnYLen;
singleInputElts = (size_t)m.numInputChannels * modelXLen * modelYLen;
singleInputGlobalElts = (size_t)m.numInputGlobalChannels;
Expand Down
87 changes: 73 additions & 14 deletions cpp/neuralnet/metalbackend.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1859,6 +1859,12 @@ public struct SWPolicyHeadDesc {
let p2Conv: SWConvLayerDesc
/// The first fully connected linear layer of the pass-move branch; for model version < 15 its output is the pass-move logits
let gpoolToPassMul: SWMatMulLayerDesc
/// The description of the bias layer that is applied to the output of the matrix multiplication layer for model version >= 15
let gpoolToPassBias: SWMatBiasLayerDesc?
/// The activation function applied to the output of the pass bias layer in model version >= 15
let passActivation: ActivationKind?
/// The second fully connected linear layer, applied after the pass activation, that outputs the logits for the pass move in model version >= 15
let gpoolToPassMul2: SWMatMulLayerDesc?

/// Initializes a SWPolicyHeadDesc object with the given parameters
/// - Parameters:
Expand All @@ -1881,7 +1887,10 @@ public struct SWPolicyHeadDesc {
p1BN: SWBatchNormLayerDesc,
p1Activation: ActivationKind,
p2Conv: SWConvLayerDesc,
gpoolToPassMul: SWMatMulLayerDesc) {
gpoolToPassMul: SWMatMulLayerDesc,
gpoolToPassBias: SWMatBiasLayerDesc?,
passActivation: ActivationKind?,
gpoolToPassMul2: SWMatMulLayerDesc?) {
self.version = version
self.p1Conv = p1Conv
self.g1Conv = g1Conv
Expand All @@ -1892,6 +1901,12 @@ public struct SWPolicyHeadDesc {
self.p1Activation = p1Activation
self.p2Conv = p2Conv
self.gpoolToPassMul = gpoolToPassMul
self.gpoolToPassBias = gpoolToPassBias
self.passActivation = passActivation
self.gpoolToPassMul2 = gpoolToPassMul2

assert((version >= 15) || ((gpoolToPassBias == nil) && (passActivation == nil) && (gpoolToPassMul2 == nil)))
assert((version < 15) || ((gpoolToPassBias != nil) && (passActivation != nil) && (gpoolToPassMul2 != nil)))
}
}

Expand All @@ -1904,17 +1919,39 @@ public func createSWPolicyHeadDesc(version: Int32,
p1BN: SWBatchNormLayerDesc,
p1Activation: ActivationKind,
p2Conv: SWConvLayerDesc,
gpoolToPassMul: SWMatMulLayerDesc) -> SWPolicyHeadDesc {
return SWPolicyHeadDesc(version: Int(version),
p1Conv: p1Conv,
g1Conv: g1Conv,
g1BN: g1BN,
g1Activation: g1Activation,
gpoolToBiasMul: gpoolToBiasMul,
p1BN: p1BN,
p1Activation: p1Activation,
p2Conv: p2Conv,
gpoolToPassMul: gpoolToPassMul)
gpoolToPassMul: SWMatMulLayerDesc,
gpoolToPassBias: SWMatBiasLayerDesc,
passActivation: ActivationKind,
gpoolToPassMul2: SWMatMulLayerDesc) -> SWPolicyHeadDesc {
if version >= 15 {
return SWPolicyHeadDesc(version: Int(version),
p1Conv: p1Conv,
g1Conv: g1Conv,
g1BN: g1BN,
g1Activation: g1Activation,
gpoolToBiasMul: gpoolToBiasMul,
p1BN: p1BN,
p1Activation: p1Activation,
p2Conv: p2Conv,
gpoolToPassMul: gpoolToPassMul,
gpoolToPassBias: gpoolToPassBias,
passActivation: passActivation,
gpoolToPassMul2: gpoolToPassMul2)
} else {
return SWPolicyHeadDesc(version: Int(version),
p1Conv: p1Conv,
g1Conv: g1Conv,
g1BN: g1BN,
g1Activation: g1Activation,
gpoolToBiasMul: gpoolToBiasMul,
p1BN: p1BN,
p1Activation: p1Activation,
p2Conv: p2Conv,
gpoolToPassMul: gpoolToPassMul,
gpoolToPassBias: nil,
passActivation: nil,
gpoolToPassMul2: nil)
}
}

/// A structure that represents a policy head of a neural network.
Expand Down Expand Up @@ -2001,14 +2038,36 @@ struct PolicyHead {
nnXLen: nnXLen,
nnYLen: nnYLen)

policyTensor = p2Conv.resultTensor

assert(g1Concat.resultTensor.shape?[1] == descriptor.gpoolToPassMul.inChannels)

let gpoolToPassMul = MatMulLayer(graph: graph,
descriptor: descriptor.gpoolToPassMul,
sourceTensor: g1Concat.resultTensor)

policyTensor = p2Conv.resultTensor
policyPassTensor = gpoolToPassMul.resultTensor
if let gpoolToPassBias = descriptor.gpoolToPassBias,
let passActivation = descriptor.passActivation,
let gpoolToPassMul2 = descriptor.gpoolToPassMul2 {
assert(descriptor.version >= 15)

let gpoolToPassBiasLayer = MatBiasLayer(graph: graph,
descriptor: gpoolToPassBias,
sourceTensor: gpoolToPassMul.resultTensor)

let passActivationLayer = ActivationLayer(graph: graph,
sourceTensor: gpoolToPassBiasLayer.resultTensor,
activationKind: passActivation)

let gpoolToPassMul2Layer = MatMulLayer(graph: graph,
descriptor: gpoolToPassMul2,
sourceTensor: passActivationLayer.resultTensor)

policyPassTensor = gpoolToPassMul2Layer.resultTensor
} else {
assert(descriptor.version < 15)
policyPassTensor = gpoolToPassMul.resultTensor
}

assert(policyTensor.shape?.count == 4)
assert(policyPassTensor.shape?.count == 2)
Expand Down
Loading

0 comments on commit 8380a78

Please sign in to comment.