From 8eef2091a249cd086297a1690f89f95242a12b11 Mon Sep 17 00:00:00 2001
From: Luke Li
Date: Tue, 3 Dec 2024 16:18:56 -0500
Subject: [PATCH] TransformIndirectLoadChain at JITServer

Implement TransformIndirectLoadChain partially for the JITServer so it
can employ the Vector API during optimization.

Signed-off-by: Luke Li
---
 .../control/JITClientCompilationThread.cpp    |  62 +++++
 runtime/compiler/env/VMJ9Server.hpp           |   2 +-
 runtime/compiler/net/CommunicationStream.hpp  |   2 +-
 runtime/compiler/net/MessageTypes.cpp         |   2 +
 runtime/compiler/net/MessageTypes.hpp         |   3 +
 .../compiler/optimizer/J9TransformUtil.cpp    | 225 +++++++++++++++++-
 .../compiler/optimizer/J9TransformUtil.hpp    |   3 +
 7 files changed, 293 insertions(+), 6 deletions(-)

diff --git a/runtime/compiler/control/JITClientCompilationThread.cpp b/runtime/compiler/control/JITClientCompilationThread.cpp
index c6f0c1e4fa7..f471099fe60 100644
--- a/runtime/compiler/control/JITClientCompilationThread.cpp
+++ b/runtime/compiler/control/JITClientCompilationThread.cpp
@@ -2950,6 +2950,68 @@ handleServerMessage(JITServer::ClientStream *client, TR_J9VM *fe, JITServer::Mes
          client->write(response, vectorBitSize);
          }
          break;
