Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement NTRU Prime #384

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions kem/ntruprime/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//go:generate go run gen.go

// Package ntruprime implements the NTRU Prime IND-CCA2 secure
// key encapsulation mechanism (KEM) as submitted to round 3 of the NIST PQC
// competition and described in
//
// https://ntruprime.cr.yp.to/nist/ntruprime-20201007.pdf
//
// The code is translated from the C reference implementation.
package ntruprime
139 changes: 139 additions & 0 deletions kem/ntruprime/gen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
//go:build ignore
// +build ignore

package main

import (
"bytes"
"go/format"
"io/ioutil"
"strings"
"text/template"
)

type Instance struct {
Name string
Hash string
}

func (m Instance) Pkg() string {
return strings.ToLower(m.Name)
}

var (
SInstances = []Instance{
{Name: "SNTRUP761"},
{Name: "SNTRUP653"},
{Name: "SNTRUP857"},
{Name: "SNTRUP953"},
{Name: "SNTRUP1013"},
{Name: "SNTRUP1277"},
}
LPRInstances = []Instance{
{Name: "NTRULPR761"},
{Name: "NTRULPR653"},
{Name: "NTRULPR857"},
{Name: "NTRULPR953"},
{Name: "NTRULPR1013"},
{Name: "NTRULPR1277"},
}
TemplateWarning = "// Code generated from"
)

func main() {
generateStreamlinedPackageFiles()
generateLPRPackageFiles()
}

func generateStreamlinedPackageFiles() {
template, err := template.ParseFiles("templates/sntrup.templ.go")
if err != nil {
panic(err)
}

for _, mode := range SInstances {
buf := new(bytes.Buffer)
err := template.Execute(buf, mode)
if err != nil {
panic(err)
}

// Formating output code
code, err := format.Source(buf.Bytes())
if err != nil {
panic("error formating code")
}

res := string(code)
offset := strings.Index(res, TemplateWarning)
if offset == -1 {
panic("Missing template warning in pkg.templ.go")
}
err = ioutil.WriteFile(mode.Pkg()+"/ntruprime.go", []byte(res[offset:]), 0o644)
if err != nil {
panic(err)
}
}
}

func generateLPRPackageFiles() {
template, err := template.ParseFiles("templates/ntrulpr.templ.go")
if err != nil {
panic(err)
}

for _, mode := range LPRInstances {
buf := new(bytes.Buffer)
err := template.Execute(buf, mode)
if err != nil {
panic(err)
}

// Formating output code
code, err := format.Source(buf.Bytes())
if err != nil {
panic("error formating code")
}

res := string(code)
offset := strings.Index(res, TemplateWarning)
if offset == -1 {
panic("Missing template warning in pkg.templ.go")
}
err = ioutil.WriteFile(mode.Pkg()+"/ntruprime.go", []byte(res[offset:]), 0o644)
if err != nil {
panic(err)
}
}
}

func generateKAT() {
template, err := template.ParseFiles("templates/kat.templ.go")
if err != nil {
panic(err)
}

for _, mode := range SInstances {
buf := new(bytes.Buffer)
err := template.Execute(buf, mode)
if err != nil {
panic(err)
}

// Formating output code
code, err := format.Source(buf.Bytes())
if err != nil {
panic("error formating code")
}

res := string(code)
offset := strings.Index(res, TemplateWarning)
if offset == -1 {
panic("Missing template warning in pkg.templ.go")
}
err = ioutil.WriteFile(mode.Pkg()+"/kat_test.go", []byte(res[offset:]), 0o600)
if err != nil {
panic(err)
}
}
}
67 changes: 67 additions & 0 deletions kem/ntruprime/internal/Decode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package internal

// TO DO: Optimize the Decode function
/* Decode(R,s,M,len) */
/* assumes 0 < M[i] < 16384 */
/* produces 0 <= R[i] < M[i] */
func Decode(out []uint16, S []uint8, M []uint16, len int) {
index := 0
if len == 1 {
if M[0] == 1 {
out[index] = 0
} else if M[0] <= 256 {
out[index] = Uint32ModUint14(uint32(S[0]), M[0])
} else {
out[index] = Uint32ModUint14(uint32(uint16(S[0])+((uint16(S[1]))<<8)), M[0])
}
}
if len > 1 {
R2 := make([]uint16, (len+1)/2)
M2 := make([]uint16, (len+1)/2)
bottomr := make([]uint16, len/2)
bottomt := make([]uint32, len/2)
i := 0
for i = 0; i < len-1; i += 2 {
m := uint32(M[i]) * uint32(M[i+1])

if m > 256*16383 {
bottomt[i/2] = 256 * 256
bottomr[i/2] = uint16(S[0]) + 256*uint16(S[1])
S = S[2:]
M2[i/2] = uint16((((m + 255) >> 8) + 255) >> 8)
} else if m >= 16384 {
bottomt[i/2] = 256
bottomr[i/2] = uint16(S[0])
S = S[1:]
M2[i/2] = uint16((m + 255) >> 8)
} else {
bottomt[i/2] = 1
bottomr[i/2] = 0
M2[i/2] = uint16(m)
}
}
if i < len {
M2[i/2] = M[i]
}

Decode(R2, S, M2, (len+1)/2)

for i = 0; i < len-1; i += 2 {
r := uint32(bottomr[i/2])
var r1 uint32
var r0 uint16

r += bottomt[i/2] * uint32(R2[i/2])
Uint32DivmodUint14(&r1, &r0, r, M[i])
r1 = uint32(Uint32ModUint14(r1, M[i+1])) /* only needed for invalid inputs */

out[index] = r0
index++
out[index] = uint16(r1)
index++
}
if i < len {
out[index] = R2[i/2]
}
}
}
102 changes: 102 additions & 0 deletions kem/ntruprime/internal/Divmod.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package internal

