Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto option ranks #84

Merged
merged 15 commits into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/AzslcBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ namespace AZ::ShaderCompiler
// We reserve the right to change it in the future so we make it explicit attribute here
shaderOption["order"] = optionOrder;
optionOrder++;
shaderOption["costImpact"] = varInfo->m_estimatedCostImpact;

bool isUdt = IsUserDefined(varInfo->GetTypeClass());
assert(isUdt || IsPredefinedType(varInfo->GetTypeClass()));
Expand Down
1 change: 0 additions & 1 deletion src/AzslcEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,6 @@ namespace AZ::ShaderCompiler
const ICodeEmissionMutator* codeMutator = m_codeMutator;

ssize_t ii = interval.a;
bool wasInPreprocessorDirective = false; // record a state to detect exit of directives, because they need to reside on their own lines
while (ii <= interval.b)
{
auto* token = GetNextToken(ii /*inout*/);
Expand Down
4 changes: 1 addition & 3 deletions src/AzslcIntermediateRepresentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ namespace AZ::ShaderCompiler
cout << " storage: " << sub.m_typeInfoExt.m_qualifiers.GetDisplayName() << "\n";
cout << " array dim: \"" << sub.m_typeInfoExt.m_arrayDims.ToString() << "\"\n";
cout << " has sampler state: " << (sub.m_samplerState ? "yes\n" : "no\n");
cout << "\n";
if (!holds_alternative<monostate>(sub.m_constVal))
{
cout << " val: " << ExtractValueAsInt64(sub.m_constVal) << "\n";
Expand Down Expand Up @@ -519,7 +518,7 @@ namespace AZ::ShaderCompiler
if (varInfo.GetTypeClass() == TypeClass::Enum)
{
auto* asClassInfo = GetSymbolSubAs<ClassInfo>(varInfo.GetTypeId().GetName());
size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.GetBaseSize();
size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.m_baseSize;
}

nextMemberStartingOffset = Packing::PackNextChunk(layoutPacking, size, startAt);
Expand Down Expand Up @@ -960,5 +959,4 @@ namespace AZ::ShaderCompiler
}
return memberList[memberList.size() - 1];
}

} // end of namespace AZ::ShaderCompiler
1 change: 0 additions & 1 deletion src/AzslcIntermediateRepresentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

