diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 81cf92190..59c73cf74 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -18,7 +18,7 @@ jobs: python-version: "3.9" - name: Install cairo-lang - run: pip install cairo-lang==0.13.1 + run: pip install cairo-lang==0.13.2 - name: Install sympy run: pip install sympy==1.11.1 diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 0d252d77f..536f1b305 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -1,17 +1,18 @@ package main import ( + "encoding/json" "fmt" "math" "os" + "path/filepath" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" hintrunner "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/zero" - cairoversion "github.com/NethermindEth/cairo-vm-go/pkg/parsers/cairo_version" "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" zero "github.com/NethermindEth/cairo-vm-go/pkg/parsers/zero" - runnerzero "github.com/NethermindEth/cairo-vm-go/pkg/runners/zero" + "github.com/NethermindEth/cairo-vm-go/pkg/runner" "github.com/urfave/cli/v2" ) @@ -24,6 +25,8 @@ func main() { var traceLocation string var memoryLocation string var layoutName string + var airPublicInputLocation string + var airPrivateInputLocation string app := &cli.App{ Name: "cairo-vm", Usage: "A cairo virtual machine", @@ -85,6 +88,18 @@ func main() { Required: false, Destination: &layoutName, }, + &cli.StringFlag{ + Name: "air_public_input", + Usage: "location to store the air_public_input", + Required: false, + Destination: &airPublicInputLocation, + }, + &cli.StringFlag{ + Name: "air_private_input", + Usage: "location to store the air_private_input", + Required: false, + Destination: &airPrivateInputLocation, + }, }, Action: func(ctx *cli.Context) error { // TODO: move this action's body to a separate function to decrease the @@ -136,6 +151,10 @@ func main() { case proofmode && cairoVersion == 0: runnerMode = runnerzero.ProofModeCairo0 } + program, err := runner.LoadCairoZeroProgram(zeroProgram) + if err != nil { + return fmt.Errorf("cannot load program: %w", err) + } fmt.Println("Running....") runner, err := runnerzero.NewRunner(runnerMode, program, hints, proofmode, collectTrace, maxsteps, layoutName) if err != nil { @@ -191,6 +210,46 @@ func main() { } } + if proofmode { + if airPublicInputLocation != "" { + airPublicInput, err := runner.GetAirPublicInput() + if err != nil { + return err + } + airPublicInputJson, err := json.MarshalIndent(airPublicInput, "", " ") + if err != nil { + return err + } + err = os.WriteFile(airPublicInputLocation, airPublicInputJson, 0644) + if err != nil { + return fmt.Errorf("cannot write air_public_input: %w", err) + } + } + + if airPrivateInputLocation != "" { + tracePath, err := filepath.Abs(traceLocation) + if err != nil { + return err + } + memoryPath, err := filepath.Abs(memoryLocation) + if err != nil { + return err + } + airPrivateInput, err := runner.GetAirPrivateInput(tracePath, memoryPath) + if err != nil { + return err + } + airPrivateInputJson, err := json.MarshalIndent(airPrivateInput, "", " ") + if err != nil { + return err + } + err = os.WriteFile(airPrivateInputLocation, airPrivateInputJson, 0644) + if err != nil { + return fmt.Errorf("cannot write air_private_input: %w", err) + } + } + } + fmt.Println("Success!") output := runner.Output() if len(output) > 0 { diff --git a/integration_tests/cairo_zero_hint_tests/hintrefs.cairo b/integration_tests/cairo_files_not_run_rust_vm/hintrefs.cairo similarity index 100% rename from integration_tests/cairo_zero_hint_tests/hintrefs.cairo rename to integration_tests/cairo_files_not_run_rust_vm/hintrefs.cairo diff --git a/integration_tests/cairo_zero_file_tests/keccak_builtin.starknet_with_keccak.cairo b/integration_tests/cairo_zero_file_tests/keccak_builtin.starknet_with_keccak.cairo new file mode 100644 index 000000000..be6e1717c --- /dev/null +++ b/integration_tests/cairo_zero_file_tests/keccak_builtin.starknet_with_keccak.cairo @@ -0,0 +1,18 @@ +%builtins keccak +from starkware.cairo.common.cairo_builtins import KeccakBuiltin +from starkware.cairo.common.keccak_state import KeccakBuiltinState + +func main{keccak_ptr: KeccakBuiltin*}() { + assert keccak_ptr[0].input = KeccakBuiltinState(1, 2, 3, 4, 5, 6, 7, 8); + let result = keccak_ptr[0].output; + let keccak_ptr = keccak_ptr + KeccakBuiltin.SIZE; + assert result.s0 = 528644516554364142278482415480021626364691973678134577961206; + assert result.s1 = 768681319646568210457759892191562701823009052229295869963057; + assert result.s2 = 1439835513376369408063324968379272676079109225238241190228026; + assert result.s3 = 1150396629165612276474514703759718478742374517669870754478270; + assert result.s4 = 1515147102575186161827863034255579930572231617017100845406254; + assert result.s5 = 1412568161597072838250338588041800080889949791225997426843744; + assert result.s6 = 982235455376248641031519404605670648838699214888770304613539; + assert result.s7 = 1339947803093378278438908448344904300127577306141693325151040; + return (); +} diff --git a/integration_tests/cairozero_test.go b/integration_tests/cairozero_test.go index 1bfe3d1c1..dd0e1fa64 100644 --- a/integration_tests/cairozero_test.go +++ b/integration_tests/cairozero_test.go @@ -6,7 +6,6 @@ import ( "os" "os/exec" "path/filepath" - "strconv" "strings" "sync" "testing" @@ -53,7 +52,7 @@ func (f *Filter) filtered(testFile string) bool { return false } -func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[string][2]int, benchmark bool, errorExpected bool) { +func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[string][3]int, benchmark bool, errorExpected bool) { t.Logf("testing: %s\n", path) compiledOutput, err := compileZeroCode(path) @@ -73,6 +72,17 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str } } + elapsedRs, rsTraceFile, rsMemoryFile, err := runRustVm(name, compiledOutput) + if errorExpected { + // we let the code go on so that we can check if the go vm also raises an error + assert.Error(t, err, path) + } else { + if err != nil { + t.Error(err) + return + } + } + elapsedGo, traceFile, memoryFile, _, err := runVm(compiledOutput) if errorExpected { assert.Error(t, err, path) @@ -85,7 +95,7 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str } if benchmark { - benchmarkMap[name] = [2]int{int(elapsedPy.Milliseconds()), int(elapsedGo.Milliseconds())} + benchmarkMap[name] = [3]int{int(elapsedPy.Milliseconds()), int(elapsedGo.Milliseconds()), int(elapsedRs.Milliseconds())} } pyTrace, pyMemory, err := decodeProof(pyTraceFile, pyMemoryFile) @@ -100,6 +110,20 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str return } + rsTrace, rsMemory, err := decodeProof(rsTraceFile, rsMemoryFile) + if err != nil { + t.Error(err) + return + } + + if !assert.Equal(t, pyTrace, rsTrace) { + t.Logf("pytrace:\n%s\n", traceRepr(pyTrace)) + t.Logf("rstrace:\n%s\n", traceRepr(rsTrace)) + } + if !assert.Equal(t, pyMemory, rsMemory) { + t.Logf("pymemory;\n%s\n", memoryRepr(pyMemory)) + t.Logf("rsmemory;\n%s\n", memoryRepr(rsMemory)) + } if !assert.Equal(t, pyTrace, trace) { t.Logf("pytrace:\n%s\n", traceRepr(pyTrace)) t.Logf("trace:\n%s\n", traceRepr(trace)) @@ -108,6 +132,14 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str t.Logf("pymemory;\n%s\n", memoryRepr(pyMemory)) t.Logf("memory;\n%s\n", memoryRepr(memory)) } + if !assert.Equal(t, rsTrace, trace) { + t.Logf("rstrace:\n%s\n", traceRepr(rsTrace)) + t.Logf("trace:\n%s\n", traceRepr(trace)) + } + if !assert.Equal(t, rsMemory, memory) { + t.Logf("rsmemory;\n%s\n", memoryRepr(rsMemory)) + t.Logf("memory;\n%s\n", memoryRepr(memory)) + } } var zerobench = flag.Bool("zerobench", false, "run integration tests and generate benchmarks file") @@ -123,7 +155,7 @@ func TestCairoZeroFiles(t *testing.T) { filter := Filter{} filter.init() - benchmarkMap := make(map[string][2]int) + benchmarkMap := make(map[string][3]int) sem := make(chan struct{}, 5) // semaphore to limit concurrency var wg sync.WaitGroup // WaitGroup to wait for all goroutines to finish @@ -176,18 +208,17 @@ func TestCairoZeroFiles(t *testing.T) { } } -// Save the Benchmarks for the integration tests in `BenchMarks.txt` -func WriteBenchMarksToFile(benchmarkMap map[string][2]int) { - totalWidth := 123 +func WriteBenchMarksToFile(benchmarkMap map[string][3]int) { + totalWidth := 113 // Reduced width to adjust for long file names border := strings.Repeat("=", totalWidth) separator := strings.Repeat("-", totalWidth) var sb strings.Builder - w := tabwriter.NewWriter(&sb, 40, 0, 0, ' ', tabwriter.Debug) + w := tabwriter.NewWriter(&sb, 0, 0, 1, ' ', tabwriter.AlignRight) sb.WriteString(border + "\n") - fmt.Fprintln(w, "| File \t PythonVM (ms) \t GoVM (ms) \t") + fmt.Fprintf(w, "| %-40s | %-20s | %-20s | %-20s |\n", "File", "PythonVM (ms)", "GoVM (ms)", "RustVM (ms)") w.Flush() sb.WriteString(border + "\n") @@ -195,16 +226,13 @@ func WriteBenchMarksToFile(benchmarkMap map[string][2]int) { totalFiles := len(benchmarkMap) for key, values := range benchmarkMap { - row := "| " + key + "\t " - - for iter, value := range values { - row = row + strconv.Itoa(value) + "\t" - if iter == 0 { - row = row + " " - } + // Adjust the key length if it's too long + displayKey := key + if len(displayKey) > 40 { + displayKey = displayKey[:37] + "..." } - fmt.Fprintln(w, row) + fmt.Fprintf(w, "| %-40s | %-20d | %-20d | %-20d |\n", displayKey, values[0], values[1], values[2]) w.Flush() if iterator < totalFiles-1 { @@ -236,6 +264,8 @@ const ( compiledSuffix = "_compiled.json" pyTraceSuffix = "_py_trace" pyMemorySuffix = "_py_memory" + rsTraceSuffix = "_rs_trace" + rsMemorySuffix = "_rs_memory" traceSuffix = "_trace" memorySuffix = "_memory" ) @@ -288,8 +318,22 @@ func runPythonVm(testFilename, path string) (time.Duration, string, string, erro // A file without this suffix will use the default ("plain") layout. if strings.HasSuffix(testFilename, ".small.cairo") { args = append(args, "--layout", "small") + } else if strings.HasSuffix(testFilename, ".dex.cairo") { + args = append(args, "--layout", "dex") + } else if strings.HasSuffix(testFilename, ".recursive.cairo") { + args = append(args, "--layout", "recursive") } else if strings.HasSuffix(testFilename, ".starknet_with_keccak.cairo") { args = append(args, "--layout", "starknet_with_keccak") + } else if strings.HasSuffix(testFilename, ".starknet.cairo") { + args = append(args, "--layout", "starknet") + } else if strings.HasSuffix(testFilename, ".recursive_large_output.cairo") { + args = append(args, "--layout", "recursive_large_output") + } else if strings.HasSuffix(testFilename, ".recursive_with_poseidon.cairo") { + args = append(args, "--layout", "recursive_with_poseidon") + } else if strings.HasSuffix(testFilename, ".all_solidity.cairo") { + args = append(args, "--layout", "all_solidity") + } else if strings.HasSuffix(testFilename, ".all_cairo.cairo") { + args = append(args, "--layout", "all_cairo") } cmd := exec.Command("cairo-run", args...) @@ -309,6 +353,61 @@ func runPythonVm(testFilename, path string) (time.Duration, string, string, erro return elapsed, traceOutput, memoryOutput, nil } +// given a path to a compiled cairo zero file, execute it using the +// rust vm and return the trace and memory files location +func runRustVm(testFilename, path string) (time.Duration, string, string, error) { + traceOutput := swapExtenstion(path, rsTraceSuffix) + memoryOutput := swapExtenstion(path, rsMemorySuffix) + + args := []string{ + path, + "--proof_mode", + "--trace_file", + traceOutput, + "--memory_file", + memoryOutput, + } + + // If any other layouts are needed, add the suffix checks here. + // The convention would be: ".$layout.cairo" + // A file without this suffix will use the default ("plain") layout. + if strings.HasSuffix(testFilename, ".small.cairo") { + args = append(args, "--layout", "small") + } else if strings.HasSuffix(testFilename, ".dex.cairo") { + args = append(args, "--layout", "dex") + } else if strings.HasSuffix(testFilename, ".recursive.cairo") { + args = append(args, "--layout", "recursive") + } else if strings.HasSuffix(testFilename, ".starknet_with_keccak.cairo") { + args = append(args, "--layout", "starknet_with_keccak") + } else if strings.HasSuffix(testFilename, ".starknet.cairo") { + args = append(args, "--layout", "starknet") + } else if strings.HasSuffix(testFilename, ".recursive_large_output.cairo") { + args = append(args, "--layout", "recursive_large_output") + } else if strings.HasSuffix(testFilename, ".recursive_with_poseidon.cairo") { + args = append(args, "--layout", "recursive_with_poseidon") + } else if strings.HasSuffix(testFilename, ".all_solidity.cairo") { + args = append(args, "--layout", "all_solidity") + } else if strings.HasSuffix(testFilename, ".all_cairo.cairo") { + args = append(args, "--layout", "all_cairo") + } + + cmd := exec.Command("./../rust_vm_bin/cairo-vm-cli", args...) + + start := time.Now() + + res, err := cmd.CombinedOutput() + + elapsed := time.Since(start) + + if err != nil { + return 0, "", "", fmt.Errorf( + "./../rust_vm_bin/cairo-vm-cli %s: %w\n%s", path, err, string(res), + ) + } + + return elapsed, traceOutput, memoryOutput, nil +} + // given a path to a compiled cairo zero file, execute // it using our vm func runVm(path string) (time.Duration, string, string, string, error) { @@ -321,8 +420,22 @@ func runVm(path string) (time.Duration, string, string, string, error) { layout := "plain" if strings.Contains(path, ".small") { layout = "small" + } else if strings.Contains(path, ".dex") { + layout = "dex" + } else if strings.Contains(path, ".recursive") { + layout = "recursive" } else if strings.Contains(path, ".starknet_with_keccak") { layout = "starknet_with_keccak" + } else if strings.Contains(path, ".starknet") { + layout = "starknet" + } else if strings.Contains(path, ".recursive_large_output") { + layout = "recursive_large_output" + } else if strings.Contains(path, ".recursive_with_poseidon") { + layout = "recursive_with_poseidon" + } else if strings.Contains(path, ".all_solidity") { + layout = "all_solidity" + } else if strings.Contains(path, ".all_cairo") { + layout = "all_cairo" } cmd := exec.Command( @@ -351,7 +464,6 @@ func runVm(path string) (time.Duration, string, string, string, error) { } return elapsed, traceOutput, memoryOutput, string(res), nil - } func decodeProof(traceLocation string, memoryLocation string) ([]vm.Trace, []*fp.Element, error) { diff --git a/pkg/hintrunner/core/cairo_hintparser.go b/pkg/hintrunner/core/cairo_hintparser.go index 51bda93de..86168d73f 100644 --- a/pkg/hintrunner/core/cairo_hintparser.go +++ b/pkg/hintrunner/core/cairo_hintparser.go @@ -95,6 +95,14 @@ func GetHintByName(hint starknet.Hint) (hinter.Hinter, error) { return &AllocSegment{ Dst: parseCellRefer(args.Dst), }, nil + case starknet.EvalCircuitName: + args := hint.Args.(*starknet.EvalCircuit) + return &EvalCircuit{ + AddModN: parseResOperand(args.NAddMods), + AddModPtr: parseResOperand(args.AddModPtr), + MulModN: parseResOperand(args.NMulMods), + MulModPtr: parseResOperand(args.MulModPtr), + }, nil case starknet.TestLessThanName: args := hint.Args.(*starknet.TestLessThan) return &TestLessThan{ diff --git a/pkg/hintrunner/core/hint.go b/pkg/hintrunner/core/hint.go index 71b725f33..8e98ad233 100644 --- a/pkg/hintrunner/core/hint.go +++ b/pkg/hintrunner/core/hint.go @@ -12,6 +12,7 @@ import ( "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" "github.com/NethermindEth/cairo-vm-go/pkg/utils" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -56,6 +57,46 @@ func (hint *AllocSegment) Execute(vm *VM.VirtualMachine, _ *hinter.HintRunnerCon return nil } +type EvalCircuit struct { + AddModN hinter.Reference + AddModPtr hinter.Reference + MulModN hinter.Reference + MulModPtr hinter.Reference +} + +func (hint *EvalCircuit) String() string { + return "EvalCircuit" +} + +func (hint *EvalCircuit) Execute(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error { + addModInputAddress, err := hinter.ResolveAsAddress(vm, hint.AddModPtr) + if err != nil { + return fmt.Errorf("resolve addModBuiltin pointer: %w", err) + } + nAddMods, err := hint.AddModN.Resolve(vm) + if err != nil { + return fmt.Errorf("resolve nAddMods operand %s: %v", hint.AddModN, err) + } + nAddModsFelt, err := nAddMods.Uint64() + if err != nil { + return err + } + mulModInputAddress, err := hinter.ResolveAsAddress(vm, hint.MulModPtr) + if err != nil { + return fmt.Errorf("resolve mulModBuiltin pointer: %w", err) + } + nMulMods, err := hint.MulModN.Resolve(vm) + if err != nil { + return fmt.Errorf("resolve nMulMods operand %s: %v", hint.MulModN, err) + } + nMulModsFelt, err := nMulMods.Uint64() + if err != nil { + return err + } + + return builtins.FillMemory(vm.Memory, *addModInputAddress, nAddModsFelt, *mulModInputAddress, nMulModsFelt) +} + type TestLessThan struct { dst hinter.Reference lhs hinter.Reference @@ -540,7 +581,7 @@ func (hint U256InvModN) Execute(vm *VM.VirtualMachine, _ *hinter.HintRunnerConte n := new(big.Int).Lsh(&N1BigInt, 128) n.Add(n, &N0BigInt) - _, r, g := u.Igcdex(n, b) + _, r, g := utils.Igcdex(n, b) mask := new(big.Int).Lsh(big.NewInt(1), 128) mask.Sub(mask, big.NewInt(1)) diff --git a/pkg/hintrunner/core/hint_test.go b/pkg/hintrunner/core/hint_test.go index 82bf52ffb..39dbe0796 100644 --- a/pkg/hintrunner/core/hint_test.go +++ b/pkg/hintrunner/core/hint_test.go @@ -455,6 +455,577 @@ func TestDivModDivisionByZeroError(t *testing.T) { require.ErrorContains(t, err, "cannot be divided by zero, rhs: 0") } +func TestEvalCircuit(t *testing.T) { + t.Run("test mod_builtin_runner (1)", func(t *testing.T) { + vm := VM.DefaultVirtualMachine() + + vm.Context.Ap = 0 + vm.Context.Fp = 0 + + // Test : p = 2^96 + 1 + // Note that these calculations are performed based on the offsets that we provide + // x1 = 17 (4 memory cells) + // nil (4 memory cells) (should become equal to 6) + // x2 = 23 (4 memory cells) + // res = nil (4 memory cells) (multiplication of the above two numbers should then equal 138) + + // Values Array + // x1 = UInt384(17,0,0,0) + utils.WriteTo(vm, VM.ExecutionSegment, 0, mem.MemoryValueFromInt(17)) + utils.WriteTo(vm, VM.ExecutionSegment, 1, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 3, mem.MemoryValueFromInt(0)) + + // 4 unallocated memory cells + + // x2 = UInt384(23,0,0,0) + utils.WriteTo(vm, VM.ExecutionSegment, 8, mem.MemoryValueFromInt(23)) + utils.WriteTo(vm, VM.ExecutionSegment, 9, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 10, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 11, mem.MemoryValueFromInt(0)) + + // 4 unallocated memory cells for res + + // AddMod Offsets Array + utils.WriteTo(vm, VM.ExecutionSegment, 16, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 17, mem.MemoryValueFromInt(4)) + utils.WriteTo(vm, VM.ExecutionSegment, 18, mem.MemoryValueFromInt(8)) + + // MulMod Offsets Array + utils.WriteTo(vm, VM.ExecutionSegment, 19, mem.MemoryValueFromInt(4)) + utils.WriteTo(vm, VM.ExecutionSegment, 20, mem.MemoryValueFromInt(8)) + utils.WriteTo(vm, VM.ExecutionSegment, 21, mem.MemoryValueFromInt(12)) + + AddModBuiltin := vm.Memory.AllocateBuiltinSegment(builtins.NewModBuiltin(1, 96, 1, builtins.Add)) + MulModBuiltin := vm.Memory.AllocateBuiltinSegment(builtins.NewModBuiltin(1, 96, 1, builtins.Mul)) + + /* + The Add and Mul Mod builtin structure are defined as: + struct ModBuiltin { + p: UInt384, // The modulus. + values_ptr: UInt384*, // A pointer to input values, the intermediate results and the output. + offsets_ptr: felt*, // A pointer to offsets inside the values array, defining the circuit. + // The offsets array should contain 3 * n elements. + n: felt, // The number of operations to perform. + } + */ + + // add_mod_ptr + // p = UInt384(1,1,0,0) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 0, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 1, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 3, mem.MemoryValueFromInt(0)) + + // values_ptr + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 4, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 0})) + + // offsets_ptr + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 5, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 16})) + + // n + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 6, mem.MemoryValueFromInt(1)) + + // mul_mod_ptr + // p = UInt384(1,1,0,0) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 0, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 1, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 3, mem.MemoryValueFromInt(0)) + + // values_ptr + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 4, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 0})) + + // offsets_ptr + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 5, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 19})) + + // n + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 6, mem.MemoryValueFromInt(1)) + + // To get the address of mul_mod_ptr and add_mod_ptr + utils.WriteTo(vm, VM.ExecutionSegment, 22, mem.MemoryValueFromSegmentAndOffset(AddModBuiltin.SegmentIndex, 0)) + utils.WriteTo(vm, VM.ExecutionSegment, 23, mem.MemoryValueFromSegmentAndOffset(MulModBuiltin.SegmentIndex, 0)) + + var addRef hinter.ApCellRef = 22 + var mulRef hinter.ApCellRef = 23 + + nAddMods := hinter.Immediate(f.NewElement(1)) + nMulMods := hinter.Immediate(f.NewElement(1)) + addModPtrAddr := hinter.Deref{Deref: addRef} + mulModPtrAddr := hinter.Deref{Deref: mulRef} + + hint := EvalCircuit{ + AddModN: nAddMods, + AddModPtr: addModPtrAddr, + MulModN: nMulMods, + MulModPtr: mulModPtrAddr, + } + + err := hint.Execute(vm, nil) + require.Nil(t, err) + + res1 := &f.Element{} + res1.SetInt64(138) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res1), + utils.ReadFrom(vm, VM.ExecutionSegment, 12), + ) + + res2 := &f.Element{} + res2.SetInt64(0) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res2), + utils.ReadFrom(vm, VM.ExecutionSegment, 13), + ) + + res3 := &f.Element{} + res3.SetInt64(0) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res2), + utils.ReadFrom(vm, VM.ExecutionSegment, 14), + ) + + res4 := &f.Element{} + res4.SetInt64(0) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res2), + utils.ReadFrom(vm, VM.ExecutionSegment, 15), + ) + }) + + t.Run("test mod_builtin_runner (2)", func(t *testing.T) { + vm := VM.DefaultVirtualMachine() + + vm.Context.Ap = 0 + vm.Context.Fp = 0 + + // Test : p = 2^96 + 1 + // Note that these calculations are performed based on the offsets that we provide + // x1 = 1 (4 memory cells) + // nil (4 memory cells) (should become equal to 0) + // x2 = 2^96 + 2 (4 memory cells) + // res = nil (4 memory cells) (multiplication of the above two numbers should then equal 0) + + // Values Array + // x1 = UInt384(1,0,0,0) + utils.WriteTo(vm, VM.ExecutionSegment, 0, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, VM.ExecutionSegment, 1, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 3, mem.MemoryValueFromInt(0)) + + // 4 unallocated memory cells + + // x2 = UInt384(2,1,0,0) + utils.WriteTo(vm, VM.ExecutionSegment, 8, mem.MemoryValueFromInt(2)) + utils.WriteTo(vm, VM.ExecutionSegment, 9, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, VM.ExecutionSegment, 10, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 11, mem.MemoryValueFromInt(0)) + + // 4 unallocated memory cells for res + + // AddMod Offsets Array + utils.WriteTo(vm, VM.ExecutionSegment, 16, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 17, mem.MemoryValueFromInt(4)) + utils.WriteTo(vm, VM.ExecutionSegment, 18, mem.MemoryValueFromInt(8)) + + // MulMod Offsets Array + utils.WriteTo(vm, VM.ExecutionSegment, 19, mem.MemoryValueFromInt(4)) + utils.WriteTo(vm, VM.ExecutionSegment, 20, mem.MemoryValueFromInt(8)) + utils.WriteTo(vm, VM.ExecutionSegment, 21, mem.MemoryValueFromInt(12)) + + AddModBuiltin := vm.Memory.AllocateBuiltinSegment(builtins.NewModBuiltin(1, 96, 1, builtins.Add)) + MulModBuiltin := vm.Memory.AllocateBuiltinSegment(builtins.NewModBuiltin(1, 96, 1, builtins.Mul)) + + /* + The Add and Mul Mod builtin structure are defined as: + struct ModBuiltin { + p: UInt384, // The modulus. + values_ptr: UInt384*, // A pointer to input values, the intermediate results and the output. + offsets_ptr: felt*, // A pointer to offsets inside the values array, defining the circuit. + // The offsets array should contain 3 * n elements. + n: felt, // The number of operations to perform. + } + */ + + // add_mod_ptr + // p = UInt384(1,1,0,0) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 0, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 1, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 3, mem.MemoryValueFromInt(0)) + + // values_ptr + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 4, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 0})) + + // offsets_ptr + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 5, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 16})) + + // n + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 6, mem.MemoryValueFromInt(1)) + + // mul_mod_ptr + // p = UInt384(1,1,0,0) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 0, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 1, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 3, mem.MemoryValueFromInt(0)) + + // values_ptr + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 4, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 0})) + + // offsets_ptr + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 5, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 19})) + + // n + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 6, mem.MemoryValueFromInt(1)) + + // To get the address of mul_mod_ptr and add_mod_ptr + utils.WriteTo(vm, VM.ExecutionSegment, 22, mem.MemoryValueFromSegmentAndOffset(AddModBuiltin.SegmentIndex, 0)) + utils.WriteTo(vm, VM.ExecutionSegment, 23, mem.MemoryValueFromSegmentAndOffset(MulModBuiltin.SegmentIndex, 0)) + + var addRef hinter.ApCellRef = 22 + var mulRef hinter.ApCellRef = 23 + + nAddMods := hinter.Immediate(f.NewElement(1)) + nMulMods := hinter.Immediate(f.NewElement(1)) + addModPtrAddr := hinter.Deref{Deref: addRef} + mulModPtrAddr := hinter.Deref{Deref: mulRef} + + hint := EvalCircuit{ + AddModN: nAddMods, + AddModPtr: addModPtrAddr, + MulModN: nMulMods, + MulModPtr: mulModPtrAddr, + } + + err := hint.Execute(vm, nil) + require.Nil(t, err) + + res1 := &f.Element{} + res1.SetInt64(0) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res1), + utils.ReadFrom(vm, VM.ExecutionSegment, 12), + ) + + res2 := &f.Element{} + res2.SetInt64(0) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res2), + utils.ReadFrom(vm, VM.ExecutionSegment, 13), + ) + + res3 := &f.Element{} + res3.SetInt64(0) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res2), + utils.ReadFrom(vm, VM.ExecutionSegment, 14), + ) + + res4 := &f.Element{} + res4.SetInt64(0) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res2), + utils.ReadFrom(vm, VM.ExecutionSegment, 15), + ) + }) + + t.Run("test mod_builtin_runner (3)", func(t *testing.T) { + vm := VM.DefaultVirtualMachine() + + vm.Context.Ap = 0 + vm.Context.Fp = 0 + + // Test : p = 2^3 + 1 + // Note that the calculations are performed based on the offsets that we provide + // x1 = 1 + // x2 = 2^3 + 2 + // x3 = 2 + + // Values Array + // x1 = UInt384(1,0,0,0) + utils.WriteTo(vm, VM.ExecutionSegment, 0, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, VM.ExecutionSegment, 1, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 3, mem.MemoryValueFromInt(0)) + + // x2 = UInt384(2,1,0,0) + utils.WriteTo(vm, VM.ExecutionSegment, 4, mem.MemoryValueFromInt(2)) + utils.WriteTo(vm, VM.ExecutionSegment, 5, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, VM.ExecutionSegment, 6, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 7, mem.MemoryValueFromInt(0)) + + // x3 = UInt384(2,0,0,0) + utils.WriteTo(vm, VM.ExecutionSegment, 8, mem.MemoryValueFromInt(2)) + utils.WriteTo(vm, VM.ExecutionSegment, 9, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 10, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 11, mem.MemoryValueFromInt(0)) + + // 20 unallocated memory cells for res and other calculations + + // AddMod Offsets Array + utils.WriteTo(vm, VM.ExecutionSegment, 32, mem.MemoryValueFromInt(0)) // x1 + utils.WriteTo(vm, VM.ExecutionSegment, 33, mem.MemoryValueFromInt(12)) // x2 - x1 + utils.WriteTo(vm, VM.ExecutionSegment, 34, mem.MemoryValueFromInt(4)) // x2 + utils.WriteTo(vm, VM.ExecutionSegment, 35, mem.MemoryValueFromInt(16)) // (x2 - x1) / x3 + utils.WriteTo(vm, VM.ExecutionSegment, 36, mem.MemoryValueFromInt(20)) // x1 / x3 + utils.WriteTo(vm, VM.ExecutionSegment, 37, mem.MemoryValueFromInt(24)) // (x2 - x1) / x3 + x1 / x3 + + // MulMod Offsets Array + utils.WriteTo(vm, VM.ExecutionSegment, 38, mem.MemoryValueFromInt(8)) // x3 + utils.WriteTo(vm, VM.ExecutionSegment, 39, mem.MemoryValueFromInt(16)) // (x2 - x1) / x3 + utils.WriteTo(vm, VM.ExecutionSegment, 40, mem.MemoryValueFromInt(12)) // (x2 - x1) + utils.WriteTo(vm, VM.ExecutionSegment, 41, mem.MemoryValueFromInt(8)) // x3 + utils.WriteTo(vm, VM.ExecutionSegment, 42, mem.MemoryValueFromInt(20)) // x1 / x3 + utils.WriteTo(vm, VM.ExecutionSegment, 43, mem.MemoryValueFromInt(0)) // x1 + utils.WriteTo(vm, VM.ExecutionSegment, 44, mem.MemoryValueFromInt(8)) // x3 + utils.WriteTo(vm, VM.ExecutionSegment, 45, mem.MemoryValueFromInt(24)) // ((x2 - x1) / x3 + x1 / x3) + utils.WriteTo(vm, VM.ExecutionSegment, 46, mem.MemoryValueFromInt(28)) // ((x2 - x1) / x3 + x1 / x3) * x3 + + AddModBuiltin := vm.Memory.AllocateBuiltinSegment(builtins.NewModBuiltin(1, 3, 1, builtins.Add)) + MulModBuiltin := vm.Memory.AllocateBuiltinSegment(builtins.NewModBuiltin(1, 3, 1, builtins.Mul)) + + /* + The Add and Mul Mod builtin structure are defined as: + struct ModBuiltin { + p: UInt384, // The modulus. + values_ptr: UInt384*, // A pointer to input values, the intermediate results and the output. + offsets_ptr: felt*, // A pointer to offsets inside the values array, defining the circuit. + // The offsets array should contain 3 * n elements. + n: felt, // The number of operations to perform. + } + */ + + // add_mod_ptr + // p = UInt384(1,1,0,0) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 0, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 1, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 3, mem.MemoryValueFromInt(0)) + + // values_ptr + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 4, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 0})) + + // offsets_ptr + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 5, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 32})) + + // n + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 6, mem.MemoryValueFromInt(2)) + + // mul_mod_ptr + // p = UInt384(1,1,0,0) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 0, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 1, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 3, mem.MemoryValueFromInt(0)) + + // values_ptr + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 4, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 0})) + + // offsets_ptr + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 5, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 38})) + + // n + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 6, mem.MemoryValueFromInt(3)) + + // To get the address of mul_mod_ptr and add_mod_ptr + utils.WriteTo(vm, VM.ExecutionSegment, 47, mem.MemoryValueFromSegmentAndOffset(AddModBuiltin.SegmentIndex, 0)) + utils.WriteTo(vm, VM.ExecutionSegment, 48, mem.MemoryValueFromSegmentAndOffset(MulModBuiltin.SegmentIndex, 0)) + + var addRef hinter.ApCellRef = 47 + var mulRef hinter.ApCellRef = 48 + + nAddMods := hinter.Immediate(f.NewElement(2)) + nMulMods := hinter.Immediate(f.NewElement(3)) + addModPtrAddr := hinter.Deref{Deref: addRef} + mulModPtrAddr := hinter.Deref{Deref: mulRef} + + hint := EvalCircuit{ + AddModN: nAddMods, + AddModPtr: addModPtrAddr, + MulModN: nMulMods, + MulModPtr: mulModPtrAddr, + } + + err := hint.Execute(vm, nil) + require.Nil(t, err) + + res1 := &f.Element{} + res1.SetInt64(1) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res1), + utils.ReadFrom(vm, VM.ExecutionSegment, 28), + ) + + res2 := &f.Element{} + res2.SetInt64(0) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res2), + utils.ReadFrom(vm, VM.ExecutionSegment, 29), + ) + + res3 := &f.Element{} + res3.SetInt64(0) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res2), + utils.ReadFrom(vm, VM.ExecutionSegment, 30), + ) + + res4 := &f.Element{} + res4.SetInt64(0) + + require.Equal( + t, + mem.MemoryValueFromFieldElement(res2), + utils.ReadFrom(vm, VM.ExecutionSegment, 31), + ) + }) + + t.Run("test mod_builtin_runner (4)", func(t *testing.T) { + vm := VM.DefaultVirtualMachine() + + vm.Context.Ap = 0 + vm.Context.Fp = 0 + + // Test : p = 2^3 + 1 + // Note that the calculations are performed based on the offsets that we provide + // x1 = 8 + // x2 = 2^3 + 2 + // x3 = 2 + + // Values Array + // x1 = UInt384(8,0,0,0) + utils.WriteTo(vm, VM.ExecutionSegment, 0, mem.MemoryValueFromInt(8)) + utils.WriteTo(vm, VM.ExecutionSegment, 1, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 3, mem.MemoryValueFromInt(0)) + + // x2 = UInt384(2,1,0,0) + utils.WriteTo(vm, VM.ExecutionSegment, 4, mem.MemoryValueFromInt(2)) + utils.WriteTo(vm, VM.ExecutionSegment, 5, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, VM.ExecutionSegment, 6, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 7, mem.MemoryValueFromInt(0)) + + // x3 = UInt384(2,0,0,0) + utils.WriteTo(vm, VM.ExecutionSegment, 8, mem.MemoryValueFromInt(2)) + utils.WriteTo(vm, VM.ExecutionSegment, 9, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 10, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, VM.ExecutionSegment, 11, mem.MemoryValueFromInt(0)) + + // 20 unallocated memory cells for res and other calculations + + // AddMod Offsets Array + utils.WriteTo(vm, VM.ExecutionSegment, 32, mem.MemoryValueFromInt(0)) // x1 + utils.WriteTo(vm, VM.ExecutionSegment, 33, mem.MemoryValueFromInt(12)) // x2 - x1 + utils.WriteTo(vm, VM.ExecutionSegment, 34, mem.MemoryValueFromInt(4)) // x2 + utils.WriteTo(vm, VM.ExecutionSegment, 35, mem.MemoryValueFromInt(16)) // (x2 - x1) / x3 + utils.WriteTo(vm, VM.ExecutionSegment, 36, mem.MemoryValueFromInt(20)) // x1 / x3 + utils.WriteTo(vm, VM.ExecutionSegment, 37, mem.MemoryValueFromInt(24)) // (x2 - x1) / x3 + x1 / x3 + + // MulMod Offsets Array + utils.WriteTo(vm, VM.ExecutionSegment, 38, mem.MemoryValueFromInt(8)) // x3 + utils.WriteTo(vm, VM.ExecutionSegment, 39, mem.MemoryValueFromInt(16)) // (x2 - x1) / x3 + utils.WriteTo(vm, VM.ExecutionSegment, 40, mem.MemoryValueFromInt(12)) // (x2 - x1) + utils.WriteTo(vm, VM.ExecutionSegment, 41, mem.MemoryValueFromInt(8)) // x3 + utils.WriteTo(vm, VM.ExecutionSegment, 42, mem.MemoryValueFromInt(20)) // x1 / x3 + utils.WriteTo(vm, VM.ExecutionSegment, 43, mem.MemoryValueFromInt(0)) // x1 + utils.WriteTo(vm, VM.ExecutionSegment, 44, mem.MemoryValueFromInt(8)) // x3 + utils.WriteTo(vm, VM.ExecutionSegment, 45, mem.MemoryValueFromInt(24)) // ((x2 - x1) / x3 + x1 / x3) + utils.WriteTo(vm, VM.ExecutionSegment, 46, mem.MemoryValueFromInt(28)) // ((x2 - x1) / x3 + x1 / x3) * x3 + + AddModBuiltin := vm.Memory.AllocateBuiltinSegment(builtins.NewModBuiltin(1, 3, 1, builtins.Add)) + MulModBuiltin := vm.Memory.AllocateBuiltinSegment(builtins.NewModBuiltin(1, 3, 1, builtins.Mul)) + + /* + The Add and Mul Mod builtin structure are defined as: + struct ModBuiltin { + p: UInt384, // The modulus. + values_ptr: UInt384*, // A pointer to input values, the intermediate results and the output. + offsets_ptr: felt*, // A pointer to offsets inside the values array, defining the circuit. + // The offsets array should contain 3 * n elements. + n: felt, // The number of operations to perform. + } + */ + + // add_mod_ptr + // p = UInt384(1,1,0,0) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 0, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 1, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 3, mem.MemoryValueFromInt(0)) + + // values_ptr + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 4, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 0})) + + // offsets_ptr + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 5, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 32})) + + // n + utils.WriteTo(vm, AddModBuiltin.SegmentIndex, 6, mem.MemoryValueFromInt(2)) + + // mul_mod_ptr + // p = UInt384(1,1,0,0) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 0, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 1, mem.MemoryValueFromInt(1)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 2, mem.MemoryValueFromInt(0)) + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 3, mem.MemoryValueFromInt(0)) + + // values_ptr + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 4, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 0})) + + // offsets_ptr + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 5, mem.MemoryValueFromMemoryAddress(&mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: 38})) + + // n + utils.WriteTo(vm, MulModBuiltin.SegmentIndex, 6, mem.MemoryValueFromInt(3)) + + // To get the address of mul_mod_ptr and add_mod_ptr + utils.WriteTo(vm, VM.ExecutionSegment, 47, mem.MemoryValueFromSegmentAndOffset(AddModBuiltin.SegmentIndex, 0)) + utils.WriteTo(vm, VM.ExecutionSegment, 48, mem.MemoryValueFromSegmentAndOffset(MulModBuiltin.SegmentIndex, 0)) + + var addRef hinter.ApCellRef = 47 + var mulRef hinter.ApCellRef = 48 + + nAddMods := hinter.Immediate(f.NewElement(2)) + nMulMods := hinter.Immediate(f.NewElement(3)) + addModPtrAddr := hinter.Deref{Deref: addRef} + mulModPtrAddr := hinter.Deref{Deref: mulRef} + + hint := EvalCircuit{ + AddModN: nAddMods, + AddModPtr: addModPtrAddr, + MulModN: nMulMods, + MulModPtr: mulModPtrAddr, + } + + err := hint.Execute(vm, nil) + require.ErrorContains(t, err, "expected integer at address") + }) + +} + func TestU256InvModN(t *testing.T) { t.Run("test u256InvModN (n == 1)", func(t *testing.T) { vm := VM.DefaultVirtualMachine() diff --git a/pkg/hintrunner/utils/math_utils.go b/pkg/hintrunner/utils/math_utils.go index 7f274f81d..678aca7d4 100644 --- a/pkg/hintrunner/utils/math_utils.go +++ b/pkg/hintrunner/utils/math_utils.go @@ -5,6 +5,7 @@ import ( "fmt" "math/big" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -55,7 +56,7 @@ func AsIntBig(value *big.Int) big.Int { func Divmod(n, m, p *big.Int) (big.Int, error) { // https://github.com/starkware-libs/cairo-lang/blob/efa9648f57568aad8f8a13fbf027d2de7c63c2c0/src/starkware/python/math_utils.py#L26 - a, _, c := Igcdex(m, p) + a, _, c := utils.Igcdex(m, p) if c.Cmp(big.NewInt(1)) != 0 { return *big.NewInt(0), errors.New("no solution exists (gcd(m, p) != 1)") } @@ -65,68 +66,6 @@ func Divmod(n, m, p *big.Int) (big.Int, error) { return *res, nil } -func Igcdex(a, b *big.Int) (big.Int, big.Int, big.Int) { - // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/core/intfunc.py#L362 - - if a.Cmp(big.NewInt(0)) == 0 && b.Cmp(big.NewInt(0)) == 0 { - return *big.NewInt(0), *big.NewInt(1), *big.NewInt(0) - } - g, x, y := gcdext(a, b) - return x, y, g -} - -func gcdext(a, b *big.Int) (big.Int, big.Int, big.Int) { - // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L125 - - if a.Cmp(big.NewInt(0)) == 0 || b.Cmp(big.NewInt(0)) == 0 { - g := new(big.Int) - if a.Cmp(big.NewInt(0)) == 0 { - g.Abs(b) - } else { - g.Abs(a) - } - - if g.Cmp(big.NewInt(0)) == 0 { - return *big.NewInt(0), *big.NewInt(0), *big.NewInt(0) - } - return *g, *new(big.Int).Div(a, g), *new(big.Int).Div(b, g) - } - - xSign, aSigned := sign(a) - ySign, bSigned := sign(b) - x, r := big.NewInt(1), big.NewInt(0) - y, s := big.NewInt(0), big.NewInt(1) - - for bSigned.Sign() != 0 { - q, c := new(big.Int).DivMod(&aSigned, &bSigned, new(big.Int)) - aSigned = bSigned - bSigned = *c - x, r = r, new(big.Int).Sub(x, new(big.Int).Mul(q, r)) - y, s = s, new(big.Int).Sub(y, new(big.Int).Mul(q, s)) - } - - return aSigned, *new(big.Int).Mul(x, big.NewInt(int64(xSign))), *new(big.Int).Mul(y, big.NewInt(int64(ySign))) -} - -func sign(n *big.Int) (int, big.Int) { - // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L119 - - if n.Sign() < 0 { - return -1, *new(big.Int).Abs(n) - } - return 1, *new(big.Int).Set(n) -} - -func SafeDiv(x, y *big.Int) (big.Int, error) { - if y.Cmp(big.NewInt(0)) == 0 { - return *big.NewInt(0), fmt.Errorf("division by zero") - } - if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 { - return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y) - } - return *new(big.Int).Div(x, y), nil -} - func IsQuadResidue(x *fp.Element) bool { // Implementation adapted from sympy implementation which can be found here : // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/ntheory/residue_ntheory.py#L689 diff --git a/pkg/hintrunner/utils/math_utils_test.go b/pkg/hintrunner/utils/math_utils_test.go index 3c1a0a609..42da85c2f 100644 --- a/pkg/hintrunner/utils/math_utils_test.go +++ b/pkg/hintrunner/utils/math_utils_test.go @@ -46,69 +46,3 @@ func TestDivMod(t *testing.T) { }) } } - -func TestIgcdex(t *testing.T) { - // https://github.com/sympy/sympy/blob/e7fb2714f17b30b83e424448aad0da9e94a4b577/sympy/core/tests/test_numbers.py#L278 - tests := []struct { - name string - a, b *big.Int - expectedX, expectedY, expectedG *big.Int - }{ - { - name: "Case 1", - a: big.NewInt(2), - b: big.NewInt(3), - expectedX: big.NewInt(-1), - expectedY: big.NewInt(1), - expectedG: big.NewInt(1), - }, - { - name: "Case 2", - a: big.NewInt(10), - b: big.NewInt(12), - expectedX: big.NewInt(-1), - expectedY: big.NewInt(1), - expectedG: big.NewInt(2), - }, - { - name: "Case 3", - a: big.NewInt(100), - b: big.NewInt(2004), - expectedX: big.NewInt(-20), - expectedY: big.NewInt(1), - expectedG: big.NewInt(4), - }, - { - name: "Case 4", - a: big.NewInt(0), - b: big.NewInt(0), - expectedX: big.NewInt(0), - expectedY: big.NewInt(1), - expectedG: big.NewInt(0), - }, - { - name: "Case 5", - a: big.NewInt(1), - b: big.NewInt(0), - expectedX: big.NewInt(1), - expectedY: big.NewInt(0), - expectedG: big.NewInt(1), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actualX, actualY, actualG := Igcdex(tt.a, tt.b) - - if actualX.Cmp(tt.expectedX) != 0 { - t.Errorf("got x: %v, want: %v", actualX, tt.expectedX) - } - if actualY.Cmp(tt.expectedY) != 0 { - t.Errorf("got x: %v, want: %v", actualY, tt.expectedY) - } - if actualG.Cmp(tt.expectedG) != 0 { - t.Errorf("got x: %v, want: %v", actualG, tt.expectedG) - } - }) - } -} diff --git a/pkg/hintrunner/zero/hintcode.go b/pkg/hintrunner/zero/hintcode.go index 1fa40ffca..d79146b38 100755 --- a/pkg/hintrunner/zero/hintcode.go +++ b/pkg/hintrunner/zero/hintcode.go @@ -173,6 +173,7 @@ const ( // ------ Other hints related code ------ allocSegmentCode string = "memory[ap] = segments.add()" + evalCircuitCode string = "from starkware.cairo.lang.builtins.modulo.mod_builtin_runner import ModBuiltinRunner\nassert builtin_runners[\"add_mod_builtin\"].instance_def.batch_size == 1\nassert builtin_runners[\"mul_mod_builtin\"].instance_def.batch_size == 1\n\nModBuiltinRunner.fill_memory(\n memory=memory,\n add_mod=(ids.add_mod_ptr.address_, builtin_runners[\"add_mod_builtin\"], ids.add_mod_n),\n mul_mod=(ids.mul_mod_ptr.address_, builtin_runners[\"mul_mod_builtin\"], ids.mul_mod_n),\n)" memcpyContinueCopyingCode string = "n -= 1\nids.continue_copying = 1 if n > 0 else 0" memsetContinueLoopCode string = "n -= 1\nids.continue_loop = 1 if n > 0 else 0" memcpyEnterScopeCode string = "vm_enter_scope({'n': ids.len})" diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index ccc105b26..269c8984b 100755 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -24,7 +24,11 @@ func (hint *GenericZeroHinter) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRu } func GetZeroHints(cairoZeroJson *zero.ZeroProgram) (map[uint64][]hinter.Hinter, error) { - hints := make(map[uint64][]hinter.Hinter) + numHints := 0 + for _, rawHints := range cairoZeroJson.Hints { + numHints += len(rawHints) + } + hints := make(map[uint64][]hinter.Hinter, numHints) for counter, rawHints := range cairoZeroJson.Hints { pc, err := strconv.ParseUint(counter, 10, 64) if err != nil { @@ -319,6 +323,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint) (hinter.Hinte // Other hints case allocSegmentCode: return createAllocSegmentHinter() + case evalCircuitCode: + return createEvalCircuitHinter(resolver) case memcpyContinueCopyingCode: return createMemContinueHinter(resolver, false) case memsetContinueLoopCode: diff --git a/pkg/hintrunner/zero/zerohint_ec.go b/pkg/hintrunner/zero/zerohint_ec.go index 011781669..de508956c 100644 --- a/pkg/hintrunner/zero/zerohint_ec.go +++ b/pkg/hintrunner/zero/zerohint_ec.go @@ -257,7 +257,7 @@ func newDivModNSafeDivPlusOneHint() hinter.Hinter { valueBig.Mul(resBig, bBig) valueBig.Sub(valueBig, aBig) - newValueBig, err := secp_utils.SafeDiv(valueBig, nBig) + newValueBig, err := utils.SafeDiv(valueBig, nBig) if err != nil { return err } diff --git a/pkg/hintrunner/zero/zerohint_math_test.go b/pkg/hintrunner/zero/zerohint_math_test.go index 2dbb63dac..97681fb70 100755 --- a/pkg/hintrunner/zero/zerohint_math_test.go +++ b/pkg/hintrunner/zero/zerohint_math_test.go @@ -6,8 +6,8 @@ import ( "testing" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" - "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" "github.com/NethermindEth/cairo-vm-go/pkg/utils" + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -445,8 +445,8 @@ func TestZeroHintMath(t *testing.T) { "Assert250bits": { { operanders: []*hintOperander{ - {Name: "low", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "high", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "low", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "high", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, {Name: "value", Kind: apRelative, Value: feltInt64(3042)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { @@ -459,8 +459,8 @@ func TestZeroHintMath(t *testing.T) { }, { operanders: []*hintOperander{ - {Name: "low", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "high", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "low", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "high", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, {Name: "value", Kind: fpRelative, Value: feltInt64(4938538853994)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { @@ -473,8 +473,8 @@ func TestZeroHintMath(t *testing.T) { }, { operanders: []*hintOperander{ - {Name: "low", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "high", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "low", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "high", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, {Name: "value", Kind: apRelative, Value: feltString("348329493943842849393993999999231222222222")}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { @@ -487,8 +487,8 @@ func TestZeroHintMath(t *testing.T) { }, { operanders: []*hintOperander{ - {Name: "low", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "high", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "low", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "high", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, {Name: "value", Kind: apRelative, Value: feltString("348329493943842849393124453993999999231222222222")}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { @@ -501,8 +501,8 @@ func TestZeroHintMath(t *testing.T) { }, { operanders: []*hintOperander{ - {Name: "low", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "high", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "low", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "high", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, {Name: "value", Kind: apRelative, Value: feltInt64(-233)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { @@ -578,8 +578,8 @@ func TestZeroHintMath(t *testing.T) { "SplitFelt": { { operanders: []*hintOperander{ - {Name: "low", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "high", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "low", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "high", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, {Name: "value", Kind: apRelative, Value: feltString("100000000000000000000000000000000000000")}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { @@ -592,8 +592,8 @@ func TestZeroHintMath(t *testing.T) { }, { operanders: []*hintOperander{ - {Name: "low", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "high", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "low", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "high", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, {Name: "value", Kind: apRelative, Value: &utils.FeltMax128}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { @@ -611,8 +611,8 @@ func TestZeroHintMath(t *testing.T) { {Name: "value", Kind: apRelative, Value: &utils.FeltZero}, {Name: "div", Kind: apRelative, Value: &utils.FeltMax128}, {Name: "bound", Kind: apRelative, Value: &utils.Felt127}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) @@ -624,8 +624,8 @@ func TestZeroHintMath(t *testing.T) { {Name: "value", Kind: apRelative, Value: &utils.FeltZero}, {Name: "div", Kind: apRelative, Value: &utils.FeltZero}, {Name: "bound", Kind: apRelative, Value: &utils.Felt127}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) @@ -637,8 +637,8 @@ func TestZeroHintMath(t *testing.T) { {Name: "value", Kind: apRelative, Value: &utils.FeltZero}, {Name: "div", Kind: apRelative, Value: &utils.FeltOne}, {Name: "bound", Kind: apRelative, Value: new(fp.Element).SetBigInt(new(big.Int).Lsh(big.NewInt(1), 130))}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) @@ -650,8 +650,8 @@ func TestZeroHintMath(t *testing.T) { {Name: "value", Kind: apRelative, Value: feltInt64(-6)}, {Name: "div", Kind: apRelative, Value: feltInt64(2)}, {Name: "bound", Kind: apRelative, Value: feltInt64(2)}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) @@ -663,8 +663,8 @@ func TestZeroHintMath(t *testing.T) { {Name: "value", Kind: apRelative, Value: feltInt64(6)}, {Name: "div", Kind: apRelative, Value: feltInt64(2)}, {Name: "bound", Kind: apRelative, Value: feltInt64(3)}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) @@ -676,8 +676,8 @@ func TestZeroHintMath(t *testing.T) { {Name: "value", Kind: apRelative, Value: feltInt64(5)}, {Name: "div", Kind: apRelative, Value: feltInt64(2)}, {Name: "bound", Kind: apRelative, Value: &utils.Felt127}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) @@ -692,8 +692,8 @@ func TestZeroHintMath(t *testing.T) { {Name: "value", Kind: apRelative, Value: feltInt64(-3)}, {Name: "div", Kind: apRelative, Value: feltInt64(2)}, {Name: "bound", Kind: apRelative, Value: &utils.Felt127}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) @@ -751,8 +751,8 @@ func TestZeroHintMath(t *testing.T) { operanders: []*hintOperander{ {Name: "value", Kind: fpRelative, Value: feltUint64(100)}, {Name: "div", Kind: fpRelative, Value: feltUint64(6)}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newUnsignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["q"], ctx.operanders["r"]) @@ -766,8 +766,8 @@ func TestZeroHintMath(t *testing.T) { operanders: []*hintOperander{ {Name: "value", Kind: fpRelative, Value: feltUint64(450326666)}, {Name: "div", Kind: fpRelative, Value: feltUint64(136310839)}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newUnsignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["q"], ctx.operanders["r"]) @@ -781,8 +781,8 @@ func TestZeroHintMath(t *testing.T) { operanders: []*hintOperander{ {Name: "value", Kind: fpRelative, Value: feltUint64(0)}, {Name: "div", Kind: fpRelative, Value: feltUint64(10)}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newUnsignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["q"], ctx.operanders["r"]) @@ -796,8 +796,8 @@ func TestZeroHintMath(t *testing.T) { operanders: []*hintOperander{ {Name: "value", Kind: fpRelative, Value: feltUint64(10)}, {Name: "div", Kind: fpRelative, Value: feltUint64(0)}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newUnsignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["q"], ctx.operanders["r"]) @@ -808,8 +808,8 @@ func TestZeroHintMath(t *testing.T) { operanders: []*hintOperander{ {Name: "value", Kind: fpRelative, Value: feltUint64(10)}, {Name: "div", Kind: fpRelative, Value: feltString("10633823966279327296825105735305134079")}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newUnsignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["q"], ctx.operanders["r"]) @@ -823,8 +823,8 @@ func TestZeroHintMath(t *testing.T) { operanders: []*hintOperander{ {Name: "value", Kind: fpRelative, Value: feltUint64(10)}, {Name: "div", Kind: fpRelative, Value: feltString("10633823966279327296825105735305134080")}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newUnsignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["q"], ctx.operanders["r"]) @@ -838,8 +838,8 @@ func TestZeroHintMath(t *testing.T) { operanders: []*hintOperander{ {Name: "value", Kind: fpRelative, Value: feltUint64(10)}, {Name: "div", Kind: fpRelative, Value: feltString("10633823966279327296825105735305134081")}, - {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, - {Name: "q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + {Name: "r", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 0)}, + {Name: "q", Kind: reference, Value: addrBuiltin(builtins.RangeCheckType, 1)}, }, makeHinter: func(ctx *hintTestContext) hinter.Hinter { return newUnsignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["q"], ctx.operanders["r"]) diff --git a/pkg/hintrunner/zero/zerohint_others.go b/pkg/hintrunner/zero/zerohint_others.go index efd061576..a44b13acb 100644 --- a/pkg/hintrunner/zero/zerohint_others.go +++ b/pkg/hintrunner/zero/zerohint_others.go @@ -81,6 +81,35 @@ func createAllocSegmentHinter() (hinter.Hinter, error) { return &core.AllocSegment{Dst: hinter.ApCellRef(0)}, nil } +func createEvalCircuitHinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + addModPtr, err := resolver.GetReference("add_mod_ptr") + if err != nil { + return nil, err + } + + nAddMods, err := resolver.GetReference("add_mod_n") + if err != nil { + return nil, err + } + + mulModPtr, err := resolver.GetReference("mul_mod_ptr") + if err != nil { + return nil, err + } + + nMulMods, err := resolver.GetReference("mul_mod_n") + if err != nil { + return nil, err + } + + return &core.EvalCircuit{ + AddModN: nAddMods, + AddModPtr: addModPtr, + MulModN: nMulMods, + MulModPtr: mulModPtr, + }, nil +} + // VMEnterScope hint enters a new scope in the Cairo VM func createVMEnterScopeHinter() (hinter.Hinter, error) { return &GenericZeroHinter{ diff --git a/pkg/hintrunner/zero/zerohint_signature.go b/pkg/hintrunner/zero/zerohint_signature.go index 8af7daf81..205842151 100644 --- a/pkg/hintrunner/zero/zerohint_signature.go +++ b/pkg/hintrunner/zero/zerohint_signature.go @@ -6,6 +6,7 @@ import ( "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" secp_utils "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" @@ -372,7 +373,7 @@ func newDivModSafeDivHint() hinter.Hinter { } divisor := new(big.Int).Sub(new(big.Int).Mul(res, b), a) - value, err := secp_utils.SafeDiv(divisor, N) + value, err := utils.SafeDiv(divisor, N) if err != nil { return err } diff --git a/pkg/hintrunner/zero/zerohint_signature_test.go b/pkg/hintrunner/zero/zerohint_signature_test.go index 49ece7c67..f29417352 100644 --- a/pkg/hintrunner/zero/zerohint_signature_test.go +++ b/pkg/hintrunner/zero/zerohint_signature_test.go @@ -5,8 +5,8 @@ import ( "testing" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" - "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" "github.com/NethermindEth/cairo-vm-go/pkg/utils" + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" "github.com/stretchr/testify/require" ) @@ -54,7 +54,7 @@ func TestVerifyZeroHint(t *testing.T) { "VerifyECDSASignature": { { operanders: []*hintOperander{ - {Name: "ecdsaPtr", Kind: reference, Value: addrBuiltin(starknet.ECDSA, 0)}, + {Name: "ecdsaPtr", Kind: reference, Value: addrBuiltin(builtins.ECDSAType, 0)}, {Name: "signature_r", Kind: apRelative, Value: feltString("3086480810278599376317923499561306189851900463386393948998357832163236918254")}, {Name: "signature_s", Kind: apRelative, Value: feltString("598673427589502599949712887611119751108407514580626464031881322743364689811")}, }, diff --git a/pkg/hintrunner/zero/zerohint_test.go b/pkg/hintrunner/zero/zerohint_test.go index 638755fbc..a71b0b7a9 100644 --- a/pkg/hintrunner/zero/zerohint_test.go +++ b/pkg/hintrunner/zero/zerohint_test.go @@ -6,7 +6,6 @@ import ( "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" runnerutil "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils" - "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" @@ -57,7 +56,7 @@ type hintOperander struct { } type builtinReference struct { - builtin starknet.Builtin + builtin builtins.BuiltinType offset uint64 } @@ -135,7 +134,7 @@ func runHinterTests(t *testing.T, tests map[string][]hintTestCase) { offset uint64 addr memory.MemoryAddress } - builtinsAllocated := map[starknet.Builtin]allocatedBuiltin{} + builtinsAllocated := map[builtins.BuiltinType]allocatedBuiltin{} for _, o := range tc.operanders { if o.Kind != reference { continue diff --git a/pkg/hintrunner/zero/zerohint_utils_test.go b/pkg/hintrunner/zero/zerohint_utils_test.go index c4af2df1b..d817f7a13 100644 --- a/pkg/hintrunner/zero/zerohint_utils_test.go +++ b/pkg/hintrunner/zero/zerohint_utils_test.go @@ -8,8 +8,8 @@ import ( "testing" runnerutil "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils" - "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" "github.com/NethermindEth/cairo-vm-go/pkg/vm" + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" "github.com/stretchr/testify/assert" @@ -30,7 +30,7 @@ func addrWithSegment(segment, offset uint64) *memory.MemoryAddress { } } -func addrBuiltin(builtin starknet.Builtin, offset uint64) *builtinReference { +func addrBuiltin(builtin builtins.BuiltinType, offset uint64) *builtinReference { return &builtinReference{ builtin: builtin, offset: offset, diff --git a/pkg/parsers/cairo_version/cairo_version.go b/pkg/parsers/cairo_version/cairo_version.go deleted file mode 100644 index c17104e3c..000000000 --- a/pkg/parsers/cairo_version/cairo_version.go +++ /dev/null @@ -1,30 +0,0 @@ -package cairoversion - -import ( - "encoding/json" - "os" - "strconv" - "strings" -) - -type CairoVersion struct { - Version string `json:"compiler_version"` -} - -func GetCairoVersion(pathToFile string) (uint8, error) { - content, err := os.ReadFile(pathToFile) - if err != nil { - return 0, err - } - cv := CairoVersion{} - err = json.Unmarshal(content, &cv) - if err != nil { - return 0, err - } - firstNumberStr := strings.Split(cv.Version, ".")[0] - firstNumber, err := strconv.ParseUint(firstNumberStr, 10, 8) - if err != nil { - return 0, err - } - return uint8(firstNumber), nil -} diff --git a/pkg/parsers/starknet/hint.go b/pkg/parsers/starknet/hint.go index 768b20d65..01bc19ac8 100644 --- a/pkg/parsers/starknet/hint.go +++ b/pkg/parsers/starknet/hint.go @@ -18,6 +18,7 @@ const ( CheatcodeName HintName = "Cheatcode" // Core hints AllocSegmentName HintName = "AllocSegment" + EvalCircuitName HintName = "EvalCircuit" TestLessThanName HintName = "TestLessThan" TestLessThanOrEqualName HintName = "TestLessThanOrEqual" TestLessThanOrEqualAddressName HintName = "TestLessThanOrEqualAddress" @@ -75,6 +76,13 @@ type AllocSegment struct { Dst CellRef `json:"dst" validate:"required"` } +type EvalCircuit struct { + NAddMods ResOperand `json:"add_mod_n" validate:"required"` + AddModPtr ResOperand `json:"add_mod_ptr" validate:"required"` + NMulMods ResOperand `json:"mul_mod_n" validate:"required"` + MulModPtr ResOperand `json:"mul_mod_ptr" validate:"required"` +} + type TestLessThan struct { Lhs ResOperand `json:"lhs" validate:"required"` Rhs ResOperand `json:"rhs" validate:"required"` @@ -489,6 +497,8 @@ func (h *Hint) UnmarshalJSON(data []byte) error { // Core hints case AllocSegmentName: args = &AllocSegment{} + case EvalCircuitName: + args = &EvalCircuit{} case TestLessThanName: args = &TestLessThan{} case TestLessThanOrEqualName: diff --git a/pkg/parsers/starknet/starknet.go b/pkg/parsers/starknet/starknet.go index 894439fe0..0bbc615d5 100644 --- a/pkg/parsers/starknet/starknet.go +++ b/pkg/parsers/starknet/starknet.go @@ -4,90 +4,15 @@ import ( "encoding/json" "fmt" "os" - "strconv" + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) -type Builtin uint8 - -const ( - Output Builtin = iota + 1 - RangeCheck - Pedersen - ECDSA - Keccak - Bitwise - ECOP - Poseidon - SegmentArena - RangeCheck96 -) - -func (b Builtin) MarshalJSON() ([]byte, error) { - switch b { - case Output: - return []byte("output"), nil - case RangeCheck: - return []byte("range_check"), nil - case RangeCheck96: - return []byte("range_check96"), nil - case Pedersen: - return []byte("pedersen"), nil - case ECDSA: - return []byte("ecdsa"), nil - case Keccak: - return []byte("keccak"), nil - case Bitwise: - return []byte("bitwise"), nil - case ECOP: - return []byte("ec_op"), nil - case Poseidon: - return []byte("poseidon"), nil - case SegmentArena: - return []byte("segment_arena"), nil - - } - return nil, fmt.Errorf("marshal unknown builtin: %d", uint8(b)) -} - -func (b *Builtin) UnmarshalJSON(data []byte) error { - builtinName, err := strconv.Unquote(string(data)) - if err != nil { - return fmt.Errorf("unmarshal builtin: %w", err) - } - - switch builtinName { - case "output": - *b = Output - case "range_check": - *b = RangeCheck - case "range_check96": - *b = RangeCheck96 - case "pedersen": - *b = Pedersen - case "ecdsa": - *b = ECDSA - case "keccak": - *b = Keccak - case "bitwise": - *b = Bitwise - case "ec_op": - *b = ECOP - case "poseidon": - *b = Poseidon - case "segment_arena": - *b = SegmentArena - default: - return fmt.Errorf("unmarshal unknown builtin: %s", builtinName) - } - return nil -} - type EntryPointByTypeInfo struct { - Selector fp.Element `json:"selector"` - Offset fp.Element `json:"offset"` - Builtins []Builtin `json:"builtins"` + Selector fp.Element `json:"selector"` + Offset fp.Element `json:"offset"` + Builtins []builtins.BuiltinType `json:"builtins"` } type EntryPointByType struct { @@ -103,10 +28,10 @@ type Arg struct { } type EntryPointByFunction struct { - Offset int `json:"offset"` - Builtins []Builtin `json:"builtins"` - InputArgs []Arg `json:"input_args"` - ReturnArg []Arg `json:"return_arg"` + Offset int `json:"offset"` + Builtins []builtins.BuiltinType `json:"builtins"` + InputArgs []Arg `json:"input_args"` + ReturnArg []Arg `json:"return_arg"` } type Hints struct { diff --git a/pkg/parsers/starknet/starknet_test.go b/pkg/parsers/starknet/starknet_test.go index bc1d1acfb..64842bcf0 100644 --- a/pkg/parsers/starknet/starknet_test.go +++ b/pkg/parsers/starknet/starknet_test.go @@ -6,6 +6,8 @@ import ( "github.com/go-playground/validator/v10" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" ) func TestCompilerVersionParsing(t *testing.T) { @@ -93,7 +95,7 @@ func TestEntryPointInfoParsing(t *testing.T) { assert.Len(t, entryPointInfo.Builtins, 9) for i := 0; i < 9; i++ { - assert.Equal(t, Builtin(i+1), entryPointInfo.Builtins[i]) + assert.Equal(t, builtins.BuiltinType(i+1), entryPointInfo.Builtins[i]) } } diff --git a/pkg/parsers/zero/zero.go b/pkg/parsers/zero/zero.go index 925350266..ceb320170 100644 --- a/pkg/parsers/zero/zero.go +++ b/pkg/parsers/zero/zero.go @@ -4,7 +4,7 @@ import ( "encoding/json" "os" - starknetParser "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" ) type FlowTrackingData struct { @@ -68,16 +68,16 @@ type AttributeScope struct { } type ZeroProgram struct { - Prime string `json:"prime"` - Data []string `json:"data"` - Builtins []starknetParser.Builtin `json:"builtins"` - Hints map[string][]Hint `json:"hints"` - CompilerVersion string `json:"version"` - MainScope string `json:"main_scope"` - Identifiers map[string]*Identifier `json:"identifiers"` - ReferenceManager ReferenceManager `json:"reference_manager"` - Attributes []AttributeScope `json:"attributes"` - DebugInfo DebugInfo `json:"debug_info"` + Prime string `json:"prime"` + Data []string `json:"data"` + Builtins []builtins.BuiltinType `json:"builtins"` + Hints map[string][]Hint `json:"hints"` + CompilerVersion string `json:"compiler_version"` + MainScope string `json:"main_scope"` + Identifiers map[string]*Identifier `json:"identifiers"` + ReferenceManager ReferenceManager `json:"reference_manager"` + Attributes []AttributeScope `json:"attributes"` + DebugInfo DebugInfo `json:"debug_info"` } type Identifier struct { @@ -95,6 +95,7 @@ type Identifier struct { Value any `json:"value"` } +// TODO: Do we really need this ? func (z ZeroProgram) MarshalToFile(filepath string) error { // Marshal Output struct into JSON bytes data, err := json.MarshalIndent(z, "", " ") diff --git a/pkg/parsers/zero/zero_test.go b/pkg/parsers/zero/zero_test.go index 0db00ef56..527e4dbe3 100644 --- a/pkg/parsers/zero/zero_test.go +++ b/pkg/parsers/zero/zero_test.go @@ -3,7 +3,7 @@ package zero import ( "testing" - starknetParser "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" "github.com/stretchr/testify/require" ) @@ -69,10 +69,10 @@ func TestBuiltins(t *testing.T) { require.Equal(t, &ZeroProgram{ - Builtins: []starknetParser.Builtin{ - starknetParser.Output, - starknetParser.RangeCheck, - starknetParser.Bitwise, + Builtins: []builtins.BuiltinType{ + builtins.OutputType, + builtins.RangeCheckType, + builtins.BitwiseType, }, }, zeroProgram, diff --git a/pkg/runner/air_input.go b/pkg/runner/air_input.go new file mode 100644 index 000000000..8a7038ea0 --- /dev/null +++ b/pkg/runner/air_input.go @@ -0,0 +1,116 @@ +package runner + +import ( + "github.com/NethermindEth/cairo-vm-go/pkg/vm" + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" +) + +func (runner *ZeroRunner) GetAirPublicInput() (AirPublicInput, error) { + rcMin, rcMax := runner.getPermRangeCheckLimits() + + // TODO: refactor to reuse earlier computed relocated trace + relocatedTrace := make([]vm.Trace, len(runner.vm.Trace)) + runner.vm.RelocateTrace(&relocatedTrace) + firstTrace := relocatedTrace[0] + lastTrace := relocatedTrace[len(relocatedTrace)-1] + memorySegments := make(map[string]AirMemorySegmentEntry) + // TODO: you need to calculate this for each builtin + memorySegments["program"] = AirMemorySegmentEntry{BeginAddr: firstTrace.Pc, StopPtr: lastTrace.Pc} + memorySegments["execution"] = AirMemorySegmentEntry{BeginAddr: firstTrace.Ap, StopPtr: lastTrace.Ap} + + return AirPublicInput{ + Layout: runner.layout.Name, + RcMin: rcMin, + RcMax: rcMax, + NSteps: len(runner.vm.Trace), + DynamicParams: nil, + // TODO: yet to be implemented fully + MemorySegments: memorySegments, + // TODO: yet to be implemented + PublicMemory: make([]AirPublicMemoryEntry, 0), + }, nil +} + +type AirPublicInput struct { + Layout string `json:"layout"` + RcMin uint16 `json:"rc_min"` + RcMax uint16 `json:"rc_max"` + NSteps int `json:"n_steps"` + DynamicParams interface{} `json:"dynamic_params"` + MemorySegments map[string]AirMemorySegmentEntry `json:"memory_segments"` + PublicMemory []AirPublicMemoryEntry `json:"public_memory"` +} + +type AirMemorySegmentEntry struct { + BeginAddr uint64 `json:"begin_addr"` + StopPtr uint64 `json:"stop_ptr"` +} + +type AirPublicMemoryEntry struct { + Address uint16 `json:"address"` + Value string `json:"value"` + Page uint16 `json:"page"` +} + +func (runner *ZeroRunner) GetAirPrivateInput(tracePath, memoryPath string) (AirPrivateInput, error) { + airPrivateInput := AirPrivateInput{ + TracePath: tracePath, + MemoryPath: memoryPath, + } + + for _, bRunner := range runner.layout.Builtins { + builtinName := bRunner.Runner.String() + builtinSegment, ok := runner.vm.Memory.FindSegmentWithBuiltin(builtinName) + if ok { + // some checks might be missing here + switch builtinName { + case builtins.RangeCheckName: + { + airPrivateInput.RangeCheck = bRunner.Runner.(*builtins.RangeCheck).GetAirPrivateInput(builtinSegment) + } + case builtins.BitwiseName: + { + airPrivateInput.Bitwise = bRunner.Runner.(*builtins.Bitwise).GetAirPrivateInput(builtinSegment) + } + case builtins.PoseidonName: + { + airPrivateInput.Poseidon = bRunner.Runner.(*builtins.Poseidon).GetAirPrivateInput(builtinSegment) + } + case builtins.PedersenName: + { + airPrivateInput.Pedersen = bRunner.Runner.(*builtins.Pedersen).GetAirPrivateInput(builtinSegment) + } + case builtins.EcOpName: + { + airPrivateInput.EcOp = bRunner.Runner.(*builtins.EcOp).GetAirPrivateInput(builtinSegment) + } + case builtins.KeccakName: + { + airPrivateInput.Keccak = bRunner.Runner.(*builtins.Keccak).GetAirPrivateInput(builtinSegment) + } + case builtins.ECDSAName: + { + ecdsaAirPrivateInput, err := bRunner.Runner.(*builtins.ECDSA).GetAirPrivateInput(builtinSegment) + if err != nil { + return AirPrivateInput{}, err + } + airPrivateInput.Ecdsa = ecdsaAirPrivateInput + } + } + } + } + + return airPrivateInput, nil +} + +type AirPrivateInput struct { + TracePath string `json:"trace_path"` + MemoryPath string `json:"memory_path"` + Pedersen []builtins.AirPrivateBuiltinPedersen `json:"pedersen"` + RangeCheck []builtins.AirPrivateBuiltinRangeCheck `json:"range_check"` + Ecdsa []builtins.AirPrivateBuiltinECDSA `json:"ecdsa"` + Bitwise []builtins.AirPrivateBuiltinBitwise `json:"bitwise"` + EcOp []builtins.AirPrivateBuiltinEcOp `json:"ec_op"` + Keccak []builtins.AirPrivateBuiltinKeccak `json:"keccak"` + Poseidon []builtins.AirPrivateBuiltinPoseidon `json:"poseidon"` +} diff --git a/pkg/runner/program.go b/pkg/runner/program.go new file mode 100644 index 000000000..adc30a71f --- /dev/null +++ b/pkg/runner/program.go @@ -0,0 +1,66 @@ +package runner + +import ( + "fmt" + + "github.com/NethermindEth/cairo-vm-go/pkg/parsers/zero" + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" +) + +type ZeroProgram struct { + // the bytecode in string format + Bytecode []*fp.Element + // given a string it returns the pc for that function call + Entrypoints map[string]uint64 + // it stores the start and end label pcs + Labels map[string]uint64 + // builtins + Builtins []builtins.BuiltinType +} + +type CairoProgram struct{} + +func LoadCairoZeroProgram(cairoZeroJson *zero.ZeroProgram) (*ZeroProgram, error) { + // bytecode + bytecode := make([]*fp.Element, len(cairoZeroJson.Data)) + for i := range cairoZeroJson.Data { + felt, err := new(fp.Element).SetString(cairoZeroJson.Data[i]) + if err != nil { + return nil, fmt.Errorf( + "cannot read bytecode %s at position %d: %w", + cairoZeroJson.Data[i], i, err, + ) + } + bytecode[i] = felt + } + + entrypoints, labels := extractEntrypointsAndLabels(cairoZeroJson) + + return &ZeroProgram{ + Bytecode: bytecode, + Entrypoints: entrypoints, + Labels: labels, + Builtins: cairoZeroJson.Builtins, + }, nil +} + +func extractEntrypointsAndLabels(json *zero.ZeroProgram) (map[string]uint64, map[string]uint64) { + entrypoints := map[string]uint64{} + for key, ident := range json.Identifiers { + if ident.IdentifierType == "function" { + name := key[len(json.MainScope)+1:] + entrypoints[name] = uint64(ident.Pc) + } + } + + labels := make(map[string]uint64, 2) + for key, ident := range json.Identifiers { + if ident.IdentifierType == "label" { + name := key[len(json.MainScope)+1:] + labels[name] = uint64(ident.Pc) + } + } + + return entrypoints, labels +} diff --git a/pkg/runners/zero/program_test.go b/pkg/runner/program_test.go similarity index 96% rename from pkg/runners/zero/program_test.go rename to pkg/runner/program_test.go index 63f274e27..599f8b1f3 100644 --- a/pkg/runners/zero/program_test.go +++ b/pkg/runner/program_test.go @@ -1,4 +1,4 @@ -package zero +package runner import ( "testing" @@ -50,7 +50,7 @@ func TestLoadCairoZeroProgram(t *testing.T) { program, err := LoadCairoZeroProgram(cairoZeroJson) require.NoError(t, err) - require.Equal(t, &Program{ + require.Equal(t, &ZeroProgram{ Bytecode: []*fp.Element{ stringToFelt("0x01"), stringToFelt("0x02"), diff --git a/pkg/runners/zero/zero.go b/pkg/runner/runner.go similarity index 78% rename from pkg/runners/zero/zero.go rename to pkg/runner/runner.go index f0c802307..bff732be2 100644 --- a/pkg/runners/zero/zero.go +++ b/pkg/runner/runner.go @@ -1,4 +1,4 @@ -package zero +package runner import ( "errors" @@ -6,7 +6,6 @@ import ( "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" - "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" "github.com/NethermindEth/cairo-vm-go/pkg/utils" "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" @@ -14,38 +13,30 @@ import ( "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) -type RunnerMode uint8 - -const ( - ExecutionMode RunnerMode = iota + 1 - ProofModeCairo0 - ProofModeCairo1 -) - -type Runner struct { +type ZeroRunner struct { // core components - program *Program + program *ZeroProgram vm *vm.VirtualMachine hintrunner hintrunner.HintRunner // config proofmode bool collectTrace bool maxsteps uint64 - runnerMode RunnerMode // auxiliar runFinished bool layout builtins.Layout } +type CairoRunner struct{} + // Creates a new Runner of a Cairo Zero program -func NewRunner(runnerMode RunnerMode, program *Program, hints map[uint64][]hinter.Hinter, proofmode bool, collectTrace bool, maxsteps uint64, layoutName string) (Runner, error) { +func NewRunner(program *ZeroProgram, hints map[uint64][]hinter.Hinter, proofmode bool, collectTrace bool, maxsteps uint64, layoutName string) (ZeroRunner, error) { hintrunner := hintrunner.NewHintRunner(hints) layout, err := builtins.GetLayout(layoutName) if err != nil { - return Runner{}, err + return ZeroRunner{}, err } - return Runner{ - runnerMode: runnerMode, + return ZeroRunner{ program: program, hintrunner: hintrunner, proofmode: proofmode, @@ -57,7 +48,7 @@ func NewRunner(runnerMode RunnerMode, program *Program, hints map[uint64][]hinte // RunEntryPoint is like Run, but it executes the program starting from the given PC offset. // This PC offset is expected to be a start from some function inside the loaded program. -func (runner *Runner) RunEntryPoint(pc uint64) error { +func (runner *ZeroRunner) RunEntryPoint(pc uint64) error { if runner.runFinished { return errors.New("cannot re-run using the same runner") } @@ -76,11 +67,7 @@ func (runner *Runner) RunEntryPoint(pc uint64) error { returnFp := memory.AllocateEmptySegment() mvReturnFp := mem.MemoryValueFromMemoryAddress(&returnFp) - cairo1FpOffset := uint64(0) - if runner.runnerMode == ProofModeCairo1 { - cairo1FpOffset = 2 - } - end, err := runner.initializeEntrypoint(pc, nil, &mvReturnFp, memory, stack, cairo1FpOffset) + end, err := runner.initializeEntrypoint(pc, nil, &mvReturnFp, memory, stack) if err != nil { return err } @@ -92,10 +79,11 @@ func (runner *Runner) RunEntryPoint(pc uint64) error { return nil } -func (runner *Runner) Run() error { +func (runner *ZeroRunner) Run() error { if runner.runFinished { return errors.New("cannot re-run using the same runner") } + end, err := runner.initializeMainEntrypoint() if err != nil { return fmt.Errorf("initializing main entry point: %w", err) @@ -117,7 +105,7 @@ func (runner *Runner) Run() error { return nil } -func (runner *Runner) initializeSegments() (*mem.Memory, error) { +func (runner *ZeroRunner) initializeSegments() (*mem.Memory, error) { memory := mem.InitializeEmptyMemory() _, err := memory.AllocateSegment(runner.program.Bytecode) // ProgramSegment if err != nil { @@ -128,7 +116,7 @@ func (runner *Runner) initializeSegments() (*mem.Memory, error) { return memory, nil } -func (runner *Runner) initializeMainEntrypoint() (mem.MemoryAddress, error) { +func (runner *ZeroRunner) initializeMainEntrypoint() (mem.MemoryAddress, error) { memory, err := runner.initializeSegments() if err != nil { return mem.UnknownAddress, err @@ -138,24 +126,8 @@ func (runner *Runner) initializeMainEntrypoint() (mem.MemoryAddress, error) { if err != nil { return mem.UnknownAddress, err } - switch runner.runnerMode { - case ExecutionMode: - returnFp := memory.AllocateEmptySegment() - mvReturnFp := mem.MemoryValueFromMemoryAddress(&returnFp) - mainPCOffset, ok := runner.program.Entrypoints["main"] - if !ok { - return mem.UnknownAddress, errors.New("can't find an entrypoint for main") - } - return runner.initializeEntrypoint(mainPCOffset, nil, &mvReturnFp, memory, stack, 0) - case ProofModeCairo1: - returnFp := memory.AllocateEmptySegment() - mvReturnFp := mem.MemoryValueFromMemoryAddress(&returnFp) - mainPCOffset, ok := runner.program.Entrypoints["main"] - if !ok { - return mem.UnknownAddress, errors.New("can't find an entrypoint for main") - } - return runner.initializeEntrypoint(mainPCOffset, nil, &mvReturnFp, memory, stack, 2) - case ProofModeCairo0: + + if runner.proofmode { initialPCOffset, ok := runner.program.Labels["__start__"] if !ok { return mem.UnknownAddress, @@ -176,7 +148,7 @@ func (runner *Runner) initializeMainEntrypoint() (mem.MemoryAddress, error) { if err := runner.initializeVm(&mem.MemoryAddress{ SegmentIndex: vm.ProgramSegment, Offset: initialPCOffset, - }, stack, memory, 0); err != nil { + }, stack, memory); err != nil { return mem.UnknownAddress, err } @@ -184,27 +156,34 @@ func (runner *Runner) initializeMainEntrypoint() (mem.MemoryAddress, error) { runner.vm.Context.Ap = 2 runner.vm.Context.Fp = 2 return mem.MemoryAddress{SegmentIndex: vm.ProgramSegment, Offset: endPcOffset}, nil + } + returnFp := memory.AllocateEmptySegment() + mvReturnFp := mem.MemoryValueFromMemoryAddress(&returnFp) + mainPCOffset, ok := runner.program.Entrypoints["main"] + if !ok { + return mem.UnknownAddress, errors.New("can't find an entrypoint for main") } - return mem.UnknownAddress, errors.New("unknown runner mode") + return runner.initializeEntrypoint(mainPCOffset, nil, &mvReturnFp, memory, stack) } -func (runner *Runner) initializeEntrypoint( - initialPCOffset uint64, arguments []*fp.Element, returnFp *mem.MemoryValue, memory *mem.Memory, stack []mem.MemoryValue, cairo1FpOffset uint64, +func (runner *ZeroRunner) initializeEntrypoint( + initialPCOffset uint64, arguments []*fp.Element, returnFp *mem.MemoryValue, memory *mem.Memory, stack []mem.MemoryValue, ) (mem.MemoryAddress, error) { for i := range arguments { stack = append(stack, mem.MemoryValueFromFieldElement(arguments[i])) } - endPC := memory.AllocateEmptySegment() - stack = append(stack, *returnFp, mem.MemoryValueFromMemoryAddress(&endPC)) - return endPC, runner.initializeVm(&mem.MemoryAddress{ + end := memory.AllocateEmptySegment() + + stack = append(stack, *returnFp, mem.MemoryValueFromMemoryAddress(&end)) + return end, runner.initializeVm(&mem.MemoryAddress{ SegmentIndex: vm.ProgramSegment, Offset: initialPCOffset, - }, stack, memory, cairo1FpOffset) + }, stack, memory) } -func (runner *Runner) initializeBuiltins(memory *mem.Memory) ([]mem.MemoryValue, error) { - builtinsSet := make(map[starknet.Builtin]bool) +func (runner *ZeroRunner) initializeBuiltins(memory *mem.Memory) ([]mem.MemoryValue, error) { + builtinsSet := make(map[builtins.BuiltinType]bool) for _, bRunner := range runner.layout.Builtins { builtinsSet[bRunner.Builtin] = true } @@ -229,30 +208,30 @@ func (runner *Runner) initializeBuiltins(memory *mem.Memory) ([]mem.MemoryValue, return stack, nil } -func (runner *Runner) initializeVm( - initialPC *mem.MemoryAddress, stack []mem.MemoryValue, memory *mem.Memory, cairo1FpOffset uint64, +func (runner *ZeroRunner) initializeVm( + initialPC *mem.MemoryAddress, stack []mem.MemoryValue, memory *mem.Memory, ) error { executionSegment := memory.Segments[vm.ExecutionSegment] offset := executionSegment.Len() - for idx := range stack { + stackSize := uint64(len(stack)) + for idx := uint64(0); idx < stackSize; idx++ { if err := executionSegment.Write(offset+uint64(idx), &stack[idx]); err != nil { return err } } - initialFp := offset + uint64(len(stack)) + cairo1FpOffset var err error // initialize vm runner.vm, err = vm.NewVirtualMachine(vm.Context{ Pc: *initialPC, - Ap: initialFp, - Fp: initialFp, + Ap: offset + stackSize, + Fp: offset + stackSize, }, memory, vm.VirtualMachineConfig{ProofMode: runner.proofmode, CollectTrace: runner.collectTrace}) return err } // run until the program counter equals the `pc` parameter -func (runner *Runner) RunUntilPc(pc *mem.MemoryAddress) error { +func (runner *ZeroRunner) RunUntilPc(pc *mem.MemoryAddress) error { for !runner.vm.Context.Pc.Equal(pc) { if runner.steps() >= runner.maxsteps { return fmt.Errorf( @@ -270,7 +249,7 @@ func (runner *Runner) RunUntilPc(pc *mem.MemoryAddress) error { } // run until the vm step count reaches the `steps` parameter -func (runner *Runner) RunFor(steps uint64) error { +func (runner *ZeroRunner) RunFor(steps uint64) error { for runner.steps() < steps { if runner.steps() >= runner.maxsteps { return fmt.Errorf( @@ -296,7 +275,7 @@ func (runner *Runner) RunFor(steps uint64) error { // until the checkUsedCells doesn't return any error. // Since this vm always finishes the run of the program at the number of steps that is a power of two in the proof mode, // there is no need to run additional steps before the loop. -func (runner *Runner) EndRun() error { +func (runner *ZeroRunner) EndRun() error { for runner.checkUsedCells() != nil { pow2Steps := utils.NextPowerOfTwo(runner.vm.Step + 1) if err := runner.RunFor(pow2Steps); err != nil { @@ -308,7 +287,7 @@ func (runner *Runner) EndRun() error { // checkUsedCells returns error if not enough steps were made to allocate required number of cells for builtins // or there are not enough trace cells to fill the entire range check range -func (runner *Runner) checkUsedCells() error { +func (runner *ZeroRunner) checkUsedCells() error { for _, bRunner := range runner.layout.Builtins { builtinName := bRunner.Runner.String() builtinSegment, ok := runner.vm.Memory.FindSegmentWithBuiltin(builtinName) @@ -327,11 +306,11 @@ func (runner *Runner) checkUsedCells() error { } // Checks if there are not enough trace cells to fill the entire range check range. Each step has assigned a number of range check units. If the number of unused range check units is less than the range of potential values to be checked (defined by rcMin and rcMax), the number of trace cells must be increased, by running additional steps. -func (runner *Runner) checkRangeCheckUsage() error { +func (runner *ZeroRunner) checkRangeCheckUsage() error { rcMin, rcMax := runner.getPermRangeCheckLimits() var rcUnitsUsedByBuiltins uint64 for _, builtin := range runner.program.Builtins { - if builtin == starknet.RangeCheck { + if builtin == builtins.RangeCheckType { for _, layoutBuiltin := range runner.layout.Builtins { if builtin == layoutBuiltin.Builtin { rangeCheckRunner, ok := layoutBuiltin.Runner.(*builtins.RangeCheck) @@ -356,11 +335,11 @@ func (runner *Runner) checkRangeCheckUsage() error { } // getPermRangeCheckLimits returns the minimum and maximum values used by the range check units in the program. To find the values, maximum and minimum values from the range check segment are compared with maximum and minimum values of instructions offsets calculated during running the instructions. -func (runner *Runner) getPermRangeCheckLimits() (uint16, uint16) { +func (runner *ZeroRunner) getPermRangeCheckLimits() (uint16, uint16) { rcMin, rcMax := runner.vm.RcLimitsMin, runner.vm.RcLimitsMax for _, builtin := range runner.program.Builtins { - if builtin == starknet.RangeCheck { + if builtin == builtins.RangeCheckType { bRunner := builtins.Runner(builtin) rangeCheckRunner, _ := bRunner.(*builtins.RangeCheck) rangeCheckSegment, ok := runner.vm.Memory.FindSegmentWithBuiltin(rangeCheckRunner.String()) @@ -381,7 +360,7 @@ func (runner *Runner) getPermRangeCheckLimits() (uint16, uint16) { // FinalizeSegments calculates the final size of the builtins segments, // using number of allocated instances and memory cells per builtin instance. // Additionally it sets the final size of the program segment to the program size. -func (runner *Runner) FinalizeSegments() error { +func (runner *ZeroRunner) FinalizeSegments() error { programSize := uint64(len(runner.program.Bytecode)) runner.vm.Memory.Segments[vm.ProgramSegment].Finalize(programSize) for _, bRunner := range runner.layout.Builtins { @@ -398,34 +377,35 @@ func (runner *Runner) FinalizeSegments() error { } // BuildMemory relocates the memory and returns it -func (runner *Runner) BuildMemory() ([]byte, error) { +func (runner *ZeroRunner) BuildMemory() ([]byte, error) { relocatedMemory := runner.vm.RelocateMemory() return vm.EncodeMemory(relocatedMemory), nil } // BuildTrace relocates the trace and returns it -func (runner *Runner) BuildTrace() ([]byte, error) { - relocatedTrace := runner.vm.RelocateTrace() +func (runner *ZeroRunner) BuildTrace() ([]byte, error) { + relocatedTrace := make([]vm.Trace, len(runner.vm.Trace)) + runner.vm.RelocateTrace(&relocatedTrace) return vm.EncodeTrace(relocatedTrace), nil } -func (runner *Runner) pc() mem.MemoryAddress { +func (runner *ZeroRunner) pc() mem.MemoryAddress { return runner.vm.Context.Pc } -func (runner *Runner) steps() uint64 { +func (runner *ZeroRunner) steps() uint64 { return runner.vm.Step } // Gives the output of the last run. Panics if there hasn't // been any runs yet. -func (runner *Runner) Output() []*fp.Element { +func (runner *ZeroRunner) Output() []*fp.Element { if runner.vm == nil { panic("cannot get the output from an uninitialized runner") } output := []*fp.Element{} - outputSegment, ok := runner.vm.Memory.FindSegmentWithBuiltin("output") + outputSegment, ok := runner.vm.Memory.FindSegmentWithBuiltin(builtins.OutputName) if !ok { return output } diff --git a/pkg/runners/zero/zero_benchmark_test.go b/pkg/runner/runner_benchmark_test.go similarity index 99% rename from pkg/runners/zero/zero_benchmark_test.go rename to pkg/runner/runner_benchmark_test.go index 980360718..8015d9719 100644 --- a/pkg/runners/zero/zero_benchmark_test.go +++ b/pkg/runner/runner_benchmark_test.go @@ -1,4 +1,4 @@ -package zero +package runner import ( "math" diff --git a/pkg/runners/zero/zero_test.go b/pkg/runner/runner_test.go similarity index 76% rename from pkg/runners/zero/zero_test.go rename to pkg/runner/runner_test.go index c358e762b..abad43e28 100644 --- a/pkg/runners/zero/zero_test.go +++ b/pkg/runner/runner_test.go @@ -1,4 +1,4 @@ -package zero +package runner import ( "fmt" @@ -7,8 +7,8 @@ import ( "github.com/NethermindEth/cairo-vm-go/pkg/assembler" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" - sn "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" "github.com/NethermindEth/cairo-vm-go/pkg/vm" + "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" pedersenhash "github.com/consensys/gnark-crypto/ecc/stark-curve/pedersen-hash" @@ -27,7 +27,7 @@ func TestSimpleProgram(t *testing.T) { `) hints := make(map[uint64][]hinter.Hinter) - runner, err := NewRunner(ExecutionMode, program, hints, false, false, math.MaxUint64, "plain") + runner, err := NewRunner(program, hints, false, false, math.MaxUint64, "plain") require.NoError(t, err) endPc, err := runner.initializeMainEntrypoint() @@ -74,7 +74,7 @@ func TestStepLimitExceeded(t *testing.T) { `) hints := make(map[uint64][]hinter.Hinter) - runner, err := NewRunner(ExecutionMode, program, hints, false, false, 3, "plain") + runner, err := NewRunner(program, hints, false, false, 3, "plain") require.NoError(t, err) endPc, err := runner.initializeMainEntrypoint() @@ -133,7 +133,7 @@ func TestStepLimitExceededProofMode(t *testing.T) { // when maxstep = 6, it fails executing the extra step required by proof mode // when maxstep = 7, it fails trying to get the trace to be a power of 2 hints := make(map[uint64][]hinter.Hinter) - runner, err := NewRunner(ProofModeCairo0, program, hints, true, false, uint64(maxstep), "plain") + runner, err := NewRunner(program, hints, true, false, uint64(maxstep), "plain") require.NoError(t, err) err = runner.Run() @@ -186,12 +186,12 @@ func TestBitwiseBuiltin(t *testing.T) { [ap + 1] = 9; [ap + 2] = 15; ret; - `, "starknet_with_keccak", sn.Bitwise) + `, "starknet_with_keccak", builtins.BitwiseType) err := runner.Run() require.NoError(t, err) - bitwise, ok := runner.vm.Memory.FindSegmentWithBuiltin("bitwise") + bitwise, ok := runner.vm.Memory.FindSegmentWithBuiltin(builtins.BitwiseName) require.True(t, ok) requireEqualSegments(t, createSegment(14, 7, 6, 9, 15), bitwise) @@ -202,7 +202,7 @@ func TestBitwiseBuiltinError(t *testing.T) { runner := createRunner(` [ap] = [[fp - 3]]; ret; - `, "starknet_with_keccak", sn.Bitwise) + `, "starknet_with_keccak", builtins.BitwiseType) err := runner.Run() require.ErrorContains(t, err, "cannot infer value") @@ -211,7 +211,7 @@ func TestBitwiseBuiltinError(t *testing.T) { runner = createRunner(` [ap] = [[fp - 3] + 1]; ret; - `, "starknet_with_keccak", sn.Bitwise) + `, "starknet_with_keccak", builtins.BitwiseType) err = runner.Run() require.ErrorContains(t, err, "cannot infer value") @@ -219,7 +219,7 @@ func TestBitwiseBuiltinError(t *testing.T) { runner = createRunner(` [ap] = [[fp - 3] + 2]; ret; - `, "starknet_with_keccak", sn.Bitwise) + `, "starknet_with_keccak", builtins.BitwiseType) err = runner.Run() require.ErrorContains(t, err, "input value at offset 0 is unknown") @@ -233,7 +233,7 @@ func TestOutputBuiltin(t *testing.T) { [ap + 1] = 7; [ap + 1] = [[fp - 3] + 1]; ret; - `, "small", sn.Output) + `, "small", builtins.OutputType) err := runner.Run() require.NoError(t, err) @@ -263,11 +263,11 @@ func TestPedersenBuiltin(t *testing.T) { ret; `, val1.Text(10), val2.Text(10), val3.Text(10)) - runner := createRunner(code, "small", sn.Pedersen) + runner := createRunner(code, "small", builtins.PedersenType) err := runner.Run() require.NoError(t, err) - pedersen, ok := runner.vm.Memory.FindSegmentWithBuiltin("pedersen") + pedersen, ok := runner.vm.Memory.FindSegmentWithBuiltin(builtins.PedersenName) require.True(t, ok) requireEqualSegments(t, createSegment(&val1, &val2, &val3), pedersen) } @@ -276,14 +276,14 @@ func TestPedersenBuiltinError(t *testing.T) { runner := createRunner(` [ap] = [[fp - 3]]; ret; - `, "small", sn.Pedersen) + `, "small", builtins.PedersenType) err := runner.Run() require.ErrorContains(t, err, "cannot infer value") runner = createRunner(` [ap] = [[fp - 3] + 2]; ret; - `, "small", sn.Pedersen) + `, "small", builtins.PedersenType) err = runner.Run() require.ErrorContains(t, err, "input value at offset 0 is unknown") } @@ -298,7 +298,7 @@ func TestRangeCheckBuiltin(t *testing.T) { [ap + 1] = 0xffffffffffffffffffffffffffffffff; [ap + 1] = [[fp - 3] + 1]; ret; - `, "small", sn.RangeCheck) + `, "small", builtins.RangeCheckType) err := runner.Run() require.NoError(t, err) @@ -319,7 +319,7 @@ func TestRangeCheckBuiltinError(t *testing.T) { [ap] = 0x100000000000000000000000000000000; [ap] = [[fp - 3]]; ret; - `, "small", sn.RangeCheck) + `, "small", builtins.RangeCheckType) err := runner.Run() require.ErrorContains(t, err, "check write: 2**128 <") @@ -328,7 +328,53 @@ func TestRangeCheckBuiltinError(t *testing.T) { runner = createRunner(` [ap] = [[fp - 3]]; ret; - `, "small", sn.RangeCheck) + `, "small", builtins.RangeCheckType) + + err = runner.Run() + require.ErrorContains(t, err, "cannot infer value") +} + +func TestRangeCheck96Builtin(t *testing.T) { + // range_check96 is located at fp - 3 (fp - 2 and fp - 1 contain initialization vals) + // we write 5 and 2**96 - 1 to range check + // no error should come from this + runner := createRunner(` + [ap] = 5; + [ap] = [[fp - 3]]; + [ap + 1] = 0xffffffffffffffffffffffff; + [ap + 1] = [[fp - 3] + 1]; + ret; + `, "all_cairo", builtins.RangeCheck96Type) + + err := runner.Run() + require.NoError(t, err) + + rangeCheck96, ok := runner.vm.Memory.FindSegmentWithBuiltin(builtins.RangeCheck96Name) + require.True(t, ok) + + felt := &fp.Element{} + felt, err = felt.SetString("0xffffffffffffffffffffffff") + require.NoError(t, err) + + requireEqualSegments(t, createSegment(5, felt), rangeCheck96) +} + +func TestRangeCheck96BuiltinError(t *testing.T) { + // first test fails due to out of bound check + runner := createRunner(` + [ap] = 0x1000000000000000000000000; + [ap] = [[fp - 3]]; + ret; + `, "all_cairo", builtins.RangeCheck96Type) + + err := runner.Run() + require.ErrorContains(t, err, "check write: 2**96 <") + + // second test fails due to reading unknown value + runner = createRunner(` + [ap] = [[fp - 3]]; + ret; + `, "all_cairo", builtins.RangeCheck96Type) err = runner.Run() require.ErrorContains(t, err, "cannot infer value") @@ -358,17 +404,39 @@ func TestEcOpBuiltin(t *testing.T) { [ap + 5] = 108925483682366235368969256555281508851459278989259552980345066351008608800; [ap + 6] = 1592365885972480102953613056006596671718206128324372995731808913669237079419; ret; - `, "starknet_with_keccak", sn.ECOP) + `, "starknet_with_keccak", builtins.ECOPType) err := runner.Run() require.NoError(t, err) } -func createRunner(code string, layoutName string, builtins ...sn.Builtin) Runner { +func TestModuloBuiltin(t *testing.T) { + // modulo is located at fp - 3 + // we first write 2048 and 5 to modulo + // then we read the modulo result from add and mul + // runner := createRunner(` + // [ap] = 2048; + // [ap] = [[fp - 3]]; + + // [ap + 1] = 5; + // [ap + 1] = [[fp - 3] + 1]; + // ret; + // `, "small", sn.AddMod, sn.MulMod) + + // err := runner.Run() + // require.NoError(t, err) + + // modulo, ok := runner.vm.Memory.FindSegmentWithBuiltin("add_mod") + // require.True(t, ok) + + // requireEqualSegments(t, createSegment(2048, 5), modulo) +} + +func createRunner(code string, layoutName string, builtins ...builtins.BuiltinType) ZeroRunner { program := createProgramWithBuiltins(code, builtins...) hints := make(map[uint64][]hinter.Hinter) - runner, err := NewRunner(ExecutionMode, program, hints, false, false, math.MaxUint64, layoutName) + runner, err := NewRunner(program, hints, false, false, math.MaxUint64, layoutName) if err != nil { panic(err) } @@ -414,13 +482,13 @@ func trimmedSegment(segment *memory.Segment) *memory.Segment { return segment } -func createProgram(code string) *Program { +func createProgram(code string) *ZeroProgram { bytecode, err := assembler.CasmToBytecode(code) if err != nil { panic(err) } - program := Program{ + program := ZeroProgram{ Bytecode: bytecode, Entrypoints: map[string]uint64{ "main": 0, @@ -430,7 +498,7 @@ func createProgram(code string) *Program { return &program } -func createProgramWithBuiltins(code string, builtins ...sn.Builtin) *Program { +func createProgramWithBuiltins(code string, builtins ...builtins.BuiltinType) *ZeroProgram { program := createProgram(code) program.Builtins = builtins return program diff --git a/pkg/runners/cairo/cairo.go b/pkg/runners/cairo/cairo.go deleted file mode 100644 index fedf8a0ad..000000000 --- a/pkg/runners/cairo/cairo.go +++ /dev/null @@ -1 +0,0 @@ -package cairo diff --git a/pkg/runners/cairo/cairo_test.go b/pkg/runners/cairo/cairo_test.go deleted file mode 100644 index fedf8a0ad..000000000 --- a/pkg/runners/cairo/cairo_test.go +++ /dev/null @@ -1 +0,0 @@ -package cairo diff --git a/pkg/runners/cairo/program.go b/pkg/runners/cairo/program.go deleted file mode 100644 index fedf8a0ad..000000000 --- a/pkg/runners/cairo/program.go +++ /dev/null @@ -1 +0,0 @@ -package cairo diff --git a/pkg/runners/cairo/program_test.go b/pkg/runners/cairo/program_test.go deleted file mode 100644 index fedf8a0ad..000000000 --- a/pkg/runners/cairo/program_test.go +++ /dev/null @@ -1 +0,0 @@ -package cairo diff --git a/pkg/utils/math.go b/pkg/utils/math.go index ea02fd0ca..40b310fc9 100644 --- a/pkg/utils/math.go +++ b/pkg/utils/math.go @@ -1,6 +1,8 @@ package utils import ( + "errors" + "fmt" "math" "math/big" "math/bits" @@ -156,3 +158,75 @@ func Int16FromBigInt(n *big.Int) (int16, bool) { func RightRot(value uint32, n uint32) uint32 { return (value >> n) | ((value & ((1 << n) - 1)) << (32 - n)) } + +func SafeDivUint64(x, y uint64) (uint64, error) { + if y == 0 { + return 0, fmt.Errorf("cannot divide: y division is zero") + } + if x%y != 0 { + return 0, errors.New("cannot divide: x is not divisible by y") + } + return x / y, nil +} + +func Igcdex(a, b *big.Int) (big.Int, big.Int, big.Int) { + // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/core/intfunc.py#L362 + + if a.Cmp(big.NewInt(0)) == 0 && b.Cmp(big.NewInt(0)) == 0 { + return *big.NewInt(0), *big.NewInt(1), *big.NewInt(0) + } + g, x, y := gcdext(a, b) + return x, y, g +} + +func gcdext(a, b *big.Int) (big.Int, big.Int, big.Int) { + // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L125 + + if a.Cmp(big.NewInt(0)) == 0 || b.Cmp(big.NewInt(0)) == 0 { + g := new(big.Int) + if a.Cmp(big.NewInt(0)) == 0 { + g.Abs(b) + } else { + g.Abs(a) + } + + if g.Cmp(big.NewInt(0)) == 0 { + return *big.NewInt(0), *big.NewInt(0), *big.NewInt(0) + } + return *g, *new(big.Int).Div(a, g), *new(big.Int).Div(b, g) + } + + xSign, aSigned := sign(a) + ySign, bSigned := sign(b) + x, r := big.NewInt(1), big.NewInt(0) + y, s := big.NewInt(0), big.NewInt(1) + + for bSigned.Sign() != 0 { + q, c := new(big.Int).DivMod(&aSigned, &bSigned, new(big.Int)) + aSigned = bSigned + bSigned = *c + x, r = r, new(big.Int).Sub(x, new(big.Int).Mul(q, r)) + y, s = s, new(big.Int).Sub(y, new(big.Int).Mul(q, s)) + } + + return aSigned, *new(big.Int).Mul(x, big.NewInt(int64(xSign))), *new(big.Int).Mul(y, big.NewInt(int64(ySign))) +} + +func sign(n *big.Int) (int, big.Int) { + // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L119 + + if n.Sign() < 0 { + return -1, *new(big.Int).Abs(n) + } + return 1, *new(big.Int).Set(n) +} + +func SafeDiv(x, y *big.Int) (big.Int, error) { + if y.Cmp(big.NewInt(0)) == 0 { + return *big.NewInt(0), fmt.Errorf("division by zero") + } + if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 { + return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y) + } + return *new(big.Int).Div(x, y), nil +} diff --git a/pkg/utils/math_test.go b/pkg/utils/math_test.go index 62cd381eb..0a455c072 100644 --- a/pkg/utils/math_test.go +++ b/pkg/utils/math_test.go @@ -1,6 +1,7 @@ package utils import ( + "math/big" "testing" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" @@ -119,3 +120,69 @@ func TestRightRot(t *testing.T) { }) } } + +func TestIgcdex(t *testing.T) { + // https://github.com/sympy/sympy/blob/e7fb2714f17b30b83e424448aad0da9e94a4b577/sympy/core/tests/test_numbers.py#L278 + tests := []struct { + name string + a, b *big.Int + expectedX, expectedY, expectedG *big.Int + }{ + { + name: "Case 1", + a: big.NewInt(2), + b: big.NewInt(3), + expectedX: big.NewInt(-1), + expectedY: big.NewInt(1), + expectedG: big.NewInt(1), + }, + { + name: "Case 2", + a: big.NewInt(10), + b: big.NewInt(12), + expectedX: big.NewInt(-1), + expectedY: big.NewInt(1), + expectedG: big.NewInt(2), + }, + { + name: "Case 3", + a: big.NewInt(100), + b: big.NewInt(2004), + expectedX: big.NewInt(-20), + expectedY: big.NewInt(1), + expectedG: big.NewInt(4), + }, + { + name: "Case 4", + a: big.NewInt(0), + b: big.NewInt(0), + expectedX: big.NewInt(0), + expectedY: big.NewInt(1), + expectedG: big.NewInt(0), + }, + { + name: "Case 5", + a: big.NewInt(1), + b: big.NewInt(0), + expectedX: big.NewInt(1), + expectedY: big.NewInt(0), + expectedG: big.NewInt(1), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actualX, actualY, actualG := Igcdex(tt.a, tt.b) + + if actualX.Cmp(tt.expectedX) != 0 { + t.Errorf("got x: %v, want: %v", actualX, tt.expectedX) + } + if actualY.Cmp(tt.expectedY) != 0 { + t.Errorf("got x: %v, want: %v", actualY, tt.expectedY) + } + if actualG.Cmp(tt.expectedG) != 0 { + t.Errorf("got x: %v, want: %v", actualG, tt.expectedG) + } + }) + } +} diff --git a/pkg/vm/builtins/bitwise.go b/pkg/vm/builtins/bitwise.go index b498a321d..0906b0921 100644 --- a/pkg/vm/builtins/bitwise.go +++ b/pkg/vm/builtins/bitwise.go @@ -3,16 +3,19 @@ package builtins import ( "errors" "fmt" + "math/big" + "sort" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) -const BitwiseName = "bitwise" - -const cellsPerBitwise = 5 -const inputCellsPerBitwise = 2 -const instancesPerComponentBitwise = 1 +const ( + BitwiseName = "bitwise" + cellsPerBitwise = 5 + inputCellsPerBitwise = 2 + instancesPerComponentBitwise = 1 +) type Bitwise struct { ratio uint64 @@ -100,3 +103,50 @@ func (b *Bitwise) String() string { func (b *Bitwise) GetAllocatedSize(segmentUsedSize uint64, vmCurrentStep uint64) (uint64, error) { return getBuiltinAllocatedSize(segmentUsedSize, vmCurrentStep, b.ratio, inputCellsPerBitwise, instancesPerComponentBitwise, cellsPerBitwise) } + +type AirPrivateBuiltinBitwise struct { + Index int `json:"index"` + X string `json:"x"` + Y string `json:"y"` +} + +func (b *Bitwise) GetAirPrivateInput(bitwiseSegment *memory.Segment) []AirPrivateBuiltinBitwise { + valueMapping := make(map[int]AirPrivateBuiltinBitwise) + for index, value := range bitwiseSegment.Data { + if !value.Known() { + continue + } + idx, typ := index/cellsPerBitwise, index%cellsPerBitwise + if typ >= 2 { + continue + } + + builtinValue, exists := valueMapping[idx] + if !exists { + builtinValue = AirPrivateBuiltinBitwise{Index: idx} + } + + valueBig := big.Int{} + value.Felt.BigInt(&valueBig) + valueHex := fmt.Sprintf("0x%x", &valueBig) + if typ == 0 { + builtinValue.X = valueHex + } else { + builtinValue.Y = valueHex + } + valueMapping[idx] = builtinValue + } + + values := make([]AirPrivateBuiltinBitwise, 0) + + sortedIndexes := make([]int, 0, len(valueMapping)) + for index := range valueMapping { + sortedIndexes = append(sortedIndexes, index) + } + sort.Ints(sortedIndexes) + for _, index := range sortedIndexes { + value := valueMapping[index] + values = append(values, value) + } + return values +} diff --git a/pkg/vm/builtins/builtin_runner.go b/pkg/vm/builtins/builtin_runner.go index 77b74fa00..1b95fe77b 100644 --- a/pkg/vm/builtins/builtin_runner.go +++ b/pkg/vm/builtins/builtin_runner.go @@ -3,33 +3,54 @@ package builtins import ( "fmt" "math" + "strconv" - starknetParser "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" "github.com/NethermindEth/cairo-vm-go/pkg/utils" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" ) -func Runner(name starknetParser.Builtin) memory.BuiltinRunner { +type BuiltinType uint8 + +const ( + OutputType BuiltinType = iota + 1 + RangeCheckType + PedersenType + ECDSAType + KeccakType + BitwiseType + ECOPType + PoseidonType + SegmentArenaType + RangeCheck96Type + AddModeType + MulModType +) + +func Runner(name BuiltinType) memory.BuiltinRunner { switch name { - case starknetParser.Output: + case OutputType: return &Output{} - case starknetParser.RangeCheck: + case RangeCheckType: return &RangeCheck{0, 8} - case starknetParser.RangeCheck96: + case RangeCheck96Type: return &RangeCheck{0, 6} - case starknetParser.Pedersen: + case PedersenType: return &Pedersen{} - case starknetParser.ECDSA: + case ECDSAType: return &ECDSA{} - case starknetParser.Keccak: + case KeccakType: return &Keccak{} - case starknetParser.Bitwise: + case BitwiseType: return &Bitwise{} - case starknetParser.ECOP: + case ECOPType: return &EcOp{} - case starknetParser.Poseidon: + case PoseidonType: return &Poseidon{} - case starknetParser.SegmentArena: + case AddModeType: + return &ModBuiltin{modBuiltinType: Add} + case MulModType: + return &ModBuiltin{modBuiltinType: Mul} + case SegmentArenaType: panic("Not implemented") default: panic("Unknown builtin") @@ -65,3 +86,71 @@ func getBuiltinAllocatedSize(segmentUsedSize uint64, vmCurrentStep uint64, ratio } return allocatedInstances * cellsPerInstance, nil } + +func (b BuiltinType) MarshalJSON() ([]byte, error) { + switch b { + case OutputType: + return []byte(OutputName), nil + case RangeCheckType: + return []byte(RangeCheckName), nil + case RangeCheck96Type: + return []byte(RangeCheck96Name), nil + case PedersenType: + return []byte(PedersenName), nil + case ECDSAType: + return []byte(ECDSAName), nil + case KeccakType: + return []byte(KeccakName), nil + case BitwiseType: + return []byte(BitwiseName), nil + case ECOPType: + return []byte(EcOpName), nil + case PoseidonType: + return []byte(PoseidonName), nil + case AddModeType: + return []byte("Add" + ModuloName), nil + case MulModType: + return []byte("Mul" + ModuloName), nil + case SegmentArenaType: + return []byte(SegmentArenaName), nil + + } + return nil, fmt.Errorf("marshal unknown builtin: %d", uint8(b)) +} + +func (b *BuiltinType) UnmarshalJSON(data []byte) error { + builtinName, err := strconv.Unquote(string(data)) + if err != nil { + return fmt.Errorf("unmarshal builtin: %w", err) + } + + switch builtinName { + case OutputName: + *b = OutputType + case RangeCheckName: + *b = RangeCheckType + case RangeCheck96Name: + *b = RangeCheck96Type + case PedersenName: + *b = PedersenType + case ECDSAName: + *b = ECDSAType + case KeccakName: + *b = KeccakType + case BitwiseName: + *b = BitwiseType + case EcOpName: + *b = ECOPType + case PoseidonName: + *b = PoseidonType + case "Add" + ModuloName: + *b = AddModeType + case "Mul" + ModuloName: + *b = MulModType + case SegmentArenaName: + *b = SegmentArenaType + default: + return fmt.Errorf("unmarshal unknown builtin: %s", builtinName) + } + return nil +} diff --git a/pkg/vm/builtins/ecdsa.go b/pkg/vm/builtins/ecdsa.go index 02fe7f3e5..7aae622bb 100644 --- a/pkg/vm/builtins/ecdsa.go +++ b/pkg/vm/builtins/ecdsa.go @@ -2,6 +2,7 @@ package builtins import ( "fmt" + "math/big" "github.com/NethermindEth/cairo-vm-go/pkg/utils" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" @@ -10,14 +11,15 @@ import ( "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) -const ECDSAName = "ecdsa" -const inputCellsPerECDSA = 2 -const cellsPerECDSA = 2 - -const instancesPerComponentECDSA = 1 +const ( + ECDSAName = "ecdsa" + inputCellsPerECDSA = 2 + cellsPerECDSA = 2 + instancesPerComponentECDSA = 1 +) type ECDSA struct { - signatures map[uint64]ecdsa.Signature + Signatures map[uint64]ecdsa.Signature ratio uint64 } @@ -58,7 +60,7 @@ func (e *ECDSA) CheckWrite(segment *memory.Segment, offset uint64, value *memory } pubKey := &ecdsa.PublicKey{A: key} - sig, ok := e.signatures[pubOffset] + sig, ok := e.Signatures[pubOffset] if !ok { return fmt.Errorf("signature is missing from ECDSA builtin") } @@ -117,8 +119,8 @@ Hint that will call this function looks like this: }, */ func (e *ECDSA) AddSignature(pubOffset uint64, r, s *fp.Element) error { - if e.signatures == nil { - e.signatures = make(map[uint64]ecdsa.Signature) + if e.Signatures == nil { + e.Signatures = make(map[uint64]ecdsa.Signature) } bytes := make([]byte, 0, 64) rBytes := r.Bytes() @@ -132,7 +134,7 @@ func (e *ECDSA) AddSignature(pubOffset uint64, r, s *fp.Element) error { return err } - e.signatures[pubOffset] = sig + e.Signatures[pubOffset] = sig return nil } @@ -162,3 +164,49 @@ func recoverY(x *fp.Element) (fp.Element, fp.Element, error) { negY.Neg(y) return *y, negY, nil } + +type AirPrivateBuiltinECDSASignatureInput struct { + R string `json:"r"` + W string `json:"w"` +} + +type AirPrivateBuiltinECDSA struct { + Index int `json:"index"` + PubKey string `json:"pubkey"` + Msg string `json:"msg"` + SignatureInput AirPrivateBuiltinECDSASignatureInput `json:"signature_input"` +} + +func (e *ECDSA) GetAirPrivateInput(ecdsaSegment *memory.Segment) ([]AirPrivateBuiltinECDSA, error) { + values := make([]AirPrivateBuiltinECDSA, 0) + for addrOffset, signature := range e.Signatures { + idx := addrOffset / cellsPerECDSA + pubKey, err := ecdsaSegment.Read(addrOffset) + if err != nil { + return values, err + } + msg, err := ecdsaSegment.Read(addrOffset + 1) + if err != nil { + return values, err + } + + pubKeyBig := big.Int{} + msgBig := big.Int{} + pubKey.Felt.BigInt(&pubKeyBig) + msg.Felt.BigInt(&msgBig) + pubKeyHex := fmt.Sprintf("0x%x", &pubKeyBig) + msgHex := fmt.Sprintf("0x%x", &msgBig) + + rBig := new(big.Int).SetBytes(signature.R[:]) + sBig := new(big.Int).SetBytes(signature.S[:]) + frModulusBig, _ := new(big.Int).SetString("3618502788666131213697322783095070105526743751716087489154079457884512865583", 10) + wBig := new(big.Int).ModInverse(sBig, frModulusBig) + signatureInput := AirPrivateBuiltinECDSASignatureInput{ + R: fmt.Sprintf("0x%x", rBig), + W: fmt.Sprintf("0x%x", wBig), + } + + values = append(values, AirPrivateBuiltinECDSA{Index: int(idx), PubKey: pubKeyHex, Msg: msgHex, SignatureInput: signatureInput}) + } + return values, nil +} diff --git a/pkg/vm/builtins/ecop.go b/pkg/vm/builtins/ecop.go index 211d2fb49..6d28d1108 100644 --- a/pkg/vm/builtins/ecop.go +++ b/pkg/vm/builtins/ecop.go @@ -3,6 +3,8 @@ package builtins import ( "errors" "fmt" + "math/big" + "sort" "github.com/NethermindEth/cairo-vm-go/pkg/utils" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" @@ -10,10 +12,12 @@ import ( "github.com/holiman/uint256" ) -const EcOpName = "ec_op" -const cellsPerEcOp = 7 -const inputCellsPerEcOp = 5 -const instancesPerComponentEcOp = 1 +const ( + EcOpName = "ec_op" + cellsPerEcOp = 7 + inputCellsPerEcOp = 5 + instancesPerComponentEcOp = 1 +) var feltThree fp.Element = fp.Element( []uint64{ @@ -234,3 +238,59 @@ func ecdouble(p *point, alpha *fp.Element) point { return point{x, y} } + +type AirPrivateBuiltinEcOp struct { + Index int `json:"index"` + PX string `json:"p_x"` + PY string `json:"p_y"` + M string `json:"m"` + QX string `json:"q_x"` + QY string `json:"q_y"` +} + +func (e *EcOp) GetAirPrivateInput(ecOpSegment *mem.Segment) []AirPrivateBuiltinEcOp { + valueMapping := make(map[int]AirPrivateBuiltinEcOp) + for index, value := range ecOpSegment.Data { + if !value.Known() { + continue + } + idx, typ := index/cellsPerEcOp, index%cellsPerEcOp + if typ >= inputCellsPerEcOp { + continue + } + + builtinValue, exists := valueMapping[idx] + if !exists { + builtinValue = AirPrivateBuiltinEcOp{Index: idx} + } + + valueBig := big.Int{} + value.Felt.BigInt(&valueBig) + valueHex := fmt.Sprintf("0x%x", &valueBig) + if typ == 0 { + builtinValue.PX = valueHex + } else if typ == 1 { + builtinValue.PY = valueHex + } else if typ == 2 { + builtinValue.QX = valueHex + } else if typ == 3 { + builtinValue.QY = valueHex + } else if typ == 4 { + builtinValue.M = valueHex + } + valueMapping[idx] = builtinValue + } + + values := make([]AirPrivateBuiltinEcOp, 0) + + sortedIndexes := make([]int, 0, len(valueMapping)) + for index := range valueMapping { + sortedIndexes = append(sortedIndexes, index) + } + sort.Ints(sortedIndexes) + for _, index := range sortedIndexes { + value := valueMapping[index] + values = append(values, value) + } + return values +} diff --git a/pkg/vm/builtins/keccak.go b/pkg/vm/builtins/keccak.go index 00e135fbf..e7dcb7767 100644 --- a/pkg/vm/builtins/keccak.go +++ b/pkg/vm/builtins/keccak.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "errors" "fmt" + "math/big" + "sort" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" @@ -17,10 +19,12 @@ import ( // It's useful to give users options to use Keccak just as Rust VM does it with it's keccak.cairo as library. // -const KeccakName = "keccak" -const cellsPerKeccak = 16 -const inputCellsPerKeccak = 8 -const instancesPerComponentKeccak = 16 +const ( + KeccakName = "keccak" + cellsPerKeccak = 16 + inputCellsPerKeccak = 8 + instancesPerComponentKeccak = 16 +) type Keccak struct { ratio uint64 @@ -89,3 +93,68 @@ func (k *Keccak) String() string { func (k *Keccak) GetAllocatedSize(segmentUsedSize uint64, vmCurrentStep uint64) (uint64, error) { return getBuiltinAllocatedSize(segmentUsedSize, vmCurrentStep, k.ratio, inputCellsPerKeccak, instancesPerComponentKeccak, cellsPerKeccak) } + +type AirPrivateBuiltinKeccak struct { + Index int `json:"index"` + InputS0 string `json:"input_s0"` + InputS1 string `json:"input_s1"` + InputS2 string `json:"input_s2"` + InputS3 string `json:"input_s3"` + InputS4 string `json:"input_s4"` + InputS5 string `json:"input_s5"` + InputS6 string `json:"input_s6"` + InputS7 string `json:"input_s7"` +} + +func (k *Keccak) GetAirPrivateInput(keccakSegment *memory.Segment) []AirPrivateBuiltinKeccak { + valueMapping := make(map[int]AirPrivateBuiltinKeccak) + for index, value := range keccakSegment.Data { + if !value.Known() { + continue + } + idx, stateIndex := index/cellsPerKeccak, index%cellsPerKeccak + if stateIndex >= inputCellsPerKeccak { + continue + } + + builtinValue, exists := valueMapping[idx] + if !exists { + builtinValue = AirPrivateBuiltinKeccak{Index: idx} + } + + valueBig := big.Int{} + value.Felt.BigInt(&valueBig) + valueHex := fmt.Sprintf("0x%x", &valueBig) + if stateIndex == 0 { + builtinValue.InputS0 = valueHex + } else if stateIndex == 1 { + builtinValue.InputS1 = valueHex + } else if stateIndex == 2 { + builtinValue.InputS2 = valueHex + } else if stateIndex == 3 { + builtinValue.InputS3 = valueHex + } else if stateIndex == 4 { + builtinValue.InputS4 = valueHex + } else if stateIndex == 5 { + builtinValue.InputS5 = valueHex + } else if stateIndex == 6 { + builtinValue.InputS6 = valueHex + } else if stateIndex == 7 { + builtinValue.InputS7 = valueHex + } + valueMapping[idx] = builtinValue + } + + values := make([]AirPrivateBuiltinKeccak, 0) + + sortedIndexes := make([]int, 0, len(valueMapping)) + for index := range valueMapping { + sortedIndexes = append(sortedIndexes, index) + } + sort.Ints(sortedIndexes) + for _, index := range sortedIndexes { + value := valueMapping[index] + values = append(values, value) + } + return values +} diff --git a/pkg/vm/builtins/layouts.go b/pkg/vm/builtins/layouts.go index 2a464ff42..dba5f1d76 100644 --- a/pkg/vm/builtins/layouts.go +++ b/pkg/vm/builtins/layouts.go @@ -3,7 +3,6 @@ package builtins import ( "fmt" - "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -12,7 +11,7 @@ type LayoutBuiltin struct { // Runner for the builtin Runner memory.BuiltinRunner // Builtin id from starknet parser - Builtin starknet.Builtin + Builtin BuiltinType } type Layout struct { @@ -24,40 +23,131 @@ type Layout struct { Builtins []LayoutBuiltin } +func getPlainLayout() Layout { + return Layout{Name: "plain", RcUnits: 16, Builtins: []LayoutBuiltin{}} +} + func getSmallLayout() Layout { return Layout{Name: "small", RcUnits: 16, Builtins: []LayoutBuiltin{ - {Runner: &Output{}, Builtin: starknet.Output}, - {Runner: &Pedersen{ratio: 8}, Builtin: starknet.Pedersen}, - {Runner: &RangeCheck{ratio: 8, RangeCheckNParts: 8}, Builtin: starknet.RangeCheck}, - {Runner: &ECDSA{ratio: 512}, Builtin: starknet.ECDSA}, + {Runner: &Output{}, Builtin: OutputType}, + {Runner: &Pedersen{ratio: 8}, Builtin: PedersenType}, + {Runner: &RangeCheck{ratio: 8, RangeCheckNParts: 8}, Builtin: RangeCheckType}, + {Runner: &ECDSA{ratio: 512}, Builtin: ECDSAType}, }} } -func getPlainLayout() Layout { - return Layout{Name: "plain", RcUnits: 16, Builtins: []LayoutBuiltin{}} +func getDexLayout() Layout { + return Layout{Name: "dex", RcUnits: 4, Builtins: []LayoutBuiltin{ + {Runner: &Output{}, Builtin: OutputType}, + {Runner: &Pedersen{ratio: 8}, Builtin: PedersenType}, + {Runner: &RangeCheck{ratio: 8, RangeCheckNParts: 8}, Builtin: RangeCheckType}, + {Runner: &ECDSA{ratio: 512}, Builtin: ECDSAType}, + }} +} + +func getRecursiveLayout() Layout { + return Layout{Name: "recursive", RcUnits: 4, Builtins: []LayoutBuiltin{ + {Runner: &Output{}, Builtin: OutputType}, + {Runner: &Pedersen{ratio: 128}, Builtin: PedersenType}, + {Runner: &RangeCheck{ratio: 8, RangeCheckNParts: 8}, Builtin: RangeCheckType}, + {Runner: &Bitwise{ratio: 8}, Builtin: BitwiseType}, + }} +} + +func getStarknetLayout() Layout { + return Layout{Name: "starknet", RcUnits: 4, Builtins: []LayoutBuiltin{ + {Runner: &Output{}, Builtin: OutputType}, + {Runner: &Pedersen{ratio: 32}, Builtin: PedersenType}, + {Runner: &RangeCheck{ratio: 16, RangeCheckNParts: 8}, Builtin: RangeCheckType}, + {Runner: &ECDSA{ratio: 2048}, Builtin: ECDSAType}, + {Runner: &Bitwise{ratio: 64}, Builtin: BitwiseType}, + {Runner: &EcOp{ratio: 1024, cache: make(map[uint64]fp.Element)}, Builtin: ECOPType}, + {Runner: &Poseidon{ratio: 32, cache: make(map[uint64]fp.Element)}, Builtin: PoseidonType}, + }} } func getStarknetWithKeccakLayout() Layout { return Layout{Name: "starknet_with_keccak", RcUnits: 4, Builtins: []LayoutBuiltin{ - {Runner: &Output{}, Builtin: starknet.Output}, - {Runner: &Pedersen{ratio: 32}, Builtin: starknet.Pedersen}, - {Runner: &RangeCheck{ratio: 16, RangeCheckNParts: 8}, Builtin: starknet.RangeCheck}, - {Runner: &ECDSA{ratio: 2048}, Builtin: starknet.ECDSA}, - {Runner: &Bitwise{ratio: 64}, Builtin: starknet.Bitwise}, - {Runner: &EcOp{ratio: 1024, cache: make(map[uint64]fp.Element)}, Builtin: starknet.ECOP}, - {Runner: &Keccak{ratio: 2048, cache: make(map[uint64]fp.Element)}, Builtin: starknet.Keccak}, - {Runner: &Poseidon{ratio: 32, cache: make(map[uint64]fp.Element)}, Builtin: starknet.Poseidon}, + {Runner: &Output{}, Builtin: OutputType}, + {Runner: &Pedersen{ratio: 32}, Builtin: PedersenType}, + {Runner: &RangeCheck{ratio: 16, RangeCheckNParts: 8}, Builtin: RangeCheckType}, + {Runner: &ECDSA{ratio: 2048}, Builtin: ECDSAType}, + {Runner: &Bitwise{ratio: 64}, Builtin: BitwiseType}, + {Runner: &EcOp{ratio: 1024, cache: make(map[uint64]fp.Element)}, Builtin: ECOPType}, + {Runner: &Keccak{ratio: 2048, cache: make(map[uint64]fp.Element)}, Builtin: KeccakType}, + {Runner: &Poseidon{ratio: 32, cache: make(map[uint64]fp.Element)}, Builtin: PoseidonType}, + }} +} + +func getRecursiveLargeOutputLayout() Layout { + return Layout{Name: "recursive_large_output", RcUnits: 4, Builtins: []LayoutBuiltin{ + {Runner: &Output{}, Builtin: OutputType}, + {Runner: &Pedersen{ratio: 128}, Builtin: PedersenType}, + {Runner: &RangeCheck{ratio: 8, RangeCheckNParts: 8}, Builtin: RangeCheckType}, + {Runner: &Bitwise{ratio: 8}, Builtin: BitwiseType}, + {Runner: &Poseidon{ratio: 8, cache: make(map[uint64]fp.Element)}, Builtin: PoseidonType}, + }} +} + +func getRecursiveWithPoseidonLayout() Layout { + return Layout{Name: "recursive_with_poseidon", RcUnits: 4, Builtins: []LayoutBuiltin{ + {Runner: &Output{}, Builtin: OutputType}, + {Runner: &Pedersen{ratio: 256}, Builtin: PedersenType}, + {Runner: &RangeCheck{ratio: 16, RangeCheckNParts: 8}, Builtin: RangeCheckType}, + {Runner: &Bitwise{ratio: 16}, Builtin: BitwiseType}, + {Runner: &Poseidon{ratio: 64, cache: make(map[uint64]fp.Element)}, Builtin: PoseidonType}, + }} +} + +func getAllSolidityLayout() Layout { + return Layout{Name: "recursive_with_poseidon", RcUnits: 8, Builtins: []LayoutBuiltin{ + {Runner: &Output{}, Builtin: OutputType}, + {Runner: &Pedersen{ratio: 8}, Builtin: PedersenType}, + {Runner: &RangeCheck{ratio: 8, RangeCheckNParts: 8}, Builtin: RangeCheckType}, + {Runner: &ECDSA{ratio: 512}, Builtin: ECDSAType}, + {Runner: &Bitwise{ratio: 256}, Builtin: BitwiseType}, + {Runner: &EcOp{ratio: 256, cache: make(map[uint64]fp.Element)}, Builtin: ECOPType}, + }} +} + +func getAllCairoLayout() Layout { + return Layout{Name: "all_cairo", RcUnits: 8, Builtins: []LayoutBuiltin{ + {Runner: &Output{}, Builtin: OutputType}, + {Runner: &Pedersen{ratio: 256}, Builtin: PedersenType}, + {Runner: &RangeCheck{ratio: 8, RangeCheckNParts: 8}, Builtin: RangeCheckType}, + {Runner: &ECDSA{ratio: 2048}, Builtin: ECDSAType}, + {Runner: &Bitwise{ratio: 16}, Builtin: BitwiseType}, + {Runner: &EcOp{ratio: 1024, cache: make(map[uint64]fp.Element)}, Builtin: ECOPType}, + {Runner: &Keccak{ratio: 2048, cache: make(map[uint64]fp.Element)}, Builtin: KeccakType}, + {Runner: &Poseidon{ratio: 256, cache: make(map[uint64]fp.Element)}, Builtin: PoseidonType}, + {Runner: &RangeCheck{ratio: 8, RangeCheckNParts: 6}, Builtin: RangeCheck96Type}, + {Runner: &ModBuiltin{ratio: 128, wordBitLen: 96, batchSize: 1, modBuiltinType: Add}, Builtin: AddModeType}, + {Runner: &ModBuiltin{ratio: 256, wordBitLen: 96, batchSize: 1, modBuiltinType: Mul}, Builtin: MulModType}, }} } func GetLayout(layout string) (Layout, error) { switch layout { - case "small": - return getSmallLayout(), nil case "plain": return getPlainLayout(), nil + case "small": + return getSmallLayout(), nil + case "dex": + return getDexLayout(), nil + case "recursive": + return getRecursiveLayout(), nil + case "starknet": + return getStarknetLayout(), nil case "starknet_with_keccak": return getStarknetWithKeccakLayout(), nil + case "recursive_large_output": + return getRecursiveLargeOutputLayout(), nil + case "recursive_with_poseidon": + return getRecursiveWithPoseidonLayout(), nil + case "all_solidity": + return getAllSolidityLayout(), nil + case "all_cairo": + return getAllCairoLayout(), nil case "": return getPlainLayout(), nil default: diff --git a/pkg/vm/builtins/modulo.go b/pkg/vm/builtins/modulo.go new file mode 100644 index 000000000..450330724 --- /dev/null +++ b/pkg/vm/builtins/modulo.go @@ -0,0 +1,599 @@ +package builtins + +import ( + "fmt" + "math/big" + "strings" + + "github.com/NethermindEth/cairo-vm-go/pkg/utils" + + "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" +) + +const ModuloName = "Mod" + +// These are the offsets in the array, which is used here as ModBuiltinInputs : +// INPUT_NAMES = [ +// +// "p0", +// "p1", +// "p2", +// "p3", +// "values_ptr", +// "offsets_ptr", +// "n", +// +// ] +const VALUES_PTR_OFFSET = 4 +const OFFSETS_PTR_OFFSET = 5 +const N_OFFSET = 6 + +// This is the number of felts in a UInt384 struct +const N_WORDS = 4 + +// number of memory cells per modulo builtin +// 4(felts) + 1(values_ptr) + 1(offsets_ptr) + 1(n) = 7 +const CELLS_PER_MOD = 7 + +// The maximum n value that the function fill_memory accepts +const MAX_N = 100000 + +// Represents a 384-bit unsigned integer d0 + 2**96 * d1 + 2**192 * d2 + 2**288 * d3 +// where each di is in [0, 2**96). +// +// struct UInt384 { +// d0: felt, +// d1: felt, +// d2: felt, +// d3: felt, +// } +// Instead of introducing UInt384, we use [N_WORDS]fp.Element to represent the 384-bit integer. + +type ModBuiltinInputs struct { + // The modulus. + p big.Int + pValues [N_WORDS]fp.Element + // A pointer to input values, the intermediate results and the output. + valuesPtr memory.MemoryAddress + // A pointer to offsets inside the values array, defining the circuit. + // The offsets array should contain 3 * n elements. + offsetsPtr memory.MemoryAddress + // The number of operations to perform. + n uint64 +} + +type ModBuiltinType string + +const ( + Add ModBuiltinType = "Add" + Mul ModBuiltinType = "Mul" +) + +type ModBuiltin struct { + ratio uint64 + // Add | Mul + modBuiltinType ModBuiltinType + // number of bits in a word + wordBitLen uint64 + batchSize uint64 + // shift by the number of bits present in a word + shift big.Int + // powers required to do the corresponding shift + shiftPowers [N_WORDS]big.Int + // k value that bounds p when finding unknown value in fillValue function + kBound *big.Int +} + +func NewModBuiltin(ratio uint64, wordBitLen uint64, batchSize uint64, modBuiltinType ModBuiltinType) *ModBuiltin { + shift := new(big.Int).Lsh(big.NewInt(1), uint(wordBitLen)) + shiftPowers := [N_WORDS]big.Int{} + shiftPowers[0] = *big.NewInt(1) + for i := 1; i < N_WORDS; i++ { + shiftPowers[i].Mul(&shiftPowers[i-1], shift) + } + kBound := big.NewInt(2) + if modBuiltinType == Mul { + kBound = nil + } + return &ModBuiltin{ + ratio: ratio, + modBuiltinType: modBuiltinType, + wordBitLen: wordBitLen, + batchSize: batchSize, + shift: *shift, + shiftPowers: shiftPowers, + kBound: kBound, + } +} + +func (m *ModBuiltin) CheckWrite(segment *memory.Segment, offset uint64, value *memory.MemoryValue) error { + return nil +} + +func (m *ModBuiltin) InferValue(segment *memory.Segment, offset uint64) error { + return fmt.Errorf("can't infer value") +} + +func (m *ModBuiltin) String() string { + return string(m.modBuiltinType) + ModuloName +} + +func (m *ModBuiltin) GetAllocatedSize(segmentUsedSize uint64, vmCurrentStep uint64) (uint64, error) { + return 0, nil +} + +// Reads N_WORDS from memory, starting at address = addr. +// Returns the words and the value if all words are in memory. +// Verifies that all words are integers and are bounded by 2**wordBitLen. +func (m *ModBuiltin) readNWordsValue(memory *memory.Memory, addr memory.MemoryAddress) ([N_WORDS]fp.Element, *big.Int, error) { + var words [N_WORDS]fp.Element + value := new(big.Int).SetInt64(0) + + for i := 0; i < N_WORDS; i++ { + newAddr, err := addr.AddOffset(int16(i)) + if err != nil { + return [N_WORDS]fp.Element{}, nil, err + } + + wordFelt, err := memory.ReadAsElement(newAddr.SegmentIndex, newAddr.Offset) + if err != nil { + return [N_WORDS]fp.Element{}, nil, err + } + + var word big.Int + wordFelt.BigInt(&word) + if word.Cmp(&m.shift) >= 0 { + return [N_WORDS]fp.Element{}, nil, fmt.Errorf("expected integer at address %d:%d to be smaller than 2^%d. Got: %s", newAddr.SegmentIndex, newAddr.Offset, m.wordBitLen, word.String()) + } + + words[i] = wordFelt + value = new(big.Int).Add(value, new(big.Int).Mul(&word, &m.shiftPowers[i])) + } + + return words, value, nil +} + +// Reads the inputs to the builtin (p, p_values, values_ptr, offsets_ptr, n) from the memory at address = addr. +// Returns an instance of ModBuiltinInputs and asserts that it exists in memory. +// If `read_n` is false, avoid reading and validating the value of 'n'. +func (m *ModBuiltin) readInputs(mem *memory.Memory, addr memory.MemoryAddress, read_n bool) (ModBuiltinInputs, error) { + valuesPtrAddr, err := addr.AddOffset(int16(VALUES_PTR_OFFSET)) + if err != nil { + return ModBuiltinInputs{}, err + } + valuesPtr, err := mem.ReadAsAddress(&valuesPtrAddr) + if err != nil { + return ModBuiltinInputs{}, err + } + offsetsPtrAddr, err := addr.AddOffset(int16(OFFSETS_PTR_OFFSET)) + if err != nil { + return ModBuiltinInputs{}, err + } + offsetsPtr, err := mem.ReadAsAddress(&offsetsPtrAddr) + if err != nil { + return ModBuiltinInputs{}, err + } + n := uint64(0) + if read_n { + nFelt, err := mem.ReadAsElement(addr.SegmentIndex, addr.Offset+N_OFFSET) + if err != nil { + return ModBuiltinInputs{}, err + } + n = nFelt.Uint64() + if n < 1 { + return ModBuiltinInputs{}, fmt.Errorf("moduloBuiltin: Expected n >= 1. Got: %d", n) + } + } + pValues, p, err := m.readNWordsValue(mem, addr) + if err != nil { + return ModBuiltinInputs{}, err + } + return ModBuiltinInputs{ + p: *p, + pValues: pValues, + valuesPtr: valuesPtr, + n: n, + offsetsPtr: offsetsPtr, + }, nil +} + +// Fills the inputs to the instances of the builtin given the inputs to the first instance. +func (m *ModBuiltin) fillInputs(mem *memory.Memory, builtinPtr memory.MemoryAddress, inputs ModBuiltinInputs) error { + if inputs.n > MAX_N { + return fmt.Errorf("fill memory max exceeded") + } + + nInstances, err := utils.SafeDivUint64(inputs.n, m.batchSize) + if err != nil { + return err + } + + for instance := 1; instance < int(nInstances); instance++ { + instancePtr, err := builtinPtr.AddOffset(int16(instance * CELLS_PER_MOD)) + if err != nil { + return err + } + + // Filling the 4 values of a UInt384 struct + for i := 0; i < N_WORDS; i++ { + addr, err := instancePtr.AddOffset(int16(i)) + if err != nil { + return err + } + mv := memory.MemoryValueFromFieldElement(&inputs.pValues[i]) + if err := mem.WriteToAddress(&addr, &mv); err != nil { + return err + } + } + + addr, err := instancePtr.AddOffset(VALUES_PTR_OFFSET) + if err != nil { + return err + } + mv := memory.MemoryValueFromMemoryAddress(&inputs.valuesPtr) + if err := mem.WriteToAddress(&addr, &mv); err != nil { + return err + } + + addr, err = instancePtr.AddOffset(OFFSETS_PTR_OFFSET) + if err != nil { + return err + } + newAddr, err := inputs.offsetsPtr.AddOffset(3 * int16(instance) * int16(m.batchSize)) + if err != nil { + return err + } + mv = memory.MemoryValueFromMemoryAddress(&newAddr) + if err := mem.WriteToAddress(&addr, &mv); err != nil { + return err + } + + // This denotes the number of operations left + // n for new instance = original n - batch_size * (number of instances passed) + addr, err = instancePtr.AddOffset(N_OFFSET) + if err != nil { + return err + } + val := fp.NewElement(inputs.n - m.batchSize*uint64(instance)) + mv = memory.MemoryValueFromFieldElement(&val) + if err := mem.WriteToAddress(&addr, &mv); err != nil { + return err + } + } + + return nil +} + +// Copies the first offsets into memory, nCopies times. +func (m *ModBuiltin) fillOffsets(mem *memory.Memory, offsetsPtr memory.MemoryAddress, index, nCopies uint64) error { + if nCopies == 0 { + return nil + } + + for i := 0; i < 3; i++ { + addr, err := offsetsPtr.AddOffset(int16(i)) + if err != nil { + return err + } + + offset, err := mem.ReadAsAddress(&addr) + if err != nil { + return err + } + + for copyI := 0; copyI < int(nCopies); copyI++ { + copyAddr, err := offsetsPtr.AddOffset(int16(3*(index+uint64(copyI)) + uint64(i))) + if err != nil { + return err + } + mv := memory.MemoryValueFromMemoryAddress(&offset) + if err := mem.WriteToAddress(©Addr, &mv); err != nil { + return err + } + } + } + + return nil +} + +// Given a value, writes its n_words to memory, starting at address = addr. +func (m *ModBuiltin) writeNWordsValue(mem *memory.Memory, addr memory.MemoryAddress, value big.Int) error { + for i := 0; i < N_WORDS; i++ { + word := new(big.Int).Mod(&value, &m.shift) + modAddr, err := addr.AddOffset(int16(i)) + if err != nil { + return err + } + mv := memory.MemoryValueFromFieldElement(new(fp.Element).SetBigInt(word)) + if err := mem.WriteToAddress(&modAddr, &mv); err != nil { + return err + } + value.Div(&value, &m.shift) + } + if value.Sign() != 0 { + return fmt.Errorf("writeNWordsValue: value should be zero") + } + return nil +} + +// Fills a value in the values table, if exactly one value is missing. +// Returns 1 on success or if all values are already known. +// Returns 0 if there is an error or is the value cannot be filled +// Returns 2 when the mulModBuiltin has a zero divisor. +// Given known, res, p fillValue tries to compute the minimal integer operand x which +// satisfies the equation op(x,known) = res + k*p for some k in {0,1,...,self.k_bound-1}. +func (m *ModBuiltin) fillValue(mem *memory.Memory, inputs ModBuiltinInputs, index int, op ModBuiltinType) (int, error) { + addresses := make([]memory.MemoryAddress, 0, 3) + values := make([]*big.Int, 0, 3) + + for i := 0; i < 3; i++ { + addr, err := inputs.offsetsPtr.AddOffset(int16(3*index + i)) + if err != nil { + return 0, err + } + offsetFelt, err := mem.ReadAsElement(addr.SegmentIndex, addr.Offset) + if err != nil { + return 0, err + } + offset := offsetFelt.Uint64() + addr, err = inputs.valuesPtr.AddOffset(int16(offset)) + if err != nil { + return 0, err + } + addresses = append(addresses, addr) + // do not check for all errors, as the value might not be in memory + // only check for the error when the value in memory exceeds 2**wordBitLen + _, value, err := m.readNWordsValue(mem, addr) + if err != nil { + if strings.Contains(err.Error(), "expected integer at address") { + return 0, err + } + } + values = append(values, value) + } + + a, b, c := values[0], values[1], values[2] + + // 2 ** 384 (max value that can be stored in 4 felts) + intLim := new(big.Int).Lsh(big.NewInt(1), uint(m.wordBitLen)*N_WORDS) + kBound := m.kBound + if kBound == nil { + kBound = new(big.Int).Set(intLim) + } + switch { + case a != nil && b != nil && c == nil: + var value big.Int + if op == Add { + value = *new(big.Int).Add(a, b) + } else { + value = *new(big.Int).Mul(a, b) + } + // value - (kBound - 1) * p <= intLim - 1 + if new(big.Int).Sub(&value, new(big.Int).Mul((new(big.Int).Sub(kBound, big.NewInt(1))), &inputs.p)).Cmp(new(big.Int).Sub(intLim, big.NewInt(1))) == 1 { + return 0, fmt.Errorf("%s builtin: Expected a %s b - %d * p <= %d", m.String(), m.modBuiltinType, kBound.Sub(kBound, big.NewInt(1)), intLim.Sub(intLim, big.NewInt(1))) + } + if value.Cmp(new(big.Int).Mul(kBound, &inputs.p)) < 0 { + value.Mod(&value, &inputs.p) + } else { + value.Sub(&value, new(big.Int).Mul(new(big.Int).Sub(kBound, big.NewInt(1)), &inputs.p)) + } + if err := m.writeNWordsValue(mem, addresses[2], value); err != nil { + return 0, err + } + return 1, nil + case a != nil && b == nil && c != nil: + zeroDivisor := false + var value big.Int + if op == Add { + // Right now only k = 2 is an option, hence as we stated above that x + known can only take values + // from res to res + (k - 1) * p, hence known <= res + p + if a.Cmp(new(big.Int).Add(c, &inputs.p)) > 0 { + return 0, fmt.Errorf("%s builtin: addend greater than sum + p: %d > %d + %d", m.String(), a, c, &inputs.p) + } else { + if a.Cmp(c) <= 0 { + value = *new(big.Int).Sub(c, a) + } else { + value = *new(big.Int).Sub(c.Add(c, &inputs.p), a) + } + } + } else { + x, _, gcd := utils.Igcdex(a, &inputs.p) + // if gcd != 1, the known value is 0, in which case the res must be 0 + if gcd.Cmp(big.NewInt(1)) != 0 { + zeroDivisor = true + value = *new(big.Int).Div(&inputs.p, &gcd) + } else { + value = *new(big.Int).Mul(c, &x) + value = *value.Mod(&value, &inputs.p) + tmpK, err := utils.SafeDiv(new(big.Int).Sub(new(big.Int).Mul(a, &value), c), &inputs.p) + if err != nil { + return 0, err + } + if tmpK.Cmp(kBound) >= 0 { + return 0, fmt.Errorf("%s builtin: ((%d * q) - %d) / %d > %d for any q > 0, such that %d * q = %d (mod %d) ", m.String(), a, c, &inputs.p, kBound, a, c, &inputs.p) + } + if tmpK.Cmp(big.NewInt(0)) < 0 { + value = *value.Add(&value, new(big.Int).Mul(&inputs.p, new(big.Int).Div(new(big.Int).Sub(a, new(big.Int).Sub(&tmpK, big.NewInt(1))), a))) + } + } + } + if err := m.writeNWordsValue(mem, addresses[1], value); err != nil { + return 0, err + } + if zeroDivisor { + return 2, nil + } + return 1, nil + case a == nil && b != nil && c != nil: + zeroDivisor := false + var value big.Int + if op == Add { + // Right now only k = 2 is an option, hence as we stated above that x + known can only take values + // from res to res + (k - 1) * p, hence known <= res + p + if b.Cmp(new(big.Int).Add(c, &inputs.p)) > 0 { + return 0, fmt.Errorf("%s builtin: addend greater than sum + p: %d > %d + %d", m.String(), b, c, &inputs.p) + } else { + if b.Cmp(c) <= 0 { + value = *new(big.Int).Sub(c, b) + } else { + value = *new(big.Int).Sub(c.Add(c, &inputs.p), b) + } + } + } else { + x, _, gcd := utils.Igcdex(b, &inputs.p) + // if gcd != 1, the known value is 0, in which case the res must be 0 + if gcd.Cmp(big.NewInt(1)) != 0 { + zeroDivisor = true + value = *new(big.Int).Div(&inputs.p, &gcd) + } else { + value = *new(big.Int).Mul(c, &x) + value = *value.Mod(&value, &inputs.p) + tmpK, err := utils.SafeDiv(new(big.Int).Sub(new(big.Int).Mul(b, &value), c), &inputs.p) + if err != nil { + return 0, err + } + if tmpK.Cmp(kBound) >= 0 { + return 0, fmt.Errorf("%s builtin: ((%d * q) - %d) / %d > %d for any q > 0, such that %d * q = %d (mod %d) ", m.String(), b, c, &inputs.p, kBound, b, c, &inputs.p) + } + if tmpK.Cmp(big.NewInt(0)) < 0 { + value = *value.Add(&value, new(big.Int).Mul(&inputs.p, new(big.Int).Div(new(big.Int).Sub(b, new(big.Int).Sub(&tmpK, big.NewInt(1))), b))) + } + } + } + if err := m.writeNWordsValue(mem, addresses[0], value); err != nil { + return 0, err + } + if zeroDivisor { + return 2, nil + } + return 1, nil + case a != nil && b != nil && c != nil: + return 1, nil + default: + return 0, nil + } +} + +// Fills the memory with inputs to the builtin instances based on the inputs to the +// first instance, pads the offsets table to fit the number of operations written in the +// input to the first instance, and calculates missing values in the values table. +// +// The number of operations written to the input of the first instance n should be at +// least n and a multiple of batch_size. Previous offsets are copied to the end of the +// offsets table to make its length 3n'. +func FillMemory(mem *memory.Memory, addModInputAddress memory.MemoryAddress, nAddMods uint64, mulModInputAddress memory.MemoryAddress, nMulMods uint64) error { + if nAddMods > MAX_N { + return fmt.Errorf("AddMod builtin: n must be <= {MAX_N}") + } + if nMulMods > MAX_N { + return fmt.Errorf("MulMod builtin: n must be <= {MAX_N}") + } + + var addModBuiltinRunner *ModBuiltin + var mulModBuiltinRunner *ModBuiltin + var addModBuiltinInputs, mulModBuiltinInputs ModBuiltinInputs + var err error + + if nAddMods != 0 { + addModBuiltinSegment, ok := mem.FindSegmentWithBuiltin("AddMod") + if !ok { + return fmt.Errorf("AddMod builtin segment doesn't exist") + } + addModBuiltinRunner, ok = addModBuiltinSegment.BuiltinRunner.(*ModBuiltin) + if !ok { + return fmt.Errorf("addModBuiltinRunner is not a ModBuiltin") + } + + addModBuiltinInputs, err = addModBuiltinRunner.readInputs(mem, addModInputAddress, true) + if err != nil { + return err + } + if err := addModBuiltinRunner.fillInputs(mem, addModInputAddress, addModBuiltinInputs); err != nil { + return err + } + if err := addModBuiltinRunner.fillOffsets(mem, addModBuiltinInputs.offsetsPtr, nAddMods, addModBuiltinInputs.n-nAddMods); err != nil { + return err + } + } else { + addModBuiltinRunner = nil + } + + if nMulMods != 0 { + mulModBuiltinSegment, ok := mem.FindSegmentWithBuiltin("MulMod") + if !ok { + return fmt.Errorf("MulMod builtin segment doesn't exist") + } + mulModBuiltinRunner, ok = mulModBuiltinSegment.BuiltinRunner.(*ModBuiltin) + if !ok { + return fmt.Errorf("mulModBuiltinRunner is not a ModBuiltin") + } + + mulModBuiltinInputs, err = mulModBuiltinRunner.readInputs(mem, mulModInputAddress, true) + if err != nil { + return err + } + } else { + mulModBuiltinRunner = nil + } + + addModIndex, mulModIndex := uint64(0), uint64(0) + nComputedMulGates := uint64(0) + for addModIndex < nAddMods || mulModIndex < nMulMods { + if addModIndex < nAddMods && addModBuiltinRunner != nil { + res, err := addModBuiltinRunner.fillValue(mem, addModBuiltinInputs, int(addModIndex), Add) + if err != nil { + return err + } + if res == 1 { + addModIndex++ + } + } + + if mulModIndex < nMulMods && mulModBuiltinRunner != nil { + res, err := mulModBuiltinRunner.fillValue(mem, mulModBuiltinInputs, int(mulModIndex), Mul) + if err != nil { + return err + } + if res == 0 { + return fmt.Errorf("MulMod builtin: Could not fill the values table") + } + if res == 2 && nComputedMulGates == 0 { + nComputedMulGates = mulModIndex + } + mulModIndex++ + } + } + + // TODO: Investigate tests that fail when nComputedMulGates is not implemented + if mulModBuiltinRunner != nil { + if nComputedMulGates == 0 { + nComputedMulGates = mulModBuiltinInputs.n + if nComputedMulGates == 0 { + nComputedMulGates = nMulMods + } + mulModBuiltinInputs.n = nComputedMulGates + if err := mulModBuiltinRunner.fillOffsets(mem, mulModBuiltinInputs.offsetsPtr, nMulMods, nComputedMulGates-nMulMods); err != nil { + return err + } + } else { + if mulModBuiltinRunner.batchSize != 1 { + return fmt.Errorf("MulMod builtin: Inverse failure is supported only at batch_size == 1") + } + } + mulModBuiltinInputs.n = nComputedMulGates + mulModInputNAddr, err := mulModInputAddress.AddOffset(int16(N_OFFSET)) + if err != nil { + return err + } + mv := memory.MemoryValueFromFieldElement(new(fp.Element).SetUint64(nComputedMulGates)) + if err := mem.WriteToAddress(&mulModInputNAddr, &mv); err != nil { + return err + } + + if err := mulModBuiltinRunner.fillInputs(mem, mulModInputAddress, mulModBuiltinInputs); err != nil { + return err + } + } + return nil +} diff --git a/pkg/vm/builtins/modulo_test.go b/pkg/vm/builtins/modulo_test.go new file mode 100644 index 000000000..875adcdb2 --- /dev/null +++ b/pkg/vm/builtins/modulo_test.go @@ -0,0 +1,158 @@ +package builtins + +import ( + // "fmt" + "math/big" + "testing" + + "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" + "github.com/stretchr/testify/require" +) + +/* +Tests whether runner completes a trio a, b, c as the input implies: +If inverse is False it tests whether a = x1, b = x2, c = None will be completed with c = res. +If inverse is True it tests whether c = x1, b = x2, a = None will be completed with a = res. +*/ +func checkResult(runner ModBuiltin, inverse bool, p, x1, x2 big.Int) (*big.Int, error) { + mem := memory.Memory{} + + mem.AllocateBuiltinSegment(&runner) + + offsetsPtr := memory.MemoryAddress{SegmentIndex: 0, Offset: 0} + + for i := 0; i < 3; i++ { + offsetsPtrAddr, err := offsetsPtr.AddOffset(int16(i)) + if err != nil { + return nil, err + } + + mv := memory.MemoryValueFromInt(i * N_WORDS) + if err := mem.WriteToAddress(&offsetsPtrAddr, &mv); err != nil { + return nil, err + } + } + + valuesAddr := memory.MemoryAddress{SegmentIndex: 0, Offset: 24} + + x1Addr, err := valuesAddr.AddOffset(int16(0)) + if err != nil { + return nil, err + } + + x2Addr, err := valuesAddr.AddOffset(int16(N_WORDS)) + if err != nil { + return nil, err + } + err = runner.writeNWordsValue(&mem, x2Addr, x2) + if err != nil { + return nil, err + } + + resAddr, err := valuesAddr.AddOffset(int16(2 * N_WORDS)) + if err != nil { + return nil, err + } + + if inverse { + x1Addr, resAddr = resAddr, x1Addr + } + + err = runner.writeNWordsValue(&mem, x1Addr, x1) + if err != nil { + return nil, err + } + + _, err = runner.fillValue(&mem, ModBuiltinInputs{ + p: p, + pValues: [N_WORDS]fp.Element{}, // not used in fillValue + valuesPtr: valuesAddr, + n: 0, // not used in fillValue + offsetsPtr: offsetsPtr, + }, 0, runner.modBuiltinType) + + if err != nil { + return nil, err + } + + _, OutRes, err := runner.readNWordsValue(&mem, resAddr) + if err != nil { + return nil, err + } + + return OutRes, nil +} + +func TestAddModBuiltinRunnerAddition(t *testing.T) { + runner := NewModBuiltin(1, 3, 1, Add) + res1, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(17), *big.NewInt(40)) + require.NoError(t, err) + require.Equal(t, big.NewInt(57), res1) + res2, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(82), *big.NewInt(31)) + require.NoError(t, err) + require.Equal(t, big.NewInt(46), res2) + res3, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(68), *big.NewInt(69)) + require.NoError(t, err) + require.Equal(t, big.NewInt(70), res3) + res4, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(68), *big.NewInt(0)) + require.NoError(t, err) + require.Equal(t, big.NewInt(1), res4) + _, err = checkResult(*runner, false, *big.NewInt(4094), *big.NewInt(4095), *big.NewInt(4095)) + require.ErrorContains(t, err, "Expected a Add b - 1 * p <= 4095") +} + +func TestAddModBuiltinRunnerSubtraction(t *testing.T) { + runner := NewModBuiltin(1, 3, 1, Add) + res1, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(52), *big.NewInt(38)) + require.NoError(t, err) + require.Equal(t, big.NewInt(14), res1) + res2, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(5), *big.NewInt(68)) + require.NoError(t, err) + require.Equal(t, big.NewInt(4), res2) + res3, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(5), *big.NewInt(0)) + require.NoError(t, err) + require.Equal(t, big.NewInt(5), res3) + res4, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(0), *big.NewInt(5)) + require.NoError(t, err) + require.Equal(t, big.NewInt(62), res4) + _, err = checkResult(*runner, true, *big.NewInt(67), *big.NewInt(70), *big.NewInt(138)) + require.ErrorContains(t, err, "addend greater than sum + p") +} + +func TestMulModBuiltinRunnerMultiplication(t *testing.T) { + runner := NewModBuiltin(1, 3, 1, Mul) + res1, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(11), *big.NewInt(8)) + require.NoError(t, err) + require.Equal(t, big.NewInt(21), res1) + res2, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(68), *big.NewInt(69)) + require.NoError(t, err) + require.Equal(t, big.NewInt(2), res2) + res3, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(525), *big.NewInt(526)) + require.NoError(t, err) + require.Equal(t, big.NewInt(1785), res3) + res4, err := checkResult(*runner, false, *big.NewInt(67), *big.NewInt(525), *big.NewInt(0)) + require.NoError(t, err) + require.Equal(t, big.NewInt(0), res4) + _, err = checkResult(*runner, false, *big.NewInt(67), *big.NewInt(3777), *big.NewInt(3989)) + require.ErrorContains(t, err, "Expected a Mul b - 4095 * p <= 4095") +} + +func TestMulModBuiltinRunnerDivision(t *testing.T) { + runner := NewModBuiltin(1, 3, 1, Mul) + res1, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(36), *big.NewInt(9)) + require.NoError(t, err) + require.Equal(t, big.NewInt(4), res1) + res2, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(138), *big.NewInt(41)) + require.NoError(t, err) + require.Equal(t, big.NewInt(5), res2) + res3, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(272), *big.NewInt(41)) + require.NoError(t, err) + require.Equal(t, big.NewInt(72), res3) + res4, err := checkResult(*runner, true, *big.NewInt(67), *big.NewInt(0), *big.NewInt(0)) + require.NoError(t, err) + require.Equal(t, big.NewInt(1), res4) + res5, err := checkResult(*runner, true, *big.NewInt(66), *big.NewInt(6), *big.NewInt(3)) + require.NoError(t, err) + require.Equal(t, big.NewInt(22), res5) +} diff --git a/pkg/vm/builtins/pedersen.go b/pkg/vm/builtins/pedersen.go index 6391e090a..195202303 100644 --- a/pkg/vm/builtins/pedersen.go +++ b/pkg/vm/builtins/pedersen.go @@ -3,15 +3,19 @@ package builtins import ( "errors" "fmt" + "math/big" + "sort" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" pedersenhash "github.com/consensys/gnark-crypto/ecc/stark-curve/pedersen-hash" ) -const PedersenName = "pedersen" -const cellsPerPedersen = 3 -const inputCellsPerPedersen = 2 -const instancesPerComponentPedersen = 1 +const ( + PedersenName = "pedersen" + cellsPerPedersen = 3 + inputCellsPerPedersen = 2 + instancesPerComponentPedersen = 1 +) type Pedersen struct { ratio uint64 @@ -63,3 +67,50 @@ func (p *Pedersen) String() string { func (p *Pedersen) GetAllocatedSize(segmentUsedSize uint64, vmCurrentStep uint64) (uint64, error) { return getBuiltinAllocatedSize(segmentUsedSize, vmCurrentStep, p.ratio, inputCellsPerPedersen, instancesPerComponentPedersen, cellsPerPedersen) } + +type AirPrivateBuiltinPedersen struct { + Index int `json:"index"` + X string `json:"x"` + Y string `json:"y"` +} + +func (p *Pedersen) GetAirPrivateInput(pedersenSegment *mem.Segment) []AirPrivateBuiltinPedersen { + valueMapping := make(map[int]AirPrivateBuiltinPedersen) + for index, value := range pedersenSegment.Data { + if !value.Known() { + continue + } + idx, typ := index/cellsPerPedersen, index%cellsPerPedersen + if typ == 2 { + continue + } + + builtinValue, exists := valueMapping[idx] + if !exists { + builtinValue = AirPrivateBuiltinPedersen{Index: idx} + } + + valueBig := big.Int{} + value.Felt.BigInt(&valueBig) + valueHex := fmt.Sprintf("0x%x", &valueBig) + if typ == 0 { + builtinValue.X = valueHex + } else { + builtinValue.Y = valueHex + } + valueMapping[idx] = builtinValue + } + + values := make([]AirPrivateBuiltinPedersen, 0) + + sortedIndexes := make([]int, 0, len(valueMapping)) + for index := range valueMapping { + sortedIndexes = append(sortedIndexes, index) + } + sort.Ints(sortedIndexes) + for _, index := range sortedIndexes { + value := valueMapping[index] + values = append(values, value) + } + return values +} diff --git a/pkg/vm/builtins/poseidon.go b/pkg/vm/builtins/poseidon.go index 0a809313b..02197e15c 100644 --- a/pkg/vm/builtins/poseidon.go +++ b/pkg/vm/builtins/poseidon.go @@ -2,6 +2,9 @@ package builtins import ( "errors" + "fmt" + "math/big" + "sort" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" @@ -62,3 +65,53 @@ func (p *Poseidon) GetAllocatedSize(segmentUsedSize uint64, vmCurrentStep uint64 func (p *Poseidon) String() string { return PoseidonName } + +type AirPrivateBuiltinPoseidon struct { + Index int `json:"index"` + InputS0 string `json:"input_s0"` + InputS1 string `json:"input_s1"` + InputS2 string `json:"input_s2"` +} + +func (p *Poseidon) GetAirPrivateInput(poseidonSegment *mem.Segment) []AirPrivateBuiltinPoseidon { + valueMapping := make(map[int]AirPrivateBuiltinPoseidon) + for index, value := range poseidonSegment.Data { + if !value.Known() { + continue + } + idx, stateIndex := index/cellsPerPoseidon, index%cellsPerPoseidon + if stateIndex >= inputCellsPerPoseidon { + continue + } + + builtinValue, exists := valueMapping[idx] + if !exists { + builtinValue = AirPrivateBuiltinPoseidon{Index: idx} + } + + valueBig := big.Int{} + value.Felt.BigInt(&valueBig) + valueHex := fmt.Sprintf("0x%x", &valueBig) + if stateIndex == 0 { + builtinValue.InputS0 = valueHex + } else if stateIndex == 1 { + builtinValue.InputS1 = valueHex + } else if stateIndex == 2 { + builtinValue.InputS2 = valueHex + } + valueMapping[idx] = builtinValue + } + + values := make([]AirPrivateBuiltinPoseidon, 0) + + sortedIndexes := make([]int, 0, len(valueMapping)) + for index := range valueMapping { + sortedIndexes = append(sortedIndexes, index) + } + sort.Ints(sortedIndexes) + for _, index := range sortedIndexes { + value := valueMapping[index] + values = append(values, value) + } + return values +} diff --git a/pkg/vm/builtins/range_check.go b/pkg/vm/builtins/range_check.go index 7670bef41..bf2bf5ab3 100644 --- a/pkg/vm/builtins/range_check.go +++ b/pkg/vm/builtins/range_check.go @@ -4,19 +4,24 @@ import ( "errors" "fmt" "math" + "math/big" "github.com/NethermindEth/cairo-vm-go/pkg/utils" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) -const inputCellsPerRangeCheck = 1 -const cellsPerRangeCheck = 1 -const instancesPerComponentRangeCheck = 1 +const ( + RangeCheckName = "range_check" + RangeCheck96Name = "range_check96" + inputCellsPerRangeCheck = 1 + cellsPerRangeCheck = 1 + instancesPerComponentRangeCheck = 1 -// Each range check instance consists of RangeCheckNParts 16-bit parts. INNER_RC_BOUND_SHIFT and INNER_RC_BOUND_MASK are used to extract 16-bit parts from the field elements stored in the range check segment. -const INNER_RC_BOUND_SHIFT = 16 -const INNER_RC_BOUND_MASK = (1 << 16) - 1 + // Each range check instance consists of RangeCheckNParts 16-bit parts. INNER_RC_BOUND_SHIFT and INNER_RC_BOUND_MASK are used to extract 16-bit parts from the field elements stored in the range check segment. + INNER_RC_BOUND_SHIFT = 16 + INNER_RC_BOUND_MASK = (1 << 16) - 1 +) type RangeCheck struct { ratio uint64 @@ -56,9 +61,9 @@ func (r *RangeCheck) InferValue(segment *memory.Segment, offset uint64) error { func (r *RangeCheck) String() string { if r.RangeCheckNParts == 6 { - return "range_check96" + return RangeCheck96Name } else { - return "range_check" + return RangeCheckName } } @@ -89,3 +94,22 @@ func (r *RangeCheck) GetRangeCheckUsage(rangeCheckSegment *memory.Segment) (uint } return minVal, maxVal } + +type AirPrivateBuiltinRangeCheck struct { + Index int `json:"index"` + Value string `json:"value"` +} + +func (r *RangeCheck) GetAirPrivateInput(rangeCheckSegment *memory.Segment) []AirPrivateBuiltinRangeCheck { + values := make([]AirPrivateBuiltinRangeCheck, 0) + for index, value := range rangeCheckSegment.Data { + if !value.Known() { + continue + } + valueBig := big.Int{} + value.Felt.BigInt(&valueBig) + valueHex := fmt.Sprintf("0x%x", &valueBig) + values = append(values, AirPrivateBuiltinRangeCheck{Index: index, Value: valueHex}) + } + return values +} diff --git a/pkg/vm/builtins/segment_arena.go b/pkg/vm/builtins/segment_arena.go new file mode 100644 index 000000000..ac4c27e91 --- /dev/null +++ b/pkg/vm/builtins/segment_arena.go @@ -0,0 +1,5 @@ +package builtins + +const ( + SegmentArenaName string = "segment_arena" +) diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index db7be1ad8..ad83aee16 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -5,7 +5,7 @@ import ( "fmt" "math" - a "github.com/NethermindEth/cairo-vm-go/pkg/assembler" + asmb "github.com/NethermindEth/cairo-vm-go/pkg/assembler" "github.com/NethermindEth/cairo-vm-go/pkg/utils" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" @@ -85,7 +85,7 @@ type VirtualMachine struct { Trace []Context config VirtualMachineConfig // instructions cache - instructions map[uint64]*a.Instruction + instructions map[uint64]*asmb.Instruction // RcLimitsMin and RcLimitsMax define the range of values of instructions offsets, used for checking the number of potential range checks holes RcLimitsMin uint16 RcLimitsMax uint16 @@ -99,7 +99,9 @@ func NewVirtualMachine( // Initialize the trace if necesary var trace []Context if config.ProofMode || config.CollectTrace { - trace = make([]Context, 0) + // starknet defines a limit on the maximum number of computational steps that a transaction can contain when processed on the Starknet network. + // https://docs.starknet.io/tools/limits-and-triggers/ + trace = make([]Context, 0, 10000000) } return &VirtualMachine{ @@ -107,7 +109,7 @@ func NewVirtualMachine( Memory: memory, Trace: trace, config: config, - instructions: make(map[uint64]*a.Instruction), + instructions: make(map[uint64]*asmb.Instruction), RcLimitsMin: math.MaxUint16, RcLimitsMax: 0, }, nil @@ -133,7 +135,7 @@ func (vm *VirtualMachine) RunStep(hintRunner HintRunner) error { return fmt.Errorf("reading instruction: %w", err) } - instruction, err = a.DecodeInstruction(bytecodeInstruction) + instruction, err = asmb.DecodeInstruction(bytecodeInstruction) if err != nil { return fmt.Errorf("decoding instruction: %w", err) } @@ -156,7 +158,9 @@ func (vm *VirtualMachine) RunStep(hintRunner HintRunner) error { const RC_OFFSET_BITS = 16 -func (vm *VirtualMachine) RunInstruction(instruction *a.Instruction) error { +//go:nosplit +func (vm *VirtualMachine) RunInstruction(instruction *asmb.Instruction) error { + var off0 int = int(instruction.OffDest) + (1 << (RC_OFFSET_BITS - 1)) var off1 int = int(instruction.OffOp0) + (1 << (RC_OFFSET_BITS - 1)) var off2 int = int(instruction.OffOp1) + (1 << (RC_OFFSET_BITS - 1)) @@ -217,9 +221,9 @@ func (vm *VirtualMachine) RunInstruction(instruction *a.Instruction) error { return nil } -func (vm *VirtualMachine) getDstAddr(instruction *a.Instruction) (mem.MemoryAddress, error) { +func (vm *VirtualMachine) getDstAddr(instruction *asmb.Instruction) (mem.MemoryAddress, error) { var dstRegister uint64 - if instruction.DstRegister == a.Ap { + if instruction.DstRegister == asmb.Ap { dstRegister = vm.Context.Ap } else { dstRegister = vm.Context.Fp @@ -232,9 +236,9 @@ func (vm *VirtualMachine) getDstAddr(instruction *a.Instruction) (mem.MemoryAddr return mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: addr}, nil } -func (vm *VirtualMachine) getOp0Addr(instruction *a.Instruction) (mem.MemoryAddress, error) { +func (vm *VirtualMachine) getOp0Addr(instruction *asmb.Instruction) (mem.MemoryAddress, error) { var op0Register uint64 - if instruction.Op0Register == a.Ap { + if instruction.Op0Register == asmb.Ap { op0Register = vm.Context.Ap } else { op0Register = vm.Context.Fp @@ -248,10 +252,10 @@ func (vm *VirtualMachine) getOp0Addr(instruction *a.Instruction) (mem.MemoryAddr return mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: addr}, nil } -func (vm *VirtualMachine) getOp1Addr(instruction *a.Instruction, op0Addr *mem.MemoryAddress) (mem.MemoryAddress, error) { +func (vm *VirtualMachine) getOp1Addr(instruction *asmb.Instruction, op0Addr *mem.MemoryAddress) (mem.MemoryAddress, error) { var op1Address mem.MemoryAddress switch instruction.Op1Source { - case a.Op0: + case asmb.Op0: // in this case Op0 is being used as an address, and must be of unwrapped as it op0Value, err := vm.Memory.ReadFromAddress(op0Addr) if err != nil { @@ -263,11 +267,11 @@ func (vm *VirtualMachine) getOp1Addr(instruction *a.Instruction, op0Addr *mem.Me return mem.UnknownAddress, fmt.Errorf("op0 is not an address: %w", err) } op1Address = mem.MemoryAddress{SegmentIndex: op0Address.SegmentIndex, Offset: op0Address.Offset} - case a.Imm: + case asmb.Imm: op1Address = vm.Context.AddressPc() - case a.FpPlusOffOp1: + case asmb.FpPlusOffOp1: op1Address = vm.Context.AddressFp() - case a.ApPlusOffOp1: + case asmb.ApPlusOffOp1: op1Address = vm.Context.AddressAp() } @@ -281,13 +285,13 @@ func (vm *VirtualMachine) getOp1Addr(instruction *a.Instruction, op0Addr *mem.Me // when there is an assertion with a substraction or division like : x = y - z // the compiler treats it as y = x + z. This means that the VM knows the -// dstCell value and either op0Cell xor op1Cell. This function infers the +// dstCell value and either op0Cell or op1Cell. This function infers the // unknow operand as well as the `res` auxiliar value func (vm *VirtualMachine) inferOperand( - instruction *a.Instruction, dstAddr *mem.MemoryAddress, op0Addr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, + instruction *asmb.Instruction, dstAddr *mem.MemoryAddress, op0Addr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, ) (mem.MemoryValue, error) { - if instruction.Opcode != a.OpCodeAssertEq || - instruction.Res == a.Unconstrained || + if instruction.Opcode != asmb.OpCodeAssertEq || + instruction.Res == asmb.Unconstrained || !vm.Memory.KnownValueAtAddress(dstAddr) { return mem.MemoryValue{}, nil } @@ -308,7 +312,7 @@ func (vm *VirtualMachine) inferOperand( return mem.MemoryValue{}, nil } - if instruction.Res == a.Op1 && !op1Value.Known() { + if instruction.Res == asmb.Op1 && !op1Value.Known() { if err = vm.Memory.WriteToAddress(op1Addr, &dstValue); err != nil { return mem.MemoryValue{}, err } @@ -326,7 +330,7 @@ func (vm *VirtualMachine) inferOperand( } var missingVal mem.MemoryValue - if instruction.Res == a.AddOperands { + if instruction.Res == asmb.AddOperands { missingVal = mem.EmptyMemoryValueAs(dstValue.IsAddress()) err = missingVal.Sub(&dstValue, &knownOpValue) } else { @@ -344,12 +348,12 @@ func (vm *VirtualMachine) inferOperand( } func (vm *VirtualMachine) computeRes( - instruction *a.Instruction, op0Addr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, + instruction *asmb.Instruction, op0Addr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, ) (mem.MemoryValue, error) { switch instruction.Res { - case a.Unconstrained: + case asmb.Unconstrained: return mem.MemoryValue{}, nil - case a.Op1: + case asmb.Op1: op1, err := vm.Memory.ReadFromAddress(op1Addr) if err != nil { return mem.UnknownValue, fmt.Errorf("cannot read op1: %w", err) @@ -368,9 +372,9 @@ func (vm *VirtualMachine) computeRes( } res := mem.EmptyMemoryValueAs(op0.IsAddress() || op1.IsAddress()) - if instruction.Res == a.AddOperands { + if instruction.Res == asmb.AddOperands { err = res.Add(&op0, &op1) - } else if instruction.Res == a.MulOperands { + } else if instruction.Res == asmb.MulOperands { err = res.Mul(&op0, &op1) } else { return mem.MemoryValue{}, fmt.Errorf("invalid res flag value: %d", instruction.Res) @@ -380,13 +384,13 @@ func (vm *VirtualMachine) computeRes( } func (vm *VirtualMachine) opcodeAssertions( - instruction *a.Instruction, + instruction *asmb.Instruction, dstAddr *mem.MemoryAddress, op0Addr *mem.MemoryAddress, res *mem.MemoryValue, ) error { switch instruction.Opcode { - case a.OpCodeCall: + case asmb.OpCodeCall: fpAddr := vm.Context.AddressFp() fpMv := mem.MemoryValueFromMemoryAddress(&fpAddr) // Store at [ap] the current fp @@ -402,7 +406,7 @@ func (vm *VirtualMachine) opcodeAssertions( if err := vm.Memory.WriteToAddress(op0Addr, &apMv); err != nil { return err } - case a.OpCodeAssertEq: + case asmb.OpCodeAssertEq: // assert that the calculated res is stored in dst if err := vm.Memory.WriteToAddress(dstAddr, res); err != nil { return err @@ -412,18 +416,18 @@ func (vm *VirtualMachine) opcodeAssertions( } func (vm *VirtualMachine) updatePc( - instruction *a.Instruction, + instruction *asmb.Instruction, dstAddr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, res *mem.MemoryValue, ) (mem.MemoryAddress, error) { switch instruction.PcUpdate { - case a.PcUpdateNextInstr: + case asmb.PcUpdateNextInstr: return mem.MemoryAddress{ SegmentIndex: vm.Context.Pc.SegmentIndex, Offset: vm.Context.Pc.Offset + uint64(instruction.Size()), }, nil - case a.PcUpdateJump: + case asmb.PcUpdateJump: // both address and felt are allowed here. It can be a felt when used // with an immediate or a memory address holding a felt. It can be an address // when a memory address holds a memory address @@ -439,7 +443,7 @@ func (vm *VirtualMachine) updatePc( fmt.Errorf("absolute jump: invalid jump location: %w", err) } - case a.PcUpdateJumpRel: + case asmb.PcUpdateJumpRel: val, err := res.FieldElement() if err != nil { return mem.UnknownAddress, fmt.Errorf("relative jump: %w", err) @@ -447,7 +451,7 @@ func (vm *VirtualMachine) updatePc( newPc := vm.Context.Pc err = newPc.Add(&newPc, val) return newPc, err - case a.PcUpdateJnz: + case asmb.PcUpdateJnz: destMv, err := vm.Memory.ReadFromAddress(dstAddr) if err != nil { return mem.UnknownAddress, err @@ -482,11 +486,11 @@ func (vm *VirtualMachine) updatePc( return mem.UnknownAddress, fmt.Errorf("unkwon pc update value: %d", instruction.PcUpdate) } -func (vm *VirtualMachine) updateAp(instruction *a.Instruction, res *mem.MemoryValue) (uint64, error) { +func (vm *VirtualMachine) updateAp(instruction *asmb.Instruction, res *mem.MemoryValue) (uint64, error) { switch instruction.ApUpdate { - case a.SameAp: + case asmb.SameAp: return vm.Context.Ap, nil - case a.AddRes: + case asmb.AddRes: apFelt := new(f.Element).SetUint64(vm.Context.Ap) // Convert ap value to felt resFelt, err := res.FieldElement() // Extract the f.Element from MemoryValue @@ -499,20 +503,20 @@ func (vm *VirtualMachine) updateAp(instruction *a.Instruction, res *mem.MemoryVa return 0, fmt.Errorf("resulting AP value is too large to fit in uint64") } return newAp.Uint64(), nil // Return the addition as uint64 - case a.Add1: + case asmb.Add1: return vm.Context.Ap + 1, nil - case a.Add2: + case asmb.Add2: return vm.Context.Ap + 2, nil } return 0, fmt.Errorf("cannot update ap, unknown ApUpdate flag: %d", instruction.ApUpdate) } -func (vm *VirtualMachine) updateFp(instruction *a.Instruction, dstAddr *mem.MemoryAddress) (uint64, error) { +func (vm *VirtualMachine) updateFp(instruction *asmb.Instruction, dstAddr *mem.MemoryAddress) (uint64, error) { switch instruction.Opcode { - case a.OpCodeCall: + case asmb.OpCodeCall: // [ap] and [ap + 1] are written to memory return vm.Context.Ap + 2, nil - case a.OpCodeRet: + case asmb.OpCodeRet: // [dst] should be a memory address of the form (executionSegment, fp - 2) destMv, err := vm.Memory.ReadFromAddress(dstAddr) if err != nil { @@ -531,15 +535,13 @@ func (vm *VirtualMachine) updateFp(instruction *a.Instruction, dstAddr *mem.Memo // It returns the trace after relocation, i.e, relocates pc, ap and fp for each step // to be their real address value -func (vm *VirtualMachine) RelocateTrace() []Trace { +func (vm *VirtualMachine) RelocateTrace(relocatedTrace *[]Trace) { // one is added, because prover expect that the first element to be // indexed on 1 instead of 0 - relocatedTrace := make([]Trace, len(vm.Trace)) totalBytecode := vm.Memory.Segments[ProgramSegment].Len() + 1 for i := range vm.Trace { - relocatedTrace[i] = vm.Trace[i].Relocate(totalBytecode) + (*relocatedTrace)[i] = vm.Trace[i].Relocate(totalBytecode) } - return relocatedTrace } // It returns all segments in memory but relocated as a single segment @@ -553,17 +555,16 @@ func (vm *VirtualMachine) RelocateMemory() []*f.Element { relocatedMemory := make([]*f.Element, maxMemoryUsed) for i, segment := range vm.Memory.Segments { for j := uint64(0); j < segment.RealLen(); j++ { - cell := segment.Data[j] - if !cell.Known() { + if !segment.Data[j].Known() { continue } var felt *f.Element - if cell.IsAddress() { - addr, _ := cell.MemoryAddress() + if segment.Data[j].IsAddress() { + addr, _ := segment.Data[j].MemoryAddress() felt = addr.Relocate(segmentsOffsets) } else { - felt, _ = cell.FieldElement() + felt, _ = segment.Data[j].FieldElement() } relocatedMemory[segmentsOffsets[i]+j] = felt } diff --git a/rust_vm_bin/cairo-vm-cli b/rust_vm_bin/cairo-vm-cli new file mode 100755 index 000000000..afa56ea90 Binary files /dev/null and b/rust_vm_bin/cairo-vm-cli differ