From 9c2b10b19805b08e4c16cd818bf647ac78387f3c Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 26 Mar 2024 23:15:01 +0100 Subject: [PATCH 1/3] attempt at compatible compare gadget --- src/lib/provable/gadgets/comparison.ts | 51 ++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 src/lib/provable/gadgets/comparison.ts diff --git a/src/lib/provable/gadgets/comparison.ts b/src/lib/provable/gadgets/comparison.ts new file mode 100644 index 0000000000..cf3565d7cf --- /dev/null +++ b/src/lib/provable/gadgets/comparison.ts @@ -0,0 +1,51 @@ +import type { Field } from '../field.js'; +import { createBool, createField } from '../core/field-constructor.js'; +import { Fp } from '../../../bindings/crypto/finite-field.js'; +import { assert } from '../../../lib/util/assert.js'; +import { exists } from '../core/exists.js'; +import { assertMul } from './compatible.js'; + +export { compareCompatible }; + +/** + * Compare x and y assuming both have at most `n` bits. + */ +function compareCompatible(x: Field, y: Field, n = Fp.sizeInBits - 2) { + // compatible with snarky's `compare` + + let maxLength = Fp.sizeInBits - 2; + assert(n <= maxLength, `bitLength must be at most ${maxLength}`); + + // 2^n + x - y + let z = createField(1n << BigInt(n)) + .add(y) + .sub(x); + + let zBits = unpack(z, n + 1); + + // n-th bit tells us if x <= y + let lessOrEqual = zBits[n]; + + // other bits tell us if x = y + let prefix = zBits.slice(0, n); + let notAllZeros = prefix.reduce((a, b) => a.or(b)); + let less = lessOrEqual.and(notAllZeros); + + return { lessOrEqual, less }; +} + +// custom version of toBits to be compatible + +function unpack(x: Field, length: number) { + let bits = exists(length, () => { + let x0 = x.toBigInt(); + return Array.from({ length }, (_, k) => (x0 >> BigInt(k)) & 1n); + }); + bits.forEach((b) => b.assertBool()); + let lc = bits.reduce( + (acc, b, i) => acc.add(b.mul(1n << BigInt(i))), + createField(0) + ); + assertMul(lc, createField(1), x); + return bits.map((b) => createBool(b.value)); +} From 572b53b357855b54e3c8112abe64b89dd8b8eeb3 Mon Sep 17 00:00:00 2001 From: Gregor Date: Wed, 27 Mar 2024 11:03:29 +0100 Subject: [PATCH 2/3] fix and enable compatible comparison --- src/lib/provable/field.ts | 27 +++---------- src/lib/provable/gadgets/comparison.ts | 54 +++++++++++++++++++++++--- 2 files changed, 53 insertions(+), 28 deletions(-) diff --git a/src/lib/provable/field.ts b/src/lib/provable/field.ts index 836386b83b..8db0e78349 100644 --- a/src/lib/provable/field.ts +++ b/src/lib/provable/field.ts @@ -23,6 +23,7 @@ import { } from './core/fieldvar.js'; import { exists, existsOne } from './core/exists.js'; import { setFieldConstructor } from './core/field-constructor.js'; +import { compareCompatible } from './gadgets/comparison.js'; // external API export { Field }; @@ -580,7 +581,7 @@ class Field { if (this.isConstant() && isConstant(y)) { return new Bool(this.toBigInt() < toFp(y)); } - return compare(this, toFieldVar(y)).less; + return compareCompatible(this, Field.from(y)).less; } /** @@ -610,7 +611,7 @@ class Field { if (this.isConstant() && isConstant(y)) { return new Bool(this.toBigInt() <= toFp(y)); } - return compare(this, toFieldVar(y)).lessOrEqual; + return compareCompatible(this, Field.from(y)).lessOrEqual; } /** @@ -688,7 +689,7 @@ class Field { } return; } - let { less } = compare(this, toFieldVar(y)); + let { less } = compareCompatible(this, Field.from(y)); less.assertTrue(); } catch (err) { throw withMessage(err, message); @@ -716,7 +717,7 @@ class Field { } return; } - let { lessOrEqual } = compare(this, toFieldVar(y)); + let { lessOrEqual } = compareCompatible(this, Field.from(y)); lessOrEqual.assertTrue(); } catch (err) { throw withMessage(err, message); @@ -1170,24 +1171,6 @@ function withMessage(error: unknown, message?: string) { return error; } -// internal base method for all comparisons -function compare(x: Field, y: FieldVar) { - // TODO: support all bit lengths - let maxLength = Fp.sizeInBits - 2; - asProver(() => { - let actualLength = Math.max( - x.toBigInt().toString(2).length, - new Field(y).toBigInt().toString(2).length - ); - if (actualLength > maxLength) - throw Error( - `Provable comparison functions can only be used on Fields of size <= ${maxLength} bits, got ${actualLength} bits.` - ); - }); - let [, less, lessOrEqual] = Snarky.field.compare(maxLength, x.value, y); - return { less: new Bool(less), lessOrEqual: new Bool(lessOrEqual) }; -} - function checkBitLength( name: string, length: number, diff --git a/src/lib/provable/gadgets/comparison.ts b/src/lib/provable/gadgets/comparison.ts index cf3565d7cf..f34c0e6709 100644 --- a/src/lib/provable/gadgets/comparison.ts +++ b/src/lib/provable/gadgets/comparison.ts @@ -1,41 +1,61 @@ import type { Field } from '../field.js'; +import type { Bool } from '../bool.js'; import { createBool, createField } from '../core/field-constructor.js'; import { Fp } from '../../../bindings/crypto/finite-field.js'; import { assert } from '../../../lib/util/assert.js'; import { exists } from '../core/exists.js'; import { assertMul } from './compatible.js'; +import { asProver } from '../core/provable-context.js'; export { compareCompatible }; /** * Compare x and y assuming both have at most `n` bits. + * + * **Important:** If `x` and `y` have more than `n` bits, this doesn't prove the comparison correctly. + * It is up to the caller to prove that `x` and `y` have at most `n` bits. + * + * **Warning:** This was created for 1:1 compatibility with snarky's `compare` gadget. + * It was designed for R1CS and is extremeley inefficient when used with plonkish arithmetization. */ function compareCompatible(x: Field, y: Field, n = Fp.sizeInBits - 2) { - // compatible with snarky's `compare` - let maxLength = Fp.sizeInBits - 2; assert(n <= maxLength, `bitLength must be at most ${maxLength}`); - // 2^n + x - y + // as prover check + asProver(() => { + let actualLength = Math.max( + x.toBigInt().toString(2).length, + y.toBigInt().toString(2).length + ); + if (actualLength > maxLength) + throw Error( + `Provable comparison functions can only be used on Fields of size <= ${maxLength} bits, got ${actualLength} bits.` + ); + }); + + // z = 2^n + y - x let z = createField(1n << BigInt(n)) .add(y) .sub(x); let zBits = unpack(z, n + 1); - // n-th bit tells us if x <= y + // highest (n-th) bit tells us if z >= 2^n + // which is equivalent to x <= y let lessOrEqual = zBits[n]; // other bits tell us if x = y let prefix = zBits.slice(0, n); - let notAllZeros = prefix.reduce((a, b) => a.or(b)); + let notAllZeros = any(prefix); let less = lessOrEqual.and(notAllZeros); return { lessOrEqual, less }; } -// custom version of toBits to be compatible +// helper functions for `compareCompatible()` +// custom version of toBits to be compatible function unpack(x: Field, length: number) { let bits = exists(length, () => { let x0 = x.toBigInt(); @@ -49,3 +69,25 @@ function unpack(x: Field, length: number) { assertMul(lc, createField(1), x); return bits.map((b) => createBool(b.value)); } + +function any(xs: Bool[]) { + let sum = xs.reduce((a, b) => a.add(b.toField()), createField(0)); + let allZero = isZero(sum); + return allZero.not(); +} + +// custom isZero to be compatible +function isZero(x: Field): Bool { + // create witnesses z = 1/x (or z=0 if x=0), and b = 1 - zx + let [b, z] = exists(2, () => { + let xmy = x.toBigInt(); + let z = Fp.inverse(xmy) ?? 0n; + let b = Fp.sub(1n, Fp.mul(z, xmy)); + return [b, z]; + }); + // b * x === 0 + assertMul(b, x, createField(0)); + // z * x === 1 - b + assertMul(z, x, createField(1).sub(b)); + return createBool(b.value); +} From 4ce01cc8dc5770a7e52b8ce7b7cca3bfa7998d3e Mon Sep 17 00:00:00 2001 From: Gregor Date: Wed, 27 Mar 2024 11:03:55 +0100 Subject: [PATCH 3/3] remove unused bindings --- src/snarky.d.ts | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/snarky.d.ts b/src/snarky.d.ts index b83d5a2711..e8f72dca55 100644 --- a/src/snarky.d.ts +++ b/src/snarky.d.ts @@ -161,14 +161,6 @@ declare const Snarky: { * x*x === x without handling of constants */ assertBoolean(x: FieldVar): void; - /** - * check x < y and x <= y - */ - compare( - bitLength: number, - x: FieldVar, - y: FieldVar - ): [_: 0, less: BoolVar, lessOrEqual: BoolVar]; /** * returns x truncated to the lowest `16 * lengthDiv16` bits * => can be used to assert that x fits in `16 * lengthDiv16` bits.