From b118de5666fadaedadb4307c9a532d3edc291b40 Mon Sep 17 00:00:00 2001 From: "jiangchuan.he" Date: Wed, 25 Oct 2023 00:39:14 -0700 Subject: [PATCH] support checking complex method signature --- airgap/method_registry.go | 50 +++++++++++++++++++++++++++------- airgap/method_registry_test.go | 5 ++++ 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/airgap/method_registry.go b/airgap/method_registry.go index 3b9f7934..5b41b203 100644 --- a/airgap/method_registry.go +++ b/airgap/method_registry.go @@ -77,23 +77,53 @@ func unregisteredMethodFromString(methodSignature string) (*CeloMethod, error) { } func validateMethodSignature(methodSig string) error { - // Check if the method signature contains both opening and closing parentheses openParenIndex := strings.Index(methodSig, "(") - closeParenIndex := strings.Index(methodSig, ")") - if openParenIndex == -1 || closeParenIndex == -1 || openParenIndex > closeParenIndex { + if openParenIndex == -1 { return fmt.Errorf("Invalid method signature: %s", methodSig) } - // Extract the contents inside the parentheses - paramString := methodSig[openParenIndex+1 : closeParenIndex] + // Check if the method signature has non-empty method name + methodName := methodSig[:openParenIndex] + if len(methodName) == 0 { + return fmt.Errorf("Invalid method signature: %s", methodSig) + } - // If there are no contents, the signature is valid - if paramString == "" { - return nil + // Perform parentheses check + paramString := methodSig[openParenIndex:] + var stack []rune + pairs := map[rune]rune{')': '(', ']': '['} + for _, char := range paramString { + if char == '(' || char == '[' { + stack = append(stack, char) + } else if char == ')' || char == ']' { + if len(stack) == 0 || stack[len(stack)-1] != pairs[char] { + return fmt.Errorf("Invalid method signature: %s", methodSig) + } + stack = stack[:len(stack)-1] + } + } + if len(stack) != 0 { + return fmt.Errorf("Invalid method signature: %s", methodSig) } - // Split the contents by comma to get individual type strings - methodTypes := strings.Split(paramString, ",") + // Extract method parameter types into a string array + paramString = strings.Replace(paramString, "(", " ", -1) + paramString = strings.Replace(paramString, ")", " ", -1) + paramString = strings.Replace(paramString, "[", " ", -1) + paramString = strings.Replace(paramString, "]", " ", -1) + paramString = strings.Replace(paramString, ",", " ", -1) + + var methodTypes []string + for _, methodType := range strings.Split(paramString, " ") { + if methodType != "" { + methodTypes = append(methodTypes, methodType) + } + } + + // If there are no contents, the signature is valid (contract call without arguments). + if len(methodTypes) == 0 { + return nil + } // Iterate through each type string and validate for _, v := range methodTypes { diff --git a/airgap/method_registry_test.go b/airgap/method_registry_test.go index b54ba4aa..be52c68d 100644 --- a/airgap/method_registry_test.go +++ b/airgap/method_registry_test.go @@ -25,9 +25,14 @@ func Test_validateMethodSignature(t *testing.T) { {name: "valid signature with no args", methodSig: "noArgs()", wantErr: false}, {name: "valid signature with one arg", methodSig: "deploy(address)", wantErr: false}, {name: "valid signature with multiple args", methodSig: "deploy(address,uint8,bytes16,address)", wantErr: false}, + {name: "valid signature with nested args", methodSig: "batchTransfer((address,(address,(address,uint256)[])[])[],uint256)", wantErr: false}, {name: "signature with invalid arg type", methodSig: "batchTransfer(DepositWalletTransfer[])", wantErr: true}, {name: "closing parenthesis only", methodSig: "noArgs)", wantErr: true}, {name: "open parenthesis only", methodSig: "noArgs(", wantErr: true}, + {name: "missing closing bracket in the args", methodSig: "batchTransfer(bytes[)", wantErr: true}, + {name: "mismatch parenthesis in the args", methodSig: "batchTransfer(bytes[))", wantErr: true}, + {name: "missing open bracket in the args", methodSig: "batchTransfer(bytes])", wantErr: true}, + {name: "missing closing bracket in the nested args", methodSig: "batchTransfer((address,(address,(address,uint256)[)[])[],uint256)", wantErr: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {