Skip to content

Commit

Permalink
WIP wrapped hash
Browse files Browse the repository at this point in the history
  • Loading branch information
ivokub committed Oct 25, 2023
1 parent 5b6b370 commit d1ab301
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 14 deletions.
78 changes: 66 additions & 12 deletions std/recursion/wrapped_hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
stdhash "github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/hash/mimc"
"github.com/consensys/gnark/std/math/bits"
"golang.org/x/exp/slices"
)

type shortNativeHash struct {
Expand Down Expand Up @@ -86,16 +87,24 @@ func newShortFromParam(hf hash.Hash, bitBlockSize, outSize int) hash.Hash {
}

func (h *shortNativeHash) Write(p []byte) (n int, err error) {
// we first write to the buffer. We want to be able to partition the inputs
// into smaller parts and buffer is good to keep track of the excess.
h.ringBuf.Write(p) // nosec: doesnt fail
for h.ringBuf.Len() >= (len(h.buf) - 1) {
// the buffer contains now enough bytes so that we can write it to the
// underlying hash.
h.ringBuf.Read(h.buf[1:])
h.wrapped.Write(h.buf)
}
return len(p), nil
}

func (h *shortNativeHash) Sum(b []byte) []byte {
// zero everything
// the cache buffer may contain still something. Write everything into the
// underlying hasher before we digest.

// zero the buffer we use for transporting bytes from bytes.Buffer to
// underlying hash. Remember that the cache buffer may not be full.
for i := range h.buf {
h.buf[i] = 0
}
Expand All @@ -104,14 +113,17 @@ func (h *shortNativeHash) Sum(b []byte) []byte {

// TODO: I'm cutting the hash on bit short to avoid edge cases:
res := h.wrapped.Sum(nil)
fmt.Printf("sum %x\n", res)
nbBytes := (h.outSize + 7) / 8
res = res[len(res)-nbBytes:]
mask := (1 << ((h.outSize - 1) % 8)) - 1
res[0] &= byte(mask)
res = res[len(res)-nbBytes+1:]
// mask := (1 << ((h.outSize - 1) % 8)) - 1
// res[0] &= byte(mask)
return append(b, res...)
}

func (h *shortNativeHash) Reset() {
h.ringBuf.Reset()
h.buf = make([]byte, (h.bitBlockSize+7)/8)
h.wrapped.Reset()
}

Expand All @@ -128,13 +140,19 @@ type shortCircuitHash struct {
bitLength int
wrapped stdhash.FieldHasher
buf []frontend.Variable
tmp []frontend.Variable
}

func newHashFromParameter(api frontend.API, hf stdhash.FieldHasher, bitLength int) stdhash.FieldHasher {
tmp := make([]frontend.Variable, ((bitLength+7)/8)*8-8)
for i := range tmp {
tmp[i] = 0
}
return &shortCircuitHash{
api: api,
bitLength: bitLength,
wrapped: hf,
tmp: tmp,
}
}

Expand All @@ -158,27 +176,63 @@ func NewHash(api frontend.API, target *big.Int) (stdhash.FieldHasher, error) {
}

func (h *shortCircuitHash) Sum() frontend.Variable {
v := bits.FromBinary(h.api, h.buf)
h.Write(v)
// before we compute the digest we have to write the rest of the buffer into
// the underlying hash. We know that we have maximum one variable left, as
// otherwise we would have written in the [Write] method.

// but first, we have to zero the buffer we use for reversing. The cache
// buffer may not be full and so some bits may be set.
for i := range h.tmp {
h.tmp[i] = 0
}
copy(h.tmp, h.buf)
slices.Reverse(h.tmp)
v := bits.FromBinary(h.api, h.tmp)
h.wrapped.Write(v)
res := h.wrapped.Sum()
h.api.Println(res)
resBts := bits.ToBinary(h.api, res)
res = bits.FromBinary(h.api, resBts[:h.bitLength-1])
res = bits.FromBinary(h.api, resBts[:len(h.tmp)])
return res
}

func (h *shortCircuitHash) Write(data ...frontend.Variable) {
// tricky part - bits representation is little-endian, i.e. least
// significant bit is at position zero. However, in the native version least
// significant BYTE is at the highest position. When we decompose into bits,
// then we first have to reverse the bits so that when we partition maximum
// number of full bytes out so it would correspond to the native version.
//
// But this means that later we have to reverse again when we recompose.
for i := range data {
bts := bits.ToBinary(h.api, data[i])
// h.tmp is maximum full number of bytes. This is one byte less than in
// the native version (the bits are on full number of bytes). Luckily,
// [bits.ToBinary] allows to decompose into arbitrary number of bits.
bts := bits.ToBinary(h.api, data[i], bits.WithNbDigits(len(h.tmp)+8))
// reverse to be in sync with native version when we later slice
// len(h.tmp) bits.
slices.Reverse(bts)
// store in the buffer. At every round we try to write to the wrapped
// hash as much as possible so the buffer isn't usually very big.
h.buf = append(h.buf, bts...)
}
blockLength := (((h.bitLength + 7) / 8) - 1) * 8
for len(h.buf) >= blockLength {
v := bits.FromBinary(h.api, h.buf[:blockLength])
for len(h.buf) >= len(h.tmp) {
// OK, now there is sufficient number of bits we can write to hash
// function. First we take the maximum number of full bytes.
copy(h.tmp, h.buf[:len(h.tmp)])
// and reverse it so that when recomposing is correct.
slices.Reverse(h.tmp)
v := bits.FromBinary(h.api, h.tmp)
// write to the underlying hash and empty the buffer.
h.wrapped.Write(v)
h.buf = h.buf[blockLength:]
h.buf = h.buf[len(h.tmp):]
}
}

func (h *shortCircuitHash) Reset() {
h.buf = nil
for i := range h.tmp {
h.tmp[i] = 0
}
h.wrapped.Reset()
}
10 changes: 8 additions & 2 deletions std/recursion/wrapped_hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func (c *shortHashCircuit) Define(api frontend.API) error {

func TestShortHash(t *testing.T) {
outerCurves := []ecc.ID{
// tinyfield.Modulus(),
ecc.BN254,
ecc.BLS12_381,
ecc.BLS12_377,
Expand All @@ -50,7 +51,7 @@ func TestShortHash(t *testing.T) {
}

assert := test.NewAssert(t)
nbInputs := 10
nbInputs := 100
for _, outer := range outerCurves {
outer := outer
for _, inner := range innerCurves {
Expand All @@ -62,15 +63,20 @@ func TestShortHash(t *testing.T) {
witness := &shortHashCircuit{Input: make([]frontend.Variable, nbInputs), inner: inner}
buf := make([]byte, (outer.ScalarField().BitLen()+7)/8)
for i := range witness.Input {
// el, _ := new(big.Int).SetString("1231231230981238971241240982112382934728934798234981324798123981724198712467928497124987124", 10)
// el.Mod(el, outer)
el, err := rand.Int(rand.Reader, outer.ScalarField())
assert.NoError(err)
el.FillBytes(buf)
// fmt.Printf("input: %x\n", buf)
h.Write(buf)
witness.Input[i] = el
}
res := h.Sum(nil)
witness.Output = res
assert.CheckCircuit(circuit, test.WithCurves(outer), test.WithValidAssignment(witness), test.NoFuzzing(), test.NoSerializationChecks(), test.NoSolidityChecks())
err = test.IsSolved(circuit, witness, outer.ScalarField())
assert.NoError(err)
// assert.CheckCircuit(circuit, test.WithCurves(outer), test.WithValidAssignment(witness), test.NoFuzzing(), test.NoSerializationChecks(), test.NoSolidityChecks())
}, inner.String())
}
}
Expand Down

0 comments on commit d1ab301

Please sign in to comment.