Skip to content

Commit

Permalink
Additional cleanup in TensileCreateLibrary
Browse files Browse the repository at this point in the history
  • Loading branch information
ellosel committed Dec 19, 2024
1 parent 52c745c commit f5c3ff5
Showing 1 changed file with 10 additions and 70 deletions.
80 changes: 10 additions & 70 deletions tensilelite/Tensile/TensileCreateLibrary.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,11 @@
exit(1)

from . import Common
from . import ClientExecutable
from . import EmbeddedData
from . import LibraryIO
from . import Utils
from .TensileInstructions import getGfxName, TensileInstructions
from .Common import globalParameters, HR, print1, print2, printExit, ensurePath, \
CHeader, CMakeHeader, assignGlobalParameters, \
architectureMap, printWarning, \
splitArchs
CHeader, assignGlobalParameters, architectureMap
from .KernelWriterAssembly import KernelWriterAssembly
from .SolutionLibrary import MasterSolutionLibrary
from .SolutionStructs import Solution
Expand Down Expand Up @@ -69,7 +65,8 @@ def wrapper(*args, **kwargs):

return res
return wrapper
################################################################################


def processKernelSource(kernel, kernelWriterAssembly, ti):
"""
Generate source for a single kernel.
Expand All @@ -91,8 +88,6 @@ def processKernelSource(kernel, kernelWriterAssembly, ti):
return (err, src, header, kernelName, filename)



################################################################################
def buildKernelSourceAndHeaderFiles(results, outputPath, kernelsWithBuildErrs):
"""
Logs errors and writes appropriate info to kernelSourceFile and kernelHeaderFile.
Expand All @@ -108,23 +103,19 @@ def buildKernelSourceAndHeaderFiles(results, outputPath, kernelsWithBuildErrs):
sourceFilenames: Array containing source kernel filenames
"""

# Find kernels to write
kernelsToWrite = []
filesToWrite = collections.defaultdict(list)
validKernelCount = 0
for (err,src,header,kernelName, filename) in results:

# Keep track of kernels with errors
if err:
kernelsWithBuildErrs[kernelName] = err

# Don't create a file for empty kernels
if len(src.strip()) == 0:
continue

kernelsToWrite.append((err, src, header, kernelName))

# Create list of files
if filename:
filesToWrite[os.path.join(os.path.normcase(outputPath),filename)].append((err, src, header, kernelName))
else:
Expand All @@ -139,9 +130,8 @@ def buildKernelSourceAndHeaderFiles(results, outputPath, kernelsWithBuildErrs):
kernelSuffix = ""
filesToWrite[os.path.join(os.path.normcase(outputPath), "Kernels"+kernelSuffix)] = []


# Write kernel data to files
#Parse list of files and write kernels
# Parse list of files and write kernels
for filename, kernelList in filesToWrite.items():
with open(filename+".h", "w", encoding="utf-8") as kernelHeaderFile, \
open(filename+".cpp", "w", encoding="utf-8") as kernelSourceFile:
Expand Down Expand Up @@ -184,14 +174,8 @@ def writeSolutionsAndKernels(outputPath, cxxCompiler, assembler, offloadBundler,
kernelSourceFile = None
kernelHeaderFile = None

##############################################################################
# Write Kernels
##############################################################################
kernelsWithBuildErrs = {}

# Kernels may be intended for different co files, but generate the same .o file
# Mark duplicate kernels to avoid race condition
# @TODO improve organization so this problem doesn't appear
objFilenames = set()
for kernel in kernels:
if kernel["KernelLanguage"] == "Assembly":
Expand Down Expand Up @@ -256,7 +240,6 @@ def success(kernel):
kernelHeaderFile = open(kernelHeaderFilename, "a", encoding="utf-8")

HeaderText = ""
# handle helper kernel function
for ko in kernelHelperObjs:
kernelName = ko.getKernelName()

Expand All @@ -267,7 +250,6 @@ def success(kernel):

HeaderText += ko.getHeaderFileString()

# write kernel.h in one shot
kernelHeaderFile.write(HeaderText)

if kernelSourceFile:
Expand All @@ -282,12 +264,9 @@ def success(kernel):
Common.popWorkingPath() # build_tmp
Common.popWorkingPath() # workingDir

return codeObjectFiles, numKernels
return numKernels


##############################################################################
# Min Naming / Solution and Kernel Writers
##############################################################################
@timing
def getSolutionAndKernelWriters(solutions, kernels, assembler):

Expand All @@ -300,9 +279,7 @@ def getSolutionAndKernelWriters(solutions, kernels, assembler):

return (kernelWriterAssembly, kernelMinNaming, solutionMinNaming)

################################################################################
# copy static cpp files and headers
################################################################################

@timing
def copyStaticFiles(outputPath=None):
if outputPath is None:
Expand All @@ -324,9 +301,6 @@ def copyStaticFiles(outputPath=None):
return libraryStaticFiles


################################################################################
# Generate Kernel Objects From Solutions
################################################################################
@timing
def generateKernelObjectsFromSolutions(solutions):
# create solution writer and kernel writer
Expand All @@ -351,13 +325,10 @@ def generateKernelObjectsFromSolutions(solutions):
kernelHelperObjs = list(dict.fromkeys(kernelHelperObjs))
return (kernels, kernelHelperObjs, kernelHelperNames)

################################################################################
# Generate Logic Data and Solutions
################################################################################

@timing
def generateLogicDataAndSolutions(logicFiles, args, cxxCompiler):

# skip the logic which architectureName is not in the build target.
if ";" in args.Architecture:
archs = args.Architecture.split(";") # user arg list format
else:
Expand Down Expand Up @@ -398,7 +369,6 @@ def libraryIter(lib: MasterSolutionLibrary):
fullMasterLibrary.merge(newLibrary)

if args.GenSolTable:
# Match yaml file solutions to solution index
for localIdx, _, s in libraryIter(newLibrary):
matchTable[s.index] = [srcFile, localIdx]

Expand Down Expand Up @@ -446,9 +416,7 @@ def validateLibrary(masterLibraries: MasterSolutionLibrary,

assert ok and "Inconsistent kernel sizes detected!"

################################################################################
# Tensile Create Library
################################################################################

@profile
def TensileCreateLibrary():
start = timer()
Expand All @@ -458,9 +426,6 @@ def TensileCreateLibrary():
print2(HR)
print2("")

##############################################################################
# Parse Command Line Arguments
##############################################################################
def splitExtraParameters(par):
"""
Allows the --global-parameters option to specify any parameters from the command line.
Expand Down Expand Up @@ -534,7 +499,6 @@ def splitExtraParameters(par):
assembler = args.Assembler
libraryFormat = args.LibraryFormat
useCompression = not args.NoCompress
print2("OutputPath: %s" % outputPath)
ensurePath(outputPath)
outputPath = os.path.abspath(outputPath)
arguments = {}
Expand All @@ -553,12 +517,6 @@ def splitExtraParameters(par):
arguments["LibraryFormat"] = args.LibraryFormat
if args.no_enumerate:
arguments["AMDGPUArchPath"] = False

