Skip to content

Commit

Permalink
Merge pull request #342 from Consensys/335-add-model-for-word_sorting…
Browse files Browse the repository at this point in the history
…-example

feat: add model for `word_sorting` example
  • Loading branch information
DavePearce authored Oct 16, 2024
2 parents be12c3f + 7894f1f commit f456663
Show file tree
Hide file tree
Showing 8 changed files with 6,786 additions and 79 deletions.
123 changes: 94 additions & 29 deletions cmd/testgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
util "github.com/consensys/go-corset/pkg/cmd"
"github.com/consensys/go-corset/pkg/hir"
sc "github.com/consensys/go-corset/pkg/schema"
tr "github.com/consensys/go-corset/pkg/trace"
Expand All @@ -23,6 +24,8 @@ func main() {
}

func init() {
rootCmd.Flags().Uint("min-lines", 1, "Minimum number of lines")
rootCmd.Flags().Uint("max-lines", 4, "Maximum number of lines")
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}

Expand All @@ -35,27 +38,31 @@ var rootCmd = &cobra.Command{
fmt.Println(cmd.UsageString())
os.Exit(1)
}
model := args[0]
var cfg TestGenConfig
// Lookup model
for _, m := range models {
if m.Name == model {
// Read schema
filename := fmt.Sprintf("%s.lisp", m.Name)
schema := readSchemaFile(path.Join("testdata", filename))
// Generate & split traces
valid, invalid := generateTestTraces(m, schema)
// Write out
writeTestTraces(m, "accepts", schema, valid)
writeTestTraces(m, "rejects", schema, invalid)
os.Exit(0)
}
}
//
fmt.Printf("unknown model \"%s\"\n", model)
os.Exit(1)
cfg.model = findModel(args[0])
cfg.min_lines = util.GetUint(cmd, "min-lines")
cfg.max_lines = util.GetUint(cmd, "max-lines")
// Read schema
filename := fmt.Sprintf("%s.lisp", cfg.model.Name)
schema := readSchemaFile(path.Join("testdata", filename))
// Generate & split traces
valid, invalid := generateTestTraces(cfg, schema)
// Write out
writeTestTraces(cfg.model, "accepts", schema, valid)
writeTestTraces(cfg.model, "rejects", schema, invalid)
os.Exit(0)

},
}

// TestGenConfig encapsulates configuration related to test generation.
type TestGenConfig struct {
model Model
min_lines uint
max_lines uint
}

