-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
332 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
86 changes: 86 additions & 0 deletions
86
packages/gated-content/src/conditions/__tests__/advanced-contract-condition.spec.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import { Interface } from '@ethersproject/abi'; | ||
import { toChainId, ConditionComparisonOperator, toEvmAddress } from '@lens-protocol/metadata'; | ||
|
||
import { mockNetworkAddress, mockAdvancedContractCondition } from '../../__helpers__/mocks'; | ||
import { transformAdvancedContractCondition } from '../advanced-contract-condition'; | ||
import { LitConditionType, SupportedChains } from '../types'; | ||
import { resolveScalarOperatorSymbol } from '../utils'; | ||
import { InvalidAccessCriteriaError } from '../validators'; | ||
|
||
describe(`Given the "${transformAdvancedContractCondition.name}" function`, () => { | ||
describe('when called with an Advanced Contract condition', () => { | ||
const operatorPairs = Object.values(ConditionComparisonOperator).map((operator) => ({ | ||
operator, | ||
expectedLitOperator: resolveScalarOperatorSymbol(operator), | ||
})); | ||
|
||
it.each(operatorPairs)( | ||
'should support $operator comparisons', | ||
({ operator, expectedLitOperator }) => { | ||
const condition = mockAdvancedContractCondition({ | ||
comparison: operator, | ||
}); | ||
|
||
const actual = transformAdvancedContractCondition(condition); | ||
|
||
const expectedLitAccessConditions = [ | ||
{ | ||
conditionType: LitConditionType.EVM_CONTRACT, | ||
chain: SupportedChains.ETHEREUM, | ||
contractAddress: condition.contract.address.toLowerCase(), | ||
functionAbi: new Interface([ | ||
'function balanceOf(address) view returns (uint256)', | ||
]).getFunction(condition.functionName), | ||
functionName: 'balanceOf', | ||
functionParams: [':userAddress'], | ||
returnValueTest: { | ||
comparator: expectedLitOperator, | ||
value: '1', | ||
key: '', | ||
}, | ||
}, | ||
]; | ||
expect(actual).toEqual(expectedLitAccessConditions); | ||
}, | ||
); | ||
|
||
it.each([ | ||
{ | ||
description: 'if with invalid contract address', | ||
condition: mockAdvancedContractCondition({ | ||
contract: mockNetworkAddress({ | ||
address: toEvmAddress('0x000000000000000000000000000000000000000000000000'), | ||
}), | ||
}), | ||
}, | ||
|
||
{ | ||
description: 'if with invalid chain ID', | ||
condition: mockAdvancedContractCondition({ | ||
contract: mockNetworkAddress({ | ||
chainId: toChainId(2), | ||
}), | ||
}), | ||
}, | ||
|
||
{ | ||
description: 'if with invalid comparison value', | ||
condition: mockAdvancedContractCondition({ | ||
comparison: ConditionComparisonOperator.GREATER_THAN, | ||
value: 'a', | ||
}), | ||
}, | ||
|
||
{ | ||
description: 'if with invalid comparison operator', | ||
condition: mockAdvancedContractCondition({ | ||
comparison: 'a' as ConditionComparisonOperator, | ||
}), | ||
}, | ||
])(`should throw an ${InvalidAccessCriteriaError.name} $description`, ({ condition }) => { | ||
expect(() => transformAdvancedContractCondition(condition)).toThrow( | ||
InvalidAccessCriteriaError, | ||
); | ||
}); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
150 changes: 150 additions & 0 deletions
150
packages/gated-content/src/conditions/advanced-contract-condition.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import { Interface } from '@ethersproject/abi'; | ||
import { BigNumber } from '@ethersproject/bignumber'; | ||
import { AdvancedContractCondition, ConditionComparisonOperator } from '@lens-protocol/metadata'; | ||
import { invariant, assertError } from '@lens-protocol/shared-kernel'; | ||
|
||
import { LitConditionType, LitEvmAccessCondition } from './types'; | ||
import { toLitSupportedChainName, resolveScalarOperatorSymbol } from './utils'; | ||
import { | ||
assertValidAddress, | ||
assertSupportedChainId, | ||
InvalidAccessCriteriaError, | ||
} from './validators'; | ||
|
||
export const transformAdvancedContractCondition = ( | ||
condition: AdvancedContractCondition, | ||
): Array<LitEvmAccessCondition> => { | ||
assertValidAddress(condition.contract.address); | ||
assertSupportedChainId(condition.contract.chainId); | ||
assertValidAbi(condition.abi, condition.functionName); | ||
assertValidFunctionParams(condition); | ||
assertValidComparison(condition); | ||
|
||
return [ | ||
{ | ||
conditionType: LitConditionType.EVM_CONTRACT, | ||
contractAddress: condition.contract.address.toLowerCase(), | ||
chain: toLitSupportedChainName(condition.contract.chainId), | ||
functionName: condition.functionName, | ||
functionParams: condition.params || [], | ||
functionAbi: new Interface([condition.abi]).getFunction(condition.functionName), | ||
returnValueTest: { | ||
key: '', | ||
comparator: resolveScalarOperatorSymbol(condition.comparison), | ||
value: condition.value, | ||
}, | ||
}, | ||
]; | ||
}; | ||
|
||
function assertValidAbi(humanReadableAbi: string, functionName: string): void { | ||
try { | ||
const fn = new Interface([humanReadableAbi]).getFunction(functionName); | ||
|
||
// assert view function | ||
invariant(fn.stateMutability === 'view', 'unsupported'); | ||
|
||
// assert single output | ||
invariant(Array.isArray(fn.outputs) && fn.outputs.length === 1, 'unsupported'); | ||
|
||
// assert output is boolean or uint | ||
invariant( | ||
fn.outputs[0] && (fn.outputs[0].type === 'bool' || fn.outputs[0].type === 'uint256'), | ||
'unsupported', | ||
); | ||
} catch (e: any) { | ||
throw new InvalidAccessCriteriaError( | ||
`Invalid abi: ${humanReadableAbi} or function: ${functionName}. Only view functions returning a single boolean or uint output are supported`, | ||
); | ||
} | ||
} | ||
|
||
/** | ||
* verifies arguments are valid, as well as the prefixed `:userAddress` param exists | ||
* @param condition the user provided condition object | ||
*/ | ||
function assertValidFunctionParams(condition: AdvancedContractCondition): void { | ||
try { | ||
const fn = new Interface([condition.abi]).getFunction(condition.functionName); | ||
let userAddressParamFound = false; | ||
|
||
invariant(fn.inputs.length === condition.params.length, 'wrong number of params'); | ||
|
||
fn.inputs.forEach((input, index) => { | ||
const param = condition.params[index]; | ||
|
||
invariant(param, `param ${input.name || input.type} is missing`); | ||
|
||
if (input.baseType === 'array' || input.baseType === 'tuple') { | ||
invariant( | ||
Array.isArray(param) && param.length > 0, | ||
`param ${input.name} expects an array argument`, | ||
); | ||
|
||
if (param.includes(':userAddress')) { | ||
userAddressParamFound = true; | ||
} | ||
} | ||
|
||
if (input.baseType === 'address') { | ||
if (param === ':userAddress') { | ||
userAddressParamFound = true; | ||
} else { | ||
assertValidAddress(param); | ||
} | ||
} else if (input.baseType.includes('int')) { | ||
BigNumber.from(param); | ||
} else if (input.baseType === 'bool') { | ||
invariant( | ||
param === 'true' || param === 'false', | ||
`param ${input.name} is invalid, must be a boolean)`, | ||
); | ||
} else if (input.baseType === 'bytes') { | ||
invariant(param.startsWith('0x'), `param ${input.name} is invalid, must be a hex string)`); | ||
} | ||
}); | ||
|
||
invariant(userAddressParamFound, `param :userAddress is missing`); | ||
} catch (e: any) { | ||
assertError(e); | ||
throw new InvalidAccessCriteriaError(e.message); | ||
} | ||
} | ||
|
||
/** | ||
* verifies the comparison is valid based on the function output type | ||
* @param condition the user provided condition object | ||
*/ | ||
function assertValidComparison(condition: AdvancedContractCondition): void { | ||
try { | ||
invariant( | ||
Object.values(ConditionComparisonOperator).includes(condition.comparison), | ||
`comparison operator ${condition.comparison} is unsupported`, | ||
); | ||
|
||
// get function return type | ||
const fn = new Interface([condition.abi]).getFunction(condition.functionName); | ||
|
||
invariant( | ||
Array.isArray(fn.outputs) && fn.outputs.length === 1 && fn.outputs[0], | ||
'function should have a single output', | ||
); | ||
|
||
// for bool, array, tuple results we only allow equal/not equal | ||
if (['bool', 'string', 'bytes', 'address', 'array', 'tuple'].includes(fn.outputs[0].baseType)) { | ||
invariant( | ||
condition.comparison === ConditionComparisonOperator.EQUAL || | ||
condition.comparison === ConditionComparisonOperator.NOT_EQUAL, | ||
`comparison ${condition.comparison} is invalid for function return type ${fn.outputs[0].type}`, | ||
); | ||
} | ||
|
||
// for uint results we allow all comparisons but we check the provided value | ||
if (fn.outputs[0].baseType.includes('int')) { | ||
BigNumber.from(condition.value); | ||
} | ||
} catch (e: any) { | ||
assertError(e); | ||
throw new InvalidAccessCriteriaError(e.message); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.