arguments["GenerateSourcesAndExit"] = args.GenerateSourcesAndExit
if arguments["GenerateSourcesAndExit"]:
# Generated sources are preserved and go into output dir
arguments["WorkingPath"] = outputPath

arguments["CpuThreads"] = args.CpuThreads
arguments["PrintLevel"] = args.PrintLevel
arguments["PrintTiming"] = args.PrintTiming
Expand Down Expand Up @@ -600,7 +558,6 @@ def splitExtraParameters(par):
else:
printExit("Architecture %s not supported" % arch)

# Recursive directory search
logicExtFormat = ".yaml"
if args.LogicFormat == "yaml":
pass
Expand Down Expand Up @@ -628,32 +585,21 @@ def validLogicFile(p: Path):
for logicFile in logicFiles:
print2("# %s" % logicFile)


##############################################################################
# Parse config files
##############################################################################

# Parse logicData, solutions, and masterLibraries from logic files
solutions, masterLibraries, fullMasterLibrary = generateLogicDataAndSolutions(logicFiles, args, cxxCompiler)

kernels, kernelHelperObjs, _ = generateKernelObjectsFromSolutions(solutions)

# if any kernels are assembly, append every ISA supported
kernelWriterAssembly, kernelMinNaming, _ = getSolutionAndKernelWriters(solutions, kernels, assembler)

if globalParameters["ValidateLibrary"]:
validateLibrary(masterLibraries, kernels, kernelWriterAssembly)

staticFiles = copyStaticFiles(outputPath)

# Make sure to copy the library static files.
for fileName in staticFiles:
shutil.copy( os.path.join(globalParameters["SourcePath"], fileName), \
outputPath )

# write solutions and kernels
codeObjectFiles, numKernels = writeSolutionsAndKernels(outputPath, cxxCompiler, assembler, offloadBundler, solutions,
kernels, kernelHelperObjs, kernelWriterAssembly, compress=useCompression)
numKernels = writeSolutionsAndKernels(outputPath, cxxCompiler, assembler, offloadBundler, solutions,
kernels, kernelHelperObjs, kernelWriterAssembly, compress=useCompression)

archs = [getGfxName(arch) for arch in globalParameters['SupportedISA'] \
if globalParameters["AsmCaps"][arch]["SupportedISA"]]
Expand All @@ -669,22 +615,16 @@ def validLogicFile(p: Path):
newMasterLibrary.applyNaming(kernelMinNaming)
LibraryIO.write(masterFile, Utils.state(newMasterLibrary), args.LibraryFormat)

#Write placeholder libraries
for name, lib in newMasterLibrary.lazyLibraries.items():
filename = os.path.join(newLibraryDir, name)
lib.applyNaming(kernelMinNaming) #@TODO Check to see if kernelMinNaming is correct
LibraryIO.write(filename, Utils.state(lib), args.LibraryFormat)

else:
masterFile = os.path.join(newLibraryDir, "TensileLibrary")
fullMasterLibrary.applyNaming = timing(fullMasterLibrary.applyNaming)
fullMasterLibrary.applyNaming(kernelMinNaming)
LibraryIO.write(masterFile, Utils.state(fullMasterLibrary), args.LibraryFormat)

theMasterLibrary = fullMasterLibrary
if globalParameters["SeparateArchitectures"]:
theMasterLibrary = list(masterLibraries.values())[0]

print1("# Tensile Library Writer DONE")
print1(HR)
print1("")
Expand Down

0 comments on commit f5c3ff5

Please sign in to comment.