Skip to content

Commit

Permalink
Fix array access implicit conversion and add error diagnostics for ou…
Browse files Browse the repository at this point in the history
…t-of-bounds array accesses (#38)

This PR fixes an issue with using arrays as inputs to gates and provides
an accompanying test. It also adds an error diagnostic to check for
out-of-bounds array accesses.
  • Loading branch information
taalexander authored Jun 12, 2024
1 parent 710d73c commit ec7731b
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 45 deletions.
1 change: 1 addition & 0 deletions conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def build(self):
cmake = CMake(self)
cmake.configure()

use_monitor = False
if self.should_build:
# Note that if a job does not produce output for a longer period of
# time, then Travis will cancel that job.
Expand Down
15 changes: 15 additions & 0 deletions include/qasm/AST/ASTArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <qasm/AST/OpenPulse/ASTOpenPulseFrame.h>
#include <qasm/AST/OpenPulse/ASTOpenPulsePort.h>
#include <qasm/AST/OpenPulse/ASTOpenPulseWaveform.h>
#include <qasm/Diagnostic/DIAGLineCounter.h>
#include <qasm/Frontend/QasmDiagnosticEmitter.h>

#include <any>
#include <sstream>
Expand All @@ -36,6 +38,8 @@

namespace QASM {

using DiagLevel = QASM::QasmDiagnosticEmitter::DiagLevel;

class ASTArrayNode : public ASTExpressionNode {
private:
ASTArrayNode() = delete;
Expand Down Expand Up @@ -66,6 +70,17 @@ class ASTArrayNode : public ASTExpressionNode {
: ASTExpressionNode(Id, ASTTypeArray), MM(), AType(ATy), SZ(Size),
EXT(Extents), INL(IL) {}

/// Validate the array access, emitting a diagnostic if invalid.
void ValidateIndex(unsigned Index, QASM::ASTLocation location) const {
if (Index >= Size()) {
std::stringstream M;
M << "Array index " << Index << " out of range for array of size "
<< Size() << ".";
QasmDiagnosticEmitter::Instance().EmitDiagnostic(location, M.str(),
DiagLevel::Error);
}
}

virtual ~ASTArrayNode() = default;

virtual ASTType GetASTType() const override { return ASTTypeArray; }
Expand Down
4 changes: 3 additions & 1 deletion lib/AST/ASTBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4281,7 +4281,9 @@ ASTAngleNode *ASTBuilder::CreateASTAngleNodeFromExpression(
ASTAngleArrayNode *AAN = ASTE->GetValue()->GetValue<ASTAngleArrayNode *>();
assert(AAN && "Could not retrieve a valid ASTAngleArray Node from "
"the SymbolTable Entry!");
AN = AAN->GetElement(IdR->GetIndex());
auto Index = IdR->GetIndex();
AAN->ValidateIndex(Index, IdR->GetLocation());
AN = AAN->GetElement(Index);
assert(AN && "Could not obtain a valid ASTAngleNode from the "
"ASTAngleArrayNode!");
} break;
Expand Down
62 changes: 31 additions & 31 deletions lib/AST/ASTGates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,25 +250,25 @@ ASTGateNode::CreateAngleConversion(const ASTSymbolTableEntry *XSTE) const {
const ASTFloatNode *FN = XSTE->GetValue()->GetValue<ASTFloatNode *>();
assert(FN && "Could not obtain a valid ASTFloatNode!");

XAN =
new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), FN,
ASTAngleTypeGeneric, XSTE->GetIdentifier()->GetBits());
unsigned ConvertBits = FN->GetBits();

XAN = new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), FN,
ASTAngleTypeGeneric, ConvertBits);
assert(XAN && "Could not create a valid ASTAngleNode!");
ICE = new ASTImplicitConversionNode(FN, ASTTypeAngle,
XSTE->GetIdentifier()->GetBits());
ICE = new ASTImplicitConversionNode(FN, ASTTypeAngle, ConvertBits);
assert(ICE && "Could not create a valid ASTImplicitConversionNode!");
XAN->SetImplicitConversion(ICE);
} break;
case ASTTypeDouble: {
const ASTDoubleNode *DN = XSTE->GetValue()->GetValue<ASTDoubleNode *>();
assert(DN && "Could not obtain a valid ASTDoubleNode!");

XAN =
new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), DN,
ASTAngleTypeGeneric, XSTE->GetIdentifier()->GetBits());
unsigned ConvertBits = DN->GetBits();