namespace AZ::ShaderCompiler
{

//! We limit the maximum number of render targets to 8, with indices in the range [0..7]
static const uint32_t kMaxRenderTargets = 8;

Expand Down
4 changes: 3 additions & 1 deletion src/AzslcKindInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ namespace AZ::ShaderCompiler
//! Get the size of a single element, ignoring array dimensions
const uint32_t GetSingleElementSize(Packing::Layout layout, bool defaultRowMajor) const
{
auto baseSize = m_coreType.m_arithmeticInfo.GetBaseSize();
auto baseSize = m_coreType.m_arithmeticInfo.m_baseSize;
bool isRowMajor = (m_mtxMajor == Packing::MatrixMajor::RowMajor ||
(m_mtxMajor == Packing::MatrixMajor::Default && defaultRowMajor));
auto rows = m_coreType.m_arithmeticInfo.m_rows;
Expand Down Expand Up @@ -399,6 +399,7 @@ namespace AZ::ShaderCompiler
ConstNumericVal m_constVal; // (attempted folded) initializer value for simple scalars
optional<SamplerStateDesc> m_samplerState;
ExtendedTypeInfo m_typeInfoExt;
int m_estimatedCostImpact = -1; //!< Cached value calculated by AnalyzeOptionRanks
};

// VarInfo methods definitions
Expand Down Expand Up @@ -791,6 +792,7 @@ namespace AZ::ShaderCompiler
vector< IdentifierUID > m_overrides; //!< list of implementing functions in child classes
optional< IdentifierUID > m_base; //!< points to the overridden function in the base interface, if applies. only supports one base
FunctionMultiForwards m_multiFwds = FMF_None; //!< presence of redundant prototype-only declarations
int m_costScore = -1; //!< heuristical static analysis of the amount of instructions contained
struct Parameter
{
IdentifierUID m_varId;
Expand Down
3 changes: 1 addition & 2 deletions src/AzslcMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ namespace StdFs = std::filesystem;
// For large features or milestones. Minor version allows for breaking changes. Existing tests can change.
#define AZSLC_MINOR "8" // last change: introduction of class inheritance
// For small features or bug fixes. They cannot introduce breaking changes. Existing tests shouldn't change.
#define AZSLC_REVISION "17" // last change: fixup alignment check logic_error because of lack of an inter-scope check limiter.
// "16" change: fixup runtime error with redundant function declarations
#define AZSLC_REVISION "18" // last change: automatic option ranks

namespace AZ::ShaderCompiler
{
Expand Down
242 changes: 223 additions & 19 deletions src/AzslcReflection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ namespace AZ::ShaderCompiler
else if (varInfo.GetTypeClass() == TypeClass::Enum)
{
auto* asClassInfo = m_ir->GetSymbolSubAs<ClassInfo>(varInfo.GetTypeId().GetName());
size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.GetBaseSize();
size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.m_baseSize;
}

offset = Packing::PackNextChunk(layoutPacking, size, startAt);
Expand Down Expand Up @@ -629,7 +629,9 @@ namespace AZ::ShaderCompiler

//! Writes the variant list for `options` to m_out, followed by a newline.
void CodeReflection::DumpVariantList(const Options& options) const
{
    // Run the rank analysis first: it fills VarInfo::m_estimatedCostImpact
    // (presumably consumed while the variant list is built - confirm in GetVariantList).
    AnalyzeOptionRanks();
    m_out << GetVariantList(options) << "\n";
}

static void ReflectBinding(Json::Value& output, const RootSigDesc::SrgParamDesc& bindInfo)
Expand Down Expand Up @@ -857,11 +859,12 @@ namespace AZ::ShaderCompiler
for (auto& seenat : kindInfo->GetSeenats())
{
assert(uid == seenat.m_referredDefinition);
// TODO: the assumption that intervals where distinct doesnt hold anymore now that we have unnamed scopes
auto intervalIter = FindInterval(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value)
{
return value.first.properlyContains({key, key});
});
// careful of the invariant: distinct intervals. (can't support functions nested in functions nor imbricated block scopes)
// ok for now because AZSL/HLSL don't have lambdas
auto intervalIter = FindIntervalInDisjointSet(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value)
{
return value.first.properlyContains({key, key});
});
if (intervalIter != scopes.cend())
{
const IdentifierUID& encloser = intervalIter->second.second;
Expand Down Expand Up @@ -909,16 +912,9 @@ namespace AZ::ShaderCompiler
uint32_t numOf32bitConst = GetNumberOf32BitConstants(options, m_ir->m_rootConstantStructUID);
RootSigDesc rootSignature = BuildSignatureDescription(options, numOf32bitConst);

// prepare a lookup acceleration data structure for reverse mapping tokens to scopes.
MapOfBeginToSpanAndUid scopeStartToFunctionIntervals;
for (auto& [uid, interval] : m_ir->m_scope.m_scopeIntervals)
{
if (m_ir->GetKind(uid) == Kind::Function) // Filter out unnamed blocs and types. We need a set of disjoint intervals as an invariant for the next algorithm.
{
// the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound)
scopeStartToFunctionIntervals[interval.a] = std::make_pair(interval, uid);
}
}
// Prepare a lookup acceleration data structure for reverse mapping tokens to scopes.
// (truth: we need a set of disjoint intervals as an invariant for the following algorithm)
GenerateTokenScopeIntervalToUidReverseMap();

Json::Value srgRoot(Json::objectValue);
// Order the reflection by SRG for convenience
Expand Down Expand Up @@ -968,7 +964,7 @@ namespace AZ::ShaderCompiler
else
{
set<IdentifierUID> dependencyList;
DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, scopeStartToFunctionIntervals);
DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, m_functionIntervals);
srgMember[srgParam.m_uid.GetNameLeaf()] = makeJsonNodeForOneResource(dependencyList, srgParam, {});
}
}
Expand All @@ -981,7 +977,7 @@ namespace AZ::ShaderCompiler
for (auto& srgConstant : srgInfo->m_implicitStruct.GetMemberFields())
{
allConstants.append({ srgConstant.GetNameLeaf() });
DiscoverTopLevelFunctionDependencies(srgConstant, dependencyList, scopeStartToFunctionIntervals);
DiscoverTopLevelFunctionDependencies(srgConstant, dependencyList, m_functionIntervals);
}
// variant fallback support
if (srgInfo->m_shaderVariantFallback)
Expand All @@ -992,7 +988,7 @@ namespace AZ::ShaderCompiler
{
if (varSub->CheckHasStorageFlag(StorageFlag::Option))
{
DiscoverTopLevelFunctionDependencies(varUid, dependencyList, scopeStartToFunctionIntervals);
DiscoverTopLevelFunctionDependencies(varUid, dependencyList, m_functionIntervals);
}
}
}
Expand All @@ -1004,4 +1000,212 @@ namespace AZ::ShaderCompiler

