Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: "start at" #882

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/consensys/gnark

go 1.20
go 1.21

require (
github.com/bits-and-blooms/bitset v1.8.0
Expand Down
156 changes: 58 additions & 98 deletions std/compress/lzss_v1/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"bytes"
"encoding/binary"
"fmt"
"index/suffixarray"

"github.com/consensys/gnark/std/compress/lzss_v1/suffixarray"

"math/bits"

"github.com/consensys/gnark-crypto/utils"
Expand All @@ -24,25 +26,37 @@ import (
func Compress(d []byte, settings Settings) (c []byte, err error) {
// d[i < 0] = Settings.BackRefSettings.Symbol by convention
var out bytes.Buffer
out.Grow(len(d))

emitBackRef := func(offset, length int) {
out.WriteByte(0)
// fmt.Println("offset -1", offset-1)
emit(&out, offset-1, settings.NbBytesAddress)
emit(&out, length-1, settings.NbBytesLength)
}
compressor := newCompressor(d, settings)
i := 0
i := int(settings.StartAt)

// under that threshold, it's more interesting to write the symbol directly.
t := int(1 + compressor.settings.NbBytesAddress + compressor.settings.NbBytesLength)

for i < len(d) {
addr, length := compressor.longestMostRecentBackRef(i)
if length == -1 {
// no backref found
if d[i] == 0 {
var addr, length int
if d[i] == 0 {
addr, length = compressor.longestMostRecentBackRef(i, 1)
if length == -1 {
// no backref found
return nil, fmt.Errorf("could not find an RLE backref at index %d", i)
}
out.WriteByte(d[i])
i++
continue
} else {
addr, length = compressor.longestMostRecentBackRef(i, t)
if length == -1 {
out.WriteByte(d[i])
i++
continue
}
}

emitBackRef(i-addr, length)
i += length
}
Expand All @@ -51,12 +65,9 @@ func Compress(d []byte, settings Settings) (c []byte, err error) {
}

type compressor struct {
// TODO @gbotrel we have to be a bit careful with the size
// and do some extra checks; here we assume that we never compress more than 1MB
longestZeroPrefix [1 << 20]int // longestZeroPrefix[i] = longest run of zeroes starting at i
d []byte
index *suffixarray.Index
settings Settings
d []byte
index *suffixarray.Index
settings Settings
}

func newCompressor(d []byte, settings Settings) *compressor {
Expand All @@ -65,107 +76,56 @@ func newCompressor(d []byte, settings Settings) *compressor {
index: suffixarray.New(d),
settings: settings,
}
compressor.initZeroPrefix()
return compressor
}

// initZeroPrefix fills longestZeroPrefix so that entry j holds the length of
// the run of zero bytes beginning at d[j] (0 when d[j] != 0). The table is
// built right to left so each entry simply extends the run already computed
// for position j+1.
func (compressor *compressor) initZeroPrefix() {
	data := compressor.d
	for j := len(data) - 1; j >= 0; j-- {
		var run int
		if data[j] == 0 {
			// Extend the zero run that starts one byte to the right.
			run = compressor.longestZeroPrefix[j+1] + 1
		}
		compressor.longestZeroPrefix[j] = run
	}
}

// longestMostRecentBackRef attempts to find a backref that is 1) longest 2) most recent in that order of priority
func (compressor *compressor) longestMostRecentBackRef(i int) (addr, length int) {
func (compressor *compressor) longestMostRecentBackRef(i, minRefLen int) (addr, length int) {
d := compressor.d
// var backRefLen int
brAddressRange := 1 << (compressor.settings.NbBytesAddress * 8)
brLengthRange := 1 << (compressor.settings.NbBytesLength * 8)
minBackRefAddr := i - brAddressRange

windowStart := utils.Max(0, minBackRefAddr)
endWindow := utils.Min(i+brAddressRange, len(d))

if d[i] == 0 { // RLE; prune the options
// we can't encode 0 as is, so we must find a backref.

// runLen := compressor.countZeroes(i, brLengthRange) // utils.Min(getRunLength(d, i), brLengthRange)
runLen := utils.Min(compressor.longestZeroPrefix[i], brLengthRange)

backrefAddr := -1
backrefLen := -1
for j := i - 1; j >= windowStart; j-- {
n := utils.Min(compressor.longestZeroPrefix[j], runLen)
if n == 0 {
continue
}
// check if we can make this backref longer
m := matchLen(d[i+n:endWindow], d[j+n:]) + n

if m > backrefLen {
if m >= brLengthRange {
// we can stop we won't find a longer backref
return j, brLengthRange
}
backrefLen = m
backrefAddr = j
}
}
if (backrefLen == -1 && minBackRefAddr < 0) || (backrefLen != -1 && minBackRefAddr < 0 && backrefLen < -minBackRefAddr) {
backrefAddr = minBackRefAddr
backrefLen = utils.Min(runLen, -minBackRefAddr)
}
return backrefAddr, backrefLen
maxRefLen := brLengthRange // utils.Min(i+brLengthRange, len(d))
if i+maxRefLen > len(d) {
maxRefLen = len(d) - i
}

// else -->
// d[i] != 0

// under that threshold, it's more interesting to write the symbol directly.
t := int(1 + compressor.settings.NbBytesAddress + compressor.settings.NbBytesLength)

if i+t > len(d) {
if i+minRefLen > len(d) {
return -1, -1
}

matches := compressor.index.Lookup(d[i:i+t], -1)

bLen := -1
bAddr := -1
for _, offset := range matches {
if offset < windowStart || offset >= i {
// out of the window bound
continue
}
n := matchLen(d[i+t:endWindow], d[offset+t:]) + t
if n > bLen {
bLen = n
if bLen >= brLengthRange {
// we can stop we won't find a longer backref
return offset, brLengthRange
}
bAddr = offset
}

addr, len := compressor.index.LookupLongest(d[i:i+maxRefLen], minRefLen, maxRefLen, windowStart, i)
if len == -1 {
return -1, -1
}
return addr, len

// matches := compressor.index.Lookup(d[i:i+t], -1)

// bLen := -1
// bAddr := -1
// for _, offset := range matches {
// if offset < windowStart || offset >= i {
// // out of the window bound
// continue
// }
// n := matchLen(d[i+t:endWindow], d[offset+t:]) + t
// if n > bLen {
// bLen = n
// if bLen >= 64 {
// // we can stop we won't find a longer backref
// return offset, min(bLen, brLengthRange)
// }
// bAddr = offset
// }

// }

// return bAddr, bLen

return bAddr, bLen

}

// countZeroes reports how many consecutive zero bytes a starts with,
// capped at maxCount.
func countZeroes(a []byte, maxCount int) (count int) {
	for _, b := range a {
		if b != 0 || count >= maxCount {
			return count
		}
		count++
	}
	return count
}

// matchLen returns the maximum common prefix length of a and b.
Expand Down
50 changes: 38 additions & 12 deletions std/compress/lzss_v1/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"os"
"strings"
"testing"
Expand All @@ -26,20 +27,23 @@ func testCompressionRoundTrip(t *testing.T, nbBytesAddress uint, d []byte, testC
d, err = os.ReadFile("../test_cases/" + testCaseName[0] + "/data.bin")
require.NoError(t, err)
}
const contextSize = 256
d = append(make([]byte, contextSize), d...)
settings := Settings{
BackRefSettings: BackRefSettings{
NbBytesAddress: nbBytesAddress,
NbBytesLength: 1,
},
StartAt: 256,
}
c, err := Compress(d, settings)
if len(testCaseName) == 1 {
assert.NoError(t, os.WriteFile("../test_cases/"+testCaseName[0]+"/data.lzssv1", c, 0600))
}
cStream := compress.NewStreamFromBytes(c)
cHuff := huffman.Encode(cStream)
fmt.Println("Size Compression ratio:", float64(len(d))/float64(len(c)))
fmt.Println("Estimated Compression ratio (with Huffman):", float64(8*len(d))/float64(len(cHuff.D)))
fmt.Println("Size Compression ratio:", float64(len(d)-contextSize)/float64(len(c)))
fmt.Println("Estimated Compression ratio (with Huffman):", float64(8*(len(d)-contextSize))/float64(len(cHuff.D)))
if len(c) > 1024 {
fmt.Printf("Compressed size: %dKB\n", int(float64(len(c)*100)/1024)/100)
fmt.Printf("Compressed size (with Huffman): %dKB\n", int(float64(len(cHuff.D)*100)/8192)/100)
Expand All @@ -53,7 +57,7 @@ func testCompressionRoundTrip(t *testing.T, nbBytesAddress uint, d []byte, testC
printHex(c)
}

require.Equal(t, d, dBack)
require.Equal(t, d[contextSize:], dBack)

// store huffman code lengths
lens := huffman.GetCodeLengths(cStream)
Expand Down Expand Up @@ -144,14 +148,15 @@ func TestAverageBatch(t *testing.T) {
data, err := hex.DecodeString(string(d))
assert.NoError(err)

dict := getDictionnary()
// test compress round trip with s2, zstd and lzss
// s2Res, err := compressWithS2(data)
// assert.NoError(err)

// zstdRes, err := compressWithZstd(data)
// assert.NoError(err)

lzssRes, err := compresslzss_v1(data)
lzssRes, err := compresslzss_v1(data, dict)
assert.NoError(err)

// fmt.Println("s2 compression ratio:", s2Res.ratio)
Expand All @@ -167,7 +172,7 @@ func TestAverageBatch(t *testing.T) {
// zstdDecompressed, err := decompressWithZstd(zstdRes.compressed)
// assert.NoError(err)

lzssDecompressed, err := decompresslzss_v1(lzssRes.compressed)
lzssDecompressed, err := decompresslzss_v1(lzssRes.compressed, dict)
assert.NoError(err)

// assert.True(bytes.Equal(data, s2Decompressed))
Expand All @@ -189,6 +194,8 @@ func BenchmarkAverageBatch(b *testing.B) {
b.Fatal(err)
}

dict := getDictionnary()

// benchmark s2
// b.Run("s2", func(b *testing.B) {
// for i := 0; i < b.N; i++ {
Expand All @@ -212,7 +219,7 @@ func BenchmarkAverageBatch(b *testing.B) {
// benchmark lzss
b.Run("lzss", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := compresslzss_v1(data)
_, err := compresslzss_v1(data, dict)
if err != nil {
b.Fatal(err)
}
Expand Down Expand Up @@ -280,21 +287,30 @@ func compressWithZstd(data []byte) (compressResult, error) {
return res, nil
}

func decompresslzss_v1(data []byte) ([]byte, error) {
func decompresslzss_v1(data, dict []byte) ([]byte, error) {
data = append(dict, data...)
return DecompressPureGo(data, Settings{
BackRefSettings: BackRefSettings{
NbBytesAddress: 2,
NbBytesLength: 1,
NbBytesAddress: nbBytesAddress,
NbBytesLength: nbBytesLength,
},
StartAt: uint(len(dict)),
})
}

func compresslzss_v1(data []byte) (compressResult, error) {
const (
nbBytesAddress = 3
nbBytesLength = 1
)

func compresslzss_v1(data []byte, dict []byte) (compressResult, error) {
data = append(dict, data...)
c, err := Compress(data, Settings{
BackRefSettings: BackRefSettings{
NbBytesAddress: 2,
NbBytesLength: 1,
NbBytesAddress: nbBytesAddress,
NbBytesLength: nbBytesLength,
},
StartAt: uint(len(dict)),
})
if err != nil {
return compressResult{}, err
Expand All @@ -306,3 +322,13 @@ func compresslzss_v1(data []byte) (compressResult, error) {
ratio: float64(len(data)) / float64(len(c)),
}, nil
}

func getDictionnary() []byte {
// read the dictionary from the file
d, err := ioutil.ReadFile("dict_naive")
if err != nil {
panic(err)
}
d = append(d, bytes.Repeat([]byte{0, 0}, 8)...)
return d
}
1 change: 1 addition & 0 deletions std/compress/lzss_v1/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ func (s BackRefSettings) NbBytes() int {

// Settings configures LZSS v1 compression. It embeds the back-reference
// encoding parameters and records where in the input compression begins.
type Settings struct {
BackRefSettings
// StartAt is the input index at which compression starts; earlier bytes
// appear to serve as pre-shared context (e.g. a dictionary) available to
// back-references, and the decompressor strips them from its output —
// NOTE(review): confirm against Compress/DecompressPureGo callers.
StartAt uint
}
10 changes: 9 additions & 1 deletion std/compress/lzss_v1/decompress.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ func DecompressPureGo(c []byte, settings Settings) (d []byte, err error) {
return
}

// read until startAt and write bytes as is
tmpBuf := make([]byte, settings.StartAt)
_, err = in.Read(tmpBuf)
if err != nil {
return nil, err
}
out.Write(tmpBuf)

s, err := in.ReadByte()
for err == nil {
if s == 0 {
Expand All @@ -40,7 +48,7 @@ func DecompressPureGo(c []byte, settings Settings) (d []byte, err error) {
s, err = in.ReadByte()
}

return out.Bytes(), nil
return out.Bytes()[settings.StartAt:], nil
}

func readNum(bytes []byte) int { //little endian
Expand Down
Binary file added std/compress/lzss_v1/dict_naive
Binary file not shown.
Loading
Loading