Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ModBuiltin #666

Merged
merged 34 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7af8191
implement partial functionalities and structs required
Sh0g0-1758 Sep 19, 2024
307c881
run fmt
Sh0g0-1758 Sep 19, 2024
0ed6213
Introduce max func
Sh0g0-1758 Sep 19, 2024
d845681
Added more functions
Sh0g0-1758 Sep 19, 2024
a219b2d
ran fmt
Sh0g0-1758 Sep 19, 2024
23129af
Add test for modulo
Sh0g0-1758 Sep 19, 2024
e1e550f
Added structure for integrating builtin with the vm
Sh0g0-1758 Sep 20, 2024
d609a30
Added useful comments
Sh0g0-1758 Sep 20, 2024
7c364e7
Update FillMemory
Sh0g0-1758 Sep 20, 2024
5d81a19
Added more comments
Sh0g0-1758 Sep 24, 2024
f5d4881
Add test
Sh0g0-1758 Sep 24, 2024
f9c5d88
Added TODOs
Sh0g0-1758 Sep 24, 2024
bedf562
fixes
Sh0g0-1758 Sep 26, 2024
b42f55c
update loop
Sh0g0-1758 Sep 26, 2024
1372810
Merge branch 'main' into mod
Sh0g0-1758 Sep 26, 2024
6497ed3
Added addmod and mulmod
Sh0g0-1758 Sep 26, 2024
2b5a41b
added some fixes
Sh0g0-1758 Sep 26, 2024
ce03991
test passes now
Sh0g0-1758 Sep 26, 2024
124d578
fix error and test
Sh0g0-1758 Sep 26, 2024
0635c8f
Make tests pass
Sh0g0-1758 Sep 26, 2024
7a48590
mod subtraction tests
Sh0g0-1758 Sep 26, 2024
7cbece4
Added some comments
Sh0g0-1758 Sep 26, 2024
7f85534
Added subtraction tests
Sh0g0-1758 Sep 26, 2024
76d201e
Added recursive case as well
Sh0g0-1758 Sep 26, 2024
e885b3d
nit
Sh0g0-1758 Sep 26, 2024
cb0ede7
Add multiplication test
Sh0g0-1758 Sep 26, 2024
8a63ad6
refactor math utils
Sh0g0-1758 Sep 26, 2024
96d5623
move one math utils function
Sh0g0-1758 Sep 26, 2024
b48ea8d
All Tests pass
Sh0g0-1758 Sep 27, 2024
84e76ab
nit
Sh0g0-1758 Sep 27, 2024
b86a637
Added mirror functionality
Sh0g0-1758 Sep 27, 2024
28459bd
remove redundant test
Sh0g0-1758 Sep 27, 2024
fefaefd
updated checkwrite and infervalue
Sh0g0-1758 Sep 27, 2024
fa837ee
nit
Sh0g0-1758 Sep 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/hintrunner/core/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func (hint U256InvModN) Execute(vm *VM.VirtualMachine, _ *hinter.HintRunnerConte
n := new(big.Int).Lsh(&N1BigInt, 128)
n.Add(n, &N0BigInt)

_, r, g := u.Igcdex(n, b)
_, r, g := utils.Igcdex(n, b)
mask := new(big.Int).Lsh(big.NewInt(1), 128)
mask.Sub(mask, big.NewInt(1))

Expand Down
65 changes: 2 additions & 63 deletions pkg/hintrunner/utils/math_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math/big"

"github.com/NethermindEth/cairo-vm-go/pkg/utils"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

Expand Down Expand Up @@ -55,7 +56,7 @@ func AsIntBig(value *big.Int) big.Int {
func Divmod(n, m, p *big.Int) (big.Int, error) {
// https://github.com/starkware-libs/cairo-lang/blob/efa9648f57568aad8f8a13fbf027d2de7c63c2c0/src/starkware/python/math_utils.py#L26

a, _, c := Igcdex(m, p)
a, _, c := utils.Igcdex(m, p)
if c.Cmp(big.NewInt(1)) != 0 {
return *big.NewInt(0), errors.New("no solution exists (gcd(m, p) != 1)")
}
Expand All @@ -65,68 +66,6 @@ func Divmod(n, m, p *big.Int) (big.Int, error) {
return *res, nil
}

func Igcdex(a, b *big.Int) (big.Int, big.Int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/core/intfunc.py#L362

if a.Cmp(big.NewInt(0)) == 0 && b.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), *big.NewInt(1), *big.NewInt(0)
}
g, x, y := gcdext(a, b)
return x, y, g
}

func gcdext(a, b *big.Int) (big.Int, big.Int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L125

if a.Cmp(big.NewInt(0)) == 0 || b.Cmp(big.NewInt(0)) == 0 {
g := new(big.Int)
if a.Cmp(big.NewInt(0)) == 0 {
g.Abs(b)
} else {
g.Abs(a)
}

if g.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), *big.NewInt(0), *big.NewInt(0)
}
return *g, *new(big.Int).Div(a, g), *new(big.Int).Div(b, g)
}