m_out << srgRoot;
}

// Helper routine for option rank analysis.
// Maps the name of a function that was not found in the symbol table (assumed intrinsic)
// to a coarse cost: 100, 10 or 1. The name must match one of the listed strings to leave the default tier.
static int GuesstimateIntrinsicFunctionCost(string_view funcName)
{
    // non measurable but assumed high
    const bool assumedVeryExpensive = IsOneOf(funcName, "CallShader", "TraceRay");
    if (assumedVeryExpensive)
    {
        return 100;
    }

    // memory access, locked or not, will have high latency
    const bool touchesMemory = IsOneOf(funcName, "Sample", "Load", "InterlockedCompareStore", "InterlockedCompareExchange", "InterlockedExchange", "Append");
    if (touchesMemory)
    {
        return 10;
    }

    // unlisted intrinsics like lerp, log2, cos, distance.. will default to a cost of 1.
    return 1;
}

// Helper routine for option rank analysis, for when picking AN overload is more useful than forfeiting.
// GetConcreteFunctionThatMatchesArgumentList forfeits when the overload set holds strictly more than
// one concrete function with the queried arity; here any candidate whose parameter count equals the
// number of call arguments is accepted.
// Returns the accepted candidate's UID, or a default-constructed IdentifierUID when none qualifies.
static IdentifierUID PickAnyOverloadThatMatchesArgCount(IntermediateRepresentation* ir,
                                                        azslParser::FunctionCallExpressionContext* callNode,
                                                        KindInfo& overload)
{
    IdentifierUID picked;
    size_t wantedArity = NumArgs(callNode);

    // Predicate handed to AnyOf: records the match in `picked` through the reference capture
    // (not clean but convenient) and reports success to the caller.
    auto hasWantedArity = [&](IdentifierUID const& candidate)
    {
        auto* candidateInfo = ir->GetSymbolSubAs<FunctionInfo>(candidate.GetName());
        size_t candidateArity = candidateInfo->GetParameters(true).size();
        if (candidateArity != wantedArity)
        {
            return false;
        }
        picked = candidate;
        return true;
    };

    overload.GetSubAs<OverloadSetInfo>()->AnyOf(hasWantedArity);
    return picked;
}