XAN = new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), DN,
ASTAngleTypeGeneric, ConvertBits);
assert(XAN && "Could not create a valid ASTAngleNode!");
ICE = new ASTImplicitConversionNode(DN, ASTTypeAngle,
XSTE->GetIdentifier()->GetBits());
ICE = new ASTImplicitConversionNode(DN, ASTTypeAngle, ConvertBits);
assert(ICE && "Could not create a valid ASTImplicitConversionNode!");
XAN->SetImplicitConversion(ICE);
} break;
Expand All @@ -277,12 +277,12 @@ ASTGateNode::CreateAngleConversion(const ASTSymbolTableEntry *XSTE) const {
const ASTIntNode *IN = XSTE->GetValue()->GetValue<ASTIntNode *>();
assert(IN && "Could not obtain a valid ASTIntNode!");

XAN =
new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), IN,
ASTAngleTypeGeneric, XSTE->GetIdentifier()->GetBits());
unsigned ConvertBits = IN->GetBits();

XAN = new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), IN,
ASTAngleTypeGeneric, ConvertBits);
assert(XAN && "Could not create a valid ASTAngleNode!");
ICE = new ASTImplicitConversionNode(IN, ASTTypeAngle,
XSTE->GetIdentifier()->GetBits());
ICE = new ASTImplicitConversionNode(IN, ASTTypeAngle, ConvertBits);
assert(ICE && "Could not create a valid ASTImplicitConversionNode!");
XAN->SetImplicitConversion(ICE);
} break;
Expand All @@ -292,12 +292,12 @@ ASTGateNode::CreateAngleConversion(const ASTSymbolTableEntry *XSTE) const {
XSTE->GetValue()->GetValue<ASTMPIntegerNode *>();
assert(MPI && "Could not obtain a valid ASTMPIntegerNode!");

XAN =
new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), MPI,
ASTAngleTypeGeneric, XSTE->GetIdentifier()->GetBits());
unsigned ConvertBits = MPI->GetBits();

XAN = new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), MPI,
ASTAngleTypeGeneric, ConvertBits);
assert(XAN && "Could not create a valid ASTAngleNode!");
ICE = new ASTImplicitConversionNode(MPI, ASTTypeAngle,
XSTE->GetIdentifier()->GetBits());
ICE = new ASTImplicitConversionNode(MPI, ASTTypeAngle, ConvertBits);
assert(ICE && "Could not create a valid ASTImplicitConversionNode!");
XAN->SetImplicitConversion(ICE);
} break;
Expand All @@ -306,20 +306,22 @@ ASTGateNode::CreateAngleConversion(const ASTSymbolTableEntry *XSTE) const {
XSTE->GetValue()->GetValue<ASTMPDecimalNode *>();
assert(MPD && "Could not obtain a valid ASTMPDecimalNode!");

XAN =
new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), MPD,
ASTAngleTypeGeneric, XSTE->GetIdentifier()->GetBits());
unsigned ConvertBits = MPD->GetBits();

