From 2bd120f20e08fd276e2901f2a7e037f83b3b039e Mon Sep 17 00:00:00 2001 From: Seokho Son Date: Thu, 13 Jun 2024 01:07:10 +0900 Subject: [PATCH] Add new built-in function feature for remote cmd --- src/core/mcis/manageInfo.go | 4 + src/core/mcis/remoteCommand.go | 212 ++++++++++++++++++++++++++++++++- 2 files changed, 213 insertions(+), 3 deletions(-) diff --git a/src/core/mcis/manageInfo.go b/src/core/mcis/manageInfo.go index 9b4640455..9f0848809 100644 --- a/src/core/mcis/manageInfo.go +++ b/src/core/mcis/manageInfo.go @@ -970,6 +970,10 @@ func GetVmCurrentPublicIp(nsId string, mcisId string, vmId string) (TbVmStatusIn key := common.GenMcisKey(nsId, mcisId, vmId) keyValue, err := common.CBStore.Get(key) if err != nil || keyValue == nil { + if keyValue == nil { + log.Error().Err(err).Msgf("Not found: %s keyValue is nil", key) + return errorInfo, fmt.Errorf("Not found: %s keyValue is nil", key) + } log.Error().Err(err).Msg("") return errorInfo, err } diff --git a/src/core/mcis/remoteCommand.go b/src/core/mcis/remoteCommand.go index 9dabd365f..7792e4fdf 100644 --- a/src/core/mcis/remoteCommand.go +++ b/src/core/mcis/remoteCommand.go @@ -17,10 +17,13 @@ package mcis import ( "bytes" "encoding/json" + "errors" "fmt" "io" "net" "os" + "regexp" + "strings" "sync" "time" @@ -145,16 +148,31 @@ func RemoteCommandToMcis(nsId string, mcisId string, subGroupId string, vmId str vmList = []string{vmId} } - //goroutine sync wg + // goroutine sync wg var wg sync.WaitGroup var resultArray []SshCmdResult + // Preprocess commands for each VM + vmCommands := make(map[string][]string) for _, vmId := range vmList { + processedCommands := make([]string, len(req.Command)) + for i, cmd := range req.Command { + processedCmd, err := processCommand(cmd, nsId, mcisId, vmId) + if err != nil { + return nil, err + } + processedCommands[i] = processedCmd + } + vmCommands[vmId] = processedCommands + } + + // Execute commands in parallel using goroutines + for vmId, commands := range vmCommands { wg.Add(1) - go RunRemoteCommandAsync(&wg, nsId, mcisId, vmId, req.UserName, req.Command, &resultArray) + go RunRemoteCommandAsync(&wg, nsId, mcisId, vmId, req.UserName, commands, &resultArray) } - wg.Wait() //goroutine sync wg + wg.Wait() // goroutine sync wg return resultArray, nil } @@ -714,3 +732,191 @@ func GetBastionNodes(nsId string, mcisId string, targetVmId string) ([]mcir.Bast return returnValue, fmt.Errorf("failed to get bastion in Subnet (ID: %s) of VNet (ID: %s) for VM (ID: %s)", vmObj.SubnetId, vmObj.VNetId, targetVmId) } + +// Helper function to extract function name and parameters from the string +func extractFunctionAndParams(funcCall string) (string, map[string]string, error) { + regex := regexp.MustCompile(`^\s*([a-zA-Z0-9]+)\((.*?)\)\s*$`) + matches := regex.FindStringSubmatch(funcCall) + if len(matches) < 3 { + return "", nil, errors.New("Built-in function error in command: no function found in command") + } + + funcName := matches[1] + paramsPart := matches[2] + params := make(map[string]string) + + paramPairs := splitParams(paramsPart) + + for _, pair := range paramPairs { + kv := strings.SplitN(pair, "=", 2) + if len(kv) == 2 { + key := strings.TrimSpace(kv[0]) + value := strings.TrimSpace(kv[1]) + if strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'") { + value = strings.Trim(value, "'") + } + params[key] = value + } + } + + return funcName, params, nil +} + +// Helper function to split parameters by comma, considering quoted parts +func splitParams(paramsPart string) []string { + var result []string + var current strings.Builder + inQuotes := false // Initialize inQuotes + + for i := 0; i < len(paramsPart); i++ { + switch paramsPart[i] { + case '\'': + inQuotes = !inQuotes + current.WriteByte(paramsPart[i]) + case ',': + if inQuotes { + current.WriteByte(paramsPart[i]) + } else { + result = append(result, current.String()) + current.Reset() + } + default: + current.WriteByte(paramsPart[i]) + } + } + if current.Len() > 0 { + result = append(result, current.String()) + } + + return result +} + +// extractFunctionAndParams is a helper function to find matching parenthesis +func findMatchingParenthesis(command string, start int) int { + count := 1 + for i := start; i < len(command); i++ { + switch command[i] { + case '(': + count++ + case ')': + count-- + if count == 0 { + return i + } + } + } + return -1 +} + +// processCommand is function to replace the keywords with actual values +func processCommand(command, nsId, mcisId, vmId string) (string, error) { + start := 0 + for { + start = strings.Index(command[start:], "$$Func(") + if start == -1 { + break + } + start += 7 // Move past "$$Func(" + end := findMatchingParenthesis(command, start) + if end == -1 { + return "", errors.New("Built-in function error in command: no matching parenthesis found") + } + + funcCall := command[start:end] + + funcName, params, err := extractFunctionAndParams(funcCall) + if err != nil { + return "", err + } + + var replacement string + if strings.EqualFold(funcName, "GetPublicIP") { + targetMcisId := mcisId + targetVmId := vmId + if val, ok := params["target"]; ok { + parts := strings.Split(val, ".") + if len(parts) == 2 { + targetMcisId = parts[0] + targetVmId = parts[1] + } else if strings.EqualFold(val, "this") { + targetMcisId = mcisId + targetVmId = vmId + } + } + prefix := "" + if pre, ok := params["prefix"]; ok { + prefix = pre + } + postfix := "" + if post, ok := params["postfix"]; ok { + postfix = post + } + replacement, err = getPublicIP(nsId, targetMcisId, targetVmId, prefix, postfix) + + if err != nil { + return "", fmt.Errorf("Built-in function getPublicIP error: %s", err.Error()) + } + + } else if strings.EqualFold(funcName, "GetPublicIPs") { + targetMcisId := mcisId + + if val, ok := params["target"]; ok { + if strings.EqualFold(val, "this") { + targetMcisId = mcisId + } else { + targetMcisId = val + } + } + separator := "," + if sep, ok := params["separator"]; ok { + separator = sep + } + prefix := "" + if pre, ok := params["prefix"]; ok { + prefix = pre + } + postfix := "" + if post, ok := params["postfix"]; ok { + postfix = post + } + replacement, err = getPublicIPs(nsId, targetMcisId, separator, prefix, postfix) + + if err != nil { + return "", fmt.Errorf("Built-in function getPublicIPs error: %s", err.Error()) + } + + } else { + return "", fmt.Errorf("Built-in function error in command: Unknown function: %s", funcName) + } + + // Replace the entire $$Func(...) expression with the result + command = command[:start-7] + replacement + command[end+1:] + start = start - 7 + len(replacement) // Adjust start for the next iteration + } + + return command, nil +} + +// Built-in functions for remote command +// getPublicIP function to get and replace string with the public IP of the target +func getPublicIP(nsId, mcisId, vmId, prefix, postfix string) (string, error) { + vmStatus, err := GetVmCurrentPublicIp(nsId, mcisId, vmId) + if err != nil { + return "", err + } + ip := vmStatus.PublicIp + return prefix + ip + postfix, err +} + +// getPublicIP function to get and replace string with the public IP list of the target +func getPublicIPs(nsId, mcisId, separator, prefix, postfix string) (string, error) { + mcisStatus, err := GetMcisStatus(nsId, mcisId) + if err != nil { + return "", err + } + ips := make([]string, len(mcisStatus.Vm)) + for i, vmStatus := range mcisStatus.Vm { + ips[i] = prefix + vmStatus.PublicIp + postfix + } + return strings.Join(ips, separator), nil +}