Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SendInLoop Interprocedural Improvement #219

Merged
169 changes: 164 additions & 5 deletions src/detectors/builtin/sendInLoop.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
import { CompilationUnit } from "../../internals/ir";
import { CallGraph, CGNodeId } from "../../internals/ir/callGraph";
import {
forEachStatement,
foldExpressions,
isSelf,
} from "../../internals/tact";
import { unreachable } from "../../internals/util";
import { MistiTactWarning, Severity } from "../../internals/warnings";
import { ASTDetector } from "../detector";
import {
AstStatement,
AstExpression,
idText,
AstFunctionDef,
AstReceiver,
AstContractInit,
AstNode,
} from "@tact-lang/compiler/dist/grammar/ast";

/**
* An optional detector that identifies send functions being called inside loops.
* An optional detector that identifies send functions being called inside loops,
* including indirect calls via other functions.
*
* ## Why is it bad?
* Calling send functions inside loops can lead to unintended consequences, such as
Expand Down Expand Up @@ -43,10 +50,60 @@ export class SendInLoop extends ASTDetector {
async check(cu: CompilationUnit): Promise<MistiTactWarning[]> {
const processedLoopIds = new Set<number>();
const allWarnings: MistiTactWarning[] = [];
const astStore = cu.ast;
const ctx = this.ctx;
const callGraph = new CallGraph(ctx).build(astStore);
const astIdToCGNodeId = new Map<number, CGNodeId>();
for (const [nodeId, node] of callGraph.getNodes()) {
if (node.astId !== undefined) {
astIdToCGNodeId.set(node.astId, nodeId);
}
}

// Collect functions that directly call send functions
const functionsCallingSend = new Set<CGNodeId>();

// Identify all functions that contain a send call
Esorat marked this conversation as resolved.
Show resolved Hide resolved
for (const func of astStore.getFunctions()) {
let containsSend = false;
foldExpressions(
func,
(acc, expr) => {
if (this.isSendCall(expr)) {
containsSend = true;
}
return acc;
},
null,
);

if (containsSend) {
const funcName = this.getFunctionName(func);
if (funcName) {
const nodeId = callGraph.getNodeIdByName(funcName);
if (nodeId !== undefined) {
functionsCallingSend.add(nodeId);
}
}
}
}

// Identify all functions that can lead to a send call
const functionsLeadingToSend = this.getFunctionsLeadingToSend(
callGraph,
functionsCallingSend,
);

// Analyze loops and check if any function called within leads to a send
Array.from(cu.ast.getProgramEntries()).forEach((node) => {
forEachStatement(node, (stmt) => {
const warnings = this.analyzeStatement(stmt, processedLoopIds);
const warnings = this.analyzeStatement(
stmt,
processedLoopIds,
callGraph,
astIdToCGNodeId,
functionsLeadingToSend,
);
allWarnings.push(...warnings);
});
});
Expand All @@ -57,15 +114,22 @@ export class SendInLoop extends ASTDetector {
private analyzeStatement(
stmt: AstStatement,
processedLoopIds: Set<number>,
callGraph: CallGraph,
astIdToCGNodeId: Map<number, CGNodeId>,
functionsLeadingToSend: Set<CGNodeId>,
): MistiTactWarning[] {
if (processedLoopIds.has(stmt.id)) {
return [];
}
if (this.isLoop(stmt)) {
processedLoopIds.add(stmt.id);
return foldExpressions(

const warnings: MistiTactWarning[] = [];

// Check direct send calls within the loop
foldExpressions(
stmt,
(acc, expr) => {
(acc: MistiTactWarning[], expr: AstExpression) => {
if (this.isSendCall(expr)) {
acc.push(
this.makeWarning("Send function called inside a loop", expr.loc, {
Expand All @@ -76,13 +140,69 @@ export class SendInLoop extends ASTDetector {
}
return acc;
},
[] as MistiTactWarning[],
warnings,
);

// Check function calls within the loop that lead to send
this.forEachExpression(stmt, (expr: AstExpression) => {
if (expr.kind === "static_call" || expr.kind === "method_call") {
const calleeName = this.getCalleeName(expr);
if (calleeName) {
const calleeNodeId = callGraph.getNodeIdByName(calleeName);
if (
calleeNodeId !== undefined &&
functionsLeadingToSend.has(calleeNodeId)
) {
warnings.push(
this.makeWarning(
`Function "${calleeName}" called inside a loop leads to a send function`,
expr.loc,
{
suggestion:
"Consider refactoring to avoid calling send functions inside loops",
},
),
);
}
}
}
});

return warnings;
}
// If the statement is not a loop, don't flag anything
return [];
}

private getFunctionsLeadingToSend(
callGraph: CallGraph,
functionsCallingSend: Set<CGNodeId>,
): Set<CGNodeId> {
const functionsLeadingToSend = new Set<CGNodeId>(functionsCallingSend);

// Use a queue for BFS
const queue: CGNodeId[] = Array.from(functionsCallingSend);

while (queue.length > 0) {
Esorat marked this conversation as resolved.
Show resolved Hide resolved
const current = queue.shift()!;
const currentNode = callGraph.getNode(current);
if (currentNode) {
for (const edgeId of currentNode.inEdges) {
const edge = callGraph.getEdge(edgeId);
if (edge) {
const callerId = edge.src;
if (!functionsLeadingToSend.has(callerId)) {
functionsLeadingToSend.add(callerId);
queue.push(callerId);
}
}
}
}
}

return functionsLeadingToSend;
}

private isSendCall(expr: AstExpression): boolean {
const staticSendFunctions = ["send", "nativeSendMessage"];
const selfMethodSendFunctions = ["reply", "forward", "notify", "emit"];
Expand All @@ -103,4 +223,43 @@ export class SendInLoop extends ASTDetector {
stmt.kind === "statement_foreach"
);
}

private getFunctionName(
Esorat marked this conversation as resolved.
Show resolved Hide resolved
func: AstFunctionDef | AstReceiver | AstContractInit,
): string | undefined {
switch (func.kind) {
case "function_def":
return func.name?.text;
case "contract_init":
return `contract_init_${func.id}`;
case "receiver":
return `receiver_${func.id}`;
default:
unreachable(func);
}
}

private getCalleeName(expr: AstExpression): string | undefined {
Esorat marked this conversation as resolved.
Show resolved Hide resolved
if (expr.kind === "static_call") {
return idText(expr.function);
} else if (expr.kind === "method_call") {
return idText(expr.method);
}
return undefined;
}

// Helper method to traverse expressions
private forEachExpression(
Esorat marked this conversation as resolved.
Show resolved Hide resolved
node: AstNode,
callback: (expr: AstExpression) => void,
): void {
foldExpressions(
node,
(acc, expr) => {
callback(expr);
return acc;
},
null,
);
}
}
65 changes: 28 additions & 37 deletions src/internals/ir/callGraph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
AstStaticCall,
} from "@tact-lang/compiler/dist/grammar/ast";

type CGNodeId = number & { readonly brand: unique symbol };
export type CGNodeId = number & { readonly brand: unique symbol };
type CGEdgeId = number & { readonly brand: unique symbol };
Esorat marked this conversation as resolved.
Show resolved Hide resolved

/**
Expand Down Expand Up @@ -80,6 +80,33 @@ export class CallGraph {
return this.edgesMap;
}

/**
* Retrieves the node ID associated with a given function name.
* @param name The function name.
* @returns The corresponding node ID, or undefined if not found.
*/
public getNodeIdByName(name: string): CGNodeId | undefined {
return this.nameToNodeId.get(name);
}