/*
CPU division instruction typically takes time depending on x.
This software is designed to take time independent of x.
Time still varies depending on m; user must ensure that m is constant.
Time also varies on CPUs where multiplication is variable-time.
There could be more CPU issues.
There could also be compiler issues.
*/
// q, r = x/m
// Returns quotient and remainder
func Uint32DivmodUint14(q *uint32, r *uint16, x uint32, m uint16) {
var v uint32 = 0x80000000

v /= uint32(m)

*q = 0

qpart := uint32(uint64(x) * uint64(v) >> 31)

x -= qpart * uint32(m)
*q += qpart

qpart = uint32(uint64(x) * uint64(v) >> 31)
x -= qpart * uint32(m)
*q += qpart

x -= uint32(m)
*q += 1
mask := -(x >> 31)
x += mask & uint32(m)
*q += mask

*r = uint16(x)
}

// Returns the quotient of x/m
func Uint32DivUint14(x uint32, m uint16) uint32 {
var q uint32
var r uint16
Uint32DivmodUint14(&q, &r, x, m)
return q
}

// Returns the remainder of x/m
func Uint32ModUint14(x uint32, m uint16) uint16 {
var q uint32
var r uint16
Uint32DivmodUint14(&q, &r, x, m)
return r
}

// Calculates quotient and remainder
func Int32DivmodUint14(q *int32, r *uint16, x int32, m uint16) {
var uq, uq2 uint32
var ur, ur2 uint16
var mask uint32

Uint32DivmodUint14(&uq, &ur, 0x80000000+uint32(x), m)
Uint32DivmodUint14(&uq2, &ur2, 0x80000000, m)

ur -= ur2
uq -= uq2
mask = -(uint32)(ur >> 15)
ur += uint16(mask & uint32(m))
uq += mask
*r = ur
*q = int32(uq)
}

// Returns quotient of x/m
func Int32DivUint14(x int32, m uint16) int32 {
var q int32
var r uint16
Int32DivmodUint14(&q, &r, x, m)
return q
}

// Returns remainder of x/m
func Int32ModUint14(x int32, m uint16) uint16 {
var q int32
var r uint16
Int32DivmodUint14(&q, &r, x, m)
return r
}

// Returns -1 if x!=0; else return 0
func Int16NonzeroMask(x int16) int {
u := uint16(x) /* 0, else 1...65535 */
v := uint32(u) /* 0, else 1...65535 */
v = -v /* 0, else 2^32-65535...2^32-1 */
v >>= 31 /* 0, else 1 */
return -int(v) /* 0, else -1 */
}

// Returns -1 if x<0; otherwise return 0
func Int16NegativeMask(x int16) int {
u := uint16(x)
u >>= 15
return -(int)(u)
}
42 changes: 42 additions & 0 deletions kem/ntruprime/internal/Encode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package internal

/* 0 <= R[i] < M[i] < 16384 */
func Encode(out []uint8, R []uint16, M []uint16, len int) {
if len > 1 {
R2 := make([]uint16, (len+1)/2)
M2 := make([]uint16, (len+1)/2)
var i int
for ; len > 1; len = (len + 1) / 2 {
for i = 0; i < len-1; i += 2 {
m0 := uint32(M[i])
r := uint32(R[i]) + uint32(R[i+1])*m0
m := uint32(M[i+1]) * m0
for m >= 16384 {
out[0] = uint8(r)
out = out[1:]

r >>= 8
m = (m + 255) >> 8
}
R2[i/2] = uint16(r)
M2[i/2] = uint16(m)
}
if i < len {
R2[i/2] = R[i]
M2[i/2] = M[i]
}
copy(R, R2)
copy(M, M2)
}
}
if len == 1 {
r := R[0]
m := M[0]
for m > 1 {
out[0] = uint8(r)
out = out[1:]
r >>= 8
m = (m + 255) >> 8
}
}
}
Loading