diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 35a0a56..e359aac 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -94,7 +94,43 @@ jobs: uses: ./.github/workflows/TestRunnerSnitch.yml with: test-names: | + Adder + iSoftmax + TestiNoNorm + TestAdderLarge + TestiSoftmaxLarge testMatMul + testRQGEMM + TestRQAdd + testRQGEMMTransB + num-cores: 9 + + snitch-kernels-tiled-singlebuffer-L2: + uses: ./.github/workflows/TestRunnerTiledSnitchSequential.yml + with: + tests-config: | + [ + { + "name": "TestiNoNorm", + "L1": [5000, 10000] + }, + { + "name": "TestAdderLarge", + "L1": [5000, 10000] + }, + { + "name": "TestiSoftmaxLarge", + "L1": [5000, 10000] + }, + { + "name": "testRQGEMM", + "L1": [2000, 5000] + }, + { + "name": "TestRQAdd", + "L1": [5000, 10000] + } + ] ### Mempool Tests ### mempool-kernels: diff --git a/.github/workflows/TestRunnerSnitch.yml b/.github/workflows/TestRunnerSnitch.yml index 9f2acd2..c7ae8c0 100644 --- a/.github/workflows/TestRunnerSnitch.yml +++ b/.github/workflows/TestRunnerSnitch.yml @@ -6,6 +6,13 @@ on: test-names: required: true type: string + num-cores: + required: true + type: number + simulator: + required: false + default: "banshee" + type: string jobs: test-runner-snitch: @@ -26,7 +33,7 @@ jobs: echo "$testNames" | while IFS= read -r testName; do if [[ -n "$testName" ]]; then echo "Running test: $testName" - python testRunner_snitch.py -t Tests/$testName --toolchain_install_dir /app/install/riscv-llvm/ + python testRunner_snitch.py -t Tests/$testName --simulator=${{ inputs.simulator }} --cores=${{ inputs.num-cores }} --toolchain_install_dir /app/install/riscv-llvm/ fi done shell: bash \ No newline at end of file diff --git a/.github/workflows/TestRunnerTiledSnitchSequential.yml b/.github/workflows/TestRunnerTiledSnitchSequential.yml new file mode 100644 index 0000000..10a101a --- /dev/null +++ b/.github/workflows/TestRunnerTiledSnitchSequential.yml @@ -0,0 +1,60 @@ +name: TestRunnerTiledSnitchSequential + +on: + workflow_call: + inputs: + tests-config: + required: true + type: string + num-cores: + required: false + default: 9 + type: number + default-memory-level: + required: false + default: "L2" + type: string + simulator: + required: false + default: "banshee" + type: string + + +jobs: + + test-runner-snitch-tiled: + runs-on: ubuntu-22.04 + container: + image: ghcr.io/pulp-platform/deeploy:main + steps: + - name: Checkout Repo + uses: actions/checkout@v4 + with: + submodules: recursive + - name: Build Deeploy + run: pip install -e . + - name: Install jq + run: apt-get install -y jq + - name: Cache ccache + id: ccache-cache + uses: actions/cache@v4 + with: + path: /app/.ccache + key: ${{ runner.os }}-ccache + - name: Run Tests + run: | + cd DeeployTest + echo '${{ inputs.tests-config }}' > tests.json + mkdir -p /app/.ccache + export CCACHE_DIR=/app/.ccache + + jq -c '.[]' tests.json | while read test; do + testName=$(echo "$test" | jq -r '.name') + L1_values=$(echo "$test" | jq -r '.L1[]') + for L1_value in $L1_values; do + echo "Running test: $testName with L1: $L1_value" + python testRunner_tiled_snitch.py -t Tests/$testName --cores=${{ inputs.num-cores }} --simulator=${{ inputs.simulator }} --l1 $L1_value --defaultMemLevel=${{ inputs.default-memory-level }} --toolchain_install_dir /app/install/riscv-llvm/ + done + done + shell: bash + \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 1c1506d..def05e2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,4 +6,4 @@ url = https://github.com/pulp-platform/pulp-nnx.git [submodule "CMSIS-NN"] path = TargetLibraries/CMSIS/third_party/CMSIS-NN - url = https://github.com/ARM-software/CMSIS-NN.git \ No newline at end of file + url = https://github.com/ARM-software/CMSIS-NN.git diff --git a/Deeploy/Targets/Generic/Layers.py b/Deeploy/Targets/Generic/Layers.py index 91e8107..a714e6d 100644 --- a/Deeploy/Targets/Generic/Layers.py +++ b/Deeploy/Targets/Generic/Layers.py @@ -29,7 +29,7 @@ import numpy as np -from Deeploy.DeeployTypes import NodeMapper, ONNXLayer, Shape +from Deeploy.DeeployTypes import NodeMapper, ONNXLayer, OperatorRepresentation, Shape class ConcatLayer(ONNXLayer): @@ -85,6 +85,23 @@ def __init__(self, maps: List[NodeMapper]): super().__init__(maps) +class iNoNormLayer(ONNXLayer): + + def __init__(self, maps: List[NodeMapper]): + super().__init__(maps) + + def computeOps(self): + return self.mapper.parser.operatorRepresentation['size'] * 4 # 2 mul, 1 add, 1 right shift + + def computeShapes(self, inputShapes: Shape, outputShapes: Shape, operatorRepresentation: OperatorRepresentation, + channels_first: bool) -> Tuple[Shape]: + + # JUNGVI: Broadcast the weights and bias to have as many dimensions as the inputs + inputShapes[1] = [1] * (len(inputShapes[0]) - len(inputShapes[1])) + list(inputShapes[1]) + inputShapes[2] = inputShapes[1] + return (inputShapes, outputShapes) + + class RQSiGELULayer(iGELULayer): def __init__(self, maps: List[NodeMapper]): diff --git a/Deeploy/Targets/Generic/Parsers.py b/Deeploy/Targets/Generic/Parsers.py index 7c9e7e7..852e323 100644 --- a/Deeploy/Targets/Generic/Parsers.py +++ b/Deeploy/Targets/Generic/Parsers.py @@ -752,6 +752,41 @@ def parseNodeCtxt(self, return ctxt, True +class iNoNormParser(NodeParser): + + def __init__(self): + super().__init__() + + def parseNode(self, node: gs.Node) -> bool: + + ret = all(['D' in node.attrs, 'mul' in node.attrs, 'n_levels' in node.attrs]) + + if ret: + self.operatorRepresentation['D'] = node.attrs['D'] + self.operatorRepresentation['log2D'] = int(np.log2(node.attrs['D'].values).tolist()[0]) + self.operatorRepresentation['mul'] = int(node.attrs['mul'].values.tolist()[0]) + self.operatorRepresentation['n_levels'] = node.attrs['n_levels'] + + return ret + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + + data_in = ctxt.lookup(node.inputs[0].name) + weights = ctxt.lookup(node.inputs[1].name) + bias = ctxt.lookup(node.inputs[2].name) + data_out = ctxt.lookup(node.outputs[0].name) + self.operatorRepresentation['data_in'] = data_in.name + self.operatorRepresentation['weights'] = weights.name + self.operatorRepresentation['bias'] = bias.name + self.operatorRepresentation['data_out'] = data_out.name + self.operatorRepresentation['size'] = np.prod(data_in.shape) + + return ctxt, True + + class RQSiHardswishParser(iHardswishParser): def __init__(self): @@ -2080,3 +2115,59 @@ def parseNodeCtxt(self, return newCtxt, True return ctxt, False + + +class RQAddParser(AddParser): + + def parseNode(self, node: gs.Node) -> bool: + + if not super().parseNode(node): + return False + + ret = all([ + 'rqs1_mul' in node.attrs, + 'rqs1_add' in node.attrs, + 'rqs1_div' in node.attrs, + 'rqs1_signed' in node.attrs, + any(['rqs1_n_levels' in node.attrs, 'rqs1_n_levels_out' in node.attrs]), + 'rqs2_mul' in node.attrs, + 'rqs2_add' in node.attrs, + 'rqs2_div' in node.attrs, + 'rqs2_signed' in node.attrs, + any(['rqs2_n_levels' in node.attrs, 'rqs2_n_levels_out' in node.attrs]), + 'rqsOut_mul' in node.attrs, + 'rqsOut_add' in node.attrs, + 'rqsOut_div' in node.attrs, + 'rqsOut_signed' in node.attrs, + any(['rqsOut_n_levels' in node.attrs, 'rqsOut_n_levels_out' in node.attrs]), + ]) + + if ret: + if 'rqs1_n_levels' in node.attrs: + self.operatorRepresentation['rqs1_n_levels'] = int(node.attrs['rqs1_n_levels'].values) + else: + self.operatorRepresentation['rqs1_n_levels'] = int(node.attrs['rqs1_n_levels_out'].values) + self.operatorRepresentation['rqs1_mul'] = int(node.attrs['rqs1_mul']) + self.operatorRepresentation['rqs1_add'] = int(node.attrs['rqs1_add']) + self.operatorRepresentation['rqs1_signed'] = int(node.attrs['rqs1_signed'].values) + self.operatorRepresentation['rqs1_log2D'] = int(math.log2(node.attrs['rqs1_div'].values)) + + if 'rqs2_n_levels' in node.attrs: + self.operatorRepresentation['rqs2_n_levels'] = int(node.attrs['rqs2_n_levels'].values) + else: + self.operatorRepresentation['rqs2_n_levels'] = int(node.attrs['rqs2_n_levels_out'].values) + self.operatorRepresentation['rqs2_mul'] = int(node.attrs['rqs2_mul']) + self.operatorRepresentation['rqs2_add'] = int(node.attrs['rqs2_add']) + self.operatorRepresentation['rqs2_signed'] = int(node.attrs['rqs2_signed'].values) + self.operatorRepresentation['rqs2_log2D'] = int(math.log2(node.attrs['rqs2_div'].values)) + + if 'rqsOut_n_levels' in node.attrs: + self.operatorRepresentation['rqsOut_n_levels'] = int(node.attrs['rqsOut_n_levels'].values) + else: + self.operatorRepresentation['rqsOut_n_levels'] = int(node.attrs['rqsOut_n_levels_out'].values) + self.operatorRepresentation['rqsOut_mul'] = int(node.attrs['rqsOut_mul']) + self.operatorRepresentation['rqsOut_add'] = int(node.attrs['rqsOut_add']) + self.operatorRepresentation['rqsOut_signed'] = int(node.attrs['rqsOut_signed'].values) + self.operatorRepresentation['rqsOut_log2D'] = int(math.log2(node.attrs['rqsOut_div'].values)) + + return ret diff --git a/Deeploy/Targets/Generic/Templates/RQAddTemplate.py b/Deeploy/Targets/Generic/Templates/RQAddTemplate.py new file mode 100644 index 0000000..dacc9ac --- /dev/null +++ b/Deeploy/Targets/Generic/Templates/RQAddTemplate.py @@ -0,0 +1,48 @@ +# ---------------------------------------------------------------------- +# +# File: RQAddTemplate.py +# +# Last edited: 11.11.2023 +# +# Copyright (C) 2023, ETH Zurich and University of Bologna. +# +# Author: +# - Moritz Scherer, ETH Zurich +# - Victor Jung, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple + +from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation + + +class RQAddTemplate(NodeTemplate): + + def __init__(self, templateStr): + super().__init__(templateStr) + + def alignToContext(self, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + # Extract signedness information of input, weights and output + signedI2 = ctxt.lookup(operatorRepresentation['data_in_2'])._type.referencedType.typeMin < 0 + signedI = ctxt.lookup(operatorRepresentation['data_in_1'])._type.referencedType.typeMin < 0 + signedO = ctxt.lookup(operatorRepresentation['data_out'])._type.referencedType.typeMin < 0 + operatorRepresentation['input_2_signed'] = signedI2 + operatorRepresentation['input_signed'] = signedI + operatorRepresentation['output_signed'] = signedO + + return ctxt, operatorRepresentation, [] diff --git a/Deeploy/Targets/Generic/Templates/iNoNormTemplate.py b/Deeploy/Targets/Generic/Templates/iNoNormTemplate.py new file mode 100644 index 0000000..242962e --- /dev/null +++ b/Deeploy/Targets/Generic/Templates/iNoNormTemplate.py @@ -0,0 +1,38 @@ +# ---------------------------------------------------------------------- +# +# File: iNoNormTemplate.py +# +# Last edited: 22.02.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from Deeploy.DeeployTypes import NodeTemplate + + +class _iNoNormTemplate(NodeTemplate): + + def __init__(self, templateStr): + super().__init__(templateStr) + + +referenceTemplate = _iNoNormTemplate(""" +// iNoNorm (Name: ${nodeName}, Op: ${nodeOp}) +SnitchiNoNorm_s${data_in_type.referencedType.typeWidth}_s${data_out_type.referencedType.typeWidth}(${data_in}, ${data_out}, ${weights}, ${bias}, ${size}, ${mul}, ${log2D}); +""") diff --git a/Deeploy/Targets/Generic/Templates/iSoftmaxPreAllocatedBuffTemplate.py b/Deeploy/Targets/Generic/Templates/iSoftmaxPreAllocatedBuffTemplate.py new file mode 100644 index 0000000..45b80a7 --- /dev/null +++ b/Deeploy/Targets/Generic/Templates/iSoftmaxPreAllocatedBuffTemplate.py @@ -0,0 +1,64 @@ +# ---------------------------------------------------------------------- +# +# File: iSoftmaxPreAllocatedBuffTemplate.py +# +# Last edited: 09.07.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Moritz Scherer, scheremo@iis.ee.ethz.ch, ETH Zurich +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple, Union + +from ortools.constraint_solver.pywrapcp import IntVar + +from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation + + +class iSoftmaxPreAllocatedBuffTemplate(NodeTemplate): + + @staticmethod + def computeTransientBuffersSize( + ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> List[Tuple[str, Union[int, IntVar]]]: + + lastDimBuffer_dim = 8 * 4 * operatorRepresentation['lastDimLength'] + lastDimBuffer_name = operatorRepresentation['nodeName'] + "_lastDimBuffer" + return [(lastDimBuffer_name, lastDimBuffer_dim)] + + def hoistTransientBuffers(self, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + lastDimBuffer_name, lastDimBuffer_dim = iSoftmaxPreAllocatedBuffTemplate.computeTransientBuffersSize( + ctxt, operatorRepresentation)[0] + ctxt.hoistTransientBuffer(lastDimBuffer_name, lastDimBuffer_dim) + + operatorRepresentation['lastDimBuffer'] = lastDimBuffer_name + operatorRepresentation['lastDimBufferSize'] = lastDimBuffer_dim + return ctxt, operatorRepresentation, [lastDimBuffer_name] + + def alignToContext(self, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + + signedI = ctxt.lookup(operatorRepresentation['data_in'])._type.referencedType.typeMin < 0 + signedO = ctxt.lookup(operatorRepresentation['data_out'])._type.referencedType.typeMin < 0 + + operatorRepresentation['input_signed'] = signedI + operatorRepresentation['output_signed'] = signedO + + return ctxt, operatorRepresentation, [] diff --git a/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py b/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py index 96eb222..3a83f4e 100644 --- a/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py +++ b/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py @@ -24,13 +24,14 @@ # limitations under the License. import copy +from collections import OrderedDict from functools import partial from typing import List import numpy as np import onnx_graphsurgeon as gs -from Deeploy.CommonExtensions.OptimizationPasses.Matchers import Match, NonBranchingMatcher +from Deeploy.CommonExtensions.OptimizationPasses.Matchers import BranchingMatcher, Match, NonBranchingMatcher from Deeploy.CommonExtensions.OptimizationPasses.PassClasses import ReplaceSequentialPatternPass, contextagnostic @@ -865,3 +866,135 @@ def __init__(self): name = "_SPLIT_RequantShift_PASS" super().__init__(graph, partial(_split_rqs_fun, splitSet = self.splitSet), name) + + +def _merge_add_rq_fun(graph: gs.Graph, match: Match, name: str): + + nodes_map = match.nodes_map + addNode = nodes_map['add'] + + rqDict = OrderedDict([("rqs1", None), ("rqs2", None), ("rqsOut", None)]) + + for key, node in nodes_map.items(): + + if node.outputs[0].name == addNode.inputs[0].name: + rqDict['rqs1'] = node + elif node.outputs[0].name == addNode.inputs[1].name: + rqDict['rqs2'] = node + elif node.inputs[0].name == addNode.outputs[0].name: + rqDict['rqsOut'] = node + + newAttrs = copy.copy(addNode.attrs) + newInputs = [] + + if rqDict['rqsOut'] is not None: + newOutputs = rqDict['rqsOut'].outputs + else: + newOutputs = addNode.outputs + + defaultAttrs = { + "mul": 1, + "add": 0, + "div": gs.Constant('div', np.array(1)), + 'shift': gs.Constant('div', np.array(0)) + } + guessAttrs = {"n_levels_out": 256, "signed": np.array([True])} + for idx, (rqKey, node) in enumerate(rqDict.items()): + if node.op == "RequantShift": + for key, attr in node.attrs.items(): + newAttrs[f"{rqKey}_{key}"] = attr + + if np.prod(node.inputs[1].values.shape) != 1: + return graph + + if np.prod(node.inputs[2].values.shape) != 1: + return graph + + if rqKey != 'rqsOut': + newInputs.append(node.inputs[0]) + + newAttrs[f"{rqKey}_mul"] = int(node.inputs[1].values.item()) + newAttrs[f"{rqKey}_add"] = int(node.inputs[2].values.item() + newAttrs[f"{rqKey}_div"].values.item() // 2) + newAttrs[f"{rqKey}_shift"] = int(np.log2(newAttrs[f"{rqKey}_div"].values.item())) + + else: + for key, attr in defaultAttrs.items(): + newAttrs[f"{rqKey}_{key}"] = attr + + for key, attr in guessAttrs.items(): + if not key in node.attrs: + newAttrs[f"{rqKey}_{key}"] = attr + else: + newAttrs[f"{rqKey}_{key}"] = node.attrs[key] + if rqKey != 'rqsOut': + newInputs.append(addNode.inputs[idx]) + + rqAdd = gs.Node(op = "RequantizedAdd", name = name, attrs = newAttrs) + graph.replaceInsertNode(newInputs, newOutputs, rqAdd) + + return graph + + +@contextagnostic +class AddRequantMergePass(ReplaceSequentialPatternPass): + pass + + def __init__(self): + _input1 = gs.Variable(name = 'input_1') + _input2 = gs.Variable(name = 'input_2') + _addIn1 = gs.Variable(name = 'addIn1') + _addIn2 = gs.Variable(name = 'addIn2') + _addOut = gs.Variable(name = 'addOut') + _rqs = gs.Variable(name = 'rqs') + + anyIn1 = gs.Node(inputs = [_input1], outputs = [_addIn1], op = r'.*', name = 'any1') + anyIn2 = gs.Node(inputs = [_input2], outputs = [_addIn2], op = r'.*', name = 'any2') + + addOut = gs.Node(inputs = [_addIn1, _addIn2], outputs = [_addOut], op = 'Add', name = 'add') + output = gs.Node(inputs = [_addOut], outputs = [_rqs], op = r'RequantShift', name = 'rqsOut') + + graph = gs.Graph(nodes = [anyIn1, anyIn2, addOut, output], inputs = [_input1, _input2], outputs = [_rqs]) + + super().__init__(graph, + replacement_fn = _merge_add_rq_fun, + name = "_MERGE_ADDRQ_PASS", + matcher = BranchingMatcher(regex_op = True)) + + +def merge_gemm_rq_fun(graph: gs.Graph, match: Match, name: str): + matched_nodes = [m for k, m in match.nodes_map.items()] + gemm = matched_nodes[0] + rqs = matched_nodes[1] + + # WIESEP: Per element quantization is not supported for RQGemm + if len(rqs.inputs[2].shape) > 0 and rqs.inputs[2].shape[-1] != 1: + return graph + + # WIESEP: Per column quantization is not supported for RQGemm + if len(rqs.inputs[2].shape) > 2 and rqs.inputs[2].shape[-3] != 1: + return graph + + _inputs = list(gemm.inputs) + list(rqs.inputs[2:]) + list(rqs.inputs[1:2]) + _outputs = rqs.outputs + + attrs = {**gemm.attrs, **rqs.attrs} + rqsGemm = gs.Node(op = 'RQGemm', name = name, attrs = attrs) + graph.replaceInsertNode(_inputs, _outputs, rqsGemm) + + return graph + + +@contextagnostic +class GEMMRequantMergePass(ReplaceSequentialPatternPass): + + def __init__(self): + passes = [] + graph = gs.Graph() + _input = gs.Variable(name = 'input_1') + output = graph.layer(inputs = [_input], outputs = ['matmul_out'], op = 'Gemm', name = 'gemm') + output = graph.layer(inputs = output, outputs = ['rqs'], op = 'RequantShift', name = 'rqs') + graph.outputs.append(output) + graph.inputs.append(_input) + + name = f"_MERGE_GEMM_RQ_PASS" + super().__init__(graph, merge_gemm_rq_fun, name) diff --git a/Deeploy/Targets/Generic/TypeCheckers.py b/Deeploy/Targets/Generic/TypeCheckers.py index cb88602..475cd20 100644 --- a/Deeploy/Targets/Generic/TypeCheckers.py +++ b/Deeploy/Targets/Generic/TypeCheckers.py @@ -392,6 +392,23 @@ def _inferSignedness(self, inputs: List[VariableBuffer], return [False] +class iNoNormChecker(SignPropTypeChecker): + + def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]): + super().__init__(input_types, output_types) + + def _inferNumLevels(self, inputs: List[VariableBuffer], + operatorRepresentation: OperatorRepresentation) -> List[int]: + return [2**(4 * self.input_types[0].referencedType.typeWidth)] + + def _inferSignedness(self, inputs: List[VariableBuffer], + operatorRepresentation: OperatorRepresentation) -> List[bool]: + if inputs[0]._signed: + return [True] + else: + return [False] + + class GELUChecker(SignPropTypeChecker): def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]): @@ -520,3 +537,26 @@ def _inferSignedness(self, inputs: List[VariableBuffer], return [True] else: return [False] + + +class RQAddChecker(SignPropTypeChecker): + + def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]): + super().__init__(input_types, output_types) + + def _inferNumLevels(self, inputs: List[VariableBuffer], + operatorRepresentation: OperatorRepresentation) -> List[int]: + return [operatorRepresentation['rqsOut_n_levels']] + + def _inferSignedness(self, inputs: List[VariableBuffer], + operatorRepresentation: OperatorRepresentation) -> List[bool]: + return [bool(operatorRepresentation["rqsOut_signed"])] + + # Override this. This should compute the signednes of each output node of the Layer + def checkOutputType(self, inputs: List[VariableBuffer], operatorRepresentation: OperatorRepresentation) -> bool: + outputTypeSigned = self.output_types[0].referencedType.typeMin < 0 + if operatorRepresentation['rqsOut_signed'] and outputTypeSigned: + return True + if (not operatorRepresentation['rqsOut_signed']) and (not outputTypeSigned): + return True + return False diff --git a/Deeploy/Targets/PULPOpen/Bindings.py b/Deeploy/Targets/PULPOpen/Bindings.py index 5d23620..cb7515b 100644 --- a/Deeploy/Targets/PULPOpen/Bindings.py +++ b/Deeploy/Targets/PULPOpen/Bindings.py @@ -38,7 +38,7 @@ from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration from Deeploy.Targets.Generic.Templates import ConcatTemplate, RQSiGELUTemplate, iHardswishTemplate from Deeploy.Targets.Generic.TypeCheckers import ConcatChecker, GELUChecker, HardswishChecker, MatMulChecker, \ - MulChecker, ReduceMeanChecker, RQHardswishChecker, SliceChecker, SoftmaxChecker, TransposeChecker, \ + MulChecker, ReduceMeanChecker, RQAddChecker, RQHardswishChecker, SliceChecker, SoftmaxChecker, TransposeChecker, \ iLayerNormChecker from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPClusterSynch import PULPSynchCoresPass from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPClusterTiling import PULPClusterTiling @@ -48,7 +48,7 @@ MulTemplate, ReduceMeanTemplate, RequantShiftTemplate, RQAddTemplate, RQSiHardswishTemplate, SliceTemplate, \ TallGEMMTemplate, TransposeTemplate, UniformRequantShiftTemplate, iRMSNormTemplate, iSoftmaxTemplate from Deeploy.Targets.PULPOpen.TypeCheckers import PULPConvChecker, PULPLinearChecker, PULPMaxPoolChecker, \ - PULPRequantShiftChecker, PULPRQAddChecker + PULPRequantShiftChecker from Deeploy.TilingExtension.CodeTransformationPasses.TilingVariableReplacement import TilingVariableReplacement _clusterEntryClosureCallTemplate = NodeTemplate(""" @@ -156,8 +156,8 @@ ] PULPRQAddBindings = [ - NodeBinding(PULPRQAddChecker([PointerClass(_type), PointerClass(_type2)], [PointerClass(_type3)]), - RQAddTemplate.RQAddTemplate, ForkTransformer) + NodeBinding(RQAddChecker([PointerClass(_type), PointerClass(_type2)], [PointerClass(_type3)]), + RQAddTemplate.referenceTemplate, ForkTransformer) for _type in [int8_t, uint8_t] for _type2 in [int8_t, uint8_t] for _type3 in [int8_t, uint8_t] diff --git a/Deeploy/Targets/PULPOpen/Parsers.py b/Deeploy/Targets/PULPOpen/Parsers.py index 6878237..33b4abe 100644 --- a/Deeploy/Targets/PULPOpen/Parsers.py +++ b/Deeploy/Targets/PULPOpen/Parsers.py @@ -29,63 +29,7 @@ import onnx_graphsurgeon as gs from Deeploy.DeeployTypes import NetworkContext -from Deeploy.Targets.Generic.Parsers import AddParser, GEMMParser, RQSConv1DParser, RQSConv2DParser, RQSParserInterface - - -class PULPRQAddParser(AddParser): - - def parseNode(self, node: gs.Node) -> bool: - - if not super().parseNode(node): - return False - - ret = all([ - 'rqs1_mul' in node.attrs, - 'rqs1_add' in node.attrs, - 'rqs1_div' in node.attrs, - 'rqs1_signed' in node.attrs, - any(['rqs1_n_levels' in node.attrs, 'rqs1_n_levels_out' in node.attrs]), - 'rqs2_mul' in node.attrs, - 'rqs2_add' in node.attrs, - 'rqs2_div' in node.attrs, - 'rqs2_signed' in node.attrs, - any(['rqs2_n_levels' in node.attrs, 'rqs2_n_levels_out' in node.attrs]), - 'rqsOut_mul' in node.attrs, - 'rqsOut_add' in node.attrs, - 'rqsOut_div' in node.attrs, - 'rqsOut_signed' in node.attrs, - any(['rqsOut_n_levels' in node.attrs, 'rqsOut_n_levels_out' in node.attrs]), - ]) - - if ret: - if 'rqs1_n_levels' in node.attrs: - self.operatorRepresentation['rqs1_n_levels'] = int(node.attrs['rqs1_n_levels'].values) - else: - self.operatorRepresentation['rqs1_n_levels'] = int(node.attrs['rqs1_n_levels_out'].values) - self.operatorRepresentation['rqs1_mul'] = int(node.attrs['rqs1_mul']) - self.operatorRepresentation['rqs1_add'] = int(node.attrs['rqs1_add']) - self.operatorRepresentation['rqs1_signed'] = int(node.attrs['rqs1_signed'].values) - self.operatorRepresentation['rqs1_log2D'] = int(math.log2(node.attrs['rqs1_div'].values)) - - if 'rqs2_n_levels' in node.attrs: - self.operatorRepresentation['rqs2_n_levels'] = int(node.attrs['rqs2_n_levels'].values) - else: - self.operatorRepresentation['rqs2_n_levels'] = int(node.attrs['rqs2_n_levels_out'].values) - self.operatorRepresentation['rqs2_mul'] = int(node.attrs['rqs2_mul']) - self.operatorRepresentation['rqs2_add'] = int(node.attrs['rqs2_add']) - self.operatorRepresentation['rqs2_signed'] = int(node.attrs['rqs2_signed'].values) - self.operatorRepresentation['rqs2_log2D'] = int(math.log2(node.attrs['rqs2_div'].values)) - - if 'rqsOut_n_levels' in node.attrs: - self.operatorRepresentation['rqsOut_n_levels'] = int(node.attrs['rqsOut_n_levels'].values) - else: - self.operatorRepresentation['rqsOut_n_levels'] = int(node.attrs['rqsOut_n_levels_out'].values) - self.operatorRepresentation['rqsOut_mul'] = int(node.attrs['rqsOut_mul']) - self.operatorRepresentation['rqsOut_add'] = int(node.attrs['rqsOut_add']) - self.operatorRepresentation['rqsOut_signed'] = int(node.attrs['rqsOut_signed'].values) - self.operatorRepresentation['rqsOut_log2D'] = int(math.log2(node.attrs['rqsOut_div'].values)) - - return ret +from Deeploy.Targets.Generic.Parsers import GEMMParser, RQSConv1DParser, RQSConv2DParser, RQSParserInterface class PULPConv2DParser(RQSConv2DParser): diff --git a/Deeploy/Targets/PULPOpen/Platform.py b/Deeploy/Targets/PULPOpen/Platform.py index c32f29d..bac2d82 100644 --- a/Deeploy/Targets/PULPOpen/Platform.py +++ b/Deeploy/Targets/PULPOpen/Platform.py @@ -39,9 +39,9 @@ PadLayer, ReduceMeanLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, RQSiHardswishLayer, \ SliceLayer, TransposeLayer, iHardswishLayer, iRMSNormLayer, iSoftmaxLayer from Deeploy.Targets.Generic.Parsers import AddParser, ConcatParser, FlattenParser, GatherParser, MatMulParser, \ - MulParser, Pad1DParser, Pad2DParser, ReduceMeanParser, RequantShiftParser, ReshapeParser, RQIntegerDivParser, \ - RQSiGELUParser, RQSiHardswishParser, SliceParser, TransposeParser, UniformRequantShiftParser, UnsqueezeParser, \ - iHardswishParser, iRMSNormParser, iSoftmaxParser + MulParser, Pad1DParser, Pad2DParser, ReduceMeanParser, RequantShiftParser, ReshapeParser, RQAddParser, \ + RQIntegerDivParser, RQSiGELUParser, RQSiHardswishParser, SliceParser, TransposeParser, UniformRequantShiftParser, \ + UnsqueezeParser, iHardswishParser, iRMSNormParser, iSoftmaxParser from Deeploy.Targets.Generic.Templates import AllocateTemplate as BasicAllocateTemplate from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import IntegerDivRequantMergePass, \ MergeConstAddAndRequantPass, MergeTrueIntegerDivRequantShiftPass, RQSSplitPass, SkipEmptyConcatPass, \ @@ -50,7 +50,7 @@ PULPReduceMeanBindings from Deeploy.Targets.PULPOpen.Layers import PULPRQSConvLayer, PULPRQSGEMMLayer from Deeploy.Targets.PULPOpen.Parsers import PULPConv1DParser, PULPConv2DParser, PULPDWConv1DParser, \ - PULPDWConv2DParser, PULPGEMMParser, PULPMatrixVecParser, PULPRQAddParser, PULPTallGEMMParser + PULPDWConv2DParser, PULPGEMMParser, PULPMatrixVecParser, PULPTallGEMMParser from Deeploy.Targets.PULPOpen.Templates import AllocateTemplate, FreeTemplate from Deeploy.Targets.PULPOpen.Tiler import PULPAddTilingReadyBindings, PULPConcatTilingReadyBindings, \ PULPFlattenTilingReadyBindings, PULPiHardswishTilingReadyBindings, PULPiRMSNormTilingReadyBindings, \ @@ -62,7 +62,7 @@ from Deeploy.Targets.PULPOpen.TopologyOptimizationPasses.Passes import PULPAddRequantMergePass, \ PULPConvRequantMergePass, PULPGEMMRequantMergePass, PULPMatMulRequantMergePass -RQAddMapper = NodeMapper(PULPRQAddParser(), PULPRQAddTilingReadyBindings) +RQAddMapper = NodeMapper(RQAddParser(), PULPRQAddTilingReadyBindings) AddMapper = NodeMapper(AddParser(), PULPAddTilingReadyBindings) FlattenMapper = NodeMapper(FlattenParser(), PULPFlattenTilingReadyBindings) GatherMapper = NodeMapper(GatherParser(), BasicGatherBindings) diff --git a/Deeploy/Targets/PULPOpen/Templates/RQAddTemplate.py b/Deeploy/Targets/PULPOpen/Templates/RQAddTemplate.py index 49ede2b..f88b2db 100644 --- a/Deeploy/Targets/PULPOpen/Templates/RQAddTemplate.py +++ b/Deeploy/Targets/PULPOpen/Templates/RQAddTemplate.py @@ -23,30 +23,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from Deeploy.Targets.Generic.Templates.RQAddTemplate import RQAddTemplate -from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation - - -class PULPRQAddTemplate(NodeTemplate): - - def __init__(self, templateStr): - super().__init__(templateStr) - - def alignToContext(self, ctxt: NetworkContext, - operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: - # Extract signedness information of input, weights and output - signedI2 = ctxt.lookup(operatorRepresentation['data_in_2'])._type.referencedType.typeMin < 0 - signedI = ctxt.lookup(operatorRepresentation['data_in_1'])._type.referencedType.typeMin < 0 - signedO = ctxt.lookup(operatorRepresentation['data_out'])._type.referencedType.typeMin < 0 - operatorRepresentation['input_2_signed'] = signedI2 - operatorRepresentation['input_signed'] = signedI - operatorRepresentation['output_signed'] = signedO - - return ctxt, operatorRepresentation, [] - - -RQAddTemplate = PULPRQAddTemplate(""" +referenceTemplate = RQAddTemplate(""" <% signatureString = '' diff --git a/Deeploy/Targets/Snitch/Bindings.py b/Deeploy/Targets/Snitch/Bindings.py new file mode 100644 index 0000000..0b319b7 --- /dev/null +++ b/Deeploy/Targets/Snitch/Bindings.py @@ -0,0 +1,101 @@ +# ---------------------------------------------------------------------- +# +# File: SnitchBindings.py +# +# Last edited: 30.05.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.CodeTransformationPasses.Closure import ClosureGeneration, MemoryAwareClosureGeneration +from Deeploy.CommonExtensions.CodeTransformationPasses.MemoryAllocation import ArgumentStructGeneration, \ + MemoryManagementGeneration +from Deeploy.CommonExtensions.DataTypes import int8_t, int32_t, uint8_t +from Deeploy.DeeployTypes import CodeTransformation, NodeBinding +from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration +from Deeploy.Targets.Generic.Templates import iNoNormTemplate +from Deeploy.Targets.Generic.TypeCheckers import AddChecker, GEMMChecker, RQAddChecker, SoftmaxChecker, iNoNormChecker +from Deeploy.Targets.Snitch.CodeTransformationPasses import SnitchClusterTiling, SnitchCoreFilterPass, \ + SnitchProfileExecutionBlockPass, SnitchSynchCoresPass +from Deeploy.Targets.Snitch.Templates import AddTemplate, RQAddTemplate, iSoftmaxTemplate +from Deeploy.Targets.Snitch.Templates.GemmTemplate import SnitchGemm_Template +from Deeploy.Targets.Snitch.Templates.RqGemmTemplate import SnitchRqGemm_Template +from Deeploy.TilingExtension.CodeTransformationPasses.TilingVariableReplacement import TilingVariableReplacement + +TilingCallClosure = partial(ClosureGeneration, closureSuffix = "_tiling_closure") +MemoryAwareFunctionCallClosure = partial(MemoryAwareClosureGeneration, + closureSuffix = "_closure", + startRegion = "L2", + endRegion = "L1") + +BasicTransformer = CodeTransformation( + [SnitchSynchCoresPass(), + ArgumentStructGeneration(), + MemoryManagementGeneration(), + FutureGeneration()]) + +TiledTransformer = CodeTransformation([ + SnitchCoreFilterPass("compute"), + SnitchProfileExecutionBlockPass(), + TilingVariableReplacement("L1"), + TilingCallClosure(writeback = False), + SnitchSynchCoresPass(), + SnitchClusterTiling("L1"), + ArgumentStructGeneration(), + MemoryManagementGeneration("L1"), + MemoryAwareFunctionCallClosure(writeback = False, generateStruct = True), + MemoryManagementGeneration() +]) + +SnitchiSoftmaxBindings = [ + NodeBinding(SoftmaxChecker([PointerClass(_type)], [PointerClass(uint8_t)]), iSoftmaxTemplate.referenceTemplate, + TiledTransformer) for _type in [int8_t, uint8_t] +] +SnitchiNoNormBindings = [ + NodeBinding( + iNoNormChecker([PointerClass(_type), PointerClass(int8_t), + PointerClass(int32_t)], [PointerClass(int8_t)]), iNoNormTemplate.referenceTemplate, + TiledTransformer) for _type in [int8_t] +] +SnitchRQAddBindings = [ + NodeBinding(RQAddChecker([PointerClass(_type), PointerClass(_type)], [PointerClass(_type)]), + RQAddTemplate.referenceTemplate, TiledTransformer) for _type in [int8_t] +] +SnitchAddBindings = [ + NodeBinding(AddChecker([PointerClass(_type), PointerClass(_type)], [PointerClass(int32_t)]), + AddTemplate.referenceTemplate, TiledTransformer) for _type in [int8_t] +] +SnitchGemmBindings = [ + NodeBinding( + GEMMChecker([PointerClass(int8_t), PointerClass(int8_t), + PointerClass(int32_t)], [PointerClass(int32_t)]), SnitchGemm_Template, TiledTransformer) +] +SnitchRqGemmBindings = [ + NodeBinding( + GEMMChecker([ + PointerClass(int8_t), + PointerClass(int8_t), + PointerClass(int32_t), + PointerClass(int32_t), + PointerClass(int32_t) + ], [PointerClass(int8_t)]), SnitchRqGemm_Template, TiledTransformer) +] diff --git a/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchClusterSynch.py b/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchClusterSynch.py new file mode 100644 index 0000000..362f6c4 --- /dev/null +++ b/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchClusterSynch.py @@ -0,0 +1,46 @@ +# ---------------------------------------------------------------------- +# +# File: SnitchClusterSynch.py +# +# Last edited: 31.05.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +from Deeploy.DeeployTypes import CodeGenVerbosity, CodeTransformationPass, ExecutionBlock, NetworkContext, \ + NodeTemplate, _NoVerbosity + +_synchTemplate = NodeTemplate(""" + snrt_cluster_hw_barrier(); +""") + + +class SnitchSynchCoresPass(CodeTransformationPass): + + def apply(self, + ctxt: NetworkContext, + executionBlock: ExecutionBlock, + name: str, + verbose: CodeGenVerbosity = _NoVerbosity) -> Tuple[NetworkContext, ExecutionBlock]: + # TODO: JUNGVI: These have to be core only barriers + executionBlock.addLeft(_synchTemplate, {}) + executionBlock.addRight(_synchTemplate, {}) + return ctxt, executionBlock diff --git a/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchClusterTiling.py b/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchClusterTiling.py new file mode 100644 index 0000000..50273a2 --- /dev/null +++ b/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchClusterTiling.py @@ -0,0 +1,49 @@ +# ---------------------------------------------------------------------- +# +# File: SnitchClusterTiling.py +# +# Last edited: 31.05.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +from Deeploy.DeeployTypes import CodeGenVerbosity, CodeTransformationPass, ExecutionBlock, NetworkContext, _NoVerbosity + +from .SnitchClusterTilingSB import SnitchClusterTilingGenerationSB + + +class SnitchClusterTiling(CodeTransformationPass): + + def __init__(self, targetMemLevel: str): + self.SB = SnitchClusterTilingGenerationSB(targetMemLevel) + + def apply(self, + ctxt: NetworkContext, + executionBlock: ExecutionBlock, + name: str, + verbose: CodeGenVerbosity = _NoVerbosity) -> Tuple[NetworkContext, ExecutionBlock]: + + if verbose.tilingProfiling == "L2": + raise NotImplementedError("Profiling not implemented for L2") + # ctxt, executionBlock = self.profilingSB.apply(ctxt, executionBlock, name) + else: + ctxt, executionBlock = self.SB.apply(ctxt, executionBlock, name) + return ctxt, executionBlock diff --git a/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchClusterTilingSB.py b/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchClusterTilingSB.py new file mode 100644 index 0000000..f1c221f --- /dev/null +++ b/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchClusterTilingSB.py @@ -0,0 +1,519 @@ +# ---------------------------------------------------------------------- +# +# File: SnitchClusterTilingSB.py +# +# Last edited: 03.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from collections import namedtuple +from typing import Dict, List, Literal, Tuple + +from Deeploy.DeeployTypes import CodeSnippet, ExecutionBlock, NetworkContext, NodeTemplate, OperatorRepresentation +from Deeploy.Targets.Snitch.DataTypes import Snitch_DMA_copy +from Deeploy.TilingExtension.CodeTransformationPasses.TilingCodeGeneration import TilingCodeGeneration +from Deeploy.TilingExtension.CodeTransformationPasses.TilingPrototypes import SingleBufferingTilingMixIn, TilingMetaInfo +from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint +from Deeploy.TilingExtension.TilingCodegen import HyperRectangle, TilingSchedule, VariableReplacementScheme, \ + calculateRectangleOffset, minimizeRectangleDims + +_openTileLoopTemplate = NodeTemplate(""" + +// TILING LOOP +for (int TILING_I=${numTiles}[*${tileIdxPtr}]; TILING_I<${numTiles}[(*${tileIdxPtr})+1]; TILING_I++){ +""") + +_closeTileLoopTemplate = NodeTemplate(""" + +// CLOSE TILING LOOP +} +*${tileIdxPtr} += 1; + +""") + +_moveTileInTemplate = NodeTemplate(""" + +// IMPORT TILE ${innerTilePtr} from ${outerTilePtr} +if(snrt_is_dm_core()){ + ${stateReference}.tid = snrt_dma_start_2d(${stateReference}.dst, + ${stateReference}.src, + ${stateReference}.size, + ${stateReference}.dst_stride, + ${stateReference}.src_stride, + ${stateReference}.repeat); +} +""") + +_iteratedMoveTileInTemplate = NodeTemplate(""" + +""") + +_blockTileInTemplate = NodeTemplate(""" + +// BLOCKING IMPORT TILE ${innerTilePtr} +if(snrt_is_dm_core()){ + // snrt_dma_wait(${stateReference}.tid); + snrt_dma_wait_all(); +} +""") + +_moveTileOutTemplate = NodeTemplate(""" + +// EXPORT TILE ${innerTilePtr} to ${outerTilePtr} +if(snrt_is_dm_core()){ + ${stateReference}.tid = snrt_dma_start_2d(${stateReference}.dst, + ${stateReference}.src, + ${stateReference}.size, + ${stateReference}.dst_stride, + ${stateReference}.src_stride, + ${stateReference}.repeat); +} +""") + +_blockTileOutTemplate = NodeTemplate(""" + +// BLOCKING EXPORT TILE ${innerTilePtr} +if(snrt_is_dm_core()){ + //snrt_dma_wait(${stateReference}.tid); + snrt_dma_wait_all(); +} +""") + +_updateDMATransferStructTemplate = NodeTemplate(""" + +// UPDATE DMA STRUCT ${stateReference} +${stateReference}.dst = ((char*)${dstPtr}) + ${dstOffsetPtr}[${tileNum}]; +${stateReference}.src = ((char*)${srcPtr}) + ${srcOffsetPtr}[${tileNum}]; +${stateReference}.size = ${sizePtr}[${tileNum}]; +${stateReference}.dst_stride = ${dstStridePtr}[${tileNum}]; +${stateReference}.src_stride = ${srcStridePtr}[${tileNum}]; +${stateReference}.repeat = ${repeatPtr}[${tileNum}]; +""") + +_updateReferenceTemplate = NodeTemplate(""" + +// UPDATE VARIABLE ${reference} +*${reference} = ${baseReference}[${tileNum}]; +""") + +_DMAUpdate = namedtuple("_DMAUpdate", "dst src size dst_stride src_stride repeat tid direction") + + +class SnitchClusterTilingSB(TilingCodeGeneration): + + _prefix = "TILING_REPLACED_" + + _openTileLoopTemplate = _openTileLoopTemplate + _closeTileLoopTemplate = _closeTileLoopTemplate + + _moveTileInTemplate = _moveTileInTemplate + _iteratedMoveTileInTemplate = _iteratedMoveTileInTemplate + _blockTileInTemplate = _blockTileInTemplate + + _moveTileOutTemplate = _moveTileOutTemplate + _blockTileOutTemplate = _blockTileOutTemplate + + _updateDMATransferStructTemplate = _updateDMATransferStructTemplate + _updateReferenceTemplate = _updateReferenceTemplate + + @property + def prefix(self): + return self._prefix + self.targetMemLevel + "_" + + def _DMAStructName(self, tensorName: str, nodeName: str) -> str: + return f"{self.prefix}_DMA_{nodeName}_{tensorName}" + + @classmethod + def _generatePointerUpdates(cls, ctxt: NetworkContext, operatorRepresentation: OperatorRepresentation, + loadSchedule: List[Dict[str, + HyperRectangle]], nodeMemoryConstraint: NodeMemoryConstraint, + tilingSchedule: TilingSchedule) -> Dict[str, _DMAUpdate]: + updateDict = {} + deltaOffsets = {} + + for idx, loadStep in enumerate(loadSchedule): + for _, (key, rect) in enumerate(loadStep.items()): + + if key in tilingSchedule.outputBaseOffsets.keys(): + baseOffsets = tilingSchedule.outputBaseOffsets[key] + direction = "FromL1" + else: + baseOffsets = tilingSchedule.inputBaseOffsets[key] + direction = "ToL1" + + if key not in updateDict.keys(): + updateDict[key] = [] + if key not in deltaOffsets.keys(): + deltaOffsets[key] = 0 + + referenceBuffer = ctxt.lookup(ctxt.lookup(operatorRepresentation[key])._referenceName) + l1Buffer = ctxt.lookup(operatorRepresentation[key]) + + finalMemoryLevel = TilingCodeGeneration.isFinalMemoryLevel(nodeMemoryConstraint, l1Buffer) + + struct = cls._rectToDMAStruct(ctxt, rect, direction, l1Buffer.name, l1Buffer._referenceName, + finalMemoryLevel) + accOffset = calculateRectangleOffset(rect, referenceBuffer) + + lIdx = idx % len(baseOffsets) + + if direction == "ToL1": + src = accOffset + dst = baseOffsets[lIdx] + else: + src = baseOffsets[lIdx] + dst = accOffset + + size = struct.value['size'].value + dst_stride = struct.value['dst_stride'].value + src_stride = struct.value['src_stride'].value + repeat = struct.value['repeat'].value + tid = struct.value['tid'].value + + sol = _DMAUpdate(dst, src, size, dst_stride, src_stride, repeat, tid, direction) + + deltaOffsets[key] = accOffset + updateDict[key].append(sol) + + return updateDict + + @classmethod + def _rectToDMAStruct(cls, ctxt: NetworkContext, rectangle: HyperRectangle, direction: Literal["ToL1", "FromL1"], + L1Name: str, L2Name: str, finalMemoryLevel: bool) -> Snitch_DMA_copy: + + referenceBuffer = ctxt.lookup(L2Name) + + rect, referenceRect = minimizeRectangleDims(rectangle, referenceBuffer) + assert len(rect.dims) <= 3, "Snitch's iDMA only 2D transfers are supported!" + + if direction == "FromL1": + _src = L1Name + _dst = referenceBuffer.name + else: + _src = referenceBuffer.name + _dst = L1Name + + transfer_size = rect.dims[-1] * (referenceBuffer._type.referencedType.typeWidth // 8) + + src_stride = 0 + dst_stride = 0 + repeat = 1 + if len(rect.dims) > 1: + repeat = rect.dims[-2] + if direction == "ToL1": + dst_stride = rect.dims[-1] * (referenceBuffer._type.referencedType.typeWidth // 8) + src_stride = referenceRect.dims[-1] * (referenceBuffer._type.referencedType.typeWidth // 8) + else: + dst_stride = referenceRect.dims[-1] * (referenceBuffer._type.referencedType.typeWidth // 8) + src_stride = rect.dims[-1] * (referenceBuffer._type.referencedType.typeWidth // 8) + + struct = Snitch_DMA_copy( + { + "dst": _dst, + "src": _src, + "size": transfer_size, + "dst_stride": dst_stride, + "src_stride": src_stride, + "repeat": repeat, + "tid": 0 + }, ctxt) + + return struct + + def _hoistDMAUpdates(self, ctxt: NetworkContext, tensorName: str, updateList: List[_DMAUpdate], + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict]: + + operatorRepresentation = operatorRepresentation.copy() + + nodeName = operatorRepresentation['nodeName'] + + dstList = [] + srcList = [] + sizeList = [] + dstStrideList = [] + srcStideList = [] + repeatList = [] + for update in updateList: + dstList.append(int(update.dst)) + srcList.append(int(update.src)) + sizeList.append(int(update.size)) + dstStrideList.append(int(update.dst_stride)) + srcStideList.append(int(update.src_stride)) + repeatList.append(int(update.repeat)) + + dmaName = self._DMAStructName(tensorName, nodeName) + + operatorRepresentation['stateReference'] = dmaName + operatorRepresentation['tileNum'] = "TILING_I" + + if updateList[0].direction == "ToL1": + operatorRepresentation['dstPtr'] = ctxt.lookup(operatorRepresentation[tensorName]).name + operatorRepresentation['srcPtr'] = ctxt.lookup(operatorRepresentation[tensorName])._referenceName + + dstOffsetList = [0] * len(updateList) + srcOffsetList = [srcList[i] - srcList[0] for i in range(0, len(srcList))] + # srcOffsetList = [0] + [sum(sizeList[:i+1]) for i in range(0, len(sizeList)-1)] + else: + operatorRepresentation['dstPtr'] = ctxt.lookup(operatorRepresentation[tensorName])._referenceName + operatorRepresentation['srcPtr'] = ctxt.lookup(operatorRepresentation[tensorName]).name + + dstOffsetList = [dstList[i] - dstList[0] for i in range(0, len(dstList))] + # dstOffsetList = [0] + [sum(sizeList[:i+1]) for i in range(0, len(sizeList)-1)] + srcOffsetList = [0] * len(updateList) + + namePrefix = self.prefix + f"{nodeName}_{tensorName}" + + name = namePrefix + "_dst_offset" + cb = ctxt.ConstantBuffer(name, [len(updateList)], dstOffsetList) + ctxt, operatorRepresentation = self._hoistConstantAndReference(ctxt, cb, operatorRepresentation, nodeName, + 'dstOffsetPtr') + + name = namePrefix + "_src_offset" + cb = ctxt.ConstantBuffer(name, [len(updateList)], srcOffsetList) + ctxt, operatorRepresentation = self._hoistConstantAndReference(ctxt, cb, operatorRepresentation, nodeName, + 'srcOffsetPtr') + + name = namePrefix + "_size" + cb = ctxt.ConstantBuffer(name, [len(updateList)], sizeList) + ctxt, operatorRepresentation = self._hoistConstantAndReference(ctxt, cb, operatorRepresentation, nodeName, + 'sizePtr', + Snitch_DMA_copy.structTypeDict['size']) + + name = namePrefix + "_dst_stride" + cb = ctxt.ConstantBuffer(name, [len(updateList)], dstStrideList) + ctxt, operatorRepresentation = self._hoistConstantAndReference(ctxt, cb, operatorRepresentation, nodeName, + 'dstStridePtr', + Snitch_DMA_copy.structTypeDict['dst_stride']) + + name = namePrefix + "_src_stride" + cb = ctxt.ConstantBuffer(name, [len(updateList)], srcStideList) + ctxt, operatorRepresentation = self._hoistConstantAndReference(ctxt, cb, operatorRepresentation, nodeName, + 'srcStridePtr', + Snitch_DMA_copy.structTypeDict['src_stride']) + + name = namePrefix + "_repeat" + cb = ctxt.ConstantBuffer(name, [len(updateList)], repeatList) + ctxt, operatorRepresentation = self._hoistConstantAndReference(ctxt, cb, operatorRepresentation, nodeName, + 'repeatPtr', + Snitch_DMA_copy.structTypeDict['repeat']) + + return ctxt, operatorRepresentation + + def _generateEgressPointerUpdates( + self, nodeMemoryConstraint: NodeMemoryConstraint, tilingSchedule: TilingSchedule, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, List[CodeSnippet]]: + + updates = [] + newCtxt = ctxt.copy() + + updateDict = self._generatePointerUpdates(ctxt, operatorRepresentation, tilingSchedule.outputLoadSchedule, + nodeMemoryConstraint, tilingSchedule) + + for key, updateList in updateDict.items(): + + newCtxt, newNodeRep = self._hoistDMAUpdates(newCtxt, key, updateList, operatorRepresentation) + updates.append(CodeSnippet(self._updateDMATransferStructTemplate, newNodeRep)) + + return newCtxt, updates + + def _generateIngressPointerUpdates( + self, nodeMemoryConstraint: NodeMemoryConstraint, tilingSchedule: TilingSchedule, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, List[CodeSnippet]]: + + updates = [] + newCtxt = ctxt.copy() + + updateDict = self._generatePointerUpdates(ctxt, operatorRepresentation, tilingSchedule.inputLoadSchedule, + nodeMemoryConstraint, tilingSchedule) + + for key, updateList in updateDict.items(): + + newCtxt, newNodeRep = self._hoistDMAUpdates(newCtxt, key, updateList, operatorRepresentation) + updates.append(CodeSnippet(self._updateDMATransferStructTemplate, newNodeRep)) + + return newCtxt, updates + + def _generateVariableUpdates(self, tilingSchedule: TilingSchedule, variableReplacement: VariableReplacementScheme, + ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> List[CodeSnippet]: + + updates = [] + + for key in variableReplacement.perTileReplacements.keys(): + + buf = ctxt.lookup(operatorRepresentation[key]) + reference = str(buf._instance) + + updates.append( + CodeSnippet(self._updateReferenceTemplate, { + "reference": reference, + "tileNum": "TILING_I", + "baseReference": buf._referenceName + })) + + return updates + + def _generateDMACode(self, nodeMemoryConstraint: NodeMemoryConstraint, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation, loadSchedule: List[Dict[str, HyperRectangle]], + direction: Literal["ToL1", "FromL1"]) -> Tuple[List[CodeSnippet], List[CodeSnippet]]: + + DMATransferCalls = [] + DMAWaitStatements = [] + transferNodeRep = {} + + loadStep = loadSchedule[0] + + for idx, (key, rectangle) in enumerate(loadStep.items()): + + permName = f"in{idx}_perm" + + externalPtr = ctxt.lookup(ctxt.lookup(operatorRepresentation[key])._referenceName) + internalPtr = ctxt.lookup(operatorRepresentation[key]) + + tensorName = key + nodeName = operatorRepresentation['nodeName'] + dmaName = self._DMAStructName(tensorName, nodeName) + + transferNodeRep = { + **transferNodeRep, + **{ + 'innerTilePtr': str(internalPtr._instance), + "outerTilePtr": str(externalPtr._instance), + "stateReference": dmaName + } + } + + finalMemoryLevel = TilingCodeGeneration.isFinalMemoryLevel(nodeMemoryConstraint, internalPtr) + struct = self._rectToDMAStruct(ctxt, rectangle, direction, internalPtr.name, externalPtr.name, + finalMemoryLevel) + + transferNodeRep["stateStruct"] = struct + _ = ctxt.hoistStruct(struct, dmaName, Snitch_DMA_copy) + ctxt.lookup(dmaName)._users += [operatorRepresentation['nodeName']] + + if permName in operatorRepresentation and direction == "ToL1": + + DMATransferCalls.append(CodeSnippet(self._iteratedMoveTileInTemplate, transferNodeRep)) + else: + DMATransferCalls.append(CodeSnippet(self._moveTileInTemplate, transferNodeRep)) + + DMAWaitStatements.append(CodeSnippet(self._blockTileInTemplate, transferNodeRep)) + + return DMATransferCalls, DMAWaitStatements + + def _generateIngressDMACode( + self, tilingSchedule: TilingSchedule, nodeMemoryConstraint: NodeMemoryConstraint, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[List[CodeSnippet], List[CodeSnippet]]: + + importLoadStep = tilingSchedule.inputLoadSchedule + ingressDMATransferCalls, ingressDMAWaitStatements = self._generateDMACode(nodeMemoryConstraint, ctxt, + operatorRepresentation, + importLoadStep, "ToL1") + return ingressDMATransferCalls, ingressDMAWaitStatements + + def _generateEgressDMACode( + self, tilingSchedule: TilingSchedule, nodeMemoryConstraint: NodeMemoryConstraint, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[List[CodeSnippet], List[CodeSnippet]]: + + exportLoadStep = tilingSchedule.outputLoadSchedule + egressDMATransferCalls, egressDMAWaitStatements = self._generateDMACode(nodeMemoryConstraint, ctxt, + operatorRepresentation, exportLoadStep, + "FromL1") + + return egressDMATransferCalls, egressDMAWaitStatements + + def _tilingLoop(self, ctxt: NetworkContext, executionBlock: ExecutionBlock, + nodeMemoryConstraint: NodeMemoryConstraint, tilingSchedule: TilingSchedule, + variableReplacement: VariableReplacementScheme, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, ExecutionBlock, bool]: + + tileIdxPtr = self._hoistTileIdxPtr(ctxt, operatorRepresentation) + + ingressDMATransferCalls, ingressDMAWaitStatements = self._generateIngressDMACode( + tilingSchedule, nodeMemoryConstraint, ctxt, operatorRepresentation) + + egressDMATransferCalls, egressDMAWaitStatements = self._generateEgressDMACode( + tilingSchedule, nodeMemoryConstraint, ctxt, operatorRepresentation) + + ctxt, ingressDMAUpdates = self._generateIngressPointerUpdates(nodeMemoryConstraint, tilingSchedule, ctxt, + operatorRepresentation) + ctxt, egressDMAUpdates = self._generateEgressPointerUpdates(nodeMemoryConstraint, tilingSchedule, ctxt, + operatorRepresentation) + + openLoopStatement = [ + CodeSnippet(self._openTileLoopTemplate, { + "numTiles": operatorRepresentation["numTiles"], + "tileIdxPtr": tileIdxPtr + }) + ] + + closeLoopStatement = [ + CodeSnippet(self._closeTileLoopTemplate, { + "numTiles": operatorRepresentation["numTiles"], + "tileIdxPtr": tileIdxPtr + }) + ] + + variableUpdates = self._generateVariableUpdates(tilingSchedule, variableReplacement, ctxt, + operatorRepresentation) + + metaInfo = TilingMetaInfo(nodeName = operatorRepresentation['nodeName'] + "_L2", + nodeOps = operatorRepresentation['nodeOps'], + numTiles = len(tilingSchedule.outputLoadSchedule), + tileIdxVar = "TILING_I") + + newExecutionBlock = self.generateAllTilingCode(executionBlock, metaInfo, ingressDMATransferCalls, + ingressDMAWaitStatements, ingressDMAUpdates, + egressDMATransferCalls, egressDMAWaitStatements, + egressDMAUpdates, variableUpdates, openLoopStatement, + closeLoopStatement, [], []) + + return ctxt, newExecutionBlock, True + + def generateTilingLoop( + self, ctxt: NetworkContext, executionBlock: ExecutionBlock, nodeMemoryConstraint: NodeMemoryConstraint, + tilingSchedules: List[TilingSchedule], variableReplacement: VariableReplacementScheme, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, ExecutionBlock, bool]: + + flatTilingSchedule = copy.copy(tilingSchedules[0]) + for tilingSchedule in tilingSchedules[1:]: + flatTilingSchedule += tilingSchedule + + # SCHEREMO: hoist numTiles + + offsetLists = list({**flatTilingSchedule.inputBaseOffsets, **flatTilingSchedule.outputBaseOffsets}.values()) + + if len(offsetLists) == 0: + return ctxt, executionBlock, False + + for offsetList in offsetLists: + if not len(offsetList) == 1: + return ctxt, executionBlock, False + + operatorRepresentation["numTiles"] = self._hoistNumTiles(ctxt, operatorRepresentation['nodeName'], + tilingSchedules) + + return self._tilingLoop(ctxt, executionBlock, nodeMemoryConstraint, flatTilingSchedule, variableReplacement, + operatorRepresentation) + + +class SnitchClusterTilingGenerationSB(SnitchClusterTilingSB, SingleBufferingTilingMixIn): + pass diff --git a/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchCoreFilter.py b/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchCoreFilter.py new file mode 100644 index 0000000..6ea21f4 --- /dev/null +++ b/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchCoreFilter.py @@ -0,0 +1,46 @@ +# ---------------------------------------------------------------------- +# +# File: SnitchCoreFilter.py +# +# Last edited: 04.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Luka Macan, luka.macan@unibo.it, University of Bologna +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal, Tuple + +from Deeploy.DeeployTypes import CodeGenVerbosity, CodeTransformationPass, ExecutionBlock, NetworkContext, \ + NodeTemplate, _NoVerbosity + + +class SnitchCoreFilterPass(CodeTransformationPass): + + def __init__(self, coreType: Literal["dm", "compute"]): + super().__init__() + self.coreType = coreType + + def apply(self, + ctxt: NetworkContext, + executionBlock: ExecutionBlock, + name: str, + verbose: CodeGenVerbosity = _NoVerbosity) -> Tuple[NetworkContext, ExecutionBlock]: + executionBlock.addLeft(NodeTemplate(f"if (snrt_is_{self.coreType}_core()) {{\n"), {}) + executionBlock.addRight(NodeTemplate("}\n"), {}) + return ctxt, executionBlock diff --git a/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchProfileExecutionBlock.py b/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchProfileExecutionBlock.py new file mode 100644 index 0000000..c2cce33 --- /dev/null +++ b/Deeploy/Targets/Snitch/CodeTransformationPasses/SnitchProfileExecutionBlock.py @@ -0,0 +1,52 @@ +# ---------------------------------------------------------------------- +# +# File: SnitchProfileExecutionBlock.py +# +# Last edited: 05.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +from Deeploy.DeeployTypes import CodeGenVerbosity, CodeTransformationPass, ExecutionBlock, NetworkContext, \ + NodeTemplate, _NoVerbosity + +_dumpCycleCntTemplate = NodeTemplate(""" + snrt_cluster_hw_barrier(); + if (snrt_is_dm_core()) { + #ifndef BANSHEE_SIMULATION + DUMP(getCycles()); + #else + printf("${position} of ${nodeName} block at cycle %d \\n", getCycles()); + #endif + } +""") + + +class SnitchProfileExecutionBlockPass(CodeTransformationPass): + + def apply(self, + ctxt: NetworkContext, + executionBlock: ExecutionBlock, + name: str, + verbose: CodeGenVerbosity = _NoVerbosity) -> Tuple[NetworkContext, ExecutionBlock]: + executionBlock.addLeft(_dumpCycleCntTemplate, {"position": "Start", "nodeName": name}) + executionBlock.addRight(_dumpCycleCntTemplate, {"position": "End", "nodeName": name}) + return ctxt, executionBlock diff --git a/Deeploy/Targets/Snitch/CodeTransformationPasses/__init__.py b/Deeploy/Targets/Snitch/CodeTransformationPasses/__init__.py new file mode 100644 index 0000000..d3281dd --- /dev/null +++ b/Deeploy/Targets/Snitch/CodeTransformationPasses/__init__.py @@ -0,0 +1,29 @@ +# ---------------------------------------------------------------------- +# +# File: __init__.py +# +# Last edited: 03.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .SnitchClusterSynch import * +from .SnitchClusterTiling import * +from .SnitchCoreFilter import * +from .SnitchProfileExecutionBlock import * diff --git a/Deeploy/Targets/Snitch/DataTypes.py b/Deeploy/Targets/Snitch/DataTypes.py new file mode 100644 index 0000000..b1d3a92 --- /dev/null +++ b/Deeploy/Targets/Snitch/DataTypes.py @@ -0,0 +1,40 @@ +# ---------------------------------------------------------------------- +# +# File: SnitchDataTypes.py +# +# Last edited: 03.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from Deeploy.AbstractDataTypes import PointerClass, Struct, VoidType +from Deeploy.CommonExtensions.DataTypes import uint16_t + + +class Snitch_DMA_copy(Struct): + typeName = "DMA_copy" + structTypeDict = { + "dst": PointerClass(VoidType), + "src": PointerClass(VoidType), + "size": uint16_t, + "dst_stride": uint16_t, + "src_stride": uint16_t, + "repeat": uint16_t, + "tid": uint16_t + } diff --git a/Deeploy/Targets/Snitch/Parsers.py b/Deeploy/Targets/Snitch/Parsers.py new file mode 100644 index 0000000..dfd3248 --- /dev/null +++ b/Deeploy/Targets/Snitch/Parsers.py @@ -0,0 +1,97 @@ +# ---------------------------------------------------------------------- +# +# File: SnitchParsers.py +# +# Last edited: 07.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# - Luka Macan, luka.macan@unibo.it, University of Bologna +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the Lic +# ense is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import onnx_graphsurgeon as gs + +from Deeploy.DeeployTypes import NetworkContext +from Deeploy.Targets.Generic.Parsers import GEMMParser, RQGEMMParser + + +class SnitchGEMMParser(GEMMParser): + + def parseNode(self, node: gs.Node) -> bool: + ret = super().parseNode(node) + + if not ret: + return False + + if not all([ + self.operatorRepresentation['transA'] == 0, + ]): + return False + + return True + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first) + + if not ret: + return ctxt, False + + if not all([ + self.operatorRepresentation['batch'] == 1, + ]): + return ctxt, False + + return newCtxt, True + + +class SnitchRQGEMMParser(RQGEMMParser): + + def parseNode(self, node: gs.Node) -> bool: + ret = super().parseNode(node) + + if not ret: + return False + + if not all([ + self.operatorRepresentation['transA'] == 0, + ]): + return False + + return True + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first) + + if not ret: + return ctxt, False + + if not all([ + self.operatorRepresentation['batch'] == 1, + ]): + return ctxt, False + + return newCtxt, True diff --git a/Deeploy/Targets/Snitch/Platform.py b/Deeploy/Targets/Snitch/Platform.py index 89cb97e..3b45d9e 100644 --- a/Deeploy/Targets/Snitch/Platform.py +++ b/Deeploy/Targets/Snitch/Platform.py @@ -30,16 +30,22 @@ from Deeploy.DeeployTypes import ConstantBuffer, DeploymentEngine, DeploymentPlatform, NodeMapper, NodeTemplate, \ StructBuffer, TopologyOptimizer, TransientBuffer, VariableBuffer -from Deeploy.Targets.Generic.Bindings import BasicGatherBindings, BasicMatMulBinding, BasicPad1DBindings, \ - BasicPad2DBindings, BasicReshapeBindings, BasicRQIntegerDivBinding -from Deeploy.Targets.Generic.Layers import GatherLayer, MatMulLayer, PadLayer, ReshapeLayer, RQIntegerDivLayer -from Deeploy.Targets.Generic.Parsers import GatherParser, MatMulParser, Pad1DParser, Pad2DParser, RQIntegerDivParser, \ - UnsqueezeParser +from Deeploy.Targets.Generic.Bindings import BasicGatherBindings, BasicLayerNormBinding, BasicMatMulBinding, \ + BasicPad1DBindings, BasicPad2DBindings, BasicReshapeBindings, BasicRQIntegerDivBinding +from Deeploy.Targets.Generic.Layers import AddLayer, GatherLayer, GEMMLayer, MatMulLayer, PadLayer, ReshapeLayer, \ + RQGEMMLayer, RQIntegerDivLayer, iLayerNormLayer, iNoNormLayer, iSoftmaxLayer +from Deeploy.Targets.Generic.Parsers import AddParser, GatherParser, MatMulParser, Pad1DParser, Pad2DParser, \ + RQAddParser, RQIntegerDivParser, UnsqueezeParser, iLayerNormParser, iNoNormParser, iSoftmaxParser from Deeploy.Targets.Generic.Templates import AllocateTemplate as BasicAllocateTemplate -from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import IntegerDivRequantMergePass, \ - MergeConstAddAndRequantPass, MergeTrueIntegerDivRequantShiftPass, RQSSplitPass, SkipEmptyConcatPass, \ - SkipUnityRequantPass, iGELURequantMergePass, iHardswishRequantMergePass +from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import AddRequantMergePass, GEMMRequantMergePass, \ + IntegerDivRequantMergePass, MergeConstAddAndRequantPass, MergeTrueIntegerDivRequantShiftPass, RQSSplitPass, \ + SkipEmptyConcatPass, SkipUnityRequantPass, iGELURequantMergePass, iHardswishRequantMergePass +from Deeploy.Targets.PULPOpen.Platform import RQAddMapper +from Deeploy.Targets.Snitch.Parsers import SnitchGEMMParser, SnitchRQGEMMParser from Deeploy.Targets.Snitch.Templates import AllocateTemplate, FreeTemplate +from Deeploy.Targets.Snitch.Tiler import SnitchAddTileReadyBindings, SnitchGemmTilingReadyBindings, \ + SnitchiNoNormTilingReadyBindings, SnitchiSoftmaxTilingReadyBindings, SnitchRQAddTilingReadyBindings, \ + SnitchRqGemmTilingReadyBindings GatherMapper = NodeMapper(GatherParser(), BasicGatherBindings) Pad1DMapper = NodeMapper(Pad1DParser(), BasicPad1DBindings) @@ -49,6 +55,13 @@ RQIntegerDivMapper = NodeMapper(RQIntegerDivParser(), [BasicRQIntegerDivBinding]) MatMulMapper = NodeMapper(MatMulParser(), [BasicMatMulBinding]) +GemmMapper = NodeMapper(SnitchGEMMParser(), SnitchGemmTilingReadyBindings) +RqGemmMapper = NodeMapper(SnitchRQGEMMParser(), SnitchRqGemmTilingReadyBindings) +iSoftmaxMapper = NodeMapper(iSoftmaxParser(), SnitchiSoftmaxTilingReadyBindings) +iNoNormMapper = NodeMapper(iNoNormParser(), SnitchiNoNormTilingReadyBindings) +iLayerNormMapper = NodeMapper(iLayerNormParser(), [BasicLayerNormBinding]) +RQAddMapper = NodeMapper(RQAddParser(), SnitchRQAddTilingReadyBindings) +AddMapper = NodeMapper(AddParser(), SnitchAddTileReadyBindings) SnitchMapping = { 'RQIntegerDiv': RQIntegerDivLayer([RQIntegerDivMapper]), @@ -56,6 +69,13 @@ 'Pad': PadLayer([Pad1DMapper, Pad2DMapper]), 'Unsqueeze': ReshapeLayer([UnsqueezeMapper]), 'MatMul': MatMulLayer([MatMulMapper]), + 'Gemm': GEMMLayer([GemmMapper]), + 'RQGemm': RQGEMMLayer([RqGemmMapper]), + 'iSoftmax': iSoftmaxLayer([iSoftmaxMapper]), + 'iNoNorm': iNoNormLayer([iNoNormMapper]), + 'iLayerNorm': iLayerNormLayer([iLayerNormMapper]), + 'RequantizedAdd': AddLayer([RQAddMapper]), + 'Add': AddLayer([AddMapper]), } @@ -136,6 +156,8 @@ class SnitchStructBuffer(StructBuffer): iGELURequantMergePass(), iHardswishRequantMergePass(), MergeConstAddAndRequantPass(), + AddRequantMergePass(), + GEMMRequantMergePass(), ]) _includeList = [ diff --git a/Deeploy/Targets/Snitch/Templates/AddTemplate.py b/Deeploy/Targets/Snitch/Templates/AddTemplate.py new file mode 100644 index 0000000..f604625 --- /dev/null +++ b/Deeploy/Targets/Snitch/Templates/AddTemplate.py @@ -0,0 +1,58 @@ +# ---------------------------------------------------------------------- +# +# File: AddTemplate.py +# +# Last edited: 11.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple + +from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation + + +class _SnitchAddTemplate(NodeTemplate): + + def alignToContext(self, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + + data_in_1 = ctxt.lookup(operatorRepresentation['data_in_1']) + data_in_2 = ctxt.lookup(operatorRepresentation['data_in_2']) + data_out = ctxt.lookup(operatorRepresentation['data_out']) + + input_1_offset = 0 + if hasattr(data_in_1, "_signed") and hasattr(data_in_1, "nLevels"): + input_1_offset = (data_in_1._signed == 0) * int(data_in_1.nLevels / 2) + input_2_offset = 0 + if hasattr(data_in_2, "_signed") and hasattr(data_in_2, "nLevels"): + input_2_offset = (data_in_2._signed == 0) * int(data_in_2.nLevels / 2) + output_offset = 0 + if hasattr(data_out, "_signed") and hasattr(data_out, "nLevels"): + output_offset = -(data_out._signed == 0) * int(data_out.nLevels // 2) + + operatorRepresentation['offset'] = input_1_offset + input_2_offset + output_offset + + return ctxt, operatorRepresentation, [] + + +referenceTemplate = _SnitchAddTemplate(""" +// Snitch Add (Name: ${nodeName}, Op: ${nodeOp}) +SnitchAdd(${data_in_1}, ${data_in_2}, ${data_out}, ${size}, ${offset}); +""") diff --git a/Deeploy/Targets/Snitch/Templates/GemmTemplate.py b/Deeploy/Targets/Snitch/Templates/GemmTemplate.py new file mode 100644 index 0000000..8bc0fee --- /dev/null +++ b/Deeploy/Targets/Snitch/Templates/GemmTemplate.py @@ -0,0 +1,31 @@ +from typing import Dict, List, Tuple + +from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation + + +class SnitchGemmTemplate(NodeTemplate): + + def alignToContext(self, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + if isinstance(operatorRepresentation['alpha'], float): + assert operatorRepresentation['alpha'].is_integer( + ), f"Parameter alpha is not an integer: {operatorRepresentation['alpha']}" + operatorRepresentation['alpha'] = int(operatorRepresentation['alpha']) + if isinstance(operatorRepresentation['beta'], float): + assert operatorRepresentation['beta'].is_integer( + ), f"Parameter beta is not an integer: {operatorRepresentation['beta']}" + operatorRepresentation['beta'] = int(operatorRepresentation['beta']) + + if operatorRepresentation['transB']: + operatorRepresentation['kernelName'] = "Gemm_s8_transB_row_parallel" + else: + operatorRepresentation['kernelName'] = "Gemm_s8_row_parallel" + + return ctxt, operatorRepresentation, [] + + +SnitchGemmTemplateStr = r""" +${kernelName}(${A}, ${B}, ${C}, ${data_out}, ${M}, ${N}, ${O}, ${alpha}, ${beta}); +""" + +SnitchGemm_Template = SnitchGemmTemplate(SnitchGemmTemplateStr) diff --git a/Deeploy/Targets/Snitch/Templates/RQAddTemplate.py b/Deeploy/Targets/Snitch/Templates/RQAddTemplate.py new file mode 100644 index 0000000..afc637c --- /dev/null +++ b/Deeploy/Targets/Snitch/Templates/RQAddTemplate.py @@ -0,0 +1,50 @@ +# ---------------------------------------------------------------------- +# +# File: RQAddTemplate.py +# +# Last edited: 11.11.2023 +# +# Copyright (C) 2023, ETH Zurich and University of Bologna. +# +# Author: +# - Moritz Scherer, ETH Zurich +# - Victor Jung, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from Deeploy.Targets.Generic.Templates.RQAddTemplate import RQAddTemplate + +referenceTemplate = RQAddTemplate(""" + +<% +signatureString = '' +if input_signed: + signatureString += '_i8' +else: + signatureString += '_u8' +if input_2_signed: + signatureString += '_i8' +else: + signatureString += '_u8' +if output_signed: + signatureString += '_i8' +else: + signatureString += '_u8' +%> + +// PULP NN RQADD +snitch_nn_add${signatureString}(${data_in_1}, ${data_in_2}, ${data_out}, ${rqs1_mul}, ${rqs1_add}, ${rqs1_log2D}, ${rqs2_mul}, ${rqs2_add}, ${rqs2_log2D}, ${rqsOut_mul}, ${rqsOut_add}, ${rqsOut_log2D}, 1, ${size}, 1, 1); +""") diff --git a/Deeploy/Targets/Snitch/Templates/RqGemmTemplate.py b/Deeploy/Targets/Snitch/Templates/RqGemmTemplate.py new file mode 100644 index 0000000..918690e --- /dev/null +++ b/Deeploy/Targets/Snitch/Templates/RqGemmTemplate.py @@ -0,0 +1,33 @@ +from typing import Dict, List, Tuple + +from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation + + +class SnitchRqGemmTemplate(NodeTemplate): + + def alignToContext(self, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + if isinstance(operatorRepresentation['alpha'], float): + assert operatorRepresentation['alpha'].is_integer() + operatorRepresentation['alpha'] = int(operatorRepresentation['alpha']) + if isinstance(operatorRepresentation['beta'], float): + assert operatorRepresentation['beta'].is_integer() + operatorRepresentation['beta'] = int(operatorRepresentation['beta']) + + #LMACAN: WARNING: Assumes rounding is expected + add = ctxt.lookup(operatorRepresentation['add']) + add.values += 2**(operatorRepresentation['log2D'] - 1) + + if operatorRepresentation['transB']: + operatorRepresentation['kernelName'] = 'RQGemm_s8_transB_row_parallel_unrolled' + else: + operatorRepresentation['kernelName'] = 'RQGemm_s8_row_parallel_unrolled' + + return ctxt, operatorRepresentation, [] + + +SnitchRqGemmTemplateStr = r""" +${kernelName}(${A}, ${B}, ${C}, ${data_out}, ${M}, ${N}, ${O}, ${alpha}, ${beta}, ${mul}, ${add}, ${log2D}); +""" + +SnitchRqGemm_Template = SnitchRqGemmTemplate(SnitchRqGemmTemplateStr) diff --git a/Deeploy/Targets/Snitch/Templates/iSoftmaxTemplate.py b/Deeploy/Targets/Snitch/Templates/iSoftmaxTemplate.py new file mode 100644 index 0000000..9a15d91 --- /dev/null +++ b/Deeploy/Targets/Snitch/Templates/iSoftmaxTemplate.py @@ -0,0 +1,41 @@ +# ---------------------------------------------------------------------- +# +# File: iSoftmaxTemplate.py +# +# Last edited: 30.05.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from Deeploy.Targets.Generic.Templates.iSoftmaxPreAllocatedBuffTemplate import iSoftmaxPreAllocatedBuffTemplate + +referenceTemplate = iSoftmaxPreAllocatedBuffTemplate(""" +<% +signatureString = '' +if input_signed: + signatureString += '_i8' +else: + signatureString += '_u8' +if output_signed: + signatureString += '_i8' +else: + signatureString += '_u8' +%> +SnitchSoftmax${signatureString}(${data_in}, ${data_out}, ${lastDimBuffer}, ${size}, ${lastDimLength}, ${coeffB}, ${coeffC}, ${log2}); +""") diff --git a/Deeploy/Targets/Snitch/TileConstraints/GemmTileConstraint.py b/Deeploy/Targets/Snitch/TileConstraints/GemmTileConstraint.py new file mode 100644 index 0000000..99fdddd --- /dev/null +++ b/Deeploy/Targets/Snitch/TileConstraints/GemmTileConstraint.py @@ -0,0 +1,161 @@ +from typing import Dict, List, Tuple + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import uint32_t +from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation +from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint +from Deeploy.TilingExtension.TileConstraint import TileConstraint +from Deeploy.TilingExtension.TilerModel import PerformanceHint, TilerModel +from Deeploy.TilingExtension.TilingCodegen import AbsoluteHyperRectangle, HyperRectangle, TilingSchedule, \ + VariableReplacementScheme + + +class GemmTileConstraint(TileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + + # Get to-be-tiled tensor's buffers + bufferA = ctxt.lookup(name = parseDict['A']) + bufferB = ctxt.lookup(name = parseDict['B']) + bufferY = ctxt.lookup(name = parseDict['data_out']) + + # Add I/O dimensions to the model as variables + for bufferName in [bufferA.name, bufferB.name, bufferY.name]: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + dimCountA = len(bufferA.shape) + if parseDict['transA'] == 0: + heightIdxA, widthIdxA = dimCountA - 2, dimCountA - 1 + else: + heightIdxA, widthIdxA = dimCountA - 1, dimCountA - 2 + AHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = heightIdxA) + AWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = widthIdxA) + + dimCountB = len(bufferB.shape) + if parseDict['transB'] == 0: + heightIdxB, widthIdxB = dimCountB - 2, dimCountB - 1 + else: + heightIdxB, widthIdxB = dimCountB - 1, dimCountB - 2 + BHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = heightIdxB) + BWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = widthIdxB) + + dimCountY = len(bufferY.shape) + heightIdxY, widthIdxY = dimCountY - 2, dimCountY - 1 + YHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferY.name, dimIdx = heightIdxY) + YWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferY.name, dimIdx = widthIdxY) + + tilerModel.addConstraint(YHeightDimVar == AHeightDimVar) + tilerModel.addConstraint(YWidthDimVar == BWidthDimVar) + tilerModel.addConstraint(AWidthDimVar == BHeightDimVar) + + if 'C' in parseDict: + bufferC = ctxt.lookup(name = parseDict['C']) + + tilerModel.addTensorDimToModel(ctxt, bufferC.name) + + dimCountC = len(bufferC.shape) + heightIdxC, widthIdxC = dimCountC - 2, dimCountC - 1 + CHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferC.name, dimIdx = heightIdxC) + CWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferC.name, dimIdx = widthIdxC) + + tilerModel.addConstraint(CHeightDimVar == YHeightDimVar) + tilerModel.addConstraint(CWidthDimVar == YWidthDimVar) + + return tilerModel + + @staticmethod + def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + bufferA = ctxt.lookup(name = parseDict['A']) + bufferY = ctxt.lookup(name = parseDict['data_out']) + + dimCountA = len(bufferA.shape) + if parseDict['transA'] == 0: + heightIdxA, widthIdxA = dimCountA - 2, dimCountA - 1 + else: + heightIdxA, widthIdxA = dimCountA - 1, dimCountA - 2 + AHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = heightIdxA) + AWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = widthIdxA) + + dimCountY = len(bufferY.shape) + heightIdxY, widthIdxY = dimCountY - 2, dimCountY - 1 + YHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferY.name, dimIdx = heightIdxY) + YWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferY.name, dimIdx = widthIdxY) + + # Full inner dimension + tilerModel.addConstraint(AWidthDimVar == AWidthDimVar.Max()) + + # We parallelize over the output height dimension so try to keep it divisible by the number of cores (8) + if parseDict["M"] > 8: + tilerModel.addTileSizeDivisibleConstraint(parseDict, + "M", + YHeightDimVar, + 8, + strategy = PerformanceHint(priority = 1)) + + return tilerModel + + @classmethod + def serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + outputCubes = [cube.rectangle for cube in absoluteOutputCubes] + + addrNames = ['A', 'B', 'C', 'data_out'] + inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, addrNames) + + NOffset = 0 + NSize = operatorRepresentation["N"] + + replacements = { + "M": [], + "O": [], + "batch": [], + } + + replacementTypes = { + "M": PointerClass(uint32_t), + "O": PointerClass(uint32_t), + "batch": PointerClass(uint32_t), + } + + inputLoadSchedule = [] + outputLoadSchedule = [] + + for YCube in outputCubes: + assert len(YCube.offset) >= 2 or len( + YCube.offset) <= 3, f"Unsupported YCube dimensionality: {len(YCube.offset)}" + + MOffset, OOffset = YCube.offset[-2:] + MSize, OSize = YCube.dims[-2:] + + replacements["M"].append(MSize) + replacements["O"].append(OSize) + + if len(YCube.offset) == 3: + BatchOffset = YCube.offset[0] + BatchSize = YCube.dims[0] + else: + BatchOffset = 0 + BatchSize = 1 + + replacements["batch"].append(BatchSize) + + if operatorRepresentation['transA'] == 0: + ACube = HyperRectangle((BatchOffset, MOffset, NOffset), (BatchSize, MSize, NSize)) + else: + ACube = HyperRectangle((BatchOffset, NOffset, MOffset), (BatchSize, NSize, MSize)) + + if operatorRepresentation['transB'] == 0: + BCube = HyperRectangle((BatchOffset, NOffset, OOffset), (BatchSize, NSize, OSize)) + else: + BCube = HyperRectangle((BatchOffset, OOffset, NOffset), (BatchSize, OSize, NSize)) + + inputLoadSchedule.append({"A": ACube, "B": BCube, "C": YCube}) + outputLoadSchedule.append({"data_out": YCube}) + + schedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule) + + return VariableReplacementScheme(replacements, replacementTypes), schedule diff --git a/Deeploy/Targets/Snitch/TileConstraints/RqGemmTileConstraint.py b/Deeploy/Targets/Snitch/TileConstraints/RqGemmTileConstraint.py new file mode 100644 index 0000000..5feae3b --- /dev/null +++ b/Deeploy/Targets/Snitch/TileConstraints/RqGemmTileConstraint.py @@ -0,0 +1,177 @@ +from typing import Dict, List, Tuple + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import uint32_t +from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation +from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint +from Deeploy.TilingExtension.TileConstraint import TileConstraint +from Deeploy.TilingExtension.TilerModel import PerformanceHint, TilerModel +from Deeploy.TilingExtension.TilingCodegen import AbsoluteHyperRectangle, HyperRectangle, TilingSchedule, \ + VariableReplacementScheme + + +class RqGemmTileConstraint(TileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + + # Get to-be-tiled tensor's buffers + bufferA = ctxt.lookup(name = parseDict['A']) + bufferB = ctxt.lookup(name = parseDict['B']) + bufferY = ctxt.lookup(name = parseDict['data_out']) + + # Add I/O dimensions to the model as variables + for bufferName in [bufferA.name, bufferB.name, bufferY.name]: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + dimCountA = len(bufferA.shape) + if parseDict['transA'] == 0: + heightIdxA, widthIdxA = dimCountA - 2, dimCountA - 1 + else: + heightIdxA, widthIdxA = dimCountA - 1, dimCountA - 2 + AHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = heightIdxA) + AWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = widthIdxA) + + dimCountB = len(bufferB.shape) + if parseDict['transB'] == 0: + heightIdxB, widthIdxB = dimCountB - 2, dimCountB - 1 + else: + heightIdxB, widthIdxB = dimCountB - 1, dimCountB - 2 + BHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = heightIdxB) + BWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = widthIdxB) + + dimCountY = len(bufferY.shape) + heightIdxY, widthIdxY = dimCountY - 2, dimCountY - 1 + YHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferY.name, dimIdx = heightIdxY) + YWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferY.name, dimIdx = widthIdxY) + + tilerModel.addConstraint(YHeightDimVar == AHeightDimVar) + tilerModel.addConstraint(YWidthDimVar == BWidthDimVar) + tilerModel.addConstraint(AWidthDimVar == BHeightDimVar) + + if 'C' in parseDict: + bufferC = ctxt.lookup(name = parseDict['C']) + + tilerModel.addTensorDimToModel(ctxt, bufferC.name) + + dimCountC = len(bufferC.shape) + heightIdxC, widthIdxC = dimCountC - 2, dimCountC - 1 + CHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferC.name, dimIdx = heightIdxC) + CWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferC.name, dimIdx = widthIdxC) + + tilerModel.addConstraint(CHeightDimVar == YHeightDimVar) + tilerModel.addConstraint(CWidthDimVar == YWidthDimVar) + + bufferMul = ctxt.lookup(name = parseDict['mul']) + bufferAdd = ctxt.lookup(name = parseDict['add']) + + # Add I/O dimensions to the model as variables + for bufferName in [bufferMul.name, bufferAdd.name]: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + MulDimVar = tilerModel.getTensorDimVar(tensorName = bufferMul.name, dimIdx = 0) + AddDimVar = tilerModel.getTensorDimVar(tensorName = bufferAdd.name, dimIdx = 0) + + tilerModel.addConstraint(MulDimVar == YHeightDimVar) + tilerModel.addConstraint(MulDimVar == AddDimVar) + + return tilerModel + + @staticmethod + def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + bufferA = ctxt.lookup(name = parseDict['A']) + bufferY = ctxt.lookup(name = parseDict['data_out']) + + dimCountA = len(bufferA.shape) + if parseDict['transA'] == 0: + heightIdxA, widthIdxA = dimCountA - 2, dimCountA - 1 + else: + heightIdxA, widthIdxA = dimCountA - 1, dimCountA - 2 + AHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = heightIdxA) + AWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = widthIdxA) + + dimCountY = len(bufferY.shape) + heightIdxY, widthIdxY = dimCountY - 2, dimCountY - 1 + YHeightDimVar = tilerModel.getTensorDimVar(tensorName = bufferY.name, dimIdx = heightIdxY) + YWidthDimVar = tilerModel.getTensorDimVar(tensorName = bufferY.name, dimIdx = widthIdxY) + + # Full inner dimension + tilerModel.addConstraint(AWidthDimVar == AWidthDimVar.Max()) + + # We parallelize over the output height dimension so try to keep it divisible by the number of cores (8) + if parseDict["M"] > 8: + tilerModel.addTileSizeDivisibleConstraint(parseDict, + "M", + YHeightDimVar, + 8, + strategy = PerformanceHint(priority = 1)) + + return tilerModel + + @classmethod + def serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + outputCubes = [cube.rectangle for cube in absoluteOutputCubes] + + addrNames = ['A', 'B', 'C', 'mul', 'add', 'data_out'] + inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, addrNames) + + NOffset = 0 + NSize = operatorRepresentation["N"] + + replacements = { + "M": [], + "O": [], + "batch": [], + } + + replacementTypes = { + "M": PointerClass(uint32_t), + "O": PointerClass(uint32_t), + "batch": PointerClass(uint32_t), + } + + inputLoadSchedule = [] + outputLoadSchedule = [] + + for YCube in outputCubes: + assert len(YCube.offset) >= 2 or len( + YCube.offset) <= 3, f"Unsupported YCube dimensionality: {len(YCube.offset)}" + + MOffset, OOffset = YCube.offset[-2:] + MSize, OSize = YCube.dims[-2:] + + replacements["M"].append(MSize) + replacements["O"].append(OSize) + + if len(YCube.offset) == 3: + BatchOffset = YCube.offset[0] + BatchSize = YCube.dims[0] + else: + BatchOffset = 0 + BatchSize = 1 + + replacements["batch"].append(BatchSize) + + if operatorRepresentation['transA'] == 0: + ACube = HyperRectangle((BatchOffset, MOffset, NOffset), (BatchSize, MSize, NSize)) + else: + ACube = HyperRectangle((BatchOffset, NOffset, MOffset), (BatchSize, NSize, MSize)) + + if operatorRepresentation['transB'] == 0: + BCube = HyperRectangle((BatchOffset, NOffset, OOffset), (BatchSize, NSize, OSize)) + else: + BCube = HyperRectangle((BatchOffset, OOffset, NOffset), (BatchSize, OSize, NSize)) + + MulCube = HyperRectangle((MOffset,), (MSize,)) + AddCube = HyperRectangle((MOffset,), (MSize,)) + + inputLoadSchedule.append({"A": ACube, "B": BCube, "C": YCube, "mul": MulCube, "add": AddCube}) + outputLoadSchedule.append({"data_out": YCube}) + + schedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule) + + return VariableReplacementScheme(replacements, replacementTypes), schedule diff --git a/Deeploy/Targets/Snitch/TileConstraints/__init__.py b/Deeploy/Targets/Snitch/TileConstraints/__init__.py new file mode 100644 index 0000000..93b3563 --- /dev/null +++ b/Deeploy/Targets/Snitch/TileConstraints/__init__.py @@ -0,0 +1,28 @@ +# ---------------------------------------------------------------------- +# +# File: __init__.py +# +# Last edited: 03.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import * +from .iNoNormTileConstraint import * +from .iSoftmaxTileConstraint import * diff --git a/Deeploy/Targets/Snitch/TileConstraints/iNoNormTileConstraint.py b/Deeploy/Targets/Snitch/TileConstraints/iNoNormTileConstraint.py new file mode 100644 index 0000000..ab7b0cf --- /dev/null +++ b/Deeploy/Targets/Snitch/TileConstraints/iNoNormTileConstraint.py @@ -0,0 +1,114 @@ +# ---------------------------------------------------------------------- +# +# File: iNoNormTileConstraint.py +# +# Last edited: 06.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple + +import numpy as np + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import uint32_t +from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation +from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint +from Deeploy.TilingExtension.TileConstraint import TileConstraint +from Deeploy.TilingExtension.TilerModel import TilerModel +from Deeploy.TilingExtension.TilingCodegen import AbsoluteHyperRectangle, TilingSchedule, VariableReplacementScheme + + +class iNoNormTileConstraint(TileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + + inputBufferName = parseDict['data_in'] + weightsBufferName = parseDict['weights'] + biasBufferName = parseDict['bias'] + outputBufferName = parseDict['data_out'] + + # Add I/O dimensions to the model as variables + for bufferName in [inputBufferName, weightsBufferName, biasBufferName, outputBufferName]: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + inputShape = ctxt.lookup(inputBufferName).shape + + weigthsBufferShapeLen = len(ctxt.lookup(weightsBufferName).shape) + biasBufferShapeLen = len(ctxt.lookup(biasBufferName).shape) + + weightsLastDimVar = tilerModel.getTensorDimVar(tensorName = weightsBufferName, + dimIdx = weigthsBufferShapeLen - 1) + biasLastDimVar = tilerModel.getTensorDimVar(tensorName = biasBufferName, dimIdx = biasBufferShapeLen - 1) + + tilerModel.addConstraint(biasLastDimVar == weightsLastDimVar) + + for dim in range(len(inputShape)): + inputDimVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = dim) + weightDimVar = tilerModel.getTensorDimVar(tensorName = weightsBufferName, dimIdx = dim) + outputDimVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = dim) + tilerModel.addConstraint(inputDimVar == outputDimVar) + tilerModel.addConstraint(weightDimVar == outputDimVar) + + return tilerModel + + @classmethod + def serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + outputCubes = [cube.rectangle for cube in absoluteOutputCubes] + + addrNames = ['data_in', 'weights', 'bias', 'data_out'] + inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, addrNames) + + replacements = {"size": []} + replacementTypes = {"size": PointerClass(uint32_t)} + + inputCubes = [] + weightCubes = [] + biasCubes = [] + + for outputCube in outputCubes: + + size = np.prod(outputCube.dims[1:]) + lastDimLength = outputCube.dims[-1] + + replacements['size'].append(size) + + inputCubes.append(outputCube) + weightCubes.append(outputCube) + biasCubes.append(outputCube) + + inputLoadSchedule = [] + outputLoadSchedule = [] + + for inp, w, b in zip(inputCubes, weightCubes, biasCubes): + inputLoadSchedule.append({"data_in": inp, "weights": w, "bias": b}) + + for out in outputCubes: + outputLoadSchedule.append({"data_out": out}) + + tilingSchedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule) + variableReplacementSchedule = VariableReplacementScheme(replacements, replacementTypes) + + return variableReplacementSchedule, tilingSchedule diff --git a/Deeploy/Targets/Snitch/TileConstraints/iSoftmaxTileConstraint.py b/Deeploy/Targets/Snitch/TileConstraints/iSoftmaxTileConstraint.py new file mode 100644 index 0000000..5528491 --- /dev/null +++ b/Deeploy/Targets/Snitch/TileConstraints/iSoftmaxTileConstraint.py @@ -0,0 +1,119 @@ +# ---------------------------------------------------------------------- +# +# File: iSoftmaxTileConstraint.py +# +# Last edited: 13.11.2023 +# +# Copyright (C) 2023, ETH Zurich and University of Bologna. +# +# Author: Moritz Scherer, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple, Union + +import numpy as np +from ortools.constraint_solver.pywrapcp import IntVar + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import uint32_t +from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation +from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint +from Deeploy.TilingExtension.TileConstraint import TileConstraint +from Deeploy.TilingExtension.TilerModel import TilerModel +from Deeploy.TilingExtension.TilingCodegen import AbsoluteHyperRectangle, TilingSchedule, VariableReplacementScheme + + +class iSoftmaxTileConstraint(TileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + inputBufferName = parseDict['data_in'] + outputBufferName = parseDict['data_out'] + + shapeLen = len(ctxt.lookup(inputBufferName).shape) + + # Add I/O dimensions to the model as variables + for bufferName in [inputBufferName, outputBufferName]: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + for idx in range(shapeLen): + outputDim = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = idx) + inputDim = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = idx) + tilerModel.addConstraint(outputDim == inputDim) + + return tilerModel + + @staticmethod + def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + inputBufferName = parseDict['data_in'] + inputBuffer = ctxt.lookup(inputBufferName) + + lastDimLength = inputBuffer.shape[-1] + lastDimIdx = len(inputBuffer.shape) - 1 + lastDimVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = lastDimIdx) + + tilerModel.addConstraint(lastDimVar == lastDimLength) + + return tilerModel + + @staticmethod + def constructSymbolicNodeRep(tilerModel: TilerModel, parseDict: Dict, + ctxt: NetworkContext) -> Dict[str, Union[int, IntVar]]: + + inputBufferName = parseDict['data_in'] + inputBuffer = ctxt.lookup(inputBufferName) + + lastDimIdx = len(inputBuffer.shape) - 1 + + symbolicParseDict = parseDict.copy() + symbolicParseDict['lastDimLength'] = tilerModel.getTensorDimVar(inputBuffer.name, lastDimIdx) + + return symbolicParseDict + + @classmethod + def serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + outputCubes = [cube.rectangle for cube in absoluteOutputCubes] + + addrNames = ['data_in', 'data_out'] + inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, addrNames) + + replacements = {"lastDimLength": [], "size": []} + + replacementTypes = {"lastDimLength": PointerClass(uint32_t), "size": PointerClass(uint32_t)} + + for cube in outputCubes: + lastDimLength = cube.dims[-1] + size = np.prod(cube.dims) + + replacements['lastDimLength'].append(lastDimLength) + replacements['size'].append(size) + + inputLoadSchedule = [] + outputLoadSchedule = [] + + for out in outputCubes: + inputLoadSchedule.append({"data_in": out}) + outputLoadSchedule.append({"data_out": out}) + + tilingSchedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule) + variableReplacementSchedule = VariableReplacementScheme(replacements, replacementTypes) + + return variableReplacementSchedule, tilingSchedule diff --git a/Deeploy/Targets/Snitch/Tiler.py b/Deeploy/Targets/Snitch/Tiler.py new file mode 100644 index 0000000..38ba29f --- /dev/null +++ b/Deeploy/Targets/Snitch/Tiler.py @@ -0,0 +1,46 @@ +# ---------------------------------------------------------------------- +# +# File: SnitchTiler.py +# +# Last edited: 03.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from Deeploy.Targets.Generic.TileConstraints.AddTileConstraint import AddTileConstraint +from Deeploy.Targets.Snitch.Bindings import SnitchAddBindings, SnitchGemmBindings, SnitchiNoNormBindings, \ + SnitchiSoftmaxBindings, SnitchRQAddBindings, SnitchRqGemmBindings +from Deeploy.Targets.Snitch.TileConstraints import iNoNormTileConstraint, iSoftmaxTileConstraint +from Deeploy.Targets.Snitch.TileConstraints.GemmTileConstraint import GemmTileConstraint +from Deeploy.Targets.Snitch.TileConstraints.RqGemmTileConstraint import RqGemmTileConstraint +from Deeploy.TilingExtension.TilerExtension import TilingReadyNodeBindings + +SnitchiSoftmaxTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = SnitchiSoftmaxBindings, + tileConstraint = iSoftmaxTileConstraint()) +SnitchiNoNormTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = SnitchiNoNormBindings, + tileConstraint = iNoNormTileConstraint()) +SnitchRQAddTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = SnitchRQAddBindings, + tileConstraint = AddTileConstraint()) +SnitchGemmTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = SnitchGemmBindings, + tileConstraint = GemmTileConstraint()) +SnitchRqGemmTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = SnitchRqGemmBindings, + tileConstraint = RqGemmTileConstraint()) + +SnitchAddTileReadyBindings = TilingReadyNodeBindings(nodeBindings = SnitchAddBindings, + tileConstraint = AddTileConstraint()) diff --git a/Deeploy/TilingExtension/CodeTransformationPasses/TilingCodeGeneration.py b/Deeploy/TilingExtension/CodeTransformationPasses/TilingCodeGeneration.py index e19d1ca..f4e4d9a 100644 --- a/Deeploy/TilingExtension/CodeTransformationPasses/TilingCodeGeneration.py +++ b/Deeploy/TilingExtension/CodeTransformationPasses/TilingCodeGeneration.py @@ -24,16 +24,16 @@ # limitations under the License. from abc import abstractmethod -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple, Type import Deeploy.CommonExtensions.DataTypes as BasicDataTypes -from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.AbstractDataTypes import Immediate, PointerClass from Deeploy.CommonExtensions.CodeTransformationPasses.Closure import ClosureExecutionBlock from Deeploy.CommonExtensions.CodeTransformationPasses.IntrospectiveCodeTransformation import \ IntrospectiveCodeTransformationMixIn from Deeploy.CommonExtensions.CodeTransformationPasses.MemoryAllocation import ArgumentStructGeneration -from Deeploy.DeeployTypes import CodeGenVerbosity, CodeTransformationPass, ExecutionBlock, NetworkContext, \ - NodeTemplate, OperatorRepresentation, VariableBuffer, _NoVerbosity +from Deeploy.DeeployTypes import CodeGenVerbosity, CodeTransformationPass, ConstantBuffer, ExecutionBlock, \ + NetworkContext, NodeTemplate, OperatorRepresentation, VariableBuffer, _NoVerbosity from Deeploy.TilingExtension.CodeTransformationPasses.TilingPrototypes import PrototypeTilingMixIn from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint from Deeploy.TilingExtension.TilingCodegen import TilingSchedule, VariableReplacementScheme, minimizeVariableReplacement @@ -121,6 +121,35 @@ def _hoistNumTiles(self, return newPtrName + def _hoistConstantAndReference(self, + ctxt: NetworkContext, + constBuf: ConstantBuffer, + operatorRepresentation: OperatorRepresentation, + nodeName: str, + operatorRepresentationName: str, + immediateType: Optional[Type[Immediate]] = None) -> Tuple[NetworkContext, Dict]: + + if immediateType is None: + _type = PointerClass(BasicDataTypes.int32_t) + else: + _type = PointerClass(immediateType) + + name = constBuf.name + + ctxt.add(constBuf, "global") + constBuf._type = _type + constBuf._instance = constBuf._type(name, ctxt) + constBuf._users = [nodeName] + constBuf._memoryLevel = self.targetMemLevel + + refName = name + "_ref" + reference = ctxt.hoistReference(name, refName) + ctxt.lookup(reference)._memoryLevel = self.targetMemLevel + + operatorRepresentation[operatorRepresentationName] = refName + + return ctxt, operatorRepresentation + def apply(self, ctxt: NetworkContext, executionBlock: ExecutionBlock, diff --git a/DeeployTest/Platforms/Snitch/main.c b/DeeployTest/Platforms/Snitch/main.c index 9ed91aa..38cfa96 100644 --- a/DeeployTest/Platforms/Snitch/main.c +++ b/DeeployTest/Platforms/Snitch/main.c @@ -31,6 +31,10 @@ #include "testinputs.h" #include "testoutputs.h" +// #define NOPRINT +// #define NOTEST +// #define CI + int main(void) { uint32_t core_id = snrt_global_core_idx(); @@ -50,8 +54,9 @@ int main(void) { snrt_cluster_num() * snrt_cluster_dm_core_num(), snrt_cluster_num()); #endif +#ifndef NOPRINT printf("Initializing...\r\n"); - +#endif InitNetwork(core_id, 1); #ifndef CI @@ -61,7 +66,7 @@ int main(void) { DeeployNetwork_inputs[buf], DeeployNetwork_inputs_bytes[buf]); } for (uint32_t buf = 0; buf < DeeployNetwork_num_outputs; buf++) { - printf("testInputVector%d @ %p\r\n", buf, testOutputVector[buf]); + printf("testOutputVector%d @ %p\r\n", buf, testOutputVector[buf]); printf("DeeployNetwork_output_%d @ %p and %u elements\r\n", buf, DeeployNetwork_outputs[buf], DeeployNetwork_outputs_bytes[buf]); } @@ -69,7 +74,9 @@ int main(void) { printf("Initialized\r\n"); #endif +#ifndef NOPRINT printf("Copy inputs...\r\n"); +#endif // WIESEP: Copy inputs to allocated memory for (uint32_t buf = 0; buf < DeeployNetwork_num_inputs; buf++) { @@ -83,18 +90,31 @@ int main(void) { #endif } +#ifndef NOPRINT + if (snrt_is_dm_core()) { + printf("Running network...\r\n"); + } +#endif + snrt_cluster_hw_barrier(); +#ifndef BANSHEE_SIMULATION if (snrt_is_dm_core()) { - printf("Running network...\r\n"); + ResetTimer(); + StartTimer(); } +#endif // BANSHEE_SIMULATION - // ResetTimer(); - // StartTimer(); - if (snrt_is_compute_core()) { - RunNetwork(compute_core_id, num_compute_cores); + RunNetwork(compute_core_id, num_compute_cores); + + uint32_t runtimeCycles = 0; +#ifndef BANSHEE_SIMULATION + if (snrt_is_dm_core()) { + runtimeCycles = getCycles(); + DUMP(runtimeCycles); + StopTimer(); } - // StopTimer(); +#endif // BANSHEE_SIMULATION snrt_cluster_hw_barrier(); @@ -108,6 +128,8 @@ int main(void) { #ifndef CI printf("Output:\r\n"); #endif + +#ifndef NOTEST int32_t tot_err = 0; uint32_t tot = 0; int32_t diff; @@ -130,8 +152,12 @@ int main(void) { } } } - printf("Runtime: %u cycles\r\n", getCycles()); printf("Errors: %u out of %u \r\n", tot_err, tot); +#endif + +#ifndef NOPRINT + printf("Runtime: %u cycles\r\n", runtimeCycles); +#endif } snrt_cluster_hw_barrier(); diff --git a/DeeployTest/Tests/TestAdderLarge/inputs.npz b/DeeployTest/Tests/TestAdderLarge/inputs.npz new file mode 100644 index 0000000..0b45680 Binary files /dev/null and b/DeeployTest/Tests/TestAdderLarge/inputs.npz differ diff --git a/DeeployTest/Tests/TestAdderLarge/network.onnx b/DeeployTest/Tests/TestAdderLarge/network.onnx new file mode 100644 index 0000000..513e19e --- /dev/null +++ b/DeeployTest/Tests/TestAdderLarge/network.onnx @@ -0,0 +1,22 @@ + onnx1.14.0:  +& +input_0 +input_1output_0Add"AddAddNetZ# +input_0 + + + +€ +€Z# +input_1 + + + +€ +€b$ +output_0 + + + +€ +€B \ No newline at end of file diff --git a/DeeployTest/Tests/TestAdderLarge/outputs.npz b/DeeployTest/Tests/TestAdderLarge/outputs.npz new file mode 100644 index 0000000..f578ee2 Binary files /dev/null and b/DeeployTest/Tests/TestAdderLarge/outputs.npz differ diff --git a/DeeployTest/Tests/TestRQAdd/activations.npz b/DeeployTest/Tests/TestRQAdd/activations.npz new file mode 100644 index 0000000..3bab8fb Binary files /dev/null and b/DeeployTest/Tests/TestRQAdd/activations.npz differ diff --git a/DeeployTest/Tests/TestRQAdd/inputs.npz b/DeeployTest/Tests/TestRQAdd/inputs.npz new file mode 100644 index 0000000..d3bea52 Binary files /dev/null and b/DeeployTest/Tests/TestRQAdd/inputs.npz differ diff --git a/DeeployTest/Tests/TestRQAdd/network.onnx b/DeeployTest/Tests/TestRQAdd/network.onnx new file mode 100644 index 0000000..46ff5c2 Binary files /dev/null and b/DeeployTest/Tests/TestRQAdd/network.onnx differ diff --git a/DeeployTest/Tests/TestRQAdd/outputs.npz b/DeeployTest/Tests/TestRQAdd/outputs.npz new file mode 100644 index 0000000..f5aa6cf Binary files /dev/null and b/DeeployTest/Tests/TestRQAdd/outputs.npz differ diff --git a/DeeployTest/Tests/TestiNoNorm/activations.npz b/DeeployTest/Tests/TestiNoNorm/activations.npz new file mode 100644 index 0000000..15cb0ec Binary files /dev/null and b/DeeployTest/Tests/TestiNoNorm/activations.npz differ diff --git a/DeeployTest/Tests/TestiNoNorm/inputs.npz b/DeeployTest/Tests/TestiNoNorm/inputs.npz new file mode 100644 index 0000000..3116968 Binary files /dev/null and b/DeeployTest/Tests/TestiNoNorm/inputs.npz differ diff --git a/DeeployTest/Tests/TestiNoNorm/network.onnx b/DeeployTest/Tests/TestiNoNorm/network.onnx new file mode 100644 index 0000000..58550bc Binary files /dev/null and b/DeeployTest/Tests/TestiNoNorm/network.onnx differ diff --git a/DeeployTest/Tests/TestiNoNorm/outputs.npz b/DeeployTest/Tests/TestiNoNorm/outputs.npz new file mode 100644 index 0000000..bb5e96c Binary files /dev/null and b/DeeployTest/Tests/TestiNoNorm/outputs.npz differ diff --git a/DeeployTest/Tests/TestiSoftmaxLarge/activations.npz b/DeeployTest/Tests/TestiSoftmaxLarge/activations.npz new file mode 100644 index 0000000..15cb0ec Binary files /dev/null and b/DeeployTest/Tests/TestiSoftmaxLarge/activations.npz differ diff --git a/DeeployTest/Tests/TestiSoftmaxLarge/inputs.npz b/DeeployTest/Tests/TestiSoftmaxLarge/inputs.npz new file mode 100644 index 0000000..e7a20ce Binary files /dev/null and b/DeeployTest/Tests/TestiSoftmaxLarge/inputs.npz differ diff --git a/DeeployTest/Tests/TestiSoftmaxLarge/network.onnx b/DeeployTest/Tests/TestiSoftmaxLarge/network.onnx new file mode 100644 index 0000000..a25ad83 Binary files /dev/null and b/DeeployTest/Tests/TestiSoftmaxLarge/network.onnx differ diff --git a/DeeployTest/Tests/TestiSoftmaxLarge/outputs.npz b/DeeployTest/Tests/TestiSoftmaxLarge/outputs.npz new file mode 100644 index 0000000..8840303 Binary files /dev/null and b/DeeployTest/Tests/TestiSoftmaxLarge/outputs.npz differ diff --git a/DeeployTest/Tests/testGEMM/inputs.npz b/DeeployTest/Tests/testGEMM/inputs.npz index d94a87e..fed9cbd 100644 Binary files a/DeeployTest/Tests/testGEMM/inputs.npz and b/DeeployTest/Tests/testGEMM/inputs.npz differ diff --git a/DeeployTest/Tests/testGEMM/network.onnx b/DeeployTest/Tests/testGEMM/network.onnx index 23adcef..2ce6397 100644 Binary files a/DeeployTest/Tests/testGEMM/network.onnx and b/DeeployTest/Tests/testGEMM/network.onnx differ diff --git a/DeeployTest/Tests/testGEMM/outputs.npz b/DeeployTest/Tests/testGEMM/outputs.npz index c2cb4e2..a089802 100644 Binary files a/DeeployTest/Tests/testGEMM/outputs.npz and b/DeeployTest/Tests/testGEMM/outputs.npz differ diff --git a/DeeployTest/Tests/testRQGEMMTransB/inputs.npz b/DeeployTest/Tests/testRQGEMMTransB/inputs.npz new file mode 100644 index 0000000..fb63c07 Binary files /dev/null and b/DeeployTest/Tests/testRQGEMMTransB/inputs.npz differ diff --git a/DeeployTest/Tests/testRQGEMMTransB/network.onnx b/DeeployTest/Tests/testRQGEMMTransB/network.onnx new file mode 100644 index 0000000..b1e31b8 Binary files /dev/null and b/DeeployTest/Tests/testRQGEMMTransB/network.onnx differ diff --git a/DeeployTest/Tests/testRQGEMMTransB/outputs.npz b/DeeployTest/Tests/testRQGEMMTransB/outputs.npz new file mode 100644 index 0000000..d709708 Binary files /dev/null and b/DeeployTest/Tests/testRQGEMMTransB/outputs.npz differ diff --git a/DeeployTest/testRunner_tiled_snitch.py b/DeeployTest/testRunner_tiled_snitch.py index 34ebefb..4cd5324 100644 --- a/DeeployTest/testRunner_tiled_snitch.py +++ b/DeeployTest/testRunner_tiled_snitch.py @@ -37,9 +37,17 @@ default = 9, help = 'Set number of cluster cores') parser.set_defaults(toolchain_install_dir = "/usr/pack/riscv-1.0-kgf/pulp-llvm-0.12.0") + parser.add_argument('--simulator', + metavar = "", + dest = "simulator", + type = str, + choices = ["banshee", "vsim", "vsim.gui"], + default = "banshee", + help = "Select the simulator to use") + args = parser.parse_args() - testRunner = TestRunner(platform = "Snitch", simulator = "banshee", tiling = True, argument_parser = parser) + testRunner = TestRunner(platform = "Snitch", simulator = args.simulator, tiling = True, argument_parser = parser) testRunner.cmake_args += f" -D NUM_CORES={args.cores}" diff --git a/Makefile b/Makefile index b2b33c9..5ce77e7 100644 --- a/Makefile +++ b/Makefile @@ -265,7 +265,8 @@ pulp-sdk: ${PULP_SDK_INSTALL_DIR} ${TOOLCHAIN_DIR}/snitch_cluster: cd ${TOOLCHAIN_DIR} && \ git clone https://github.com/pulp-platform/snitch_cluster.git && \ - cd ${TOOLCHAIN_DIR}/snitch_cluster && git submodule update --init --recursive && \ + cd ${TOOLCHAIN_DIR}/snitch_cluster && git checkout ${SNITCH_COMMIT_HASH} && \ + git submodule update --init --recursive && \ git checkout ${SNITCH_COMMIT_HASH} && git apply ${TOOLCHAIN_DIR}/snitch_cluster.patch ${SNITCH_INSTALL_DIR}: ${TOOLCHAIN_DIR}/snitch_cluster diff --git a/TargetLibraries/Snitch/CMakeLists.txt b/TargetLibraries/Snitch/CMakeLists.txt index 634aa3c..78a214f 100644 --- a/TargetLibraries/Snitch/CMakeLists.txt +++ b/TargetLibraries/Snitch/CMakeLists.txt @@ -8,6 +8,7 @@ add_deeploy_library(deeploysnitch STATIC ${SOURCES}) target_include_directories(deeploysnitch PUBLIC ${CMAKE_CURRENT_LIST_DIR}/inc + ${CMAKE_CURRENT_LIST_DIR}/inc/kernel ) diff --git a/TargetLibraries/Snitch/inc/DeeploySnitchMath.h b/TargetLibraries/Snitch/inc/DeeploySnitchMath.h index e64124a..193df11 100644 --- a/TargetLibraries/Snitch/inc/DeeploySnitchMath.h +++ b/TargetLibraries/Snitch/inc/DeeploySnitchMath.h @@ -54,5 +54,8 @@ #include "kernel/RQMatMul.h" #include "kernel/Softmax.h" #include "kernel/UniformRequantShift.h" +#include "kernel/iNoNorm.h" + +#include "dmaStruct.h" #endif //__DEEPLOY_MATH_HEADER_ diff --git a/TargetLibraries/Snitch/inc/dmaStruct.h b/TargetLibraries/Snitch/inc/dmaStruct.h new file mode 100644 index 0000000..bf36074 --- /dev/null +++ b/TargetLibraries/Snitch/inc/dmaStruct.h @@ -0,0 +1,37 @@ +/* ---------------------------------------------------------------------- +# +# File: dmaStruct.h +# +# Last edited: 03.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +#include "snrt.h" + +typedef struct { + void *dst; + void *src; + size_t size; + size_t dst_stride; + size_t src_stride; + size_t repeat; + snrt_dma_txid_t tid; +} DMA_copy; \ No newline at end of file diff --git a/TargetLibraries/Snitch/inc/kernel/Gemm.h b/TargetLibraries/Snitch/inc/kernel/Gemm.h index f320942..b1197b7 100644 --- a/TargetLibraries/Snitch/inc/kernel/Gemm.h +++ b/TargetLibraries/Snitch/inc/kernel/Gemm.h @@ -48,6 +48,38 @@ /* General Matrix Multiplication (8bit) */ /******************************************************************************/ +/* + * General Matrix Multiplication + * transposed A = no + * transposed B = no + * multi-core = yes + * unrolling = no + * simd = no + * parallelization = row-wise + * bias pushing = no + */ +void Gemm_s8_row_parallel(int8_t const *__restrict__ pSrcA, + int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, + int32_t *__restrict__ pDstY, uint32_t M, uint32_t N, + uint32_t O, int32_t alpha, int32_t beta); + +/* + * General Matrix Multiplication + * transposed A = no + * transposed B = yes + * multi-core = yes + * unrolling = no + * simd = no + * parallelization = row-wise + * bias pushing = no + */ +void Gemm_s8_transB_row_parallel(int8_t const *__restrict__ pSrcA, + int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, + int32_t *__restrict__ pDstY, uint32_t M, + uint32_t N, uint32_t O, int32_t alpha, + int32_t beta); /* * General Matrix Multiplication ---------------------------------- * kernel = Gemm_parallel_s8_rv32im diff --git a/TargetLibraries/Snitch/inc/kernel/RQGemm.h b/TargetLibraries/Snitch/inc/kernel/RQGemm.h index e86b579..b1d77c1 100644 --- a/TargetLibraries/Snitch/inc/kernel/RQGemm.h +++ b/TargetLibraries/Snitch/inc/kernel/RQGemm.h @@ -48,6 +48,75 @@ /* General Requantized Matrix Multiplication (8bit) */ /******************************************************************************/ +/* + * General Requantized Matrix Multiplication ---------------------------------- + * transposed A = no + * transposed B = no + * multi-core = yes + * unrolling = no + * simd = no + * parallelization = row-wise + * bias pushing = no + */ +void RQGemm_s8_row_parallel(int8_t const *__restrict__ pSrcA, + int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, + int8_t *__restrict__ pDstY, uint32_t M, uint32_t N, + uint32_t O, int32_t alpha, int32_t beta, + int32_t *mul, int32_t *add, int32_t log2D); + +/* + * General Requantized Matrix Multiplication ---------------------------------- + * transposed A = no + * transposed B = no + * multi-core = yes + * unrolling = yes + * simd = no + * parallelization = row-wise + * bias pushing = no + */ +void RQGemm_s8_row_parallel_unrolled(int8_t const *__restrict__ pSrcA, + int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, + int8_t *__restrict__ pDstY, uint32_t M, + uint32_t N, uint32_t O, int32_t alpha, + int32_t beta, int32_t *mul, int32_t *add, + int32_t log2D); + +/* + * General Requantized Matrix Multiplication ---------------------------------- + * transposed A = no + * transposed B = yes + * multi-core = yes + * unrolling = no + * simd = no + * parallelization = row-wise + * bias pushing = no + */ +void RQGemm_s8_transB_row_parallel(int8_t const *__restrict__ pSrcA, + int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, + int8_t *__restrict__ pDstY, uint32_t M, + uint32_t N, uint32_t O, int32_t alpha, + int32_t beta, int32_t *mul, int32_t *add, + int32_t log2D); + +/* + * General Requantized Matrix Multiplication ---------------------------------- + * transposed A = no + * transposed B = yes + * multi-core = yes + * unrolling = yes + * simd = no + * parallelization = row-wise + * bias pushing = no + */ +void RQGemm_s8_transB_row_parallel_unrolled( + int8_t const *__restrict__ pSrcA, int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, int8_t *__restrict__ pDstY, uint32_t M, + uint32_t N, uint32_t O, int32_t alpha, int32_t beta, int32_t *mul, + int32_t *add, int32_t log2D); + /* * General Requantized Matrix Multiplication ---------------------------------- * kernel = RQGemm_parallel_s8_rv32im diff --git a/TargetLibraries/Snitch/inc/kernel/iNoNorm.h b/TargetLibraries/Snitch/inc/kernel/iNoNorm.h new file mode 100644 index 0000000..56b58e0 --- /dev/null +++ b/TargetLibraries/Snitch/inc/kernel/iNoNorm.h @@ -0,0 +1,31 @@ +/* ---------------------------------------------------------------------- +# +# File: iNoNorm.h +# +# Last edited: 06.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +#include "DeeploySnitchMath.h" + +void SnitchiNoNorm_s8_s8(int8_t *data_in, int8_t *data_out, int8_t *weights, + int32_t *bias, uint32_t size, int32_t mul, + int32_t log2D); diff --git a/TargetLibraries/Snitch/inc/macros.h b/TargetLibraries/Snitch/inc/macros.h index 44c708e..a54bc24 100644 --- a/TargetLibraries/Snitch/inc/macros.h +++ b/TargetLibraries/Snitch/inc/macros.h @@ -29,7 +29,14 @@ #ifndef __DEEPLOY_MATH_MACROS_HEADER_ #define __DEEPLOY_MATH_MACROS_HEADER_ -// #define log2(x) __builtin_pulp_fl1(x) #define INT_LOG2(x) __builtin_ctz(x) +#define CLAMP(x, low, high) \ + (((x) > (high)) ? (high) : (((x) < (low)) ? (low) : (x))) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// JUNGVI: The following macros are here to ensure compatibility with some +// PULP-NN kernels +#define clips8(x) CLAMP(x, -128, 127) #endif //__DEEPLOY_MATH_MACROS_HEADER_ diff --git a/TargetLibraries/Snitch/src/Add.c b/TargetLibraries/Snitch/src/Add.c new file mode 100644 index 0000000..094739c --- /dev/null +++ b/TargetLibraries/Snitch/src/Add.c @@ -0,0 +1,50 @@ +/* ---------------------------------------------------------------------- +# +# File: Add.c +# +# Last edited: 11.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +#include "DeeploySnitchMath.h" + +void SnitchAdd(int8_t *pIn1, int8_t *pIn2, int32_t *pOut, uint32_t size, + int32_t offset) { + + uint32_t core_id = snrt_global_compute_core_idx(); + uint32_t numThreads = snrt_global_compute_core_num(); + uint32_t chunk, chunkSize, start, stop; + + chunkSize = size / numThreads; + if (core_id < (numThreads - 1)) { + chunk = chunkSize * core_id; + stop = chunk + chunkSize; + } else { + chunk = (chunkSize * core_id - 1) + (size - chunk); + stop = size; + } + start = chunk; + +#pragma loopunroll 2 + for (int i = start; i < stop; i++) { + pOut[i] = pIn1[i] + pIn2[i] + offset; + } +} diff --git a/TargetLibraries/Snitch/src/Gemm_s8.c b/TargetLibraries/Snitch/src/Gemm_s8.c new file mode 100644 index 0000000..eefd407 --- /dev/null +++ b/TargetLibraries/Snitch/src/Gemm_s8.c @@ -0,0 +1,61 @@ +#include "DeeploySnitchMath.h" +#include "Gemm.h" + +void Gemm_s8_row_parallel(int8_t const *__restrict__ pSrcA, + int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, + int32_t *__restrict__ pDstY, uint32_t M, uint32_t N, + uint32_t O, int32_t alpha, int32_t beta) { + uint32_t core_id = snrt_global_compute_core_idx(); + uint32_t numThreads = snrt_global_compute_core_num(); + + // Parallelize by assigning each core a row tile + uint32_t const MQuotient = M / numThreads; + uint32_t const MRemainder = M % numThreads; + uint32_t const MSize = MQuotient + (core_id < MRemainder ? 1 : 0); + uint32_t const MStart = + core_id * MQuotient + (core_id < MRemainder ? core_id : MRemainder); + uint32_t const MEnd = MStart + MSize; + + for (uint32_t m = MStart; m < MEnd; m++) { + for (uint32_t o = 0; o < O; o++) { + int32_t sum = 0; + for (uint32_t n = 0; n < N; ++n) { + sum += (int32_t)pSrcA[m * N + n] * pSrcB[n * O + o]; + } + sum = alpha * sum + beta * pSrcC[m * O + o]; + + pDstY[m * O + o] = sum; + } + } +} + +void Gemm_s8_transB_row_parallel(int8_t const *__restrict__ pSrcA, + int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, + int32_t *__restrict__ pDstY, uint32_t M, + uint32_t N, uint32_t O, int32_t alpha, + int32_t beta) { + uint32_t core_id = snrt_global_compute_core_idx(); + uint32_t numThreads = snrt_global_compute_core_num(); + + // Parallelize by assigning each core a row tile + uint32_t const MQuotient = M / numThreads; + uint32_t const MRemainder = M % numThreads; + uint32_t const MSize = MQuotient + (core_id < MRemainder ? 1 : 0); + uint32_t const MStart = + core_id * MQuotient + (core_id < MRemainder ? core_id : MRemainder); + uint32_t const MEnd = MStart + MSize; + + for (uint32_t m = MStart; m < MEnd; m++) { + for (uint32_t o = 0; o < O; o++) { + int32_t sum = 0; + for (uint32_t n = 0; n < N; ++n) { + sum += (int32_t)pSrcA[m * N + n] * pSrcB[o * N + n]; + } + sum = alpha * sum + beta * pSrcC[m * O + o]; + + pDstY[m * O + o] = sum; + } + } +} diff --git a/TargetLibraries/Snitch/src/RQGemm_s8.c b/TargetLibraries/Snitch/src/RQGemm_s8.c index 1850bf0..cfbc867 100644 --- a/TargetLibraries/Snitch/src/RQGemm_s8.c +++ b/TargetLibraries/Snitch/src/RQGemm_s8.c @@ -28,6 +28,302 @@ */ #include "DeeploySnitchMath.h" +#include "RQGemm.h" + +// Assumptions: +// - per-row requantization +// - single batch +void RQGemm_s8_row_parallel(int8_t const *__restrict__ pSrcA, + int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, + int8_t *__restrict__ pDstY, uint32_t M, uint32_t N, + uint32_t O, int32_t alpha, int32_t beta, + int32_t *mul, int32_t *add, int32_t log2D) { + uint32_t core_id = snrt_global_compute_core_idx(); + uint32_t numThreads = snrt_global_compute_core_num(); + + // Parallelize by assigning each core a row tile + uint32_t const MQuotient = M / numThreads; + uint32_t const MRemainder = M % numThreads; + uint32_t const MSize = MQuotient + (core_id < MRemainder ? 1 : 0); + uint32_t const MStart = + core_id * MQuotient + (core_id < MRemainder ? core_id : MRemainder); + uint32_t const MEnd = MStart + MSize; + + if (core_id < numThreads) { + for (uint32_t m = MStart; m < MEnd; m++) { + for (uint32_t o = 0; o < O; o++) { + int32_t sum = 0; + for (uint32_t n = 0; n < N; ++n) { + sum += (int32_t)pSrcA[m * N + n] * pSrcB[n * O + o]; + } + sum = alpha * sum + beta * pSrcC[m * O + o]; + + // Requantize value + sum = (sum * mul[m] + add[m]) >> log2D; + pDstY[m * O + o] = (int8_t)CLAMP(sum, -128, 127); + } + } + } +} + +// Assumptions: +// - per-row requantization +// - transposed input B +// - single batch +void RQGemm_s8_row_parallel_unrolled(int8_t const *__restrict__ pSrcA, + int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, + int8_t *__restrict__ pDstY, uint32_t M, + uint32_t N, uint32_t O, int32_t alpha, + int32_t beta, int32_t *mul, int32_t *add, + int32_t log2D) { + uint32_t core_id = snrt_global_compute_core_idx(); + uint32_t numThreads = snrt_global_compute_core_num(); + + // Parallelize by assigning each core a row tile + uint32_t const MQuotient = M / numThreads; + uint32_t const MRemainder = M % numThreads; + uint32_t const MSize = MQuotient + (core_id < MRemainder ? 1 : 0); + uint32_t const MStart = + core_id * MQuotient + (core_id < MRemainder ? core_id : MRemainder); + uint32_t const MEnd = MStart + MSize; + + if (core_id < numThreads) { + for (uint32_t m = MStart; m + 1 < MEnd; m += 2) { + for (uint32_t o = 0; o + 1 < O; o += 2) { + int32_t sum0 = 0; + int32_t sum1 = 0; + int32_t sum2 = 0; + int32_t sum3 = 0; +#pragma unroll 2 + for (uint32_t n = 0; n < N; ++n) { + sum0 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[n * O + (o + 0)]; + sum1 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[n * O + (o + 1)]; + sum2 += (int32_t)pSrcA[(m + 1) * N + n] * pSrcB[n * O + (o + 0)]; + sum3 += (int32_t)pSrcA[(m + 1) * N + n] * pSrcB[n * O + (o + 1)]; + } + sum0 = alpha * sum0 + beta * pSrcC[(m + 0) * O + (o + 0)]; + sum1 = alpha * sum1 + beta * pSrcC[(m + 0) * O + (o + 1)]; + sum2 = alpha * sum2 + beta * pSrcC[(m + 1) * O + (o + 0)]; + sum3 = alpha * sum3 + beta * pSrcC[(m + 1) * O + (o + 1)]; + + // Requantize value + sum0 = (sum0 * mul[m + 0] + add[m + 0]) >> log2D; + sum1 = (sum1 * mul[m + 0] + add[m + 0]) >> log2D; + sum2 = (sum2 * mul[m + 1] + add[m + 1]) >> log2D; + sum3 = (sum3 * mul[m + 1] + add[m + 1]) >> log2D; + pDstY[(m + 0) * O + (o + 0)] = (int8_t)CLAMP(sum0, -128, 127); + pDstY[(m + 0) * O + (o + 1)] = (int8_t)CLAMP(sum1, -128, 127); + pDstY[(m + 1) * O + (o + 0)] = (int8_t)CLAMP(sum2, -128, 127); + pDstY[(m + 1) * O + (o + 1)] = (int8_t)CLAMP(sum3, -128, 127); + } + + if (O % 2 == 1) { + int32_t sum0 = 0; + int32_t sum1 = 0; +#pragma unroll 2 + for (uint32_t n = 0; n < N; ++n) { + sum0 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[n * O + (O - 1)]; + sum1 += (int32_t)pSrcA[(m + 1) * N + n] * pSrcB[n * O + (O - 1)]; + } + + sum0 = alpha * sum0 + beta * pSrcC[(m + 0) * O + (O - 1)]; + sum1 = alpha * sum1 + beta * pSrcC[(m + 1) * O + (O - 1)]; + + // Requantize value + sum0 = (sum0 * mul[m + 0] + add[m + 0]) >> log2D; + sum1 = (sum1 * mul[m + 1] + add[m + 1]) >> log2D; + pDstY[(m + 0) * O + (O - 1)] = (int8_t)CLAMP(sum0, -128, 127); + pDstY[(m + 1) * O + (O - 1)] = (int8_t)CLAMP(sum1, -128, 127); + } + } + + if (MSize % 2 == 1) { + uint32_t m = MEnd - 1; + + for (uint32_t o = 0; o + 1 < O; o += 2) { + int32_t sum0 = 0; + int32_t sum1 = 0; +#pragma unroll 2 + for (uint32_t n = 0; n < N; ++n) { + sum0 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[n * O + (o + 0)]; + sum1 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[n * O + (o + 1)]; + } + sum0 = alpha * sum0 + beta * pSrcC[(m + 0) * O + (o + 0)]; + sum1 = alpha * sum1 + beta * pSrcC[(m + 0) * O + (o + 1)]; + + // Requantize value + sum0 = (sum0 * mul[m + 0] + add[m + 0]) >> log2D; + sum1 = (sum1 * mul[m + 0] + add[m + 0]) >> log2D; + pDstY[(m + 0) * O + (o + 0)] = (int8_t)CLAMP(sum0, -128, 127); + pDstY[(m + 0) * O + (o + 1)] = (int8_t)CLAMP(sum1, -128, 127); + } + + if (O % 2 == 1) { + int32_t sum0 = 0; +#pragma unroll 2 + for (uint32_t n = 0; n < N; ++n) { + sum0 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[n * O + (O - 1)]; + } + + sum0 = alpha * sum0 + beta * pSrcC[(m + 0) * O + (O - 1)]; + + // Requantize value + sum0 = (sum0 * mul[m + 0] + add[m + 0]) >> log2D; + pDstY[(m + 0) * O + (O - 1)] = (int8_t)CLAMP(sum0, -128, 127); + } + } + } +} + +// Assumptions: +// - per-row requantization +// - transposed input B +// - single batch +void RQGemm_s8_transB_row_parallel(int8_t const *__restrict__ pSrcA, + int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, + int8_t *__restrict__ pDstY, uint32_t M, + uint32_t N, uint32_t O, int32_t alpha, + int32_t beta, int32_t *mul, int32_t *add, + int32_t log2D) { + uint32_t core_id = snrt_global_compute_core_idx(); + uint32_t numThreads = snrt_global_compute_core_num(); + + // Parallelize by assigning each core a row tile + uint32_t const MQuotient = M / numThreads; + uint32_t const MRemainder = M % numThreads; + uint32_t const MSize = MQuotient + (core_id < MRemainder ? 1 : 0); + uint32_t const MStart = + core_id * MQuotient + (core_id < MRemainder ? core_id : MRemainder); + uint32_t const MEnd = MStart + MSize; + + if (core_id < numThreads) { + for (uint32_t m = MStart; m < MEnd; m++) { + for (uint32_t o = 0; o < O; o++) { + int32_t sum = 0; + for (uint32_t n = 0; n < N; ++n) { + sum += (int32_t)pSrcA[m * N + n] * pSrcB[o * N + n]; + } + sum = alpha * sum + beta * pSrcC[m * O + o]; + + // Requantize value + sum = (sum * mul[m] + add[m]) >> log2D; + pDstY[m * O + o] = (int8_t)CLAMP(sum, -128, 127); + } + } + } +} + +// Assumptions: +// - per-row requantization +// - transposed input B +// - single batch +void RQGemm_s8_transB_row_parallel_unrolled( + int8_t const *__restrict__ pSrcA, int8_t const *__restrict__ pSrcB, + int32_t const *__restrict__ pSrcC, int8_t *__restrict__ pDstY, uint32_t M, + uint32_t N, uint32_t O, int32_t alpha, int32_t beta, int32_t *mul, + int32_t *add, int32_t log2D) { + uint32_t core_id = snrt_global_compute_core_idx(); + uint32_t numThreads = snrt_global_compute_core_num(); + + // Parallelize by assigning each core a row tile + uint32_t const MQuotient = M / numThreads; + uint32_t const MRemainder = M % numThreads; + uint32_t const MSize = MQuotient + (core_id < MRemainder ? 1 : 0); + uint32_t const MStart = + core_id * MQuotient + (core_id < MRemainder ? core_id : MRemainder); + uint32_t const MEnd = MStart + MSize; + + if (core_id < numThreads) { + for (uint32_t m = MStart; m + 1 < MEnd; m += 2) { + for (uint32_t o = 0; o + 1 < O; o += 2) { + int32_t sum0 = 0; + int32_t sum1 = 0; + int32_t sum2 = 0; + int32_t sum3 = 0; +#pragma unroll 2 + for (uint32_t n = 0; n < N; ++n) { + sum0 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[(o + 0) * N + n]; + sum1 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[(o + 1) * N + n]; + sum2 += (int32_t)pSrcA[(m + 1) * N + n] * pSrcB[(o + 0) * N + n]; + sum3 += (int32_t)pSrcA[(m + 1) * N + n] * pSrcB[(o + 1) * N + n]; + } + sum0 = alpha * sum0 + beta * pSrcC[(m + 0) * O + (o + 0)]; + sum1 = alpha * sum1 + beta * pSrcC[(m + 0) * O + (o + 1)]; + sum2 = alpha * sum2 + beta * pSrcC[(m + 1) * O + (o + 0)]; + sum3 = alpha * sum3 + beta * pSrcC[(m + 1) * O + (o + 1)]; + + // Requantize value + sum0 = (sum0 * mul[m + 0] + add[m + 0]) >> log2D; + sum1 = (sum1 * mul[m + 0] + add[m + 0]) >> log2D; + sum2 = (sum2 * mul[m + 1] + add[m + 1]) >> log2D; + sum3 = (sum3 * mul[m + 1] + add[m + 1]) >> log2D; + pDstY[(m + 0) * O + (o + 0)] = (int8_t)CLAMP(sum0, -128, 127); + pDstY[(m + 0) * O + (o + 1)] = (int8_t)CLAMP(sum1, -128, 127); + pDstY[(m + 1) * O + (o + 0)] = (int8_t)CLAMP(sum2, -128, 127); + pDstY[(m + 1) * O + (o + 1)] = (int8_t)CLAMP(sum3, -128, 127); + } + + if (O % 2 == 1) { + int32_t sum0 = 0; + int32_t sum1 = 0; +#pragma unroll 2 + for (uint32_t n = 0; n < N; ++n) { + sum0 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[(O - 1) * N + n]; + sum1 += (int32_t)pSrcA[(m + 1) * N + n] * pSrcB[(O - 1) * N + n]; + } + + sum0 = alpha * sum0 + beta * pSrcC[(m + 0) * O + (O - 1)]; + sum1 = alpha * sum1 + beta * pSrcC[(m + 1) * O + (O - 1)]; + + // Requantize value + sum0 = (sum0 * mul[m + 0] + add[m + 0]) >> log2D; + sum1 = (sum1 * mul[m + 1] + add[m + 1]) >> log2D; + pDstY[(m + 0) * O + (O - 1)] = (int8_t)CLAMP(sum0, -128, 127); + pDstY[(m + 1) * O + (O - 1)] = (int8_t)CLAMP(sum1, -128, 127); + } + } + + if (MSize % 2 == 1) { + uint32_t m = MEnd - 1; + + for (uint32_t o = 0; o + 1 < O; o += 2) { + int32_t sum0 = 0; + int32_t sum1 = 0; +#pragma unroll 2 + for (uint32_t n = 0; n < N; ++n) { + sum0 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[(o + 0) * N + n]; + sum1 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[(o + 1) * N + n]; + } + sum0 = alpha * sum0 + beta * pSrcC[(m + 0) * O + (o + 0)]; + sum1 = alpha * sum1 + beta * pSrcC[(m + 0) * O + (o + 1)]; + + // Requantize value + sum0 = (sum0 * mul[m + 0] + add[m + 0]) >> log2D; + sum1 = (sum1 * mul[m + 0] + add[m + 0]) >> log2D; + pDstY[(m + 0) * O + (o + 0)] = (int8_t)CLAMP(sum0, -128, 127); + pDstY[(m + 0) * O + (o + 1)] = (int8_t)CLAMP(sum1, -128, 127); + } + + if (O % 2 == 1) { + int32_t sum0 = 0; +#pragma unroll 2 + for (uint32_t n = 0; n < N; ++n) { + sum0 += (int32_t)pSrcA[(m + 0) * N + n] * pSrcB[(O - 1) * N + n]; + } + + sum0 = alpha * sum0 + beta * pSrcC[(m + 0) * O + (O - 1)]; + + // Requantize value + sum0 = (sum0 * mul[m + 0] + add[m + 0]) >> log2D; + pDstY[(m + 0) * O + (O - 1)] = (int8_t)CLAMP(sum0, -128, 127); + } + } + } +} + void RQGemm_parallel_s8_rv32im( int8_t const *__restrict__ pSrcA, int8_t const *__restrict__ pSrcB, int32_t const *__restrict__ pSrcC, int8_t *__restrict__ pDstY, uint32_t M, diff --git a/TargetLibraries/Snitch/src/iNoNorm.c b/TargetLibraries/Snitch/src/iNoNorm.c new file mode 100644 index 0000000..30b3c68 --- /dev/null +++ b/TargetLibraries/Snitch/src/iNoNorm.c @@ -0,0 +1,108 @@ +/* ---------------------------------------------------------------------- +# +# File: iNoNorm.c +# +# Last edited: 06.06.2024 +# +# Copyright (C) 2024, ETH Zurich and University of Bologna. +# +# Author: +# - Victor Jung, jungvi@iis.ee.ethz.ch, ETH Zurich +# +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +#include "DeeploySnitchMath.h" + +void SnitchiNoNorm_s8_s8(int8_t *data_in, int8_t *data_out, int8_t *weights, + int32_t *bias, uint32_t size, int32_t mul, + int32_t log2D) { + + uint32_t core_id = snrt_global_compute_core_idx(); + uint32_t numThreads = snrt_global_compute_core_num(); + uint32_t chunk, chunkSize, start, stop; + + chunkSize = size / numThreads; + if (core_id < (numThreads - 1)) { + chunk = chunkSize * core_id; + stop = chunk + chunkSize; + } else { + chunk = (chunkSize * core_id - 1) + (size - chunk); + stop = size; + } + start = chunk; + + uint32_t packedIn, packedWeights; + int8_t unpackedIn1, unpackedIn2, unpackedIn3, unpackedIn4; + int8_t unpackedWeights1, unpackedWeights2, unpackedWeights3, unpackedWeights4; + int16_t partialProduct1, partialProduct2, partialProduct3, partialProduct4; + + int32_t *dataInPtr = (int32_t *)(data_in); + int32_t *weightsPtr = (int32_t *)(weights); + int32_t *outputPtr = (int32_t *)(data_out); + + uint32_t firstReminderLoopSize = start % 4; + uint32_t lastReminderLoopSize = stop % 4; + uint32_t firstReminderLoopIdx = start; + uint32_t lastReminderLoopIdx = stop - lastReminderLoopSize; + start = (start + firstReminderLoopSize) >> 2; + stop = (stop - lastReminderLoopSize) >> 2; + uint32_t biasIdx = start * 4; + + // JUNGVI: Compute sequentially the first elements not aligned to a word (32b) + for (uint32_t i = firstReminderLoopIdx; + i < firstReminderLoopIdx + firstReminderLoopSize; i++) { + data_out[i] = + ((((int32_t)data_in[i] * weights[i]) + bias[i]) * mul) >> log2D; + } + + for (uint32_t i = start; i < stop; i++) { + + packedIn = dataInPtr[i]; + packedWeights = weightsPtr[i]; + + unpackedIn1 = (packedIn & 0x000000FF); + unpackedIn2 = (packedIn & 0x0000FF00) >> 8; + unpackedIn3 = (packedIn & 0x00FF0000) >> 16; + unpackedIn4 = packedIn >> 24; + + unpackedWeights1 = (packedWeights & 0x000000FF); + unpackedWeights2 = (packedWeights & 0x0000FF00) >> 8; + unpackedWeights3 = (packedWeights & 0x00FF0000) >> 16; + unpackedWeights4 = packedWeights >> 24; + + partialProduct1 = (int16_t)(unpackedIn1 * unpackedWeights1); + partialProduct2 = (int16_t)(unpackedIn2 * unpackedWeights2); + partialProduct3 = (int16_t)(unpackedIn3 * unpackedWeights3); + partialProduct4 = (int16_t)(unpackedIn4 * unpackedWeights4); + + uint8_t outBuf1 = ((partialProduct1 + bias[biasIdx + 0]) * mul) >> log2D; + uint8_t outBuf2 = ((partialProduct2 + bias[biasIdx + 1]) * mul) >> log2D; + uint8_t outBuf3 = ((partialProduct3 + bias[biasIdx + 2]) * mul) >> log2D; + uint8_t outBuf4 = ((partialProduct4 + bias[biasIdx + 3]) * mul) >> log2D; + + uint32_t outPacked = + (outBuf1 << 0) | (outBuf2 << 8) | (outBuf3 << 16) | (outBuf4 << 24); + outputPtr[i] = outPacked; + biasIdx += 4; + } + + // JUNGVI: Compute sequentially the last elements not aligned to a word (32b) + for (uint32_t i = lastReminderLoopIdx; + i < lastReminderLoopIdx + lastReminderLoopSize; i++) { + data_out[i] = + ((((int32_t)data_in[i] * weights[i]) + bias[i]) * mul) >> log2D; + } +} diff --git a/TargetLibraries/Snitch/src/snitch_nn_add_i8_i8_i8.c b/TargetLibraries/Snitch/src/snitch_nn_add_i8_i8_i8.c new file mode 100644 index 0000000..f83e3ba --- /dev/null +++ b/TargetLibraries/Snitch/src/snitch_nn_add_i8_i8_i8.c @@ -0,0 +1,205 @@ +/* + * pulp_nn_add_i8_i8_i8.c + * Georg Rutishauser + * Victor Jung + * + * Copyright (C) 2018-2020 University of Bologna + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "DeeploySnitchMath.h" + +void __attribute__((noinline)) +snitch_nn_add_i8_i8_i8(int8_t *pIn1, int8_t *pIn2, int8_t *pOut, + int32_t in1_mul, int32_t in1_add, uint16_t in1_shift, + int32_t in2_mul, int32_t in2_add, uint16_t in2_shift, + int32_t out_mul, int32_t out_add, uint16_t out_shift, + uint16_t dim_im_in_x, uint16_t dim_im_in_y, + uint16_t ch_im_in, int out_requant_flag) { + int core_id = snrt_global_compute_core_idx(); + int n_cores = snrt_global_compute_core_num(); + + if (dim_im_in_y < n_cores) { + n_cores = dim_im_in_y; + } + + int Log2Core = INT_LOG2(n_cores); + int chunck = (dim_im_in_y >> Log2Core) + ((dim_im_in_y & (n_cores - 1)) != 0); + + int32_t in1_rq1, in1_rq2, in1_rq3, in1_rq4, in2_rq1, in2_rq2, in2_rq3, + in2_rq4; + int32_t sum1, sum2, sum3, sum4; + int32_t sum_out1, sum_out2, sum_out3, sum_out4; + int32_t out1, out2, out3, out4, sum_int1, sum_int2, sum_int3, sum_int4; + + int ch_im_in1_r = ch_im_in >> 0; + int ch_im_in2_r = ch_im_in >> 0; + int ch_im_out_r = ch_im_in >> 0; + + int start = MIN(chunck * core_id, dim_im_in_y); + int stop = MIN(start + chunck, dim_im_in_y); + + int8_t *target1 = pIn1 + start * ch_im_in1_r * dim_im_in_x; + int8_t *target2 = pIn2 + start * ch_im_in2_r * dim_im_in_x; + int8_t *pOutBuffer = pOut + start * ch_im_out_r * dim_im_in_x; + + int a = 0; + int b = 0; + + int8_t *target1_ext = &a; + int8_t *target2_ext = &b; + + for (int i = 0; i < (((stop - start) * ch_im_out_r * dim_im_in_x) >> 2); + i++) { + target1_ext = target1; + target1 += 4; + + target2_ext = target2; + target2 += 4; +#ifdef ADD_VERBOSE + printf("core %d - in1 it0 before requant: %d\n", core_id, *(target1_ext)); + printf("core %d - in2 it0 before requant: %d\n", core_id, *(target2_ext)); +#endif + in1_rq1 = ((*(target1_ext)) * in1_mul + in1_add) >> in1_shift; + in2_rq1 = ((*(target2_ext)) * in2_mul + in2_add) >> in2_shift; + sum1 = clips8(in1_rq1) + clips8(in2_rq1); +#ifdef ADD_VERBOSE + printf("core %d - in1_rq1 it0 after requant: %d\nclipped in1_rq1: %d\n", + core_id, in1_rq1, clips8(in1_rq1)); + printf("core %d - in2_rq1 it0 after requant: %d\nclipped in2_rq1: %d\n", + core_id, in2_rq1), + clips8(in2_rq1); + printf("core %d - sum1: %d\n", core_id, sum1); +#endif +#ifdef ADD_VERBOSE + printf("core %d - in1 it1 before requant: %d\n", core_id, + *(target1_ext + 1)); + printf("core %d - in2 it1 before requant: %d\n", core_id, + *(target2_ext + 1)); +#endif + in1_rq2 = ((*(target1_ext + 1)) * in1_mul + in1_add) >> in1_shift; + in2_rq2 = ((*(target2_ext + 1)) * in2_mul + in2_add) >> in2_shift; + sum2 = clips8(in1_rq2) + clips8(in2_rq2); +#ifdef ADD_VERBOSE + printf("core %d - in1_rq2 it1 after requant: %d\nclipped in1_rq2: %d\n", + core_id, in1_rq2, clips8(in1_rq2)); + printf("core %d - in2_rq2 it1 after requant: %d\nclipped in2_rq2: %d\n", + core_id, in2_rq2), + clips8(in2_rq2); + printf("core %d - sum2: %d\n", core_id, sum2); +#endif +#ifdef ADD_VERBOSE + printf("core %d - in1 it2 before requant: %d\n", core_id, + *(target1_ext + 2)); + printf("core %d - in2 it2 before requant: %d\n", core_id, + *(target2_ext + 2)); +#endif + in1_rq3 = ((*(target1_ext + 2)) * in1_mul + in1_add) >> in1_shift; + in2_rq3 = ((*(target2_ext + 2)) * in2_mul + in2_add) >> in2_shift; + sum3 = clips8(in1_rq3) + clips8(in2_rq3); +#ifdef ADD_VERBOSE + printf("core %d - in1_rq3 it2 after requant: %d\nclipped in1_rq3: %d\n", + core_id, in1_rq3, clips8(in1_rq3)); + printf("core %d - in2_rq3 it2 after requant: %d\nclipped in2_rq3: %d\n", + core_id, in2_rq3), + clips8(in2_rq3); + printf("core %d - sum3: %d\n", core_id, sum3); +#endif +#ifdef ADD_VERBOSE + printf("core %d - in1 it3 before requant: %d\n", core_id, + *(target1_ext + 3)); + printf("core %d - in2 it3 before requant: %d\n", core_id, + *(target2_ext + 3)); +#endif + in1_rq4 = ((*(target1_ext + 3)) * in1_mul + in1_add) >> in1_shift; + in2_rq4 = ((*(target2_ext + 3)) * in2_mul + in2_add) >> in2_shift; + sum4 = clips8(in1_rq4) + clips8(in2_rq4); +#ifdef ADD_VERBOSE + printf("core %d - in1_rq4 it3 after requant: %d\nclipped in1_rq4: %d\n", + core_id, in1_rq4, clips8(in1_rq4)); + printf("core %d - in2_rq4 it3 after requant: %d\nclipped in2_rq4: %d\n", + core_id, in2_rq4), + clips8(in2_rq4); + printf("core %d - sum4: %d\n", core_id, sum4); +#endif + + if (out_requant_flag) { + sum1 = (sum1 * out_mul + out_add) >> out_shift; +#ifdef ADD_VERBOSE + printf("core %d - requantized sum1: %d\n", core_id, sum1); +#endif + sum2 = (sum2 * out_mul + out_add) >> out_shift; +#ifdef ADD_VERBOSE + printf("core %d - requantized sum2: %d\n", core_id, sum2); +#endif + sum3 = (sum3 * out_mul + out_add) >> out_shift; +#ifdef ADD_VERBOSE + printf("core %d - requantized sum3: %d\n", core_id, sum3); +#endif + sum4 = (sum4 * out_mul + out_add) >> out_shift; +#ifdef ADD_VERBOSE + printf("core %d - requantized sum4: %d\n", core_id, sum4); +#endif + } + out1 = clips8(sum1); +#ifdef ADD_VERBOSE + printf("core %d - out1 clipped: %d\n", core_id, out1); +#endif + out2 = clips8(sum2); +#ifdef ADD_VERBOSE + printf("core %d - out2 clipped: %d\n", core_id, out2); +#endif + out3 = clips8(sum3); +#ifdef ADD_VERBOSE + printf("core %d - out3 clipped: %d\n", core_id, out3); +#endif + out4 = clips8(sum4); +#ifdef ADD_VERBOSE + printf("core %d - out4 clipped: %d\n", core_id, out4); +#endif + + *pOutBuffer = (int8_t)out1; + pOutBuffer++; + *pOutBuffer = (int8_t)out2; + pOutBuffer++; + *pOutBuffer = (int8_t)out3; + pOutBuffer++; + *pOutBuffer = (int8_t)out4; + pOutBuffer++; + } + // SCHEREMO: Cleanup leftovers, not doing it with this codebase for sub-byte + // formats + for (int i = 0; i < (((stop - start) * ch_im_out_r * dim_im_in_x) % 4); i++) { + in1_rq1 = ((*(target1)) * in1_mul + in1_add) >> in1_shift; + in2_rq1 = ((*(target2)) * in2_mul + in2_add) >> in2_shift; + +// SCHEREMO: Maybe it's just LLVM, but unless I hack 3 non-unrolled nops in +// here, stuff fails +#pragma nounroll + for (int j = 0; j < 3; j++) { + asm volatile("nop" ::); + } + + target1++; + target2++; + sum1 = clips8(in1_rq1) + clips8(in2_rq1); + if (out_requant_flag) { + sum1 = (sum1 * out_mul + out_add) >> out_shift; + } + + out1 = clips8(sum1); + *pOutBuffer = (int8_t)out1; + pOutBuffer++; + } +} diff --git a/toolchain/snitch_cluster.patch b/toolchain/snitch_cluster.patch index c39c525..b50b33c 100644 --- a/toolchain/snitch_cluster.patch +++ b/toolchain/snitch_cluster.patch @@ -15,21 +15,34 @@ index d0979b7..171921d 100644 + KEEP(*(.cbss .cbss.*)) __cbss_end = .; } >L3 - + +diff --git a/sw/snRuntime/src/alloc.h b/sw/snRuntime/src/alloc.h +index ba1dee9..f79769f 100644 +--- a/sw/snRuntime/src/alloc.h ++++ b/sw/snRuntime/src/alloc.h +@@ -69,6 +69,8 @@ inline void *snrt_l3alloc(size_t size) { + + // TODO: L3 alloc size check + ++ size = ALIGN_UP(size, MIN_CHUNK_SIZE); ++ + void *ret = (void *)alloc->next; + alloc->next += size; + return ret; diff --git a/sw/snRuntime/src/team.c b/sw/snRuntime/src/team.c index a9eb840..5290e1d 100644 --- a/sw/snRuntime/src/team.c +++ b/sw/snRuntime/src/team.c @@ -10,6 +10,10 @@ extern uint32_t snrt_global_core_idx(); - + extern uint32_t snrt_global_core_num(); - + +extern uint32_t snrt_global_compute_core_num(); + +extern uint32_t snrt_global_compute_core_idx(); + extern uint32_t snrt_cluster_idx(); - + extern uint32_t snrt_cluster_num(); diff --git a/target/snitch_cluster/sw/runtime/rtl/src/putchar.c b/target/snitch_cluster/sw/runtime/rtl/src/putchar.c index 0ad9500..215c8b1 100644 @@ -37,7 +50,7 @@ index 0ad9500..215c8b1 100644 +++ b/target/snitch_cluster/sw/runtime/rtl/src/putchar.c @@ -5,16 +5,19 @@ extern uintptr_t volatile tohost, fromhost; - + // Rudimentary string buffer for putc calls. -extern uint32_t _edram; #define PUTC_BUFFER_LEN (1024 - sizeof(size_t)) @@ -58,6 +71,6 @@ index 0ad9500..215c8b1 100644 +} putc_buffer_t; + +static volatile putc_buffer_t putc_buffer[SNRT_CLUSTER_NUM*SNRT_CLUSTER_CORE_NUM] __attribute__((section(".dram"))); - + // Provide an implementation for putchar. void _putchar(char character) {