/**
* Retrieves a node from the graph by its ID.
* @param nodeId The ID of the node.
* @returns The `CGNode` instance, or undefined if not found.
*/
public getNode(nodeId: CGNodeId): CGNode | undefined {
return this.nodeMap.get(nodeId);
}

/**
* Retrieves an edge from the graph by its ID.
* @param edgeId The ID of the edge.
* @returns The `CGEdge` instance, or undefined if not found.
*/
public getEdge(edgeId: CGEdgeId): CGEdge | undefined {
return this.edgesMap.get(edgeId);
}

/**
* Builds the call graph based on functions in the provided AST store.
* @param astStore - The AST store containing functions to be added to the graph.
Expand All @@ -102,42 +129,6 @@ export class CallGraph {
return this;
}

/**
* Determines if there exists a path in the call graph from the source node to the destination node.
* This method performs a breadth-first search to find if the destination node is reachable from the source node.
*
* @param src The ID of the source node to start the search from
* @param dst The ID of the destination node to search for
* @returns true if there exists a path from src to dst in the call graph, false otherwise
* Returns false if either src or dst node IDs are not found in the graph
*/
public areConnected(src: CGNodeId, dst: CGNodeId): boolean {
const srcNode = this.nodeMap.get(src);
const dstNode = this.nodeMap.get(dst);
if (!srcNode || !dstNode) {
return false;
}
const queue: CGNodeId[] = [src];
const visited = new Set<CGNodeId>([src]);
while (queue.length > 0) {
const current = queue.shift()!;
if (current === dst) {
return true;
}
const currentNode = this.nodeMap.get(current);
if (currentNode) {
for (const edgeId of currentNode.outEdges) {
const edge = this.edgesMap.get(edgeId);
if (edge && !visited.has(edge.dst)) {
visited.add(edge.dst);
queue.push(edge.dst);
}
}
}
}
return false;
}

/**
* Analyzes function calls in the AST store and adds corresponding edges in the call graph.
* @param astStore The AST store to analyze for function calls.
Expand Down
Loading