xSign, aSigned := sign(a)
ySign, bSigned := sign(b)
x, r := big.NewInt(1), big.NewInt(0)
y, s := big.NewInt(0), big.NewInt(1)

for bSigned.Sign() != 0 {
q, c := new(big.Int).DivMod(&aSigned, &bSigned, new(big.Int))
aSigned = bSigned
bSigned = *c
x, r = r, new(big.Int).Sub(x, new(big.Int).Mul(q, r))
y, s = s, new(big.Int).Sub(y, new(big.Int).Mul(q, s))
}

return aSigned, *new(big.Int).Mul(x, big.NewInt(int64(xSign))), *new(big.Int).Mul(y, big.NewInt(int64(ySign)))
}

func sign(n *big.Int) (int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L119

if n.Sign() < 0 {
return -1, *new(big.Int).Abs(n)
}
return 1, *new(big.Int).Set(n)
}

func SafeDiv(x, y *big.Int) (big.Int, error) {
if y.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), fmt.Errorf("division by zero")
}
if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 {
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y)
}
return *new(big.Int).Div(x, y), nil
}

func IsQuadResidue(x *fp.Element) bool {
// Implementation adapted from sympy implementation which can be found here :
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/ntheory/residue_ntheory.py#L689
Expand Down
66 changes: 0 additions & 66 deletions pkg/hintrunner/utils/math_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,69 +46,3 @@ func TestDivMod(t *testing.T) {
})
}
}

func TestIgcdex(t *testing.T) {
// https://github.com/sympy/sympy/blob/e7fb2714f17b30b83e424448aad0da9e94a4b577/sympy/core/tests/test_numbers.py#L278
tests := []struct {
name string
a, b *big.Int
expectedX, expectedY, expectedG *big.Int
}{
{
name: "Case 1",
a: big.NewInt(2),
b: big.NewInt(3),
expectedX: big.NewInt(-1),
expectedY: big.NewInt(1),
expectedG: big.NewInt(1),
},
{
name: "Case 2",
a: big.NewInt(10),
b: big.NewInt(12),
expectedX: big.NewInt(-1),
expectedY: big.NewInt(1),
expectedG: big.NewInt(2),
},
{
name: "Case 3",
a: big.NewInt(100),
b: big.NewInt(2004),
expectedX: big.NewInt(-20),
expectedY: big.NewInt(1),
expectedG: big.NewInt(4),
},
{
name: "Case 4",
a: big.NewInt(0),
b: big.NewInt(0),
expectedX: big.NewInt(0),
expectedY: big.NewInt(1),
expectedG: big.NewInt(0),
},
{
name: "Case 5",
a: big.NewInt(1),
b: big.NewInt(0),
expectedX: big.NewInt(1),
expectedY: big.NewInt(0),
expectedG: big.NewInt(1),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actualX, actualY, actualG := Igcdex(tt.a, tt.b)

if actualX.Cmp(tt.expectedX) != 0 {
t.Errorf("got x: %v, want: %v", actualX, tt.expectedX)
}
if actualY.Cmp(tt.expectedY) != 0 {
t.Errorf("got x: %v, want: %v", actualY, tt.expectedY)
}
if actualG.Cmp(tt.expectedG) != 0 {
t.Errorf("got x: %v, want: %v", actualG, tt.expectedG)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/hintrunner/zero/zerohint_ec.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func newDivModNSafeDivPlusOneHint() hinter.Hinter {
valueBig.Mul(resBig, bBig)
valueBig.Sub(valueBig, aBig)

newValueBig, err := secp_utils.SafeDiv(valueBig, nBig)
newValueBig, err := utils.SafeDiv(valueBig, nBig)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/hintrunner/zero/zerohint_signature.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
secp_utils "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins"
mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
Expand Down Expand Up @@ -372,7 +373,7 @@ func newDivModSafeDivHint() hinter.Hinter {
}

divisor := new(big.Int).Sub(new(big.Int).Mul(res, b), a)
value, err := secp_utils.SafeDiv(divisor, N)
value, err := utils.SafeDiv(divisor, N)
if err != nil {
return err
}
Expand Down
22 changes: 22 additions & 0 deletions pkg/runner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,28 @@ func TestEcOpBuiltin(t *testing.T) {
require.NoError(t, err)
}

func TestModuloBuiltin(t *testing.T) {
// modulo is located at fp - 3
// we first write 2048 and 5 to modulo
// then we read the modulo result from add and mul
// runner := createRunner(`
// [ap] = 2048;
// [ap] = [[fp - 3]];

// [ap + 1] = 5;
// [ap + 1] = [[fp - 3] + 1];
// ret;
// `, "small", sn.AddMod, sn.MulMod)

// err := runner.Run()
// require.NoError(t, err)

// modulo, ok := runner.vm.Memory.FindSegmentWithBuiltin("add_mod")
// require.True(t, ok)

// requireEqualSegments(t, createSegment(2048, 5), modulo)
}

func createRunner(code string, layoutName string, builtins ...builtins.BuiltinType) ZeroRunner {
program := createProgramWithBuiltins(code, builtins...)

Expand Down
74 changes: 74 additions & 0 deletions pkg/utils/math.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package utils

import (
"errors"
"fmt"
"math"
"math/big"
"math/bits"
Expand Down Expand Up @@ -156,3 +158,75 @@ func Int16FromBigInt(n *big.Int) (int16, bool) {
func RightRot(value uint32, n uint32) uint32 {
return (value >> n) | ((value & ((1 << n) - 1)) << (32 - n))
}

func SafeDivUint64(x, y uint64) (uint64, error) {
if y == 0 {
return 0, fmt.Errorf("cannot divide: y division is zero")
}
if x%y != 0 {
return 0, errors.New("cannot divide: x is not divisible by y")
}
return x / y, nil
}

func Igcdex(a, b *big.Int) (big.Int, big.Int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/core/intfunc.py#L362

if a.Cmp(big.NewInt(0)) == 0 && b.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), *big.NewInt(1), *big.NewInt(0)
}
g, x, y := gcdext(a, b)
return x, y, g
}

func gcdext(a, b *big.Int) (big.Int, big.Int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L125

if a.Cmp(big.NewInt(0)) == 0 || b.Cmp(big.NewInt(0)) == 0 {
g := new(big.Int)
if a.Cmp(big.NewInt(0)) == 0 {
g.Abs(b)
} else {
g.Abs(a)
}

if g.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), *big.NewInt(0), *big.NewInt(0)
}
return *g, *new(big.Int).Div(a, g), *new(big.Int).Div(b, g)
}

xSign, aSigned := sign(a)
ySign, bSigned := sign(b)
x, r := big.NewInt(1), big.NewInt(0)
y, s := big.NewInt(0), big.NewInt(1)

for bSigned.Sign() != 0 {
q, c := new(big.Int).DivMod(&aSigned, &bSigned, new(big.Int))
aSigned = bSigned
bSigned = *c
x, r = r, new(big.Int).Sub(x, new(big.Int).Mul(q, r))
y, s = s, new(big.Int).Sub(y, new(big.Int).Mul(q, s))
}

return aSigned, *new(big.Int).Mul(x, big.NewInt(int64(xSign))), *new(big.Int).Mul(y, big.NewInt(int64(ySign)))
}

func sign(n *big.Int) (int, big.Int) {
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/external/ntheory.py#L119

if n.Sign() < 0 {
return -1, *new(big.Int).Abs(n)
}
return 1, *new(big.Int).Set(n)
}

func SafeDiv(x, y *big.Int) (big.Int, error) {
if y.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), fmt.Errorf("division by zero")
}
if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 {
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y)
}
return *new(big.Int).Div(x, y), nil
}
67 changes: 67 additions & 0 deletions pkg/utils/math_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package utils

import (
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
Expand Down Expand Up @@ -119,3 +120,69 @@ func TestRightRot(t *testing.T) {
})
}
}

