Skip to content

Commit

Permalink
Merge pull request #1524 from o1-labs/feature/compatible-comparisons
Browse files Browse the repository at this point in the history
Compatible comparisons
  • Loading branch information
mitschabaude authored Mar 28, 2024
2 parents a892d6a + 4ce01cc commit e34d28f
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 30 deletions.
27 changes: 5 additions & 22 deletions src/lib/provable/field.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
} from './core/fieldvar.js';
import { exists, existsOne } from './core/exists.js';
import { setFieldConstructor } from './core/field-constructor.js';
import { compareCompatible } from './gadgets/comparison.js';

// external API
export { Field };
Expand Down Expand Up @@ -580,7 +581,7 @@ class Field {
if (this.isConstant() && isConstant(y)) {
return new Bool(this.toBigInt() < toFp(y));
}
return compare(this, toFieldVar(y)).less;
return compareCompatible(this, Field.from(y)).less;
}

/**
Expand Down Expand Up @@ -610,7 +611,7 @@ class Field {
if (this.isConstant() && isConstant(y)) {
return new Bool(this.toBigInt() <= toFp(y));
}
return compare(this, toFieldVar(y)).lessOrEqual;
return compareCompatible(this, Field.from(y)).lessOrEqual;
}

/**
Expand Down Expand Up @@ -688,7 +689,7 @@ class Field {
}
return;
}
let { less } = compare(this, toFieldVar(y));
let { less } = compareCompatible(this, Field.from(y));
less.assertTrue();
} catch (err) {
throw withMessage(err, message);
Expand Down Expand Up @@ -716,7 +717,7 @@ class Field {
}
return;
}
let { lessOrEqual } = compare(this, toFieldVar(y));
let { lessOrEqual } = compareCompatible(this, Field.from(y));
lessOrEqual.assertTrue();
} catch (err) {
throw withMessage(err, message);
Expand Down Expand Up @@ -1170,24 +1171,6 @@ function withMessage(error: unknown, message?: string) {
return error;
}

// internal base method for all comparisons
function compare(x: Field, y: FieldVar) {
// TODO: support all bit lengths
let maxLength = Fp.sizeInBits - 2;
asProver(() => {
let actualLength = Math.max(
x.toBigInt().toString(2).length,
new Field(y).toBigInt().toString(2).length
);
if (actualLength > maxLength)
throw Error(
`Provable comparison functions can only be used on Fields of size <= ${maxLength} bits, got ${actualLength} bits.`
);
});
let [, less, lessOrEqual] = Snarky.field.compare(maxLength, x.value, y);
return { less: new Bool(less), lessOrEqual: new Bool(lessOrEqual) };
}

function checkBitLength(
name: string,
length: number,
Expand Down
93 changes: 93 additions & 0 deletions src/lib/provable/gadgets/comparison.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import type { Field } from '../field.js';
import type { Bool } from '../bool.js';
import { createBool, createField } from '../core/field-constructor.js';
import { Fp } from '../../../bindings/crypto/finite-field.js';
import { assert } from '../../../lib/util/assert.js';
import { exists } from '../core/exists.js';
import { assertMul } from './compatible.js';
import { asProver } from '../core/provable-context.js';

export { compareCompatible };

/**
* Compare x and y assuming both have at most `n` bits.
*
* **Important:** If `x` and `y` have more than `n` bits, this doesn't prove the comparison correctly.
* It is up to the caller to prove that `x` and `y` have at most `n` bits.
*
* **Warning:** This was created for 1:1 compatibility with snarky's `compare` gadget.
* It was designed for R1CS and is extremeley inefficient when used with plonkish arithmetization.
*/
function compareCompatible(x: Field, y: Field, n = Fp.sizeInBits - 2) {
let maxLength = Fp.sizeInBits - 2;
assert(n <= maxLength, `bitLength must be at most ${maxLength}`);

// as prover check
asProver(() => {
let actualLength = Math.max(
x.toBigInt().toString(2).length,
y.toBigInt().toString(2).length
);
if (actualLength > maxLength)
throw Error(
`Provable comparison functions can only be used on Fields of size <= ${maxLength} bits, got ${actualLength} bits.`
);
});

// z = 2^n + y - x
let z = createField(1n << BigInt(n))
.add(y)
.sub(x);

let zBits = unpack(z, n + 1);

// highest (n-th) bit tells us if z >= 2^n
// which is equivalent to x <= y
let lessOrEqual = zBits[n];

// other bits tell us if x = y
let prefix = zBits.slice(0, n);
let notAllZeros = any(prefix);
let less = lessOrEqual.and(notAllZeros);

return { lessOrEqual, less };
}

// helper functions for `compareCompatible()`

// custom version of toBits to be compatible
function unpack(x: Field, length: number) {
let bits = exists(length, () => {
let x0 = x.toBigInt();
return Array.from({ length }, (_, k) => (x0 >> BigInt(k)) & 1n);
});
bits.forEach((b) => b.assertBool());
let lc = bits.reduce(
(acc, b, i) => acc.add(b.mul(1n << BigInt(i))),
createField(0)
);
assertMul(lc, createField(1), x);
return bits.map((b) => createBool(b.value));
}

function any(xs: Bool[]) {
let sum = xs.reduce((a, b) => a.add(b.toField()), createField(0));
let allZero = isZero(sum);
return allZero.not();
}

// custom isZero to be compatible
function isZero(x: Field): Bool {
// create witnesses z = 1/x (or z=0 if x=0), and b = 1 - zx
let [b, z] = exists(2, () => {
let xmy = x.toBigInt();
let z = Fp.inverse(xmy) ?? 0n;
let b = Fp.sub(1n, Fp.mul(z, xmy));
return [b, z];
});
// b * x === 0
assertMul(b, x, createField(0));
// z * x === 1 - b
assertMul(z, x, createField(1).sub(b));
return createBool(b.value);
}
8 changes: 0 additions & 8 deletions src/snarky.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,6 @@ declare const Snarky: {
* x*x === x without handling of constants
*/
assertBoolean(x: FieldVar): void;
/**
* check x < y and x <= y
*/
compare(
bitLength: number,
x: FieldVar,
y: FieldVar
): [_: 0, less: BoolVar, lessOrEqual: BoolVar];
/**
* returns x truncated to the lowest `16 * lengthDiv16` bits
* => can be used to assert that x fits in `16 * lengthDiv16` bits.
Expand Down

0 comments on commit e34d28f

Please sign in to comment.