+      case MessageType::KnownObjectTable_addFieldAddressFromBaseIndex:
+         {
+         // Request: (known object index of the base object, byte offset of a reference field).
+         // Reply: (known object index of the referenced object, or UNKNOWN when the field
+         // holds no reference; location of that known object table entry, or NULL).
+         auto recv = client->getRecvData<TR::KnownObjectTable::Index, intptr_t>();
+         TR::KnownObjectTable::Index baseObjectIndex = std::get<0>(recv);
+         intptr_t fieldOffset = std::get<1>(recv);
+
+         TR::KnownObjectTable::Index resultIndex = TR::KnownObjectTable::UNKNOWN;
+
+            {
+            // Raw object addresses are only handled inside the critical section
+            TR::VMAccessCriticalSection addFieldAddressFromBaseIndex(fe);
+            uintptr_t fieldAddress = knot->getPointer(baseObjectIndex) + fieldOffset;
+            uintptr_t objectPointer = fe->getReferenceFieldAtAddress(fieldAddress);
+
+            if (objectPointer)
+               resultIndex = knot->getOrCreateIndex(objectPointer);
+            }
+
+         uintptr_t *resultPointer =
+            (resultIndex == TR::KnownObjectTable::UNKNOWN) ? NULL : knot->getPointerLocation(resultIndex);
+
+         client->write(response, resultIndex, resultPointer);
+         }
+         break;
+      case MessageType::KnownObjectTable_getFieldAddressData:
+         {
+         // Request: (known object index of the base object, byte offset of a primitive
+         // field, width of that field in bytes).
+         // Reply: the raw field bits, zero-extended to 64 bits.
+         auto recv = client->getRecvData<TR::KnownObjectTable::Index, intptr_t, uint32_t>();
+         TR::KnownObjectTable::Index baseObjectIndex = std::get<0>(recv);
+         intptr_t fieldOffset = std::get<1>(recv);
+         uint32_t fieldSize = std::get<2>(recv);
+
+         uint64_t data = 0;
+
+            {
+            TR::VMAccessCriticalSection getFieldAddressData(fe);
+            uintptr_t fieldAddress = knot->getPointer(baseObjectIndex) + fieldOffset;
+
+            // Read exactly the width of the field: a fixed UDATA-sized read picks up
+            // neighbouring bytes for 4-byte fields (the wrong half on big-endian
+            // platforms) and truncates 8-byte fields when UDATA is 32 bits wide.
+            switch (fieldSize)
+               {
+               case 4:
+                  data = *(uint32_t *)fieldAddress;
+                  break;
+               case 8:
+                  data = *(uint64_t *)fieldAddress;
+                  break;
+               default:
+                  break; // unsupported width: reply with 0, which the server never folds
+               }
+            }
+
+         client->write(response, data);
+         }
+         break;
       case MessageType::AOTCache_getROMClassBatch:
          {
          auto recv = client->getRecvData>();
diff --git a/runtime/compiler/env/VMJ9Server.hpp b/runtime/compiler/env/VMJ9Server.hpp
index 3d20990a165..14eb6712a98 100644
--- a/runtime/compiler/env/VMJ9Server.hpp
+++ b/runtime/compiler/env/VMJ9Server.hpp
@@ -210,7 +210,7 @@ class TR_J9ServerVM: public TR_J9VM
    virtual intptr_t getVFTEntry(TR_OpaqueClassBlock *clazz, int32_t offset) override;
    virtual bool isClassArray(TR_OpaqueClassBlock *klass) override;
    virtual uintptr_t getFieldOffset(TR::Compilation * comp, TR::SymbolReference* classRef, TR::SymbolReference* fieldRef) override { return 0; } // safe answer
-   virtual bool canDereferenceAtCompileTime(TR::SymbolReference *fieldRef, TR::Compilation *comp) override { return false; } // safe answer, might change in the future
+   // canDereferenceAtCompileTime() is not overridden: the base version should be safe.
    virtual bool instanceOfOrCheckCast(J9Class *instanceClass, J9Class* castClass) override;
    virtual bool instanceOfOrCheckCastNoCacheUpdate(J9Class *instanceClass, J9Class* castClass) override;
    virtual bool transformJlrMethodInvoke(J9Method *callerMethod, J9Class *callerClass) override;
diff --git a/runtime/compiler/net/CommunicationStream.hpp b/runtime/compiler/net/CommunicationStream.hpp
index b637e190665..eca531df600 100644
--- a/runtime/compiler/net/CommunicationStream.hpp
+++ b/runtime/compiler/net/CommunicationStream.hpp
@@ -129,7 +129,7 @@ class CommunicationStream
    // likely to lose an increment when merging/rebasing/etc.
    //
    static const uint8_t MAJOR_NUMBER = 1;
-   static const uint16_t MINOR_NUMBER = 75; // ID: kzkyjklaOnYjEzzJyIl7
+   static const uint16_t MINOR_NUMBER = 76; // ID: BpR0Syhau116Bh0vAoVr
    static const uint8_t PATCH_NUMBER = 0;
    static uint32_t CONFIGURATION_FLAGS;
 
diff --git a/runtime/compiler/net/MessageTypes.cpp b/runtime/compiler/net/MessageTypes.cpp
index 36cb325a2ef..4715a258704 100644
--- a/runtime/compiler/net/MessageTypes.cpp
+++ b/runtime/compiler/net/MessageTypes.cpp
@@ -263,6 +263,8 @@ const char *messageNames[] =
    "KnownObjectTable_getKnownObjectTableDumpInfo",
    "KnownObjectTable_getOpaqueClass",
    "KnownObjectTable_getVectorBitSize",
+   "KnownObjectTable_addFieldAddressFromBaseIndex",
+   "KnownObjectTable_getFieldAddressData",
    "AOTCache_getROMClassBatch",
    "AOTCacheMap_request",
    "AOTCacheMap_reply"
diff --git a/runtime/compiler/net/MessageTypes.hpp b/runtime/compiler/net/MessageTypes.hpp
index 8cdceb317b1..f90e791aa93 100644
--- a/runtime/compiler/net/MessageTypes.hpp
+++ b/runtime/compiler/net/MessageTypes.hpp
@@ -290,6 +290,9 @@ enum MessageType : uint16_t
    KnownObjectTable_getOpaqueClass,
    // for getting a vectorBitSize from KnownObjectTable
    KnownObjectTable_getVectorBitSize,
+   // used with J9TransformUtil
+   KnownObjectTable_addFieldAddressFromBaseIndex,
+   KnownObjectTable_getFieldAddressData,
 
    AOTCache_getROMClassBatch,
 
diff --git a/runtime/compiler/optimizer/J9TransformUtil.cpp b/runtime/compiler/optimizer/J9TransformUtil.cpp
index 80c07d0e656..53daa4bb6fc 100644
--- a/runtime/compiler/optimizer/J9TransformUtil.cpp
+++ b/runtime/compiler/optimizer/J9TransformUtil.cpp
@@ -1717,12 +1717,20 @@ bool
 J9::TransformUtil::transformIndirectLoadChain(TR::Compilation *comp, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, TR::Node **removedNode)
    {
 #if defined(J9VM_OPT_JITSERVER)
-   // JITServer KOT: Bypass this method at the JITServer.
-   // transformIndirectLoadChainImpl requires access to the VM.
-   // It is already bypassed by transformIndirectLoadChainAt().
+   // Under JITServer, call a simplified version of transformIndirectLoadChain
+   // that asks the client for field contents instead of accessing the VM
    if (comp->isOutOfProcessCompilation())
       {
-      return false;
+      int32_t stableArrayRank =
+         comp->getKnownObjectTable()->getArrayWithStableElementsRank(baseKnownObject);
+      bool result =
+         TR::TransformUtil::transformIndirectLoadChainServerImpl(comp,
+                                                                  node,
+                                                                  baseExpression,
+                                                                  baseKnownObject,
+                                                                  stableArrayRank,
+                                                                  removedNode);
+      return result;
       }
 #endif /* defined(J9VM_OPT_JITSERVER) */
 
@@ -1733,6 +1741,215 @@ J9::TransformUtil::transformIndirectLoadChain(TR::Compilation *comp, TR::Node *n
    return result;
    }
 
+#if defined(J9VM_OPT_JITSERVER)
+/** Dereference node and fold it into a constant when possible.
+ *
+ * A simpler version of transformIndirectLoadChain() for the JITServer mode, which only considers
+ * the case where the node's symRef is a Java field. Field contents are requested from the client
+ * with KnownObjectTable_* messages instead of being read from memory directly.
+ *
+ * @param comp The compilation object
+ * @param node The indirect load to fold
+ * @param baseExpression The node that must be the base address child of node
+ * @param baseKnownObject The known object index of the value of baseExpression
+ * @param baseStableArrayRank Stable array rank of the base object; not a stable array unless > 0
+ * @param removedNode Passed through to changeIndirectLoadIntoConst()
+ *
+ * @return true if node was transformed, false otherwise
+ */
+bool
+J9::TransformUtil::transformIndirectLoadChainServerImpl(TR::Compilation *comp, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, int32_t baseStableArrayRank, TR::Node **removedNode)
+   {
+   bool isBaseStableArray = baseStableArrayRank > 0;
+   TR_J9VMBase *fej9 = comp->fej9();
+
+   TR_ASSERT(node->getOpCode().isLoadIndirect(),
+             "Expecting indirect load; found %s %p", node->getOpCode().getName(), node);
+   TR_ASSERT(node->getNumChildren() == 1,
+             "Expecting indirect load %s %p to have one child; actually has %d",
+             node->getOpCode().getName(), node, node->getNumChildren());
+
+   TR::SymbolReference *symRef = node->getSymbolReference();
+
+   if (comp->compileRelocatableCode() ||
+       (isBaseStableArray && !symRef->getSymbol()->isArrayShadowSymbol()) ||
+       symRef->hasKnownObjectIndex())
+      {
+      return false;
+      }
+
+   // Ignore the case of the J9Class whose finality is conditional on the holding value for now.
+   if (!symRef->isUnresolved() &&
+       symRef == comp->getSymRefTab()->findInitializeStatusFromClassSymbolRef())
+      {
+      return false;
+      }
+
+   if (!isBaseStableArray && !fej9->canDereferenceAtCompileTime(symRef, comp))
+      {
+      if (comp->getOption(TR_TraceOptDetails))
+         {
+         traceMsg(comp, "Abort transformIndirectLoadChain - cannot dereference at compile time!\n");
+         }
+      return false;
+      }
+
+   // Instead of the recursive dereferenceStructPointerChain, we only consider a single level
+   // of indirection
+   TR::Symbol *field = symRef->getSymbol();
+   TR::Node *addressChildNode = field->isArrayShadowSymbol() ?
+      node->getFirstChild()->getFirstChild() :
+      node->getFirstChild();
+   if (!addressChildNode->getOpCode().hasSymbolReference()
+       || addressChildNode != baseExpression)
+      return false;
+   // baseStruct is always the value of baseExpression; dereference is not needed
+
+   // We only consider the case where isJavaField is true for verifyFieldAccess
+   if (!isJavaField(symRef, comp))
+      return false;
+
+   TR_OpaqueClassBlock *fieldClass = NULL;
+   if (symRef->getCPIndex() < 0 &&
+       field->getRecognizedField() != TR::Symbol::UnknownField)
+      {
+      int32_t length = 0;
+      const char *className = field->owningClassNameCharsForRecognizedField(length);
+      fieldClass = fej9->getClassFromSignature(className, length, symRef->getOwningMethod(comp));
+      }
+   else
+      {
+      fieldClass = symRef->getOwningMethod(comp)->getDeclaringClassFromFieldOrStatic(comp,
+                                                                                      symRef->getCPIndex());
+      }
+
+   TR_OpaqueClassBlock *objectClass =
+      fej9->getObjectClassFromKnownObjectIndex(comp, baseKnownObject);
+
+   // The field access is verified only when the base object is an instance of the field's class
+   if (fieldClass == NULL || fej9->isInstanceOf(objectClass, fieldClass, true) != TR_yes)
+      return false;
+
+   // check the recognized fields case of avoidFoldingInstanceField
+   // the non-null checks are done when we get the actual values
+   if (field->getRecognizedField() == TR::Symbol::Java_lang_invoke_CallSite_target ||
+       field->getRecognizedField() == TR::Symbol::Java_lang_invoke_MethodHandle_form)
+      return false;
+
+   TR::DataType loadType = node->getDataType();
+
+   if (loadType == TR::Address)
+      {
+      // Native struct pointers and uncollected references are not folded here
+      if (isFinalFieldPointingAtRepresentableNativeStruct(symRef, comp) ||
+          isFinalFieldPointingAtNativeStruct(symRef, comp) ||
+          !field->isCollectedReference())
+         return false;
+
+      auto stream = comp->getStream();
+      stream->write(
+         JITServer::MessageType::KnownObjectTable_addFieldAddressFromBaseIndex,
+         baseKnownObject, (intptr_t)symRef->getOffset());
+      auto recv = stream->read<TR::KnownObjectTable::Index, uintptr_t *>();
+      TR::KnownObjectTable::Index value = std::get<0>(recv);
+      uintptr_t *objectReferenceLocationClient = std::get<1>(recv);
+
+      // The client replies with UNKNOWN when the field holds no reference
+      if (value == TR::KnownObjectTable::UNKNOWN)
+         return false;
+
+      comp->getKnownObjectTable()->updateKnownObjectTableAtServer(value,
+                                                                   objectReferenceLocationClient);
+
+      TR::SymbolReference *improvedSymRef =
+         comp->getSymRefTab()->findOrCreateSymRefWithKnownObject(symRef, value);
+
+      // The trace arguments are listed in the order of the format conversions:
+      // opcode name, node, field offset, known object index, client reference location
+      if (!improvedSymRef->hasKnownObjectIndex()
+          || !performTransformation(comp,
+                "O^O transformIndirectLoadChain: %s [%p] with fieldOffset %d is obj%d referenceAddr is %p\n",
+                node->getOpCode().getName(), node, (int32_t)symRef->getOffset(),
+                improvedSymRef->getKnownObjectIndex(), objectReferenceLocationClient))
+         return false;
+
+      node->setSymbolReference(improvedSymRef);
+      node->setIsNull(false);
+      node->setIsNonNull(true);
+
+      int32_t stableArrayRank = isArrayWithStableElements(symRef->getCPIndex(),
+                                                          symRef->getOwningMethod(comp),
+                                                          comp);
+      // An element loaded from a stable array has one rank less than the array itself
+      if (isBaseStableArray)
+         stableArrayRank = baseStableArrayRank - 1;
+
+      if (stableArrayRank > 0)
+         {
+         TR::KnownObjectTable *knot = comp->getOrCreateKnownObjectTable();
+         knot->addStableArray(improvedSymRef->getKnownObjectIndex(),
+                              stableArrayRank);
+         }
+      return true;
+      }
+
+   // Non-address types: only 4-byte and 8-byte primitives are folded. The width is
+   // chosen before messaging the client so unsupported types cost no round trip.
+   uint32_t fieldSize = 0;
+   switch (loadType)
+      {
+      case TR::Int32:
+      case TR::Float:
+         fieldSize = 4;
+         break;
+      case TR::Int64:
+      case TR::Double:
+         fieldSize = 8;
+         break;
+      default:
+         return false;
+      }
+
+   auto stream = comp->getStream();
+   stream->write(
+      JITServer::MessageType::KnownObjectTable_getFieldAddressData,
+      baseKnownObject, (intptr_t)symRef->getOffset(), fieldSize);
+   uint64_t data = std::get<0>(stream->read<uint64_t>());
+
+   // A field whose bits are all zero is not folded
+   if (data == 0)
+      return false;
+
+   switch (loadType)
+      {
+      case TR::Int32:
+         if (!changeIndirectLoadIntoConst(node, TR::iconst, removedNode, comp))
+            return false;
+         node->setInt((int32_t)data);
+         break;
+      case TR::Int64:
+         if (!changeIndirectLoadIntoConst(node, TR::lconst, removedNode, comp))
+            return false;
+         node->setLongInt((int64_t)data);
+         break;
+      case TR::Float:
+         if (!changeIndirectLoadIntoConst(node, TR::fconst, removedNode, comp))
+            return false;
+         // data holds the field's bit pattern; a numeric cast would convert its integer value
+         node->setFloatBits((uint32_t)data);
+         break;
+      case TR::Double:
+         if (!changeIndirectLoadIntoConst(node, TR::dconst, removedNode, comp))
+            return false;
+         node->setDoubleBits(data);
+         break;
+      default:
+         return false;
+      }
+   return true;
+   }
+#endif /* defined(J9VM_OPT_JITSERVER) */
+
 /** Dereference node and fold it into a constant when possible
  *
  * @parm comp The compilation object
diff --git a/runtime/compiler/optimizer/J9TransformUtil.hpp b/runtime/compiler/optimizer/J9TransformUtil.hpp
index f3ac986ba94..34a58c1c4a5 100644
--- a/runtime/compiler/optimizer/J9TransformUtil.hpp
+++ b/runtime/compiler/optimizer/J9TransformUtil.hpp
@@ -223,6 +223,9 @@ class OMR_EXTENSIBLE TransformUtil : public OMR::TransformUtilConnector
    static bool transformIndirectLoadChain(TR::Compilation *, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, TR::Node **removedNode);
    static bool transformIndirectLoadChainAt(TR::Compilation *, TR::Node *node, TR::Node *baseExpression, uintptr_t *baseReferenceLocation, TR::Node **removedNode);
    static bool transformIndirectLoadChainImpl( TR::Compilation *, TR::Node *node, TR::Node *baseExpression, void *baseAddress, int32_t baseStableArrayRank, TR::Node **removedNode);
+#if defined(J9VM_OPT_JITSERVER)
+   static bool transformIndirectLoadChainServerImpl( TR::Compilation *, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, int32_t baseStableArrayRank, TR::Node **removedNode);
+#endif /* defined(J9VM_OPT_JITSERVER) */
 
    static bool fieldShouldBeCompressed(TR::Node *node, TR::Compilation *comp);
 