// Model represents a hard-coded oracle for a given test.
type Model struct {
// Name of the model in question
Expand All @@ -66,26 +73,39 @@ type Model struct {

var models []Model = []Model{
{"memory", memoryModel},
{"word_sorting", wordSortingModel},
}

func findModel(name string) Model {
for _, m := range models {
if m.Name == name {
return m
}
}
//
panic(fmt.Sprintf("unknown model \"%s\"", name))
}

// Generate test traces
func generateTestTraces(model Model, schema sc.Schema) ([]tr.Trace, []tr.Trace) {
func generateTestTraces(cfg TestGenConfig, schema sc.Schema) ([]tr.Trace, []tr.Trace) {
// NOTE: This is really a temporary solution for now. It doesn't handle
// length multipliers. It doesn't allow for modules with different heights.
// It uses a fixed pool.
pool := []fr.Element{fr.NewElement(0), fr.NewElement(1), fr.NewElement(2)}
//
enumerator := sc.NewTraceEnumerator(2, schema, pool)
valid := make([]tr.Trace, 0)
invalid := make([]tr.Trace, 0)
// Generate and split the traces
for enumerator.HasNext() {
trace := enumerator.Next()
// Check whether trace is valid or not (according to the oracle)
if model.Oracle(schema, trace) {
valid = append(valid, trace)
} else {
invalid = append(invalid, trace)
//
for n := cfg.min_lines; n < cfg.max_lines; n++ {
enumerator := sc.NewTraceEnumerator(n, schema, pool)
// Generate and split the traces
for enumerator.HasNext() {
trace := enumerator.Next()
// Check whether trace is valid or not (according to the oracle)
if cfg.model.Oracle(schema, trace) {
valid = append(valid, trace)
} else {
invalid = append(invalid, trace)
}
}
}
// Done
Expand All @@ -108,7 +128,7 @@ func writeTestTraces(model Model, ext string, schema sc.Schema, traces []tr.Trac
panic(err)
}
// Log what happened
log.Infof("Wrote %s\n", filename)
log.Infof("Wrote %s (%d traces)\n", filename, len(traces))
}

// Convert a trace into an array of raw columns.
Expand Down Expand Up @@ -222,6 +242,36 @@ func memoryModel(schema sc.Schema, trace tr.Trace) bool {
return true
}

func wordSortingModel(schema sc.Schema, trace tr.Trace) bool {
TWO_8 := fr.NewElement(256)
//
X := findColumn(0, "X", schema, trace).Data()
Delta := findColumn(0, "Delta", schema, trace).Data()
Byte_0 := findColumn(0, "Byte_0", schema, trace).Data()
Byte_1 := findColumn(0, "Byte_1", schema, trace).Data()
//
for i := uint(0); i < X.Len(); i++ {
X_i := X.Get(i)
Delta_i := Delta.Get(i)
Byte_0_i := Byte_0.Get(i)
Byte_1_i := Byte_1.Get(i)
tmp := add(mul(Byte_1_i, TWO_8), Byte_0_i)
//
if Delta_i.Cmp(&tmp) != 0 {
return false
} else if i > 0 {
X_im1 := X.Get(i - 1)
diff := sub(X_i, X_im1)

if Delta_i.Cmp(&diff) != 0 {
return false
}
}
}
// Success
return true
}

// ============================================================================
// Helpers
// ============================================================================
Expand All @@ -232,3 +282,18 @@ func isIncremented(before fr.Element, after fr.Element) bool {
//
return after.IsOne()
}

func add(lhs fr.Element, rhs fr.Element) fr.Element {
lhs.Add(&lhs, &rhs)
return lhs
}

func sub(lhs fr.Element, rhs fr.Element) fr.Element {
lhs.Sub(&lhs, &rhs)
return lhs
}

func mul(lhs fr.Element, rhs fr.Element) fr.Element {
lhs.Mul(&lhs, &rhs)
return lhs
}
28 changes: 14 additions & 14 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ var checkCmd = &cobra.Command{
os.Exit(1)
}
// Configure log level
if getFlag(cmd, "debug") {
if GetFlag(cmd, "debug") {
log.SetLevel(log.DebugLevel)
}
//
cfg.air = getFlag(cmd, "air")
cfg.mir = getFlag(cmd, "mir")
cfg.hir = getFlag(cmd, "hir")
cfg.expand = !getFlag(cmd, "raw")
cfg.report = getFlag(cmd, "report")
cfg.reportPadding = getUint(cmd, "report-context")
cfg.spillage = getInt(cmd, "spillage")
cfg.strict = !getFlag(cmd, "warn")
cfg.quiet = getFlag(cmd, "quiet")
cfg.padding.Right = getUint(cmd, "padding")
cfg.parallelExpansion = !getFlag(cmd, "sequential")
cfg.batchSize = getUint(cmd, "batch")
cfg.ansiEscapes = getFlag(cmd, "ansi-escapes")
cfg.air = GetFlag(cmd, "air")
cfg.mir = GetFlag(cmd, "mir")
cfg.hir = GetFlag(cmd, "hir")
cfg.expand = !GetFlag(cmd, "raw")
cfg.report = GetFlag(cmd, "report")
cfg.reportPadding = GetUint(cmd, "report-context")
cfg.spillage = GetInt(cmd, "spillage")
cfg.strict = !GetFlag(cmd, "warn")
cfg.quiet = GetFlag(cmd, "quiet")
cfg.padding.Right = GetUint(cmd, "padding")
cfg.parallelExpansion = !GetFlag(cmd, "sequential")
cfg.batchSize = GetUint(cmd, "batch")
cfg.ansiEscapes = GetFlag(cmd, "ansi-escapes")
// TODO: support true ranges
cfg.padding.Left = cfg.padding.Right
if !cfg.hir && !cfg.mir && !cfg.air {
Expand Down
8 changes: 4 additions & 4 deletions pkg/cmd/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ var debugCmd = &cobra.Command{
fmt.Println(cmd.UsageString())
os.Exit(1)
}
hir := getFlag(cmd, "hir")
mir := getFlag(cmd, "mir")
air := getFlag(cmd, "air")
stats := getFlag(cmd, "stats")
hir := GetFlag(cmd, "hir")
mir := GetFlag(cmd, "mir")
air := GetFlag(cmd, "air")
stats := GetFlag(cmd, "stats")
// Parse constraints
hirSchema := readSchemaFile(args[0])

Expand Down
26 changes: 13 additions & 13 deletions pkg/cmd/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@ var testCmd = &cobra.Command{
os.Exit(1)
}
// Configure log level
if getFlag(cmd, "debug") {
if GetFlag(cmd, "debug") {
log.SetLevel(log.DebugLevel)
}
// Setup check config
cfg.air = getFlag(cmd, "air")
cfg.mir = getFlag(cmd, "mir")
cfg.hir = getFlag(cmd, "hir")
cfg.expand = !getFlag(cmd, "raw")
cfg.report = getFlag(cmd, "report")
cfg.reportPadding = getUint(cmd, "report-context")
// cfg.strict = !getFlag(cmd, "warn")
// cfg.quiet = getFlag(cmd, "quiet")
cfg.padding.Right = getUint(cmd, "padding")
cfg.parallelExpansion = !getFlag(cmd, "sequential")
cfg.batchSize = getUint(cmd, "batch")
cfg.ansiEscapes = getFlag(cmd, "ansi-escapes")
cfg.air = GetFlag(cmd, "air")
cfg.mir = GetFlag(cmd, "mir")
cfg.hir = GetFlag(cmd, "hir")
cfg.expand = !GetFlag(cmd, "raw")
cfg.report = GetFlag(cmd, "report")
cfg.reportPadding = GetUint(cmd, "report-context")
// cfg.strict = !GetFlag(cmd, "warn")
// cfg.quiet = GetFlag(cmd, "quiet")
cfg.padding.Right = GetUint(cmd, "padding")
cfg.parallelExpansion = !GetFlag(cmd, "sequential")
cfg.batchSize = GetUint(cmd, "batch")
cfg.ansiEscapes = GetFlag(cmd, "ansi-escapes")
// TODO: support true ranges
cfg.padding.Left = cfg.padding.Right
// Normalise IRs
Expand Down
18 changes: 9 additions & 9 deletions pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ var traceCmd = &cobra.Command{
}
// Parse trace
cols := readTraceFile(args[0])
list := getFlag(cmd, "list")
stats := getFlag(cmd, "stats")
includes := getStringArray(cmd, "include")
print := getFlag(cmd, "print")
start := getUint(cmd, "start")
end := getUint(cmd, "end")
max_width := getUint(cmd, "max-width")
filter := getString(cmd, "filter")
output := getString(cmd, "out")
list := GetFlag(cmd, "list")
stats := GetFlag(cmd, "stats")
includes := GetStringArray(cmd, "include")
print := GetFlag(cmd, "print")
start := GetUint(cmd, "start")
end := GetUint(cmd, "end")
max_width := GetUint(cmd, "max-width")
filter := GetString(cmd, "filter")
output := GetString(cmd, "out")
// construct filters
if filter != "" {
cols = filterColumns(cols, filter)
Expand Down
20 changes: 10 additions & 10 deletions pkg/cmd/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
"github.com/spf13/cobra"
)

// Get an expected flag, or panic if an error arises.
func getFlag(cmd *cobra.Command, flag string) bool {
// GetFlag gets an expected flag, or panic if an error arises.
func GetFlag(cmd *cobra.Command, flag string) bool {
r, err := cmd.Flags().GetBool(flag)
if err != nil {
fmt.Println(err)
Expand All @@ -26,8 +26,8 @@ func getFlag(cmd *cobra.Command, flag string) bool {
return r
}

// Get an expectedsigned integer, or panic if an error arises.
func getInt(cmd *cobra.Command, flag string) int {
// GetInt gets an expectedsigned integer, or panic if an error arises.
func GetInt(cmd *cobra.Command, flag string) int {
r, err := cmd.Flags().GetInt(flag)
if err != nil {
fmt.Println(err)
Expand All @@ -37,8 +37,8 @@ func getInt(cmd *cobra.Command, flag string) int {
return r
}

// Get an expected unsigned integer, or panic if an error arises.
func getUint(cmd *cobra.Command, flag string) uint {
// GetUint gets an expected unsigned integer, or panic if an error arises.
func GetUint(cmd *cobra.Command, flag string) uint {
r, err := cmd.Flags().GetUint(flag)
if err != nil {
fmt.Println(err)
Expand All @@ -48,8 +48,8 @@ func getUint(cmd *cobra.Command, flag string) uint {
return r
}

// Get an expected string, or panic if an error arises.
func getString(cmd *cobra.Command, flag string) string {
// GetString gets an expected string, or panic if an error arises.
func GetString(cmd *cobra.Command, flag string) string {
r, err := cmd.Flags().GetString(flag)
if err != nil {
fmt.Println(err)
Expand All @@ -59,8 +59,8 @@ func getString(cmd *cobra.Command, flag string) string {
return r
}

// Get an expected string array, or panic if an error arises.
func getStringArray(cmd *cobra.Command, flag string) []string {
// GetStringArray gets an expected string array, or panic if an error arises.
func GetStringArray(cmd *cobra.Command, flag string) []string {
r, err := cmd.Flags().GetStringArray(flag)
if err != nil {
fmt.Println(err)
Expand Down
9 changes: 9 additions & 0 deletions testdata/word_sorting.auto.accepts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{"X": [0, 0], "Delta": [0, 0], "Byte_0": [0, 0], "Byte_1": [0, 0]}
{"X": [0, 1], "Delta": [0, 1], "Byte_0": [0, 1], "Byte_1": [0, 0]}
{"X": [0, 2], "Delta": [0, 2], "Byte_0": [0, 2], "Byte_1": [0, 0]}
{"X": [0, 0, 0], "Delta": [0, 0, 0], "Byte_0": [0, 0, 0], "Byte_1": [0, 0, 0]}
{"X": [0, 1, 1], "Delta": [0, 1, 0], "Byte_0": [0, 1, 0], "Byte_1": [0, 0, 0]}
{"X": [0, 2, 2], "Delta": [0, 2, 0], "Byte_0": [0, 2, 0], "Byte_1": [0, 0, 0]}
{"X": [0, 0, 1], "Delta": [0, 0, 1], "Byte_0": [0, 0, 1], "Byte_1": [0, 0, 0]}
{"X": [0, 1, 2], "Delta": [0, 1, 1], "Byte_0": [0, 1, 1], "Byte_1": [0, 0, 0]}
{"X": [0, 0, 2], "Delta": [0, 0, 2], "Byte_0": [0, 0, 2], "Byte_1": [0, 0, 0]}
Loading

0 comments on commit f456663

Please sign in to comment.