XAN = new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), MPD,
ASTAngleTypeGeneric, ConvertBits);
assert(XAN && "Could not create a valid ASTAngleNode!");
ICE = new ASTImplicitConversionNode(MPD, ASTTypeAngle,
XSTE->GetIdentifier()->GetBits());
ICE = new ASTImplicitConversionNode(MPD, ASTTypeAngle, ConvertBits);
assert(ICE && "Could not create a valid ASTImplicitConversionNode!");
XAN->SetImplicitConversion(ICE);
} break;
case ASTTypeBitset: {
const ASTCBitNode *CBN = XSTE->GetValue()->GetValue<ASTCBitNode *>();
assert(CBN && "Could not obtain a valid ASTCBitNode!");

if (CBN->Size() > XSTE->GetIdentifier()->GetBits()) {
unsigned ConvertBits = CBN->GetBits();

if (CBN->Size() > ConvertBits) {
std::stringstream M;
M << "Conversion from " << PrintTypeEnum(XSTE->GetValueType())
<< " to Angle Type will result in truncation.";
Expand All @@ -328,8 +330,7 @@ ASTGateNode::CreateAngleConversion(const ASTSymbolTableEntry *XSTE) const {
M.str(), DiagLevel::Warning);
}

unsigned SZ = std::min(static_cast<unsigned>(CBN->Size()),
XSTE->GetIdentifier()->GetBits());
unsigned SZ = std::min(static_cast<unsigned>(CBN->Size()), ConvertBits);

if (SZ >= 4U)
SZ = SZ % 4;
Expand All @@ -340,9 +341,8 @@ ASTGateNode::CreateAngleConversion(const ASTSymbolTableEntry *XSTE) const {
if ((*CBN)[I])
D += static_cast<double>(M_PI / 2);

XAN =
new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), D,
ASTAngleTypeGeneric, XSTE->GetIdentifier()->GetBits());
XAN = new ASTAngleNode(ASTIdentifierNode::Angle.Clone(LC), D,
ASTAngleTypeGeneric, ConvertBits);
assert(XAN && "Could not create a valid ASTAngleNode!");
ICE = new ASTImplicitConversionNode(CBN, ASTTypeAngle, SZ);
assert(ICE && "Could not create a valid ASTImplicitConversionNode!");
Expand Down
7 changes: 7 additions & 0 deletions lib/AST/ASTOpenPulseFrame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ ASTOpenPulseFrameNodeResolver::ResolveAngle(ASTExpressionNode *E) {
ASTE->GetValue()->GetValue<ASTFloatArrayNode *>();
assert(FAN && "Could not obtain a valid ASTFloatArrayNode "
"from the SymbolTable!");
FAN->ValidateIndex(IX, AIdR->GetLocation());
ASTFloatNode *FN = FAN->GetElement(IX);
assert(FN && "Could not obtain a valid ASTFloatNode "
"from the ASTFloat array!");
Expand All @@ -250,6 +251,8 @@ ASTOpenPulseFrameNodeResolver::ResolveAngle(ASTExpressionNode *E) {
ASTE->GetValue()->GetValue<ASTMPDecimalArrayNode *>();
assert(MPAN && "Could not obtain a valid ASTMPDecimalArrayNode "
"from the SymbolTable!");

MPAN->ValidateIndex(IX, AIdR->GetLocation());
ASTMPDecimalNode *MPD = MPAN->GetElement(IX);
assert(MPD && "Could not obtain a valid ASTMPDecimalNode "
"from the ASTMPDecimal array!");
Expand All @@ -262,6 +265,7 @@ ASTOpenPulseFrameNodeResolver::ResolveAngle(ASTExpressionNode *E) {
ASTE->GetValue()->GetValue<ASTAngleArrayNode *>();
assert(AAN && "Could not obtain a valid ASTAngleArrayNode "
"from the SymbolTable!");
AAN->ValidateIndex(IX, AIdR->GetLocation());
AN = AAN->GetElement(IX);
assert(AN && "Could not obtain a valid ASTAngleNode "
"from the ASTAngle array!");
Expand Down Expand Up @@ -556,6 +560,7 @@ ASTOpenPulseFrameNodeResolver::ResolveFrequency(ASTExpressionNode *E) {
DSTE->GetValue()->GetValue<ASTFloatArrayNode *>();
assert(FAN && "Could not obtain a valid ASTFloatArrayNode "
"from the SymbolTable!");
FAN->ValidateIndex(IX, DIdR->GetLocation());
ASTFloatNode *FN = FAN->GetElement(IX);
assert(FN && "Could not obtain a valid ASTFloatNode "
"from the ASTFloat array!");
Expand All @@ -567,6 +572,7 @@ ASTOpenPulseFrameNodeResolver::ResolveFrequency(ASTExpressionNode *E) {
DSTE->GetValue()->GetValue<ASTMPDecimalArrayNode *>();
assert(MPDA && "Could not obtain a valid ASTMPDecimalArrayNode "
"from the SymbolTable!");
MPDA->ValidateIndex(IX, DIdR->GetLocation());
ASTMPDecimalNode *MPDD = MPDA->GetElement(IX);
assert(MPDD && "Could not obtain a valid ASTMPDecimalNode "
"from the ASTMPDecimal array!");
Expand All @@ -577,6 +583,7 @@ ASTOpenPulseFrameNodeResolver::ResolveFrequency(ASTExpressionNode *E) {
DSTE->GetValue()->GetValue<ASTMPIntegerArrayNode *>();
assert(MPIA && "Could not obtain a valid ASTMPIntegerArrayNode "
"from the SymbolTable!");
MPIA->ValidateIndex(IX, DIdR->GetLocation());
ASTMPIntegerNode *MPI = MPIA->GetElement(IX);
assert(MPI && "Could not obtain a valid ASTMPIntegerNode "
"from the ASTMPInteger array!");
Expand Down
39 changes: 31 additions & 8 deletions lib/AST/ASTProductionFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,7 @@ ASTProductionFactory::ProductionRule_105(const ASTToken *TK,
ASTIntArrayNode *IAN = dynamic_cast<ASTIntArrayNode *>(
STE->GetValue()->GetValue<ASTArrayNode *>());
assert(IAN && "Invalid Value obtained from the SymbolTable Entry!");
IAN->ValidateIndex(Index, Id->GetLocation());
ASTIntNode *IN = IAN->GetElement(Index);
assert(IN && "Invalid ASTIntNode obtained from the SymbolTable Entry!");
BV = ASTUtils::Instance().GetBooleanValue(IN);
Expand All @@ -900,6 +901,7 @@ ASTProductionFactory::ProductionRule_105(const ASTToken *TK,
ASTMPIntegerArrayNode *MPIA = dynamic_cast<ASTMPIntegerArrayNode *>(
STE->GetValue()->GetValue<ASTArrayNode *>());
assert(MPIA && "Invalid Value obtained from the SymbolTable Entry!");
MPIA->ValidateIndex(Index, Id->GetLocation());
ASTMPIntegerNode *MPI = MPIA->GetElement(Index);
assert(MPI &&
"Invalid ASTMPIntegerNode obtained from the SymbolTable Entry!");
Expand All @@ -909,6 +911,7 @@ ASTProductionFactory::ProductionRule_105(const ASTToken *TK,
ASTBoolArrayNode *BAN = dynamic_cast<ASTBoolArrayNode *>(
STE->GetValue()->GetValue<ASTArrayNode *>());
assert(BAN && "Invalid Value obtained from the SymbolTable Entry!");
BAN->ValidateIndex(Index, Id->GetLocation());
ASTBoolNode *BN = BAN->GetElement(Index);
assert(BN && "Invalid ASTBoolNode obtained from the SymbolTable Entry!");
BV = BN->GetValue();
Expand All @@ -917,6 +920,7 @@ ASTProductionFactory::ProductionRule_105(const ASTToken *TK,
ASTCBitArrayNode *CBA = dynamic_cast<ASTCBitArrayNode *>(
STE->GetValue()->GetValue<ASTArrayNode *>());
assert(CBA && "Invalid Value obtained from the SymbolTable Entry!");
CBA->ValidateIndex(Index, Id->GetLocation());
ASTCBitNode *CBN = CBA->GetElement(Index);
assert(CBN && "Invalid ASTCBitNode obtained from the SymbolTable Entry!");
BV = CBN->AsBool();
Expand All @@ -925,6 +929,7 @@ ASTProductionFactory::ProductionRule_105(const ASTToken *TK,
ASTCBitNArrayNode *CBNA = dynamic_cast<ASTCBitNArrayNode *>(
STE->GetValue()->GetValue<ASTArrayNode *>());
assert(CBNA && "Invalid Value obtained from the SymbolTable Entry!");
CBNA->ValidateIndex(Index, Id->GetLocation());
ASTCBitNode *CBN = CBNA->GetElement(Index);
assert(CBN && "Invalid ASTCBitNode obtained from the SymbolTable Entry!");
BV = CBN->AsBool();
Expand Down Expand Up @@ -1322,7 +1327,9 @@ ASTProductionFactory::ProductionRule_110(const ASTToken *TK,
assert(CRN && "Could not obtain a valid ASTCBitArrayNode from "
"the SymbolTable!");

CBN = CRN->GetElement(IdR->GetIndex());
auto Index = IdR->GetIndex();
CRN->ValidateIndex(Index, IdR->GetLocation());
CBN = CRN->GetElement(Index);
assert(CBN && "Could not obtain a valid ASTCBitNode from "
"the ASTCBitArrayNode!");
FromArray = true;
Expand All @@ -1334,7 +1341,9 @@ ASTProductionFactory::ProductionRule_110(const ASTToken *TK,
assert(CRN && "Could not obtain a valid ASTCBitNArrayNode from "
"the SymbolTable!");

CBN = CRN->GetElement(IdR->GetIndex());
auto Index = IdR->GetIndex();
CRN->ValidateIndex(Index, IdR->GetLocation());
CBN = CRN->GetElement(Index);
assert(CBN && "Could not obtain a valid ASTCBitNode from "
"the ASTCBitNArrayNode!");
FromArray = true;
Expand Down Expand Up @@ -16661,7 +16670,9 @@ ASTDurationNode *ASTProductionFactory::ProductionRule_1202(
LSTE->GetValue()->GetValue<ASTArrayNode *>());
assert(DAN && "Could not dynamic_cast to an ASTDurationArrayNode!");

DRN = DAN->GetElement(IdR->GetIndex());
auto Index = IdR->GetIndex();
DAN->ValidateIndex(Index, IdR->GetLocation());
DRN = DAN->GetElement(Index);
assert(DRN && "Could not obtain a valid ASTDurationNode from "
"the ASTDurationArrayNode!");
LS = DRN->AsString();
Expand Down Expand Up @@ -16692,7 +16703,9 @@ ASTDurationNode *ASTProductionFactory::ProductionRule_1202(
LSTE->GetValue()->GetValue<ASTArrayNode *>());
assert(DAN && "Could not dynamic_cast to an ASTDurationArrayNode!");

DRN = DAN->GetElement(Id->GetBits());
auto Index = Id->GetBits();
DAN->ValidateIndex(Index, Id->GetLocation());
DRN = DAN->GetElement(Index);
assert(DRN && "Could not obtain a valid ASTDurationNode from "
"the ASTLengthArrayNode!");
LS = DRN->AsString();
Expand Down Expand Up @@ -16791,7 +16804,9 @@ ASTDurationNode *ASTProductionFactory::ProductionRule_1203(
LSTE->GetValue()->GetValue<ASTArrayNode *>());
assert(DAN && "Could not dynamic_cast to an ASTLengthArrayNode!");

DRN = DAN->GetElement(IdR->GetIndex());
auto Index = IdR->GetIndex();
DAN->ValidateIndex(Index, IdR->GetLocation());
DRN = DAN->GetElement(Index);
assert(DRN && "Could not obtain a valid ASTDurationNode from "
"the ASTDurationArrayNode!");
LS = DRN->AsString();
Expand Down Expand Up @@ -16822,7 +16837,9 @@ ASTDurationNode *ASTProductionFactory::ProductionRule_1203(
LSTE->GetValue()->GetValue<ASTArrayNode *>());
assert(DAN && "Could not dynamic_cast to an ASTLengthArrayNode!");

DRN = DAN->GetElement(Id->GetBits());
auto Index = Id->GetBits();
DAN->ValidateIndex(Index, Id->GetLocation());
DRN = DAN->GetElement(Index);
assert(DRN && "Could not obtain a valid ASTDurationNode from "
"the ASTDurationArrayNode!");
LS = DRN->AsString();
Expand Down Expand Up @@ -17968,7 +17985,9 @@ ASTProductionFactory::ProductionRule_1103(const ASTToken *TK,
STE->GetValue()->GetValue<ASTArrayNode *>());
assert(QAN && "Could not retrieve a valid ASTQubitArrayNode!");

QCN = QAN->GetElement(QId->GetBits());
auto Index = QId->GetBits();
QAN->ValidateIndex(Index, QId->GetLocation());
QCN = QAN->GetElement(Index);
assert(QCN && "Could not dynamic_cast to an ASTQubitContainerNode!");
} break;
case ASTTypeQubitContainer: {
Expand Down Expand Up @@ -18204,7 +18223,9 @@ ASTDeclarationNode *ASTProductionFactory::ProductionRule_1106(
STE->GetValue()->GetValue<ASTArrayNode *>());
assert(QAN && "Could not retrieve a valid ASTQubitArrayNode!");

QCN = QAN->GetElement(QId->GetBits());
auto Index = QId->GetBits();
QAN->ValidateIndex(Index, QId->GetLocation());
QCN = QAN->GetElement(Index);
assert(QCN && "Could not dynamic_cast to an ASTQubitContainerNode!");
} break;
case ASTTypeQubitContainer: {
Expand Down Expand Up @@ -22158,6 +22179,7 @@ ASTProductionFactory::ProductionRule_1461(const ASTToken *TK,
unsigned I = RIdR ? RIdR->GetIndex() : RId->GetBits();
assert(I != static_cast<unsigned>(~0x0) &&
"Invalid ASTCBitArrayNode Index!");
CAN->ValidateIndex(I, RIdR ? RIdR->GetLocation() : RId->GetLocation());
CBN = CAN->GetElement(I);
assert(CBN && "Invalid Bitset obtained from the ASTCBitArrayNode!");
RTy = CBN->GetASTType();
Expand All @@ -22169,6 +22191,7 @@ ASTProductionFactory::ProductionRule_1461(const ASTToken *TK,
unsigned I = RIdR ? RIdR->GetIndex() : RId->GetBits();
assert(I != static_cast<unsigned>(~0x0) &&
"Invalid ASTCBitArrayNode Index!");
CAN->ValidateIndex(I, RIdR ? RIdR->GetLocation() : RId->GetLocation());
CBN = CAN->GetElement(I);
assert(CBN && "Invalid Bitset obtained from the ASTCBitNArrayNode!");
RTy = CBN->GetASTType();
Expand Down
1 change: 1 addition & 0 deletions lib/AST/ASTTypeDiscovery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2159,6 +2159,7 @@ ASTIdentifierRefNode *ResolveASTIdentifierRef(
__AT *A = dynamic_cast<__AT *>(STE->GetValue()->GetValue<ASTArrayNode *>());
assert(A && "Could not dynamic_cast to a valid array node!");

A->ValidateIndex(IX, ASN->GetLocation());
__ET *E = A->GetElement(IX);
assert(E && "Could not obtain a valid array element!");

Expand Down
Loading

0 comments on commit ec7731b

Please sign in to comment.