Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add model for word_sorting example #342

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 94 additions & 29 deletions cmd/testgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
util "github.com/consensys/go-corset/pkg/cmd"
"github.com/consensys/go-corset/pkg/hir"
sc "github.com/consensys/go-corset/pkg/schema"
tr "github.com/consensys/go-corset/pkg/trace"
Expand All @@ -23,6 +24,8 @@ func main() {
}

func init() {
rootCmd.Flags().Uint("min-lines", 1, "Minimum number of lines")
rootCmd.Flags().Uint("max-lines", 4, "Maximum number of lines")
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}

Expand All @@ -35,27 +38,31 @@ var rootCmd = &cobra.Command{
fmt.Println(cmd.UsageString())
os.Exit(1)
}
model := args[0]
var cfg TestGenConfig
// Lookup model
for _, m := range models {
if m.Name == model {
// Read schema
filename := fmt.Sprintf("%s.lisp", m.Name)
schema := readSchemaFile(path.Join("testdata", filename))
// Generate & split traces
valid, invalid := generateTestTraces(m, schema)
// Write out
writeTestTraces(m, "accepts", schema, valid)
writeTestTraces(m, "rejects", schema, invalid)
os.Exit(0)
}
}
//
fmt.Printf("unknown model \"%s\"\n", model)
os.Exit(1)
cfg.model = findModel(args[0])
cfg.min_lines = util.GetUint(cmd, "min-lines")
cfg.max_lines = util.GetUint(cmd, "max-lines")
// Read schema
filename := fmt.Sprintf("%s.lisp", cfg.model.Name)
schema := readSchemaFile(path.Join("testdata", filename))
// Generate & split traces
valid, invalid := generateTestTraces(cfg, schema)
// Write out
writeTestTraces(cfg.model, "accepts", schema, valid)
writeTestTraces(cfg.model, "rejects", schema, invalid)
os.Exit(0)

},
}

// TestGenConfig encapsulates configuration related to test generation.
type TestGenConfig struct {
model Model
min_lines uint
max_lines uint
}