func TestIgcdex(t *testing.T) {
// https://github.com/sympy/sympy/blob/e7fb2714f17b30b83e424448aad0da9e94a4b577/sympy/core/tests/test_numbers.py#L278
tests := []struct {
name string
a, b *big.Int
expectedX, expectedY, expectedG *big.Int
}{
{
name: "Case 1",
a: big.NewInt(2),
b: big.NewInt(3),
expectedX: big.NewInt(-1),
expectedY: big.NewInt(1),
expectedG: big.NewInt(1),
},
{
name: "Case 2",
a: big.NewInt(10),
b: big.NewInt(12),
expectedX: big.NewInt(-1),
expectedY: big.NewInt(1),
expectedG: big.NewInt(2),
},
{
name: "Case 3",
a: big.NewInt(100),
b: big.NewInt(2004),
expectedX: big.NewInt(-20),
expectedY: big.NewInt(1),
expectedG: big.NewInt(4),
},
{
name: "Case 4",
a: big.NewInt(0),
b: big.NewInt(0),
expectedX: big.NewInt(0),
expectedY: big.NewInt(1),
expectedG: big.NewInt(0),
},
{
name: "Case 5",
a: big.NewInt(1),
b: big.NewInt(0),
expectedX: big.NewInt(1),
expectedY: big.NewInt(0),
expectedG: big.NewInt(1),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actualX, actualY, actualG := Igcdex(tt.a, tt.b)

if actualX.Cmp(tt.expectedX) != 0 {
t.Errorf("got x: %v, want: %v", actualX, tt.expectedX)
}
if actualY.Cmp(tt.expectedY) != 0 {
t.Errorf("got x: %v, want: %v", actualY, tt.expectedY)
}
if actualG.Cmp(tt.expectedG) != 0 {
t.Errorf("got x: %v, want: %v", actualG, tt.expectedG)
}
})
}
}
Loading
Loading