void CodeReflection::AnalyzeOptionRanks() const
{
// make sure we have the scope lookup cache ready
GenerateTokenScopeIntervalToUidReverseMap();
// loop over variables
for (auto& [uid, varInfo, kindInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_3<VarInfo>())
{
// only options
if (varInfo->CheckHasStorageFlag(StorageFlag::Option))
{
int impactScore = 0;
// loop over appearances over the program
for (Seenat& ref : kindInfo->GetSeenats())
{
// determine an impact score
impactScore += AnalyzeImpact(ref.m_where) // dependent code that may be skipped depending on the value of that ref
+ 1; // by virtue of being mentioned (seenat), we count the reference as an access of cost 1.
}
varInfo->m_estimatedCostImpact = impactScore;
}
}
}

//! Scores the code that depends on the reference found at `location`.
//! Starts from the AST node mapped to the focused token; when a while/if/for/switch statement is
//! found among its ancestors (within the depth passed to DeepParentAs), the walk starts from that
//! statement's embedded statement / switch block instead.
//! @return the score accumulated by the recursive AST walk
int CodeReflection::AnalyzeImpact(TokensLocation const& location) const
{
    // the node at `location`:
    ParserRuleContext* walkRoot = m_ir->m_tokenMap.GetNode(location.m_focusedTokenId);

    // Climb towards a block node that has visitable depth. The first match wins, tested in the
    // order while, if, for, switch. Depths are arbitrary: 4 for `for`, enough to search up things like
    // `for (a, b<(ref+1), c)` binaryop->braces->cmpexpr->cond->for; 3 for the others.
    if (auto* enclosingWhile = DeepParentAs<azslParser::WhileStatementContext*>(walkRoot->parent, 3))
    {
        walkRoot = enclosingWhile->embeddedStatement();
    }
    else if (auto* enclosingIf = DeepParentAs<azslParser::IfStatementContext*>(walkRoot->parent, 3))
    {
        walkRoot = enclosingIf->embeddedStatement();
    }
    else if (auto* enclosingFor = DeepParentAs<azslParser::ForStatementContext*>(walkRoot->parent, 4))
    {
        walkRoot = enclosingFor->embeddedStatement();
    }
    else if (auto* enclosingSwitch = DeepParentAs<azslParser::SwitchStatementContext*>(walkRoot->parent, 3))
    {
        walkRoot = enclosingSwitch->switchBlock();
    }

    int total = 0;
    AnalyzeImpact(walkRoot, total);
    return total;
}

//! Recursive AST walk that adds to `scoreAccumulator`:
//!  - function call expression children are delegated to the call-specific overload,
//!  - other rule-context children are walked recursively,
//!  - every terminal child that is a semicolon token adds 1.
void CodeReflection::AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const
{
    for (auto& child : astNode->children)
    {
        if (auto* call = As<azslParser::FunctionCallExpressionContext*>(child))
        {
            // overload specialized for function lookup
            AnalyzeImpact(call, scoreAccumulator);
        }
        else if (auto* inner = As<ParserRuleContext*>(child))
        {
            // descend to capture embedded calls, like e.g. "x ? f() : 0;"
            AnalyzeImpact(inner, scoreAccumulator);
        }

        if (auto* terminal = As<tree::TerminalNode*>(child))
        {
            // cost is the number of full expressions, i.e. the number of semicolons
            if (terminal->getSymbol()->getType() == azslLexer::Semi)
            {
                ++scoreAccumulator;
            }
        }
    }
}

//! Adds the estimated cost of the function called at `callNode` to `scoreAccumulator`.
//!  - Callee not found by symbol lookup: treated as an intrinsic and scored by GuesstimateIntrinsicFunctionCost.
//!  - Callee found: its body is walked once, the result is cached in FunctionInfo::m_costScore and added.
//!  - No scope interval surrounds the call token: nothing is added.
//! Side effect: overwrites the semantic scope tracker's current scope path (not restored by this function).
void CodeReflection::AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const
{
// to access the function symbol info we need the current scope, the function call name and perform a lookup.

// figure out the scope at this token.
// theoretically should be something in the like of the body of another function,
// or an anonymous block within another function.
auto interval = m_intervals.GetClosestIntervalSurrounding(callNode->start->getTokenIndex());
if (!interval.IsEmpty())
{
// UID of the scope that owns the interval found above
IdentifierUID encloser = m_intervalToUid[interval];

// Because we are past the end of the semantic analysis,
// the scope tracker is registering the last seen scope (surely "/").
// This is a stateful side-effect system unfortunately, and since we'll call
// some feature of the semantic orchestrator (like TypeofExpr) we need to hack
// the scope tracker:
m_ir->m_sema.m_scope->m_currentScopePath = encloser.GetName();
m_ir->m_sema.m_scope->UpdateCurScopeUID();

// Lookup starts from the enclosing scope, unless the call is a member access,
// in which case it starts from the type of the left-hand side expression.
QualifiedName startupLookupScope = encloser.GetName();
UnqualifiedName funcName;
if (auto* idExpr = As<azslParser::IdentifierExpressionContext*>(callNode->Expr))
{
funcName = ExtractNameFromIdExpression(idExpr->idExpression());
}
else if (auto* maeExpr = As<AstMemberAccess*>(callNode->Expr))
{
startupLookupScope = m_ir->m_sema.TypeofExpr(maeExpr->LHSExpr);
funcName = ExtractNameFromIdExpression(maeExpr->Member);
}
// NOTE(review): when neither cast above matches, funcName stays default-constructed and the lookup
// below still runs with it - confirm this is the intended "forfeit" path (it may end up scored as an intrinsic).
IdAndKind* overload = m_ir->m_symbols.LookupSymbol(startupLookupScope, funcName);
if (!overload) // in case of function not found, we assume it's an intrinsic.
{
scoreAccumulator += GuesstimateIntrinsicFunctionCost(funcName);
}
else
{
azslParser::ArgumentListContext* args = GetArgumentListIfBelongsToFunctionCall(callNode);
IdAndKind* symbolMeantUnderCallNode = m_ir->m_sema.ResolveOverload(overload, args);
IdentifierUID concrete;
if (!symbolMeantUnderCallNode || m_ir->GetKind(symbolMeantUnderCallNode->first) == Kind::OverloadSet)
{ // in case of strict selection failure, run a fuzzy select
concrete = PickAnyOverloadThatMatchesArgCount(m_ir, callNode, overload->second);
// if still not enough to get a fix (concrete=={}), it might be an ill-formed input. prefer to forfeit
}
else
{
concrete = symbolMeantUnderCallNode->first;
}

// a null result here (e.g. concrete left default-constructed) means no score is added
if (auto* funcInfo = m_ir->GetSymbolSubAs<FunctionInfo>(concrete.GetName()))
{
if (funcInfo->m_costScore == -1) // cost not yet discovered for this function
{
// Assign 0 before walking the body: if this same function is reached again during
// the walk, the -1 check fails and its body is not walked a second time.
funcInfo->m_costScore = 0;
using AstFDef = azslParser::HlslFunctionDefinitionContext;
// NOTE(review): m_defNode is dereferenced without a null check - assumes every function
// reaching this point has a definition node whose parent is a function definition; confirm.
AnalyzeImpact(polymorphic_downcast<AstFDef*>(funcInfo->m_defNode->parent)->block(),
funcInfo->m_costScore); // recurse and cache
}
scoreAccumulator += funcInfo->m_costScore;
}
}
// other cases forfeited for now, but that would at least include things like eg braces (f)()
}
else // no interval found
{
// function calls outside of function bodies can appear in an initializer:
// int g_a = MakeA(); // global init
// class C { int m_a = CompA(); // constructor init (invalid AZSL/HLSL)
// class D { void Method(int a_a = DefaultA()); // default parameter value
// in any case, extracting the scope is impossible with this system.
// we forfeit evaluation of a score
}
}

//! Builds the token-to-scope lookup caches from m_ir->m_scope.m_scopeIntervals:
//!  - m_functionIntervals: function scopes only, keyed by interval start,
//!  - m_intervals / m_intervalToUid: every scope interval and its owning UID.
//! Does nothing when m_functionIntervals already has entries.
void CodeReflection::GenerateTokenScopeIntervalToUidReverseMap() const
{
    // NOTE(review): the "already built" test relies on m_functionIntervals being non-empty; with an
    // input that has no function scope the body would run again on a later call - confirm that
    // adding to m_intervals after Seal() is harmless in that case.
    if (!m_functionIntervals.empty())
    {
        return;
    }

    for (auto& [uid, interval] : m_ir->m_scope.m_scopeIntervals)
    {
        const bool isFunctionScope = (m_ir->GetKind(uid) == Kind::Function);
        if (isFunctionScope) // Filter out unnamed blocks and types.
        {
            // keyed by .a so that queries can use Infimum (sort of lower_bound)
            m_functionIntervals[interval.a] = std::make_pair(interval, uid);
        }

        auto span = Interval<ssize_t>{interval.a, interval.b};
        m_intervals.Add(span);
        m_intervalToUid[span] = uid;
    }
    m_intervals.Seal();
}
}
Loading