diff --git a/.noir-sync-commit b/.noir-sync-commit index 87a3bf56846..23a34206cf8 100644 --- a/.noir-sync-commit +++ b/.noir-sync-commit @@ -1 +1 @@ -1df102a1ee0eb39dcbada50e10b226c7f7be0f26 \ No newline at end of file +0864e7c945089cc06f8cc9e5c7d933c465d8c892 diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs index 945b768efcf..b7b25c6db49 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs @@ -1,26 +1,290 @@ use acvm::{acir::brillig::MemoryAddress, AcirField}; +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use super::{debug_show::DebugToString, registers::RegisterAllocator, BrilligContext}; impl BrilligContext { /// This function moves values from a set of registers to another set of registers. - /// It first moves all sources to new allocated registers to avoid overwriting. + /// The only requirement is that every destination needs to be written at most once. pub(crate) fn codegen_mov_registers_to_registers( &mut self, sources: Vec, destinations: Vec, ) { - let new_sources: Vec<_> = sources - .iter() - .map(|source| { - let new_source = self.allocate_register(); - self.mov_instruction(new_source, *source); - new_source - }) + assert_eq!(sources.len(), destinations.len()); + // Remove all no-ops + let movements: Vec<_> = sources + .into_iter() + .zip(destinations) + .filter(|(source, destination)| source != destination) .collect(); - for (new_source, destination) in new_sources.iter().zip(destinations.iter()) { - self.mov_instruction(*destination, *new_source); - self.deallocate_register(*new_source); + + // Now we need to detect all cycles. + // First build a map of the movements. Note that a source could have multiple destinations + let mut movements_map: HashMap> = + movements.into_iter().fold(HashMap::default(), |mut map, (source, destination)| { + map.entry(source).or_default().insert(destination); + map + }); + + let destinations_set: HashSet<_> = movements_map.values().flatten().copied().collect(); + assert_eq!( + destinations_set.len(), + movements_map.values().flatten().count(), + "Multiple moves to the same register found" + ); + + let mut loop_detector = LoopDetector::default(); + loop_detector.collect_loops(&movements_map); + let loops = loop_detector.loops; + // In order to break the loops we need to store one register from each in a temporary and then use that temporary as source. + let mut temporaries = Vec::with_capacity(loops.len()); + for loop_found in loops { + let temp_register = self.allocate_register(); + temporaries.push(temp_register); + let first_source = loop_found.iter().next().unwrap(); + self.mov_instruction(temp_register, *first_source); + let destinations_of_temp = movements_map.remove(first_source).unwrap(); + movements_map.insert(temp_register, destinations_of_temp); + } + // After removing loops we should have an DAG with each node having only one ancestor (but could have multiple successors) + // Now we should be able to move the registers just by performing a DFS on the movements map + let heads: Vec<_> = movements_map + .keys() + .filter(|source| !destinations_set.contains(source)) + .copied() + .collect(); + for head in heads { + self.perform_movements(&movements_map, head); + } + + // Deallocate all temporaries + for temp in temporaries { + self.deallocate_register(temp); } } + + fn perform_movements( + &mut self, + movements: &HashMap>, + current_source: MemoryAddress, + ) { + if let Some(destinations) = movements.get(¤t_source) { + for destination in destinations { + self.perform_movements(movements, *destination); + } + for destination in destinations { + self.mov_instruction(*destination, current_source); + } + } + } +} + +#[derive(Default)] +struct LoopDetector { + visited_sources: HashSet, + loops: Vec>, +} + +impl LoopDetector { + fn collect_loops(&mut self, movements: &HashMap>) { + for source in movements.keys() { + self.find_loop_recursive(*source, movements, im::OrdSet::default()); + } + } + + fn find_loop_recursive( + &mut self, + source: MemoryAddress, + movements: &HashMap>, + mut previous_sources: im::OrdSet, + ) { + if self.visited_sources.contains(&source) { + return; + } + // Mark as visited + self.visited_sources.insert(source); + + previous_sources.insert(source); + // Get all destinations + if let Some(destinations) = movements.get(&source) { + for destination in destinations { + if previous_sources.contains(destination) { + // Found a loop + let loop_sources = previous_sources.clone(); + self.loops.push(loop_sources); + } else { + self.find_loop_recursive(*destination, movements, previous_sources.clone()); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use acvm::{ + acir::brillig::{MemoryAddress, Opcode}, + FieldElement, + }; + use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; + + use crate::{ + brillig::brillig_ir::{artifact::Label, registers::Stack, BrilligContext}, + ssa::ir::function::FunctionId, + }; + + // Tests for the loop finder + + fn generate_movements_map( + movements: Vec<(usize, usize)>, + ) -> HashMap> { + movements.into_iter().fold(HashMap::default(), |mut map, (source, destination)| { + map.entry(MemoryAddress(source)).or_default().insert(MemoryAddress(destination)); + map + }) + } + + #[test] + fn test_loop_detector_basic_loop() { + let movements = vec![(0, 1), (1, 2), (2, 3), (3, 0)]; + let movements_map = generate_movements_map(movements); + let mut loop_detector = super::LoopDetector::default(); + loop_detector.collect_loops(&movements_map); + assert_eq!(loop_detector.loops.len(), 1); + assert_eq!(loop_detector.loops[0].len(), 4); + } + + #[test] + fn test_loop_detector_no_loop() { + let movements = vec![(0, 1), (1, 2), (2, 3), (3, 4)]; + let movements_map = generate_movements_map(movements); + let mut loop_detector = super::LoopDetector::default(); + loop_detector.collect_loops(&movements_map); + assert_eq!(loop_detector.loops.len(), 0); + } + + #[test] + fn test_loop_detector_loop_with_branch() { + let movements = vec![(0, 1), (1, 2), (2, 0), (0, 3), (3, 4)]; + let movements_map = generate_movements_map(movements); + let mut loop_detector = super::LoopDetector::default(); + loop_detector.collect_loops(&movements_map); + assert_eq!(loop_detector.loops.len(), 1); + assert_eq!(loop_detector.loops[0].len(), 3); + } + + #[test] + fn test_loop_detector_two_loops() { + let movements = vec![(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3)]; + let movements_map = generate_movements_map(movements); + let mut loop_detector = super::LoopDetector::default(); + loop_detector.collect_loops(&movements_map); + assert_eq!(loop_detector.loops.len(), 2); + assert_eq!(loop_detector.loops[0].len(), 3); + assert_eq!(loop_detector.loops[1].len(), 3); + } + + // Tests for mov_registers_to_registers + + fn movements_to_source_and_destinations( + movements: Vec<(usize, usize)>, + ) -> (Vec, Vec) { + let sources = movements.iter().map(|(source, _)| MemoryAddress::from(*source)).collect(); + let destinations = + movements.iter().map(|(_, destination)| MemoryAddress::from(*destination)).collect(); + (sources, destinations) + } + + pub(crate) fn create_context() -> BrilligContext { + let mut context = BrilligContext::new(true); + context.enter_context(Label::function(FunctionId::test_new(0))); + context + } + + #[test] + #[should_panic(expected = "Multiple moves to the same register found")] + fn test_mov_registers_to_registers_overwrite() { + let movements = vec![(10, 11), (12, 11), (10, 13)]; + let (sources, destinations) = movements_to_source_and_destinations(movements); + let mut context = create_context(); + + context.codegen_mov_registers_to_registers(sources, destinations); + } + + #[test] + fn test_mov_registers_to_registers_no_loop() { + let movements = vec![(10, 11), (11, 12), (12, 13), (13, 14)]; + let (sources, destinations) = movements_to_source_and_destinations(movements); + let mut context = create_context(); + + context.codegen_mov_registers_to_registers(sources, destinations); + let opcodes = context.artifact().byte_code; + assert_eq!( + opcodes, + vec![ + Opcode::Mov { destination: MemoryAddress(14), source: MemoryAddress(13) }, + Opcode::Mov { destination: MemoryAddress(13), source: MemoryAddress(12) }, + Opcode::Mov { destination: MemoryAddress(12), source: MemoryAddress(11) }, + Opcode::Mov { destination: MemoryAddress(11), source: MemoryAddress(10) }, + ] + ); + } + #[test] + fn test_mov_registers_to_registers_no_op_filter() { + let movements = vec![(10, 11), (11, 11), (11, 12)]; + let (sources, destinations) = movements_to_source_and_destinations(movements); + let mut context = create_context(); + + context.codegen_mov_registers_to_registers(sources, destinations); + let opcodes = context.artifact().byte_code; + assert_eq!( + opcodes, + vec![ + Opcode::Mov { destination: MemoryAddress(12), source: MemoryAddress(11) }, + Opcode::Mov { destination: MemoryAddress(11), source: MemoryAddress(10) }, + ] + ); + } + + #[test] + fn test_mov_registers_to_registers_loop() { + let movements = vec![(10, 11), (11, 12), (12, 13), (13, 10)]; + let (sources, destinations) = movements_to_source_and_destinations(movements); + let mut context = create_context(); + + context.codegen_mov_registers_to_registers(sources, destinations); + let opcodes = context.artifact().byte_code; + assert_eq!( + opcodes, + vec![ + Opcode::Mov { destination: MemoryAddress(3), source: MemoryAddress(10) }, + Opcode::Mov { destination: MemoryAddress(10), source: MemoryAddress(13) }, + Opcode::Mov { destination: MemoryAddress(13), source: MemoryAddress(12) }, + Opcode::Mov { destination: MemoryAddress(12), source: MemoryAddress(11) }, + Opcode::Mov { destination: MemoryAddress(11), source: MemoryAddress(3) } + ] + ); + } + + #[test] + fn test_mov_registers_to_registers_loop_and_branch() { + let movements = vec![(10, 11), (11, 12), (12, 10), (10, 13), (13, 14)]; + let (sources, destinations) = movements_to_source_and_destinations(movements); + let mut context = create_context(); + + context.codegen_mov_registers_to_registers(sources, destinations); + let opcodes = context.artifact().byte_code; + assert_eq!( + opcodes, + vec![ + Opcode::Mov { destination: MemoryAddress(3), source: MemoryAddress(10) }, // Temporary + Opcode::Mov { destination: MemoryAddress(14), source: MemoryAddress(13) }, // Branch + Opcode::Mov { destination: MemoryAddress(10), source: MemoryAddress(12) }, // Loop + Opcode::Mov { destination: MemoryAddress(12), source: MemoryAddress(11) }, // Loop + Opcode::Mov { destination: MemoryAddress(13), source: MemoryAddress(3) }, // Finish branch + Opcode::Mov { destination: MemoryAddress(11), source: MemoryAddress(3) } // Finish loop + ] + ); + } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs index 6be18df7b52..264b83956f8 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs @@ -633,7 +633,7 @@ impl<'context> Elaborator<'context> { 0 } - pub fn unify( + pub(super) fn unify( &mut self, actual: &Type, expected: &Type, @@ -644,6 +644,22 @@ impl<'context> Elaborator<'context> { } } + /// Do not apply type bindings even after a successful unification. + /// This function is used by the interpreter for some comptime code + /// which can change types e.g. on each iteration of a for loop. + pub fn unify_without_applying_bindings( + &mut self, + actual: &Type, + expected: &Type, + file: fm::FileId, + make_error: impl FnOnce() -> TypeCheckError, + ) { + let mut bindings = TypeBindings::new(); + if actual.try_unify(expected, &mut bindings).is_err() { + self.errors.push((make_error().into(), file)); + } + } + /// Wrapper of Type::unify_with_coercions using self.errors pub(super) fn unify_with_coercions( &mut self, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index b5ed8126e33..e920073b453 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -1303,9 +1303,11 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { // Macro calls are typed as type variables during type checking. // Now that we know the type we need to further unify it in case there // are inconsistencies or the type needs to be known. + // We don't commit any type bindings made this way in case the type of + // the macro result changes across loop iterations. let expected_type = self.elaborator.interner.id_type(id); let actual_type = result.get_type(); - self.unify(&actual_type, &expected_type, location); + self.unify_without_binding(&actual_type, &expected_type, location); } Ok(result) } @@ -1319,16 +1321,14 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } } - fn unify(&mut self, actual: &Type, expected: &Type, location: Location) { - // We need to swap out the elaborator's file since we may be - // in a different one currently, and it uses that for the error location. - let old_file = std::mem::replace(&mut self.elaborator.file, location.file); - self.elaborator.unify(actual, expected, || TypeCheckError::TypeMismatch { - expected_typ: expected.to_string(), - expr_typ: actual.to_string(), - expr_span: location.span, + fn unify_without_binding(&mut self, actual: &Type, expected: &Type, location: Location) { + self.elaborator.unify_without_applying_bindings(actual, expected, location.file, || { + TypeCheckError::TypeMismatch { + expected_typ: expected.to_string(), + expr_typ: actual.to_string(), + expr_span: location.span, + } }); - self.elaborator.file = old_file; } fn evaluate_method_call( diff --git a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs index 22de18b6461..672328c05bd 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs @@ -3722,5 +3722,29 @@ fn use_numeric_generic_in_trait_method() { "#; let errors = get_program_errors(src); + println!("{errors:?}"); assert_eq!(errors.len(), 0); } + +#[test] +fn macro_result_type_mismatch() { + let src = r#" + fn main() { + comptime { + let x = unquote!(quote { "test" }); + let _: Field = x; + } + } + + comptime fn unquote(q: Quoted) -> Quoted { + q + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + assert!(matches!( + errors[0].0, + CompilationError::TypeError(TypeCheckError::TypeMismatch { .. }) + )); +} diff --git a/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr b/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr index d93b4f41cf0..38bc1764b64 100644 --- a/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr +++ b/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr @@ -71,6 +71,21 @@ impl EmbeddedCurveScalar { let (a,b) = crate::field::bn254::decompose(scalar); EmbeddedCurveScalar { lo: a, hi: b } } + + //Bytes to scalar: take the first (after the specified offset) 16 bytes of the input as the lo value, and the next 16 bytes as the hi value + #[field(bn254)] + fn from_bytes(bytes: [u8; 64], offset: u32) -> EmbeddedCurveScalar { + let mut v = 1; + let mut lo = 0 as Field; + let mut hi = 0 as Field; + for i in 0..16 { + lo = lo + (bytes[offset+31 - i] as Field) * v; + hi = hi + (bytes[offset+15 - i] as Field) * v; + v = v * 256; + } + let sig_s = crate::embedded_curve_ops::EmbeddedCurveScalar { lo, hi }; + sig_s + } } impl Eq for EmbeddedCurveScalar { diff --git a/noir/noir-repo/noir_stdlib/src/schnorr.nr b/noir/noir-repo/noir_stdlib/src/schnorr.nr index 24ca514025c..336041fec19 100644 --- a/noir/noir-repo/noir_stdlib/src/schnorr.nr +++ b/noir/noir-repo/noir_stdlib/src/schnorr.nr @@ -1,3 +1,6 @@ +use crate::collections::vec::Vec; +use crate::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar}; + #[foreign(schnorr_verify)] // docs:start:schnorr_verify pub fn verify_signature( @@ -20,3 +23,65 @@ pub fn verify_signature_slice( // docs:end:schnorr_verify_slice {} +pub fn verify_signature_noir(public_key: EmbeddedCurvePoint, signature: [u8; 64], message: [u8; N]) -> bool { + //scalar lo/hi from bytes + let sig_s = EmbeddedCurveScalar::from_bytes(signature, 0); + let sig_e = EmbeddedCurveScalar::from_bytes(signature, 32); + // pub_key is on Grumpkin curve + let mut is_ok = (public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17) + & (!public_key.is_infinite); + + if ((sig_s.lo != 0) | (sig_s.hi != 0)) & ((sig_e.lo != 0) | (sig_e.hi != 0)) { + let (r_is_infinite, result) = calculate_signature_challenge(public_key, sig_s, sig_e, message); + + is_ok = !r_is_infinite; + for i in 0..32 { + is_ok &= result[i] == signature[32 + i]; + } + } + is_ok +} + +pub fn assert_valid_signature(public_key: EmbeddedCurvePoint, signature: [u8; 64], message: [u8; N]) { + //scalar lo/hi from bytes + let sig_s = EmbeddedCurveScalar::from_bytes(signature, 0); + let sig_e = EmbeddedCurveScalar::from_bytes(signature, 32); + + // assert pub_key is on Grumpkin curve + assert(public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17); + assert(public_key.is_infinite == false); + // assert signature is not null + assert((sig_s.lo != 0) | (sig_s.hi != 0)); + assert((sig_e.lo != 0) | (sig_e.hi != 0)); + + let (r_is_infinite, result) = calculate_signature_challenge(public_key, sig_s, sig_e, message); + + assert(!r_is_infinite); + for i in 0..32 { + assert(result[i] == signature[32 + i]); + } +} + +fn calculate_signature_challenge( + public_key: EmbeddedCurvePoint, + sig_s: EmbeddedCurveScalar, + sig_e: EmbeddedCurveScalar, + message: [u8; N] +) -> (bool, [u8; 32]) { + let g1 = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; + let r = crate::embedded_curve_ops::multi_scalar_mul([g1, public_key], [sig_s, sig_e]); + // compare the _hashes_ rather than field elements modulo r + let pedersen_hash = crate::hash::pedersen_hash([r.x, public_key.x, public_key.y]); + let pde: [u8; 32] = pedersen_hash.to_be_bytes(); + + let mut hash_input = [0; N + 32]; + for i in 0..32 { + hash_input[i] = pde[i]; + } + for i in 0..N { + hash_input[32+i] = message[i]; + } + + let result = crate::hash::blake2s(hash_input); + (r.is_infinite, result) +} diff --git a/noir/noir-repo/test_programs/compile_success_empty/macro_result_type/Nargo.toml b/noir/noir-repo/test_programs/compile_failure/macro_result_type/Nargo.toml similarity index 100% rename from noir/noir-repo/test_programs/compile_success_empty/macro_result_type/Nargo.toml rename to noir/noir-repo/test_programs/compile_failure/macro_result_type/Nargo.toml diff --git a/noir/noir-repo/test_programs/compile_success_empty/macro_result_type/src/main.nr b/noir/noir-repo/test_programs/compile_failure/macro_result_type/src/main.nr similarity index 100% rename from noir/noir-repo/test_programs/compile_success_empty/macro_result_type/src/main.nr rename to noir/noir-repo/test_programs/compile_failure/macro_result_type/src/main.nr diff --git a/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/Nargo.toml b/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/Nargo.toml new file mode 100644 index 00000000000..38e72395bb5 --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "comptime_change_type_each_iteration" +type = "bin" +authors = [""] +compiler_version = ">=0.34.0" + +[dependencies] diff --git a/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/src/main.nr new file mode 100644 index 00000000000..7b34c112d4f --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/src/main.nr @@ -0,0 +1,19 @@ +fn main() { + comptime + { + for i in 9..11 { + // Lengths are different on each iteration: + // foo9, foo10 + let name = f"foo{i}".as_ctstring().as_quoted_str!(); + + // So to call `from_signature` we need to delay the type check + // by quoting the function call so that we re-typecheck on each iteration + let hash = std::meta::unquote!(quote { from_signature($name) }); + assert(hash > 3); + } + } +} + +fn from_signature(_signature: str) -> u32 { + N +} diff --git a/noir/noir-repo/test_programs/compile_success_empty/macro_result_type/t.rs b/noir/noir-repo/test_programs/compile_success_empty/macro_result_type/t.rs deleted file mode 100644 index bcd91d7bf5d..00000000000 --- a/noir/noir-repo/test_programs/compile_success_empty/macro_result_type/t.rs +++ /dev/null @@ -1,12 +0,0 @@ - -trait Foo { - fn foo() {} -} - -impl Foo<3> for () { - fn foo() {} -} - -fn main() { - let _ = Foo::foo(); -} diff --git a/noir/noir-repo/test_programs/execution_success/schnorr/src/main.nr b/noir/noir-repo/test_programs/execution_success/schnorr/src/main.nr index b64078e6b46..835ea2ffb1f 100644 --- a/noir/noir-repo/test_programs/execution_success/schnorr/src/main.nr +++ b/noir/noir-repo/test_programs/execution_success/schnorr/src/main.nr @@ -12,11 +12,6 @@ fn main( // Regression for issue #2421 // We want to make sure that we can accurately verify a signature whose message is a slice vs. an array let message_field_bytes: [u8; 10] = message_field.to_be_bytes(); - let mut message2 = [0; 42]; - for i in 0..10 { - assert(message[i] == message_field_bytes[i]); - message2[i] = message[i]; - } // Is there ever a situation where someone would want // to ensure that a signature was invalid? @@ -27,102 +22,7 @@ fn main( let valid_signature = std::schnorr::verify_signature(pub_key_x, pub_key_y, signature, message); assert(valid_signature); let pub_key = embedded_curve_ops::EmbeddedCurvePoint { x: pub_key_x, y: pub_key_y, is_infinite: false }; - let valid_signature = verify_signature_noir(pub_key, signature, message2); + let valid_signature = std::schnorr::verify_signature_noir(pub_key, signature, message); assert(valid_signature); - assert_valid_signature(pub_key, signature, message2); -} - -// TODO: to put in the stdlib once we have numeric generics -// Meanwhile, you have to use a message with 32 additional bytes: -// If you want to verify a signature on a message of 10 bytes, you need to pass a message of length 42, -// where the first 10 bytes are the one from the original message (the other bytes are not used) -pub fn verify_signature_noir( - public_key: embedded_curve_ops::EmbeddedCurvePoint, - signature: [u8; 64], - message: [u8; M] -) -> bool { - let N = message.len() - 32; - - //scalar lo/hi from bytes - let sig_s = bytes_to_scalar(signature, 0); - let sig_e = bytes_to_scalar(signature, 32); - // pub_key is on Grumpkin curve - let mut is_ok = (public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17) - & (!public_key.is_infinite); - - if ((sig_s.lo != 0) | (sig_s.hi != 0)) & ((sig_e.lo != 0) | (sig_e.hi != 0)) { - let g1 = embedded_curve_ops::EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; - let r = embedded_curve_ops::multi_scalar_mul([g1, public_key], [sig_s, sig_e]); - // compare the _hashes_ rather than field elements modulo r - let pedersen_hash = std::hash::pedersen_hash([r.x, public_key.x, public_key.y]); - let mut hash_input = [0; M]; - let pde: [u8; 32] = pedersen_hash.to_be_bytes(); - - for i in 0..32 { - hash_input[i] = pde[i]; - } - for i in 0..N { - hash_input[32+i] = message[i]; - } - let result = std::hash::blake2s(hash_input); - - is_ok = !r.is_infinite; - for i in 0..32 { - if result[i] != signature[32 + i] { - is_ok = false; - } - } - } - is_ok -} - -pub fn bytes_to_scalar(bytes: [u8; 64], offset: u32) -> embedded_curve_ops::EmbeddedCurveScalar { - let mut v = 1; - let mut lo = 0 as Field; - let mut hi = 0 as Field; - for i in 0..16 { - lo = lo + (bytes[offset+31 - i] as Field) * v; - hi = hi + (bytes[offset+15 - i] as Field) * v; - v = v * 256; - } - let sig_s = embedded_curve_ops::EmbeddedCurveScalar { lo, hi }; - sig_s -} - -pub fn assert_valid_signature( - public_key: embedded_curve_ops::EmbeddedCurvePoint, - signature: [u8; 64], - message: [u8; M] -) { - let N = message.len() - 32; - //scalar lo/hi from bytes - let sig_s = bytes_to_scalar(signature, 0); - let sig_e = bytes_to_scalar(signature, 32); - - // assert pub_key is on Grumpkin curve - assert(public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17); - assert(public_key.is_infinite == false); - // assert signature is not null - assert((sig_s.lo != 0) | (sig_s.hi != 0)); - assert((sig_e.lo != 0) | (sig_e.hi != 0)); - - let g1 = embedded_curve_ops::EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; - let r = embedded_curve_ops::multi_scalar_mul([g1, public_key], [sig_s, sig_e]); - // compare the _hashes_ rather than field elements modulo r - let pedersen_hash = std::hash::pedersen_hash([r.x, public_key.x, public_key.y]); - let mut hash_input = [0; M]; - let pde: [u8; 32] = pedersen_hash.to_be_bytes(); - - for i in 0..32 { - hash_input[i] = pde[i]; - } - for i in 0..N { - hash_input[32+i] = message[i]; - } - let result = std::hash::blake2s(hash_input); - - assert(!r.is_infinite); - for i in 0..32 { - assert(result[i] == signature[32 + i]); - } + std::schnorr::assert_valid_signature(pub_key, signature, message); }