Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Draft) Base64 read & write Buffer #819

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 92 additions & 51 deletions okio/src/commonMain/kotlin/okio/-Base64.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,32 @@
@file:JvmName("-Base64")
package okio

import okio.ByteString.Companion.encodeUtf8
import kotlin.jvm.JvmName

/** @author Alexander Y. Kleymenov */

internal val BASE64 =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/".encodeUtf8().data
internal val BASE64_URL_SAFE =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_".encodeUtf8().data
internal const val BASE64 =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
internal const val BASE64_URL_SAFE =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"

internal fun String.decodeBase64ToArray(): ByteArray? {
fun Buffer.commonWriteBase64(string: String): Buffer {
// Ignore trailing '=' padding and whitespace from the input.
var limit = length
var limit = string.length
while (limit > 0) {
val c = this[limit - 1]
val c = string[limit - 1]
if (c != '=' && c != '\n' && c != '\r' && c != ' ' && c != '\t') {
break
}
limit--
}

// If the input includes whitespace, this output array will be longer than necessary.
val out = ByteArray((limit * 6L / 8L).toInt())
var outCount = 0
var inCount = 0

var word = 0
for (pos in 0 until limit) {
val c = this[pos]

var pos = 0
var s = head
while (pos < limit) {
val c = string[pos++]
val bits: Int
if (c in 'A'..'Z') {
// char ASCII value
Expand All @@ -71,7 +67,7 @@ internal fun String.decodeBase64ToArray(): ByteArray? {
} else if (c == '\n' || c == '\r' || c == ' ' || c == '\t') {
continue
} else {
return null
throw IllegalArgumentException("Invalid Base64") // TODO: Dedicated exception? IOException?
}

// Append this char's 6 bits to the word.
Expand All @@ -80,69 +76,114 @@ internal fun String.decodeBase64ToArray(): ByteArray? {
// For every 4 chars of input, we accumulate 24 bits of output. Emit 3 bytes.
inCount++
if (inCount % 4 == 0) {
out[outCount++] = (word shr 16).toByte()
out[outCount++] = (word shr 8).toByte()
out[outCount++] = word.toByte()
if (s == null || s.limit + 3 > Segment.SIZE) {
// For simplicity, don't try to write blocks across different segments, allocate new segment when current doesn't have enough capacity
s = writableSegment(3)
}
val data = s.data
var i = s.limit
data[i++] = (word shr 16).toByte()
data[i++] = (word shr 8).toByte()
data[i++] = word.toByte()
s.limit = i
size += 3
}
}

val lastWordChars = inCount % 4
when (lastWordChars) {
1 -> {
// We read 1 char followed by "===". But 6 bits is a truncated byte! Fail.
return null
throw IllegalArgumentException("Invalid Base64") // TODO: Dedicated exception? IOException?
}
2 -> {
// We read 2 chars followed by "==". Emit 1 byte with 8 of those 12 bits.
if (s == null || s.limit + 1 > Segment.SIZE) {
s = writableSegment(1)
}
word = word shl 12
out[outCount++] = (word shr 16).toByte()
s.data[s.limit++] = (word shr 16).toByte()
size += 1
}
3 -> {
// We read 3 chars, followed by "=". Emit 2 bytes for 16 of those 18 bits.
if (s == null || s.limit + 2 > Segment.SIZE) {
s = writableSegment(2)
}
word = word shl 6
out[outCount++] = (word shr 16).toByte()
out[outCount++] = (word shr 8).toByte()
val data = s.data
var i = s.limit
data[i++] = (word shr 16).toByte()
data[i++] = (word shr 8).toByte()
s.limit = i
size += 2
}
}

// If we sized our out array perfectly, we're done.
if (outCount == out.size) return out

// Copy the decoded bytes to a new, right-sized array.
return out.copyOf(outCount)
return this
}

internal fun ByteArray.encodeBase64(map: ByteArray = BASE64): String {
val length = (size + 2) / 3 * 4
val out = ByteArray(length)
fun Buffer.commonReadBase64(): String =
readBase64(BASE64)

fun Buffer.commonReadBase64Url(): String =
readBase64(BASE64_URL_SAFE)

private fun Buffer.readBase64(map: String = BASE64): String {
val length = ((size + 2) / 3 * 4).toInt() // TODO: Prevent Int overflow / arithmetic overflow ?
val out = CharArray(length)
var index = 0
val end = size - size % 3
var i = 0
while (i < end) {
val b0 = this[i++].toInt()
val b1 = this[i++].toInt()
val b2 = this[i++].toInt()
out[index++] = map[(b0 and 0xff shr 2)]
out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)]
out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)]
out[index++] = map[(b2 and 0x3f)]
while (size >= 3) {
val s = head!!
val segmentSize = s.limit - s.pos
if (segmentSize > 3) {
// Read all complete blocks from head segment
val data = s.data
val end = s.limit - segmentSize % 3
var i = s.pos
while (i < end) {
val b0 = data[i++].toInt()
val b1 = data[i++].toInt()
val b2 = data[i++].toInt()
out[index++] = map[(b0 and 0xff shr 2)]
out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)]
out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)]
out[index++] = map[(b2 and 0x3f)]
}
size -= end - s.pos
if (end == s.limit) {
head = s.pop()
SegmentPool.recycle(s)
} else {
s.pos = end
}
} else {
// Read next block, which is spread over multiple segments
val b0 = readByte().toInt()
val b1 = readByte().toInt()
val b2 = readByte().toInt()
out[index++] = map[(b0 and 0xff shr 2)]
out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)]
out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)]
out[index++] = map[(b2 and 0x3f)]
}
}
when (size - end) {
1 -> {
val b0 = this[i].toInt()
when (size) {
1L -> {
val b0 = readByte().toInt()
out[index++] = map[b0 and 0xff shr 2]
out[index++] = map[b0 and 0x03 shl 4]
out[index++] = '='.toByte()
out[index] = '='.toByte()
out[index++] = '='
out[index] = '='
}
2 -> {
val b0 = this[i++].toInt()
val b1 = this[i].toInt()
2L -> {
val b0 = readByte().toInt()
val b1 = readByte().toInt()
out[index++] = map[(b0 and 0xff shr 2)]
out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)]
out[index++] = map[(b1 and 0x0f shl 2)]
out[index] = '='.toByte()
out[index] = '='
}
}
return out.toUtf8String()
return out.concatToString()
}
2 changes: 2 additions & 0 deletions okio/src/commonMain/kotlin/okio/Buffer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ expect class Buffer() : BufferedSource, BufferedSink {

override fun writeHexadecimalUnsignedLong(v: Long): Buffer

override fun writeBase64(string: String): Buffer

/** Returns a deep copy of this buffer. */
fun copy(): Buffer

Expand Down
5 changes: 5 additions & 0 deletions okio/src/commonMain/kotlin/okio/BufferedSink.kt
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ expect interface BufferedSink : Sink {
*/
fun writeHexadecimalUnsignedLong(v: Long): BufferedSink

/**
* Decodes the Base64-encoded bytes from [string] and writes them to this sink.
*/
fun writeBase64(string: String): BufferedSink

/**
* Writes all buffered data to the underlying sink, if one exists. Then that sink is recursively
* flushed which pushes data as far as possible towards its ultimate destination. Typically that
Expand Down
12 changes: 12 additions & 0 deletions okio/src/commonMain/kotlin/okio/BufferedSource.kt
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,18 @@ expect interface BufferedSource : Source {
*/
fun readUtf8CodePoint(): Int

/**
* Removes all bytes from this, encodes them as as [Base64](http://www.ietf.org/rfc/rfc2045.txt), and returns the
* string. In violation of the RFC, the returned string does not wrap lines at 76 columns.
*/
fun readBase64(): String

/**
* Removes all bytes from this, encodes them as as [URL-safe Base64](http://www.ietf.org/rfc/rfc4648.txt), and
* returns the string.
*/
fun readBase64Url(): String

/** Equivalent to [indexOf(b, 0)][indexOf]. */
fun indexOf(b: Byte): Long

Expand Down
19 changes: 12 additions & 7 deletions okio/src/commonMain/kotlin/okio/internal/ByteString.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,16 @@

package okio.internal

import okio.BASE64_URL_SAFE
import okio.Buffer
import okio.ByteString
import okio.REPLACEMENT_CODE_POINT
import okio.and
import okio.arrayRangeEquals
import okio.asUtf8ToByteArray
import okio.checkOffsetAndCount
import okio.decodeBase64ToArray
import okio.encodeBase64
import okio.isIsoControl
import okio.processUtf8CodePoints
import okio.commonReadBase64Url
import okio.shr
import okio.toUtf8String

Expand All @@ -46,10 +44,12 @@ internal inline fun ByteString.commonUtf8(): String {
}

@Suppress("NOTHING_TO_INLINE")
internal inline fun ByteString.commonBase64(): String = data.encodeBase64()
internal inline fun ByteString.commonBase64(): String =
Buffer().write(this).readBase64()

@Suppress("NOTHING_TO_INLINE")
internal inline fun ByteString.commonBase64Url() = data.encodeBase64(map = BASE64_URL_SAFE)
internal inline fun ByteString.commonBase64Url(): String =
Buffer().write(this).commonReadBase64Url()

internal val HEX_DIGIT_CHARS =
charArrayOf('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f')
Expand Down Expand Up @@ -266,8 +266,13 @@ internal inline fun String.commonEncodeUtf8(): ByteString {

@Suppress("NOTHING_TO_INLINE")
internal inline fun String.commonDecodeBase64(): ByteString? {
val decoded = decodeBase64ToArray()
return if (decoded != null) ByteString(decoded) else null
val buffer = Buffer()
try {
buffer.writeBase64(this)
} catch (e: IllegalArgumentException) { // TODO: Dedicated Base64 exception?
return null
}
return buffer.readByteString()
}

@Suppress("NOTHING_TO_INLINE")
Expand Down
6 changes: 6 additions & 0 deletions okio/src/commonMain/kotlin/okio/internal/RealBufferedSink.kt
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ internal inline fun RealBufferedSink.commonWriteHexadecimalUnsignedLong(v: Long)
return emitCompleteSegments()
}

internal inline fun RealBufferedSink.commonWriteBase64(string: String): BufferedSink {
check(!closed) { "closed" }
buffer.writeBase64(string)
return emitCompleteSegments()
}

internal inline fun RealBufferedSink.commonEmitCompleteSegments(): BufferedSink {
check(!closed) { "closed" }
val byteCount = buffer.completeSegmentByteCount()
Expand Down
10 changes: 10 additions & 0 deletions okio/src/commonMain/kotlin/okio/internal/RealBufferedSource.kt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ internal inline fun RealBufferedSource.commonReadByteArray(byteCount: Long): Byt
return buffer.readByteArray(byteCount)
}

internal inline fun RealBufferedSource.commonReadBase64(): String {
buffer.writeAll(source)
return buffer.readBase64()
}

internal inline fun RealBufferedSource.commonReadBase64Url(): String {
buffer.writeAll(source)
return buffer.readBase64Url()
}

internal inline fun RealBufferedSource.commonReadFully(sink: ByteArray) {
try {
require(sink.size.toLong())
Expand Down
7 changes: 7 additions & 0 deletions okio/src/jsMain/kotlin/okio/Buffer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ actual class Buffer : BufferedSource, BufferedSink {

override fun readUtf8CodePoint(): Int = commonReadUtf8CodePoint()

override fun readBase64(): String = commonReadBase64()

override fun readBase64Url(): String = commonReadBase64Url()

override fun select(options: Options): Int = commonSelect(options)

override fun readByteArray(): ByteArray = commonReadByteArray()
Expand Down Expand Up @@ -192,6 +196,9 @@ actual class Buffer : BufferedSource, BufferedSink {
actual override fun writeHexadecimalUnsignedLong(v: Long): Buffer =
commonWriteHexadecimalUnsignedLong(v)

actual override fun writeBase64(string: String): Buffer =
commonWriteBase64(string)

override fun write(source: Buffer, byteCount: Long): Unit = commonWrite(source, byteCount)

override fun read(sink: Buffer, byteCount: Long): Long = commonRead(sink, byteCount)
Expand Down
2 changes: 2 additions & 0 deletions okio/src/jsMain/kotlin/okio/BufferedSink.kt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ actual interface BufferedSink : Sink {

actual fun writeHexadecimalUnsignedLong(v: Long): BufferedSink

actual fun writeBase64(string: String): BufferedSink

actual fun emit(): BufferedSink

actual fun emitCompleteSegments(): BufferedSink
Expand Down
4 changes: 4 additions & 0 deletions okio/src/jsMain/kotlin/okio/BufferedSource.kt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ actual interface BufferedSource : Source {

actual fun readUtf8CodePoint(): Int

actual fun readBase64(): String

actual fun readBase64Url(): String

actual fun indexOf(b: Byte): Long

actual fun indexOf(b: Byte, fromIndex: Long): Long
Expand Down
2 changes: 2 additions & 0 deletions okio/src/jsMain/kotlin/okio/RealBufferedSink.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import okio.internal.commonTimeout
import okio.internal.commonToString
import okio.internal.commonWrite
import okio.internal.commonWriteAll
import okio.internal.commonWriteBase64
import okio.internal.commonWriteByte
import okio.internal.commonWriteDecimalLong
import okio.internal.commonWriteHexadecimalUnsignedLong
Expand Down Expand Up @@ -66,6 +67,7 @@ internal actual class RealBufferedSink actual constructor(
override fun writeLongLe(v: Long) = commonWriteLongLe(v)
override fun writeDecimalLong(v: Long) = commonWriteDecimalLong(v)
override fun writeHexadecimalUnsignedLong(v: Long) = commonWriteHexadecimalUnsignedLong(v)
override fun writeBase64(string: String): BufferedSink = commonWriteBase64(string)
override fun emitCompleteSegments() = commonEmitCompleteSegments()
override fun emit() = commonEmit()
override fun flush() = commonFlush()
Expand Down
Loading