diff --git a/core/genesis.go b/core/genesis.go index 97e36190a..eae3bee7a 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -614,17 +614,16 @@ func DeveloperGenesisBlock(gasLimit uint64, faucet *common.Address) *Genesis { BaseFee: big.NewInt(params.InitialBaseFee), Difficulty: big.NewInt(0), Alloc: map[common.Address]types.Account{ - common.BytesToAddress([]byte{1}): {Balance: big.NewInt(1)}, // ECRecover - common.BytesToAddress([]byte{2}): {Balance: big.NewInt(1)}, // SHA256 - common.BytesToAddress([]byte{3}): {Balance: big.NewInt(1)}, // RIPEMD - common.BytesToAddress([]byte{4}): {Balance: big.NewInt(1)}, // Identity - common.BytesToAddress([]byte{5}): {Balance: big.NewInt(1)}, // ModExp - common.BytesToAddress([]byte{6}): {Balance: big.NewInt(1)}, // ECAdd - common.BytesToAddress([]byte{7}): {Balance: big.NewInt(1)}, // ECScalarMul - common.BytesToAddress([]byte{8}): {Balance: big.NewInt(1)}, // ECPairing - common.BytesToAddress([]byte{9}): {Balance: big.NewInt(1)}, // BLAKE2b - common.BytesToAddress([]byte{26}): {Balance: big.NewInt(1)}, // ipGraph - + common.BytesToAddress([]byte{1}): {Balance: big.NewInt(1)}, // ECRecover + common.BytesToAddress([]byte{2}): {Balance: big.NewInt(1)}, // SHA256 + common.BytesToAddress([]byte{3}): {Balance: big.NewInt(1)}, // RIPEMD + common.BytesToAddress([]byte{4}): {Balance: big.NewInt(1)}, // Identity + common.BytesToAddress([]byte{5}): {Balance: big.NewInt(1)}, // ModExp + common.BytesToAddress([]byte{6}): {Balance: big.NewInt(1)}, // ECAdd + common.BytesToAddress([]byte{7}): {Balance: big.NewInt(1)}, // ECScalarMul + common.BytesToAddress([]byte{8}): {Balance: big.NewInt(1)}, // ECPairing + common.BytesToAddress([]byte{9}): {Balance: big.NewInt(1)}, // BLAKE2b + common.BytesToAddress([]byte{0x01, 0x01}): {Balance: big.NewInt(1)}, // ipGraph // Pre-deploy EIP-4788 system contract params.BeaconRootsAddress: {Nonce: 1, Code: params.BeaconRootsCode, Balance: common.Big0}, }, diff --git a/core/vm/contracts.go b/core/vm/contracts.go index e61f7137b..6041a10e5 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -100,17 +100,17 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{ // PrecompiledContractsCancun contains the default set of pre-compiled Ethereum // contracts used in the Cancun release. var PrecompiledContractsCancun = map[common.Address]PrecompiledContract{ - common.BytesToAddress([]byte{0x1}): &ecrecover{}, - common.BytesToAddress([]byte{0x2}): &sha256hash{}, - common.BytesToAddress([]byte{0x3}): &ripemd160hash{}, - common.BytesToAddress([]byte{0x4}): &dataCopy{}, - common.BytesToAddress([]byte{0x5}): &bigModExp{eip2565: true}, - common.BytesToAddress([]byte{0x6}): &bn256AddIstanbul{}, - common.BytesToAddress([]byte{0x7}): &bn256ScalarMulIstanbul{}, - common.BytesToAddress([]byte{0x8}): &bn256PairingIstanbul{}, - common.BytesToAddress([]byte{0x9}): &blake2F{}, - common.BytesToAddress([]byte{0xa}): &kzgPointEvaluation{}, - common.BytesToAddress([]byte{0x1a}): &ipGraph{}, + common.BytesToAddress([]byte{0x1}): &ecrecover{}, + common.BytesToAddress([]byte{0x2}): &sha256hash{}, + common.BytesToAddress([]byte{0x3}): &ripemd160hash{}, + common.BytesToAddress([]byte{0x4}): &dataCopy{}, + common.BytesToAddress([]byte{0x5}): &bigModExp{eip2565: true}, + common.BytesToAddress([]byte{0x6}): &bn256AddIstanbul{}, + common.BytesToAddress([]byte{0x7}): &bn256ScalarMulIstanbul{}, + common.BytesToAddress([]byte{0x8}): &bn256PairingIstanbul{}, + common.BytesToAddress([]byte{0x9}): &blake2F{}, + common.BytesToAddress([]byte{0xa}): &kzgPointEvaluation{}, + common.BytesToAddress([]byte{0x01, 0x01}): &ipGraph{}, } // PrecompiledContractsPrague contains the set of pre-compiled Ethereum diff --git a/core/vm/ipgraph.go b/core/vm/ipgraph.go index 29c6b89c7..f15fede2b 100644 --- a/core/vm/ipgraph.go +++ b/core/vm/ipgraph.go @@ -10,10 +10,21 @@ import ( "github.com/ethereum/go-ethereum/log" ) +const ( + ipGraphWriteGas = 100 + ipGraphReadGas = 10 + averageAncestorIpCount = 30 + averageParentIpCount = 4 + intrinsicGas = 100 +) + var ( + royaltyPolicyKindLAP = big.NewInt(0) // Liquid Absolute Percentage (LAP) Royalty Policy + royaltyPolicyKindLRP = big.NewInt(1) // Liquid Relative Percentage (LRP) Royalty Policy + hundredPercent = big.NewInt(100000000) // 100% in the integer format + ipGraphAddress = common.HexToAddress("0x0000000000000000000000000000000000000101") aclAddress = common.HexToAddress("0x680E66e4c7Df9133a7AFC1ed091089B32b89C4ae") aclSlot = "af99b37fdaacca72ee7240cb1435cc9e498aee6ef4edc19c8cc0cd787f4e6800" - ipGraphAddress = common.HexToAddress("0x000000000000000000000000000000000000001A") addParentIpSelector = crypto.Keccak256Hash([]byte("addParentIp(address,address[])")).Bytes()[:4] hasParentIpSelector = crypto.Keccak256Hash([]byte("hasParentIp(address,address)")).Bytes()[:4] getParentIpsSelector = crypto.Keccak256Hash([]byte("getParentIps(address)")).Bytes()[:4] @@ -21,19 +32,63 @@ var ( getAncestorIpsSelector = crypto.Keccak256Hash([]byte("getAncestorIps(address)")).Bytes()[:4] getAncestorIpsCountSelector = crypto.Keccak256Hash([]byte("getAncestorIpsCount(address)")).Bytes()[:4] hasAncestorIpsSelector = crypto.Keccak256Hash([]byte("hasAncestorIp(address,address)")).Bytes()[:4] - setRoyaltySelector = crypto.Keccak256Hash([]byte("setRoyalty(address,address,uint256)")).Bytes()[:4] - getRoyaltySelector = crypto.Keccak256Hash([]byte("getRoyalty(address,address)")).Bytes()[:4] - getRoyaltyStackSelector = crypto.Keccak256Hash([]byte("getRoyaltyStack(address)")).Bytes()[:4] + setRoyaltySelector = crypto.Keccak256Hash([]byte("setRoyalty(address,address,uint256,uint256)")).Bytes()[:4] + getRoyaltySelector = crypto.Keccak256Hash([]byte("getRoyalty(address,address,uint256)")).Bytes()[:4] + getRoyaltyStackSelector = crypto.Keccak256Hash([]byte("getRoyaltyStack(address,uint256)")).Bytes()[:4] ) type ipGraph struct{} func (c *ipGraph) RequiredGas(input []byte) uint64 { - return uint64(1) + // Smart contract function's selector is the first 4 bytes of the input + if len(input) < 4 { + return intrinsicGas + } + + selector := input[:4] + + switch { + case bytes.Equal(selector, addParentIpSelector): + return ipGraphWriteGas + case bytes.Equal(selector, hasParentIpSelector): + return ipGraphReadGas * averageParentIpCount + case bytes.Equal(selector, getParentIpsSelector): + return ipGraphReadGas * averageParentIpCount + case bytes.Equal(selector, getParentIpsCountSelector): + return ipGraphReadGas + case bytes.Equal(selector, getAncestorIpsSelector): + return ipGraphReadGas * averageAncestorIpCount * 2 + case bytes.Equal(selector, getAncestorIpsCountSelector): + return ipGraphReadGas * averageParentIpCount * 2 + case bytes.Equal(selector, hasAncestorIpsSelector): + return ipGraphReadGas * averageAncestorIpCount * 2 + case bytes.Equal(selector, setRoyaltySelector): + return ipGraphWriteGas + case bytes.Equal(selector, getRoyaltySelector): + royaltyPolicyKind := new(big.Int).SetBytes(getData(input, 64+4, 32)) + if royaltyPolicyKind.Cmp(royaltyPolicyKindLAP) == 0 { + return ipGraphReadGas * (averageAncestorIpCount * 3) + } else if royaltyPolicyKind.Cmp(royaltyPolicyKindLRP) == 0 { + return ipGraphReadGas * (averageAncestorIpCount*2 + 2) + } else { + return intrinsicGas + } + case bytes.Equal(selector, getRoyaltyStackSelector): + royaltyPolicyKind := new(big.Int).SetBytes(getData(input, 32+4, 32)) + if royaltyPolicyKind.Cmp(royaltyPolicyKindLAP) == 0 { + return ipGraphReadGas * (averageParentIpCount + 1) + } else if royaltyPolicyKind.Cmp(royaltyPolicyKindLRP) == 0 { + return ipGraphReadGas * (averageAncestorIpCount * 2) + } else { + return intrinsicGas + } + default: + return intrinsicGas + } } func (c *ipGraph) Run(evm *EVM, input []byte) ([]byte, error) { - log.Info("ipGraph.Run", "input", input) + log.Info("ipGraph.Run", "ipGraphAddress", ipGraphAddress, "input", input) if len(input) < 4 { return nil, fmt.Errorf("input too short") @@ -44,25 +99,25 @@ func (c *ipGraph) Run(evm *EVM, input []byte) ([]byte, error) { switch { case bytes.Equal(selector, addParentIpSelector): - return c.addParentIp(args, evm) + return c.addParentIp(args, evm, ipGraphAddress) case bytes.Equal(selector, hasParentIpSelector): - return c.hasParentIp(args, evm) + return c.hasParentIp(args, evm, ipGraphAddress) case bytes.Equal(selector, getParentIpsSelector): - return c.getParentIps(args, evm) + return c.getParentIps(args, evm, ipGraphAddress) case bytes.Equal(selector, getParentIpsCountSelector): - return c.getParentIpsCount(args, evm) + return c.getParentIpsCount(args, evm, ipGraphAddress) case bytes.Equal(selector, getAncestorIpsSelector): - return c.getAncestorIps(args, evm) + return c.getAncestorIps(args, evm, ipGraphAddress) case bytes.Equal(selector, getAncestorIpsCountSelector): - return c.getAncestorIpsCount(args, evm) + return c.getAncestorIpsCount(args, evm, ipGraphAddress) case bytes.Equal(selector, hasAncestorIpsSelector): - return c.hasAncestorIp(args, evm) + return c.hasAncestorIp(args, evm, ipGraphAddress) case bytes.Equal(selector, setRoyaltySelector): - return c.setRoyalty(args, evm) + return c.setRoyalty(args, evm, ipGraphAddress) case bytes.Equal(selector, getRoyaltySelector): - return c.getRoyalty(args, evm) + return c.getRoyalty(args, evm, ipGraphAddress) case bytes.Equal(selector, getRoyaltyStackSelector): - return c.getRoyaltyStack(args, evm) + return c.getRoyaltyStack(args, evm, ipGraphAddress) default: return nil, fmt.Errorf("unknown selector") } @@ -83,7 +138,7 @@ func (c *ipGraph) isAllowed(evm *EVM) (bool, error) { return false, nil } -func (c *ipGraph) addParentIp(input []byte, evm *EVM) ([]byte, error) { +func (c *ipGraph) addParentIp(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) { allowed, err := c.isAllowed(evm) if err != nil { @@ -122,7 +177,7 @@ func (c *ipGraph) addParentIp(input []byte, evm *EVM) ([]byte, error) { return nil, nil } -func (c *ipGraph) hasParentIp(input []byte, evm *EVM) ([]byte, error) { +func (c *ipGraph) hasParentIp(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) { if len(input) < 64 { return nil, fmt.Errorf("input too short for hasParentIp") } @@ -146,7 +201,7 @@ func (c *ipGraph) hasParentIp(input []byte, evm *EVM) ([]byte, error) { return common.LeftPadBytes([]byte{0}, 32), nil } -func (c *ipGraph) getParentIps(input []byte, evm *EVM) ([]byte, error) { +func (c *ipGraph) getParentIps(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) { log.Info("getParentIps", "input", input) if len(input) < 32 { return nil, fmt.Errorf("input too short for getParentIps") @@ -170,7 +225,7 @@ func (c *ipGraph) getParentIps(input []byte, evm *EVM) ([]byte, error) { return output, nil } -func (c *ipGraph) getParentIpsCount(input []byte, evm *EVM) ([]byte, error) { +func (c *ipGraph) getParentIpsCount(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) { log.Info("getParentIpsCount", "input", input) if len(input) < 32 { return nil, fmt.Errorf("input too short for getParentIpsCount") @@ -184,13 +239,13 @@ func (c *ipGraph) getParentIpsCount(input []byte, evm *EVM) ([]byte, error) { return common.BigToHash(currentLength).Bytes(), nil } -func (c *ipGraph) getAncestorIps(input []byte, evm *EVM) ([]byte, error) { +func (c *ipGraph) getAncestorIps(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) { log.Info("getAncestorIps", "input", input) if len(input) < 32 { return nil, fmt.Errorf("input too short for getAncestorIps") } ipId := common.BytesToAddress(input[0:32]) - ancestors := c.findAncestors(ipId, evm) + ancestors := c.findAncestors(ipId, evm, ipGraphAddress) output := make([]byte, 64+len(ancestors)*32) copy(output[0:32], common.BigToHash(new(big.Int).SetUint64(32)).Bytes()) @@ -206,26 +261,26 @@ func (c *ipGraph) getAncestorIps(input []byte, evm *EVM) ([]byte, error) { return output, nil } -func (c *ipGraph) getAncestorIpsCount(input []byte, evm *EVM) ([]byte, error) { +func (c *ipGraph) getAncestorIpsCount(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) { log.Info("getAncestorIpsCount", "input", input) if len(input) < 32 { return nil, fmt.Errorf("input too short for getAncestorIpsCount") } ipId := common.BytesToAddress(input[0:32]) - ancestors := c.findAncestors(ipId, evm) + ancestors := c.findAncestors(ipId, evm, ipGraphAddress) count := new(big.Int).SetUint64(uint64(len(ancestors))) log.Info("getAncestorIpsCount", "ipId", ipId, "count", count) return common.BigToHash(count).Bytes(), nil } -func (c *ipGraph) hasAncestorIp(input []byte, evm *EVM) ([]byte, error) { +func (c *ipGraph) hasAncestorIp(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) { if len(input) < 64 { return nil, fmt.Errorf("input too short for hasAncestorIp") } ipId := common.BytesToAddress(input[0:32]) parentIpId := common.BytesToAddress(input[32:64]) - ancestors := c.findAncestors(ipId, evm) + ancestors := c.findAncestors(ipId, evm, ipGraphAddress) if _, found := ancestors[parentIpId]; found { log.Info("hasAncestorIp", "found", true) @@ -235,7 +290,7 @@ func (c *ipGraph) hasAncestorIp(input []byte, evm *EVM) ([]byte, error) { return common.LeftPadBytes([]byte{0}, 32), nil } -func (c *ipGraph) findAncestors(ipId common.Address, evm *EVM) map[common.Address]struct{} { +func (c *ipGraph) findAncestors(ipId common.Address, evm *EVM, ipGraphAddress common.Address) map[common.Address]struct{} { ancestors := make(map[common.Address]struct{}) var stack []common.Address stack = append(stack, ipId) @@ -261,7 +316,7 @@ func (c *ipGraph) findAncestors(ipId common.Address, evm *EVM) map[common.Addres return ancestors } -func (c *ipGraph) setRoyalty(input []byte, evm *EVM) ([]byte, error) { +func (c *ipGraph) setRoyalty(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) { allowed, err := c.isAllowed(evm) if err != nil { @@ -272,41 +327,58 @@ func (c *ipGraph) setRoyalty(input []byte, evm *EVM) ([]byte, error) { return nil, fmt.Errorf("caller not allowed to set Royalty") } - log.Info("setRoyalty", "input", input) + log.Info("setRoyalty", "ipGraphAddress", ipGraphAddress, "input", input) if len(input) < 96 { return nil, fmt.Errorf("input too short for setRoyalty") } ipId := common.BytesToAddress(input[0:32]) parentIpId := common.BytesToAddress(input[32:64]) - royalty := new(big.Int).SetBytes(getData(input, 64, 32)) - slot := crypto.Keccak256Hash(ipId.Bytes(), parentIpId.Bytes()).Big() - log.Info("setRoyalty", "ipId", ipId, "parentIpId", parentIpId, "royalty", royalty, "slot", slot) + royaltyPolicyKind := new(big.Int).SetBytes(getData(input, 64, 32)) + royalty := new(big.Int).SetBytes(getData(input, 96, 32)) + slot := crypto.Keccak256Hash(ipId.Bytes(), parentIpId.Bytes(), royaltyPolicyKind.Bytes()).Big() + log.Info("setRoyalty", "ipId", ipId, "ipGraphAddress", ipGraphAddress, "parentIpId", parentIpId, + "royaltyPolicyKind", royaltyPolicyKind, "royalty", royalty, "slot", slot) evm.StateDB.SetState(ipGraphAddress, common.BigToHash(slot), common.BigToHash(royalty)) return nil, nil } -func (c *ipGraph) getRoyalty(input []byte, evm *EVM) ([]byte, error) { - log.Info("getRoyalty", "input", input) +func (c *ipGraph) getRoyalty(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) { + log.Info("getRoyalty", "ipGraphAddress", ipGraphAddress, "input", input) if len(input) < 64 { return nil, fmt.Errorf("input too short for getRoyalty") } ipId := common.BytesToAddress(input[0:32]) ancestorIpId := common.BytesToAddress(input[32:64]) - ancestors := c.findAncestors(ipId, evm) + royaltyPolicyKind := new(big.Int).SetBytes(getData(input, 64, 32)) + totalRoyalty := big.NewInt(0) + if royaltyPolicyKind.Cmp(royaltyPolicyKindLAP) == 0 { + totalRoyalty = c.getRoyaltyLap(ipId, ancestorIpId, evm, ipGraphAddress) + } else if royaltyPolicyKind.Cmp(royaltyPolicyKindLRP) == 0 { + totalRoyalty = c.getRoyaltyLrp(ipId, ancestorIpId, evm, ipGraphAddress) + } else { + return nil, fmt.Errorf("unknown royalty policy kind") + } + + log.Info("getRoyalty", "ipId", ipId, "ancestorIpId", ancestorIpId, "ipGraphAddress", ipGraphAddress, "royaltyPolicyKind", royaltyPolicyKind, "totalRoyalty", totalRoyalty) + return common.BigToHash(totalRoyalty).Bytes(), nil +} + +func (c *ipGraph) getRoyaltyLap(ipId, ancestorIpId common.Address, evm *EVM, ipGraphAddress common.Address) *big.Int { + log.Info("getRoyaltyLap", "ipId", ipId, "ancestorIpId", ancestorIpId, "ipGraphAddress", ipGraphAddress) + ancestors := c.findAncestors(ipId, evm, ipGraphAddress) totalRoyalty := big.NewInt(0) for ancestor := range ancestors { + log.Info("getRoyaltyLap", "found_ancestor", ancestor) if ancestor == ancestorIpId { // Traverse the graph to accumulate royalties - totalRoyalty.Add(totalRoyalty, c.getRoyaltyForAncestor(ipId, ancestorIpId, evm)) + totalRoyalty.Add(totalRoyalty, c.getRoyaltyLapForAncestor(ipId, ancestorIpId, evm, ipGraphAddress)) } } - - log.Info("getRoyalty", "ipId", ipId, "ancestorIpId", ancestorIpId, "totalRoyalty", totalRoyalty) - return common.BigToHash(totalRoyalty).Bytes(), nil + return totalRoyalty } -func (c *ipGraph) getRoyaltyForAncestor(ipId, ancestorIpId common.Address, evm *EVM) *big.Int { +func (c *ipGraph) getRoyaltyLapForAncestor(ipId, ancestorIpId common.Address, evm *EVM, ipGraphAddress common.Address) *big.Int { ancestors := make(map[common.Address]struct{}) totalRoyalty := big.NewInt(0) var stack []common.Address @@ -330,7 +402,7 @@ func (c *ipGraph) getRoyaltyForAncestor(ipId, ancestorIpId common.Address, evm * } if parentIpId == ancestorIpId { - royaltySlot := crypto.Keccak256Hash(node.Bytes(), ancestorIpId.Bytes()).Big() + royaltySlot := crypto.Keccak256Hash(node.Bytes(), ancestorIpId.Bytes(), royaltyPolicyKindLAP.Bytes()).Big() royalty := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(royaltySlot)).Big() totalRoyalty.Add(totalRoyalty, royalty) } @@ -339,12 +411,113 @@ func (c *ipGraph) getRoyaltyForAncestor(ipId, ancestorIpId common.Address, evm * return totalRoyalty } -func (c *ipGraph) getRoyaltyStack(input []byte, evm *EVM) ([]byte, error) { - log.Info("getRoyaltyStack", "input", input) +func (c *ipGraph) getRoyaltyLrp(ipId, ancestorIpId common.Address, evm *EVM, ipGraphAddress common.Address) *big.Int { + royalty := make(map[common.Address]*big.Int) + royalty[ipId] = hundredPercent + + topoOrder, allParents, err := c.topologicalSort(ipId, ancestorIpId, evm, ipGraphAddress) + if err != nil { + log.Error("Failed to perform topological sort", "error", err) + return big.NewInt(0) // Return 0 if any error occurs + } + log.Info("getRoyaltyLrp", "topoOrder", topoOrder, "allParents", allParents) + + for i := len(topoOrder) - 1; i >= 0; i-- { + node := topoOrder[i] + // If we reached the ancestor IP, we can stop the calculation + if node == ancestorIpId { + break + } + + currentRoyalty, exists := royalty[node] + if !exists || currentRoyalty.Sign() == 0 { + continue // Skip if there's no royalty to distribute + } + + parents := allParents[node] + for _, parentIpId := range parents { + royaltySlot := crypto.Keccak256Hash(node.Bytes(), parentIpId.Bytes(), royaltyPolicyKindLRP.Bytes()).Big() + royaltyHash := common.BigToHash(royaltySlot) + parentRoyalty := evm.StateDB.GetState(ipGraphAddress, royaltyHash).Big() + + contribution := new(big.Int).Div(new(big.Int).Mul(currentRoyalty, parentRoyalty), hundredPercent) + + if existingRoyalty, exists := royalty[parentIpId]; exists { + royalty[parentIpId] = new(big.Int).Add(existingRoyalty, contribution) + } else { + royalty[parentIpId] = contribution + } + } + } + + if result, exists := royalty[ancestorIpId]; exists { + log.Info("getRoyaltyLrp", "msg", "Royalty for ancestor IP", "ancestorIpId", ancestorIpId, "royalty", result) + return result + } + log.Info("getRoyaltyLrp", "msg", "Royalty for ancestor IP not found", "ancestorIpId", ancestorIpId) + return big.NewInt(0) +} + +func (c *ipGraph) topologicalSort(ipId, ancestorIpId common.Address, evm *EVM, ipGraphAddress common.Address) ( + []common.Address, map[common.Address][]common.Address, error) { + + allParents := make(map[common.Address][]common.Address) + visited := make(map[common.Address]bool) + topoOrder := []common.Address{} + stack := []common.Address{ipId} + + for len(stack) > 0 { + current := stack[len(stack)-1] + stack = stack[:len(stack)-1] // pop from stack + + if visited[current] { + topoOrder = append(topoOrder, current) + continue + } + visited[current] = true + stack = append(stack, current) + + currentLengthHash := evm.StateDB.GetState(ipGraphAddress, common.BytesToHash(current.Bytes())) + currentLength := currentLengthHash.Big() + for i := uint64(0); i < currentLength.Uint64(); i++ { + slot := crypto.Keccak256Hash(current.Bytes()).Big() + slot.Add(slot, new(big.Int).SetUint64(i)) + parentIpIdBytes := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(slot)).Bytes() + parentIpId := common.BytesToAddress(parentIpIdBytes) + allParents[current] = append(allParents[current], parentIpId) + + if !visited[parentIpId] { + stack = append(stack, parentIpId) + } + } + } + if !visited[ancestorIpId] { + return []common.Address{}, map[common.Address][]common.Address{}, nil + } + return topoOrder, allParents, nil +} + +func (c *ipGraph) getRoyaltyStack(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) { + log.Info("getRoyaltyStack", "ipGraphAddress", ipGraphAddress, "input", input) + totalRoyalty := big.NewInt(0) if len(input) < 32 { return nil, fmt.Errorf("input too short for getRoyaltyStack") } ipId := common.BytesToAddress(input[0:32]) + royaltyPolicyKind := new(big.Int).SetBytes(getData(input, 32, 32)) + if royaltyPolicyKind.Cmp(royaltyPolicyKindLAP) == 0 { + totalRoyalty = c.getRoyaltyStackLap(ipId, evm, ipGraphAddress) + } else if royaltyPolicyKind.Cmp(royaltyPolicyKindLRP) == 0 { + totalRoyalty = c.getRoyaltyStackLrp(ipId, evm, ipGraphAddress) + } else { + return nil, fmt.Errorf("unknown royalty policy kind") + } + log.Info("getRoyaltyStack", "ipId", ipId, "ipGraphAddress", ipGraphAddress, "royaltyPolicyKind", royaltyPolicyKind, "totalRoyalty", totalRoyalty) + return common.BigToHash(totalRoyalty).Bytes(), nil +} + +func (c *ipGraph) getRoyaltyStackLap(ipId common.Address, evm *EVM, ipGraphAddress common.Address) *big.Int { + log.Info("getRoyaltyStackLap", "ipGraphAddress", ipGraphAddress, "IP ID", ipId) ancestors := make(map[common.Address]struct{}) totalRoyalty := big.NewInt(0) var stack []common.Address @@ -367,10 +540,28 @@ func (c *ipGraph) getRoyaltyStack(input []byte, evm *EVM) ([]byte, error) { stack = append(stack, parentIpId) } - royaltySlot := crypto.Keccak256Hash(node.Bytes(), parentIpId.Bytes()).Big() + royaltySlot := crypto.Keccak256Hash(node.Bytes(), parentIpId.Bytes(), royaltyPolicyKindLAP.Bytes()).Big() royalty := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(royaltySlot)).Big() totalRoyalty.Add(totalRoyalty, royalty) } } - return common.BigToHash(totalRoyalty).Bytes(), nil + return totalRoyalty +} + +func (c *ipGraph) getRoyaltyStackLrp(ipId common.Address, evm *EVM, ipGraphAddress common.Address) *big.Int { + log.Info("getRoyaltyStackLrp", "ipGraphAddress", ipGraphAddress, "IP ID", ipId) + totalRoyalty := big.NewInt(0) + currentLengthHash := evm.StateDB.GetState(ipGraphAddress, common.BytesToHash(ipId.Bytes())) + currentLength := currentLengthHash.Big() + + for i := uint64(0); i < currentLength.Uint64(); i++ { + slot := crypto.Keccak256Hash(ipId.Bytes()).Big() + slot.Add(slot, new(big.Int).SetUint64(i)) + storedParent := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(slot)) + parentIpId := common.BytesToAddress(storedParent.Bytes()) + royaltySlot := crypto.Keccak256Hash(ipId.Bytes(), parentIpId.Bytes(), royaltyPolicyKindLRP.Bytes()).Big() + royalty := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(royaltySlot)).Big() + totalRoyalty.Add(totalRoyalty, royalty) + } + return totalRoyalty }