// Model represents a hard-coded oracle for a given test.
type Model struct {
// Name of the model in question
Expand All @@ -66,26 +73,39 @@ type Model struct {

var models []Model = []Model{
{"memory", memoryModel},
{"word_sorting", wordSortingModel},
}

func findModel(name string) Model {
for _, m := range models {
if m.Name == name {
return m
}
}
//
panic(fmt.Sprintf("unknown model \"%s\"", name))
}

// Generate test traces
func generateTestTraces(model Model, schema sc.Schema) ([]tr.Trace, []tr.Trace) {
func generateTestTraces(cfg TestGenConfig, schema sc.Schema) ([]tr.Trace, []tr.Trace) {
// NOTE: This is really a temporary solution for now. It doesn't handle
// length multipliers. It doesn't allow for modules with different heights.
// It uses a fixed pool.
pool := []fr.Element{fr.NewElement(0), fr.NewElement(1), fr.NewElement(2)}
//
enumerator := sc.NewTraceEnumerator(2, schema, pool)
valid := make([]tr.Trace, 0)
invalid := make([]tr.Trace, 0)
// Generate and split the traces
for enumerator.HasNext() {
trace := enumerator.Next()
// Check whether trace is valid or not (according to the oracle)
if model.Oracle(schema, trace) {
valid = append(valid, trace)
} else {
invalid = append(invalid, trace)
//
for n := cfg.min_lines; n < cfg.max_lines; n++ {
enumerator := sc.NewTraceEnumerator(n, schema, pool)
// Generate and split the traces
for enumerator.HasNext() {
trace := enumerator.Next()
// Check whether trace is valid or not (according to the oracle)
if cfg.model.Oracle(schema, trace) {
valid = append(valid, trace)
} else {
invalid = append(invalid, trace)
}
}
}
// Done
Expand All @@ -108,7 +128,7 @@ func writeTestTraces(model Model, ext string, schema sc.Schema, traces []tr.Trac
panic(err)
}
// Log what happened
log.Infof("Wrote %s\n", filename)
log.Infof("Wrote %s (%d traces)\n", filename, len(traces))
}

// Convert a trace into an array of raw columns.
Expand Down Expand Up @@ -222,6 +242,36 @@ func memoryModel(schema sc.Schema, trace tr.Trace) bool {
return true
}

func wordSortingModel(schema sc.Schema, trace tr.Trace) bool {
TWO_8 := fr.NewElement(256)
//
X := findColumn(0, "X", schema, trace).Data()
Delta := findColumn(0, "Delta", schema, trace).Data()
Byte_0 := findColumn(0, "Byte_0", schema, trace).Data()
Byte_1 := findColumn(0, "Byte_1", schema, trace).Data()
//
for i := uint(0); i < X.Len(); i++ {
X_i := X.Get(i)
Delta_i := Delta.Get(i)
Byte_0_i := Byte_0.Get(i)
Byte_1_i := Byte_1.Get(i)
tmp := add(mul(Byte_1_i, TWO_8), Byte_0_i)
//
if Delta_i.Cmp(&tmp) != 0 {
return false
} else if i > 0 {
X_im1 := X.Get(i - 1)
diff := sub(X_i, X_im1)

if Delta_i.Cmp(&diff) != 0 {
return false
}
}
}
// Success
return true
}

// ============================================================================
// Helpers
// ============================================================================
Expand All @@ -232,3 +282,18 @@ func isIncremented(before fr.Element, after fr.Element) bool {
//
return after.IsOne()
}

func add(lhs fr.Element, rhs fr.Element) fr.Element {
lhs.Add(&lhs, &rhs)
return lhs
}

func sub(lhs fr.Element, rhs fr.Element) fr.Element {
lhs.Sub(&lhs, &rhs)
return lhs
}

func mul(lhs fr.Element, rhs fr.Element) fr.Element {
lhs.Mul(&lhs, &rhs)
return lhs
}
28 changes: 14 additions & 14 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ var checkCmd = &cobra.Command{
os.Exit(1)
}
// Configure log level
if getFlag(cmd, "debug") {
if GetFlag(cmd, "debug") {
log.SetLevel(log.DebugLevel)
}
//
cfg.air = getFlag(cmd, "air")
cfg.mir = getFlag(cmd, "mir")
cfg.hir = getFlag(cmd, "hir")
cfg.expand = !getFlag(cmd, "raw")
cfg.report = getFlag(cmd, "report")
cfg.reportPadding = getUint(cmd, "report-context")
cfg.spillage = getInt(cmd, "spillage")
cfg.strict = !getFlag(cmd, "warn")
cfg.quiet = getFlag(cmd, "quiet")
cfg.padding.Right = getUint(cmd, "padding")
cfg.parallelExpansion = !getFlag(cmd, "sequential")
cfg.batchSize = getUint(cmd, "batch")
cfg.ansiEscapes = getFlag(cmd, "ansi-escapes")
cfg.air = GetFlag(cmd, "air")
cfg.mir = GetFlag(cmd, "mir")
cfg.hir = GetFlag(cmd, "hir")
cfg.expand = !GetFlag(cmd, "raw")
cfg.report = GetFlag(cmd, "report")
cfg.reportPadding = GetUint(cmd, "report-context")
cfg.spillage = GetInt(cmd, "spillage")
cfg.strict = !GetFlag(cmd, "warn")
cfg.quiet = GetFlag(cmd, "quiet")
cfg.padding.Right = GetUint(cmd, "padding")
cfg.parallelExpansion = !GetFlag(cmd, "sequential")
cfg.batchSize = GetUint(cmd, "batch")
cfg.ansiEscapes = GetFlag(cmd, "ansi-escapes")
// TODO: support true ranges
cfg.padding.Left = cfg.padding.Right
if !cfg.hir && !cfg.mir && !cfg.air {
Expand Down
8 changes: 4 additions & 4 deletions pkg/cmd/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ var debugCmd = &cobra.Command{
fmt.Println(cmd.UsageString())
os.Exit(1)
}
hir := getFlag(cmd, "hir")
mir := getFlag(cmd, "mir")
air := getFlag(cmd, "air")
stats := getFlag(cmd, "stats")
hir := GetFlag(cmd, "hir")
mir := GetFlag(cmd, "mir")
air := GetFlag(cmd, "air")
stats := GetFlag(cmd, "stats")
// Parse constraints
hirSchema := readSchemaFile(args[0])

Expand Down
26 changes: 13 additions & 13 deletions pkg/cmd/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@ var testCmd = &cobra.Command{
os.Exit(1)
}
// Configure log level
if getFlag(cmd, "debug") {
if GetFlag(cmd, "debug") {
log.SetLevel(log.DebugLevel)
}
// Setup check config
cfg.air = getFlag(cmd, "air")
cfg.mir = getFlag(cmd, "mir")
cfg.hir = getFlag(cmd, "hir")
cfg.expand = !getFlag(cmd, "raw")
cfg.report = getFlag(cmd, "report")
cfg.reportPadding = getUint(cmd, "report-context")
// cfg.strict = !getFlag(cmd, "warn")
// cfg.quiet = getFlag(cmd, "quiet")
cfg.padding.Right = getUint(cmd, "padding")
cfg.parallelExpansion = !getFlag(cmd, "sequential")
cfg.batchSize = getUint(cmd, "batch")
cfg.ansiEscapes = getFlag(cmd, "ansi-escapes")
cfg.air = GetFlag(cmd, "air")
cfg.mir = GetFlag(cmd, "mir")
cfg.hir = GetFlag(cmd, "hir")
cfg.expand = !GetFlag(cmd, "raw")
cfg.report = GetFlag(cmd, "report")
cfg.reportPadding = GetUint(cmd, "report-context")
// cfg.strict = !GetFlag(cmd, "warn")
// cfg.quiet = GetFlag(cmd, "quiet")
cfg.padding.Right = GetUint(cmd, "padding")
cfg.parallelExpansion = !GetFlag(cmd, "sequential")
cfg.batchSize = GetUint(cmd, "batch")
cfg.ansiEscapes = GetFlag(cmd, "ansi-escapes")
// TODO: support true ranges
cfg.padding.Left = cfg.padding.Right
// Normalise IRs
Expand Down
18 changes: 9 additions & 9 deletions pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ var traceCmd = &cobra.Command{
}
// Parse trace
cols := readTraceFile(args[0])
list := getFlag(cmd, "list")
stats := getFlag(cmd, "stats")
includes := getStringArray(cmd, "include")
print := getFlag(cmd, "print")
start := getUint(cmd, "start")
end := getUint(cmd, "end")
max_width := getUint(cmd, "max-width")
filter := getString(cmd, "filter")
output := getString(cmd, "out")
list := GetFlag(cmd, "list")
stats := GetFlag(cmd, "stats")
includes := GetStringArray(cmd, "include")
print := GetFlag(cmd, "print")
start := GetUint(cmd, "start")
end := GetUint(cmd, "end")
max_width := GetUint(cmd, "max-width")
filter := GetString(cmd, "filter")
output := GetString(cmd, "out")
// construct filters
if filter != "" {
cols = filterColumns(cols, filter)
Expand Down
20 changes: 10 additions & 10 deletions pkg/cmd/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
"github.com/spf13/cobra"
)

// Get an expected flag, or panic if an error arises.
func getFlag(cmd *cobra.Command, flag string) bool {
// GetFlag gets an expected flag, or panic if an error arises.
func GetFlag(cmd *cobra.Command, flag string) bool {
r, err := cmd.Flags().GetBool(flag)
if err != nil {
fmt.Println(err)
Expand All @@ -26,8 +26,8 @@ func getFlag(cmd *cobra.Command, flag string) bool {
return r
}

// Get an expectedsigned integer, or panic if an error arises.
func getInt(cmd *cobra.Command, flag string) int {
// GetInt gets an expectedsigned integer, or panic if an error arises.
func GetInt(cmd *cobra.Command, flag string) int {
r, err := cmd.Flags().GetInt(flag)
if err != nil {
fmt.Println(err)
Expand All @@ -37,8 +37,8 @@ func getInt(cmd *cobra.Command, flag string) int {
return r
}

// Get an expected unsigned integer, or panic if an error arises.
func getUint(cmd *cobra.Command, flag string) uint {
// GetUint gets an expected unsigned integer, or panic if an error arises.
func GetUint(cmd *cobra.Command, flag string) uint {
r, err := cmd.Flags().GetUint(flag)
if err != nil {
fmt.Println(err)
Expand All @@ -48,8 +48,8 @@ func getUint(cmd *cobra.Command, flag string) uint {
return r
}

// Get an expected string, or panic if an error arises.
func getString(cmd *cobra.Command, flag string) string {
// GetString gets an expected string, or panic if an error arises.
func GetString(cmd *cobra.Command, flag string) string {
r, err := cmd.Flags().GetString(flag)
if err != nil {
fmt.Println(err)
Expand All @@ -59,8 +59,8 @@ func getString(cmd *cobra.Command, flag string) string {
return r
}

// Get an expected string array, or panic if an error arises.
func getStringArray(cmd *cobra.Command, flag string) []string {
// GetStringArray gets an expected string array, or panic if an error arises.
func GetStringArray(cmd *cobra.Command, flag string) []string {
r, err := cmd.Flags().GetStringArray(flag)
if err != nil {
fmt.Println(err)
Expand Down
9 changes: 9 additions & 0 deletions testdata/word_sorting.auto.accepts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{"X": [0, 0], "Delta": [0, 0], "Byte_0": [0, 0], "Byte_1": [0, 0]}
{"X": [0, 1], "Delta": [0, 1], "Byte_0": [0, 1], "Byte_1": [0, 0]}
{"X": [0, 2], "Delta": [0, 2], "Byte_0": [0, 2], "Byte_1": [0, 0]}
{"X": [0, 0, 0], "Delta": [0, 0, 0], "Byte_0": [0, 0, 0], "Byte_1": [0, 0, 0]}
{"X": [0, 1, 1], "Delta": [0, 1, 0], "Byte_0": [0, 1, 0], "Byte_1": [0, 0, 0]}
{"X": [0, 2, 2], "Delta": [0, 2, 0], "Byte_0": [0, 2, 0], "Byte_1": [0, 0, 0]}
{"X": [0, 0, 1], "Delta": [0, 0, 1], "Byte_0": [0, 0, 1], "Byte_1": [0, 0, 0]}
{"X": [0, 1, 2], "Delta": [0, 1, 1], "Byte_0": [0, 1, 1], "Byte_1": [0, 0, 0]}
{"X": [0, 0, 2], "Delta": [0, 0, 2], "Byte_0": [0, 0, 2], "Byte_1": [0, 0, 0]}
Loading