Skip to content

Commit

Permalink
apply Victor's comment
Browse files Browse the repository at this point in the history
  • Loading branch information
tahaelbayad committed Jan 8, 2025
1 parent 41dcc0f commit 478d618
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 59 deletions.
23 changes: 23 additions & 0 deletions Deeploy/Targets/Generic/TypeCheckers.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,3 +537,26 @@ def _inferSignedness(self, inputs: List[VariableBuffer],
return [True]
else:
return [False]


class RQAddChecker(SignPropTypeChecker):

def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]):
super().__init__(input_types, output_types)

def _inferNumLevels(self, inputs: List[VariableBuffer],
operatorRepresentation: OperatorRepresentation) -> List[int]:
return [operatorRepresentation['rqsOut_n_levels']]

def _inferSignedness(self, inputs: List[VariableBuffer],
operatorRepresentation: OperatorRepresentation) -> List[bool]:
return [bool(operatorRepresentation["rqsOut_signed"])]

# Override this. This should compute the signednes of each output node of the Layer
def checkOutputType(self, inputs: List[VariableBuffer], operatorRepresentation: OperatorRepresentation) -> bool:
outputTypeSigned = self.output_types[0].referencedType.typeMin < 0
if operatorRepresentation['rqsOut_signed'] and outputTypeSigned:
return True
if (not operatorRepresentation['rqsOut_signed']) and (not outputTypeSigned):
return True
return False
6 changes: 3 additions & 3 deletions Deeploy/Targets/PULPOpen/Bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
from Deeploy.Targets.Generic.Templates import ConcatTemplate, RQSiGELUTemplate, iHardswishTemplate
from Deeploy.Targets.Generic.TypeCheckers import ConcatChecker, GELUChecker, HardswishChecker, MatMulChecker, \
MulChecker, ReduceMeanChecker, RQHardswishChecker, SliceChecker, SoftmaxChecker, TransposeChecker, \
MulChecker, ReduceMeanChecker, RQAddChecker, RQHardswishChecker, SliceChecker, SoftmaxChecker, TransposeChecker, \
iLayerNormChecker
from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPClusterSynch import PULPSynchCoresPass
from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPClusterTiling import PULPClusterTiling
Expand All @@ -48,7 +48,7 @@
MulTemplate, ReduceMeanTemplate, RequantShiftTemplate, RQAddTemplate, RQSiHardswishTemplate, SliceTemplate, \
TallGEMMTemplate, TransposeTemplate, UniformRequantShiftTemplate, iRMSNormTemplate, iSoftmaxTemplate
from Deeploy.Targets.PULPOpen.TypeCheckers import PULPConvChecker, PULPLinearChecker, PULPMaxPoolChecker, \
PULPRequantShiftChecker, PULPRQAddChecker
PULPRequantShiftChecker
from Deeploy.TilingExtension.CodeTransformationPasses.TilingVariableReplacement import TilingVariableReplacement

_clusterEntryClosureCallTemplate = NodeTemplate("""
Expand Down Expand Up @@ -156,7 +156,7 @@
]

PULPRQAddBindings = [
NodeBinding(PULPRQAddChecker([PointerClass(_type), PointerClass(_type2)], [PointerClass(_type3)]),
NodeBinding(RQAddChecker([PointerClass(_type), PointerClass(_type2)], [PointerClass(_type3)]),
RQAddTemplate.referenceTemplate, ForkTransformer)
for _type in [int8_t, uint8_t]
for _type2 in [int8_t, uint8_t]
Expand Down
5 changes: 2 additions & 3 deletions Deeploy/Targets/Snitch/Bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,12 @@
from Deeploy.DeeployTypes import CodeTransformation, NodeBinding
from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
from Deeploy.Targets.Generic.Templates import iNoNormTemplate
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, GEMMChecker, SoftmaxChecker, iNoNormChecker
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, GEMMChecker, RQAddChecker, SoftmaxChecker, iNoNormChecker
from Deeploy.Targets.Snitch.CodeTransformationPasses import SnitchClusterTiling, SnitchCoreFilterPass, \
SnitchProfileExecutionBlockPass, SnitchSynchCoresPass
from Deeploy.Targets.Snitch.Templates import AddTemplate, RQAddTemplate, iSoftmaxTemplate
from Deeploy.Targets.Snitch.Templates.GemmTemplate import SnitchGemm_Template
from Deeploy.Targets.Snitch.Templates.RqGemmTemplate import SnitchRqGemm_Template
from Deeploy.Targets.Snitch.TypeCheckers import SnitchRQAddChecker
from Deeploy.TilingExtension.CodeTransformationPasses.TilingVariableReplacement import TilingVariableReplacement

TilingCallClosure = partial(ClosureGeneration, closureSuffix = "_tiling_closure")
Expand Down Expand Up @@ -78,7 +77,7 @@
TiledTransformer) for _type in [int8_t]
]
SnitchRQAddBindings = [
NodeBinding(SnitchRQAddChecker([PointerClass(_type), PointerClass(_type)], [PointerClass(_type)]),
NodeBinding(RQAddChecker([PointerClass(_type), PointerClass(_type)], [PointerClass(_type)]),
RQAddTemplate.referenceTemplate, TiledTransformer) for _type in [int8_t]
]
SnitchAddBindings = [
Expand Down
53 changes: 0 additions & 53 deletions Deeploy/Targets/Snitch/TypeCheckers.py

This file was deleted.

0 comments on commit 478d618

Please sign in to comment.