diff --git a/okio/src/commonMain/kotlin/okio/-Base64.kt b/okio/src/commonMain/kotlin/okio/-Base64.kt index bdf2e1bc65..9d401ce3f3 100644 --- a/okio/src/commonMain/kotlin/okio/-Base64.kt +++ b/okio/src/commonMain/kotlin/okio/-Base64.kt @@ -18,36 +18,32 @@ @file:JvmName("-Base64") package okio -import okio.ByteString.Companion.encodeUtf8 import kotlin.jvm.JvmName /** @author Alexander Y. Kleymenov */ -internal val BASE64 = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/".encodeUtf8().data -internal val BASE64_URL_SAFE = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_".encodeUtf8().data +internal const val BASE64 = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" +internal const val BASE64_URL_SAFE = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" -internal fun String.decodeBase64ToArray(): ByteArray? { +fun Buffer.commonWriteBase64(string: String): Buffer { // Ignore trailing '=' padding and whitespace from the input. - var limit = length + var limit = string.length while (limit > 0) { - val c = this[limit - 1] + val c = string[limit - 1] if (c != '=' && c != '\n' && c != '\r' && c != ' ' && c != '\t') { break } limit-- } - // If the input includes whitespace, this output array will be longer than necessary. - val out = ByteArray((limit * 6L / 8L).toInt()) - var outCount = 0 var inCount = 0 - var word = 0 - for (pos in 0 until limit) { - val c = this[pos] - + var pos = 0 + var s = head + while (pos < limit) { + val c = string[pos++] val bits: Int if (c in 'A'..'Z') { // char ASCII value @@ -71,7 +67,7 @@ internal fun String.decodeBase64ToArray(): ByteArray? { } else if (c == '\n' || c == '\r' || c == ' ' || c == '\t') { continue } else { - return null + throw IllegalArgumentException("Invalid Base64") // TODO: Dedicated exception? IOException? } // Append this char's 6 bits to the word. @@ -80,9 +76,17 @@ internal fun String.decodeBase64ToArray(): ByteArray? { // For every 4 chars of input, we accumulate 24 bits of output. Emit 3 bytes. inCount++ if (inCount % 4 == 0) { - out[outCount++] = (word shr 16).toByte() - out[outCount++] = (word shr 8).toByte() - out[outCount++] = word.toByte() + if (s == null || s.limit + 3 > Segment.SIZE) { + // For simplicity, don't try to write blocks across different segments, allocate new segment when current doesn't have enough capacity + s = writableSegment(3) + } + val data = s.data + var i = s.limit + data[i++] = (word shr 16).toByte() + data[i++] = (word shr 8).toByte() + data[i++] = word.toByte() + s.limit = i + size += 3 } } @@ -90,59 +94,96 @@ internal fun String.decodeBase64ToArray(): ByteArray? { when (lastWordChars) { 1 -> { // We read 1 char followed by "===". But 6 bits is a truncated byte! Fail. - return null + throw IllegalArgumentException("Invalid Base64") // TODO: Dedicated exception? IOException? } 2 -> { // We read 2 chars followed by "==". Emit 1 byte with 8 of those 12 bits. + if (s == null || s.limit + 1 > Segment.SIZE) { + s = writableSegment(1) + } word = word shl 12 - out[outCount++] = (word shr 16).toByte() + s.data[s.limit++] = (word shr 16).toByte() + size += 1 } 3 -> { // We read 3 chars, followed by "=". Emit 2 bytes for 16 of those 18 bits. + if (s == null || s.limit + 2 > Segment.SIZE) { + s = writableSegment(2) + } word = word shl 6 - out[outCount++] = (word shr 16).toByte() - out[outCount++] = (word shr 8).toByte() + val data = s.data + var i = s.limit + data[i++] = (word shr 16).toByte() + data[i++] = (word shr 8).toByte() + s.limit = i + size += 2 } } - // If we sized our out array perfectly, we're done. - if (outCount == out.size) return out - - // Copy the decoded bytes to a new, right-sized array. - return out.copyOf(outCount) + return this } -internal fun ByteArray.encodeBase64(map: ByteArray = BASE64): String { - val length = (size + 2) / 3 * 4 - val out = ByteArray(length) +fun Buffer.commonReadBase64(): String = + readBase64(BASE64) + +fun Buffer.commonReadBase64Url(): String = + readBase64(BASE64_URL_SAFE) + +private fun Buffer.readBase64(map: String = BASE64): String { + val length = ((size + 2) / 3 * 4).toInt() // TODO: Prevent Int overflow / arithmetic overflow ? + val out = CharArray(length) var index = 0 - val end = size - size % 3 - var i = 0 - while (i < end) { - val b0 = this[i++].toInt() - val b1 = this[i++].toInt() - val b2 = this[i++].toInt() - out[index++] = map[(b0 and 0xff shr 2)] - out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] - out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)] - out[index++] = map[(b2 and 0x3f)] + while (size >= 3) { + val s = head!! + val segmentSize = s.limit - s.pos + if (segmentSize > 3) { + // Read all complete blocks from head segment + val data = s.data + val end = s.limit - segmentSize % 3 + var i = s.pos + while (i < end) { + val b0 = data[i++].toInt() + val b1 = data[i++].toInt() + val b2 = data[i++].toInt() + out[index++] = map[(b0 and 0xff shr 2)] + out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] + out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)] + out[index++] = map[(b2 and 0x3f)] + } + size -= end - s.pos + if (end == s.limit) { + head = s.pop() + SegmentPool.recycle(s) + } else { + s.pos = end + } + } else { + // Read next block, which is spread over multiple segments + val b0 = readByte().toInt() + val b1 = readByte().toInt() + val b2 = readByte().toInt() + out[index++] = map[(b0 and 0xff shr 2)] + out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] + out[index++] = map[(b1 and 0x0f shl 2) or (b2 and 0xff shr 6)] + out[index++] = map[(b2 and 0x3f)] + } } - when (size - end) { - 1 -> { - val b0 = this[i].toInt() + when (size) { + 1L -> { + val b0 = readByte().toInt() out[index++] = map[b0 and 0xff shr 2] out[index++] = map[b0 and 0x03 shl 4] - out[index++] = '='.toByte() - out[index] = '='.toByte() + out[index++] = '=' + out[index] = '=' } - 2 -> { - val b0 = this[i++].toInt() - val b1 = this[i].toInt() + 2L -> { + val b0 = readByte().toInt() + val b1 = readByte().toInt() out[index++] = map[(b0 and 0xff shr 2)] out[index++] = map[(b0 and 0x03 shl 4) or (b1 and 0xff shr 4)] out[index++] = map[(b1 and 0x0f shl 2)] - out[index] = '='.toByte() + out[index] = '=' } } - return out.toUtf8String() + return out.concatToString() } diff --git a/okio/src/commonMain/kotlin/okio/Buffer.kt b/okio/src/commonMain/kotlin/okio/Buffer.kt index 1ff7342556..0bbc7abb07 100644 --- a/okio/src/commonMain/kotlin/okio/Buffer.kt +++ b/okio/src/commonMain/kotlin/okio/Buffer.kt @@ -114,6 +114,8 @@ expect class Buffer() : BufferedSource, BufferedSink { override fun writeHexadecimalUnsignedLong(v: Long): Buffer + override fun writeBase64(string: String): Buffer + /** Returns a deep copy of this buffer. */ fun copy(): Buffer diff --git a/okio/src/commonMain/kotlin/okio/BufferedSink.kt b/okio/src/commonMain/kotlin/okio/BufferedSink.kt index 40c26585c8..0646ae49b9 100644 --- a/okio/src/commonMain/kotlin/okio/BufferedSink.kt +++ b/okio/src/commonMain/kotlin/okio/BufferedSink.kt @@ -241,6 +241,11 @@ expect interface BufferedSink : Sink { */ fun writeHexadecimalUnsignedLong(v: Long): BufferedSink + /** + * Decodes the Base64-encoded bytes from [string] and writes them to this sink. + */ + fun writeBase64(string: String): BufferedSink + /** * Writes all buffered data to the underlying sink, if one exists. Then that sink is recursively * flushed which pushes data as far as possible towards its ultimate destination. Typically that diff --git a/okio/src/commonMain/kotlin/okio/BufferedSource.kt b/okio/src/commonMain/kotlin/okio/BufferedSource.kt index 0ba4d152bd..eccda298c0 100644 --- a/okio/src/commonMain/kotlin/okio/BufferedSource.kt +++ b/okio/src/commonMain/kotlin/okio/BufferedSource.kt @@ -421,6 +421,18 @@ expect interface BufferedSource : Source { */ fun readUtf8CodePoint(): Int + /** + * Removes all bytes from this, encodes them as as [Base64](http://www.ietf.org/rfc/rfc2045.txt), and returns the + * string. In violation of the RFC, the returned string does not wrap lines at 76 columns. + */ + fun readBase64(): String + + /** + * Removes all bytes from this, encodes them as as [URL-safe Base64](http://www.ietf.org/rfc/rfc4648.txt), and + * returns the string. + */ + fun readBase64Url(): String + /** Equivalent to [indexOf(b, 0)][indexOf]. */ fun indexOf(b: Byte): Long diff --git a/okio/src/commonMain/kotlin/okio/internal/ByteString.kt b/okio/src/commonMain/kotlin/okio/internal/ByteString.kt index 5f0ec021e4..1d4fe82ded 100644 --- a/okio/src/commonMain/kotlin/okio/internal/ByteString.kt +++ b/okio/src/commonMain/kotlin/okio/internal/ByteString.kt @@ -16,7 +16,6 @@ package okio.internal -import okio.BASE64_URL_SAFE import okio.Buffer import okio.ByteString import okio.REPLACEMENT_CODE_POINT @@ -24,10 +23,9 @@ import okio.and import okio.arrayRangeEquals import okio.asUtf8ToByteArray import okio.checkOffsetAndCount -import okio.decodeBase64ToArray -import okio.encodeBase64 import okio.isIsoControl import okio.processUtf8CodePoints +import okio.commonReadBase64Url import okio.shr import okio.toUtf8String @@ -46,10 +44,12 @@ internal inline fun ByteString.commonUtf8(): String { } @Suppress("NOTHING_TO_INLINE") -internal inline fun ByteString.commonBase64(): String = data.encodeBase64() +internal inline fun ByteString.commonBase64(): String = + Buffer().write(this).readBase64() @Suppress("NOTHING_TO_INLINE") -internal inline fun ByteString.commonBase64Url() = data.encodeBase64(map = BASE64_URL_SAFE) +internal inline fun ByteString.commonBase64Url(): String = + Buffer().write(this).commonReadBase64Url() internal val HEX_DIGIT_CHARS = charArrayOf('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f') @@ -266,8 +266,13 @@ internal inline fun String.commonEncodeUtf8(): ByteString { @Suppress("NOTHING_TO_INLINE") internal inline fun String.commonDecodeBase64(): ByteString? { - val decoded = decodeBase64ToArray() - return if (decoded != null) ByteString(decoded) else null + val buffer = Buffer() + try { + buffer.writeBase64(this) + } catch (e: IllegalArgumentException) { // TODO: Dedicated Base64 exception? + return null + } + return buffer.readByteString() } @Suppress("NOTHING_TO_INLINE") diff --git a/okio/src/commonMain/kotlin/okio/internal/RealBufferedSink.kt b/okio/src/commonMain/kotlin/okio/internal/RealBufferedSink.kt index 49b0c4d755..54aae6bd6b 100644 --- a/okio/src/commonMain/kotlin/okio/internal/RealBufferedSink.kt +++ b/okio/src/commonMain/kotlin/okio/internal/RealBufferedSink.kt @@ -163,6 +163,12 @@ internal inline fun RealBufferedSink.commonWriteHexadecimalUnsignedLong(v: Long) return emitCompleteSegments() } +internal inline fun RealBufferedSink.commonWriteBase64(string: String): BufferedSink { + check(!closed) { "closed" } + buffer.writeBase64(string) + return emitCompleteSegments() +} + internal inline fun RealBufferedSink.commonEmitCompleteSegments(): BufferedSink { check(!closed) { "closed" } val byteCount = buffer.completeSegmentByteCount() diff --git a/okio/src/commonMain/kotlin/okio/internal/RealBufferedSource.kt b/okio/src/commonMain/kotlin/okio/internal/RealBufferedSource.kt index 40938ea2f8..b955058449 100644 --- a/okio/src/commonMain/kotlin/okio/internal/RealBufferedSource.kt +++ b/okio/src/commonMain/kotlin/okio/internal/RealBufferedSource.kt @@ -110,6 +110,16 @@ internal inline fun RealBufferedSource.commonReadByteArray(byteCount: Long): Byt return buffer.readByteArray(byteCount) } +internal inline fun RealBufferedSource.commonReadBase64(): String { + buffer.writeAll(source) + return buffer.readBase64() +} + +internal inline fun RealBufferedSource.commonReadBase64Url(): String { + buffer.writeAll(source) + return buffer.readBase64Url() +} + internal inline fun RealBufferedSource.commonReadFully(sink: ByteArray) { try { require(sink.size.toLong()) diff --git a/okio/src/jsMain/kotlin/okio/Buffer.kt b/okio/src/jsMain/kotlin/okio/Buffer.kt index 402bf6c41d..28a742ea49 100644 --- a/okio/src/jsMain/kotlin/okio/Buffer.kt +++ b/okio/src/jsMain/kotlin/okio/Buffer.kt @@ -130,6 +130,10 @@ actual class Buffer : BufferedSource, BufferedSink { override fun readUtf8CodePoint(): Int = commonReadUtf8CodePoint() + override fun readBase64(): String = commonReadBase64() + + override fun readBase64Url(): String = commonReadBase64Url() + override fun select(options: Options): Int = commonSelect(options) override fun readByteArray(): ByteArray = commonReadByteArray() @@ -192,6 +196,9 @@ actual class Buffer : BufferedSource, BufferedSink { actual override fun writeHexadecimalUnsignedLong(v: Long): Buffer = commonWriteHexadecimalUnsignedLong(v) + actual override fun writeBase64(string: String): Buffer = + commonWriteBase64(string) + override fun write(source: Buffer, byteCount: Long): Unit = commonWrite(source, byteCount) override fun read(sink: Buffer, byteCount: Long): Long = commonRead(sink, byteCount) diff --git a/okio/src/jsMain/kotlin/okio/BufferedSink.kt b/okio/src/jsMain/kotlin/okio/BufferedSink.kt index 65d717c60a..20db6cbc00 100644 --- a/okio/src/jsMain/kotlin/okio/BufferedSink.kt +++ b/okio/src/jsMain/kotlin/okio/BufferedSink.kt @@ -54,6 +54,8 @@ actual interface BufferedSink : Sink { actual fun writeHexadecimalUnsignedLong(v: Long): BufferedSink + actual fun writeBase64(string: String): BufferedSink + actual fun emit(): BufferedSink actual fun emitCompleteSegments(): BufferedSink diff --git a/okio/src/jsMain/kotlin/okio/BufferedSource.kt b/okio/src/jsMain/kotlin/okio/BufferedSource.kt index 98b7718a14..23323e81e6 100644 --- a/okio/src/jsMain/kotlin/okio/BufferedSource.kt +++ b/okio/src/jsMain/kotlin/okio/BufferedSource.kt @@ -76,6 +76,10 @@ actual interface BufferedSource : Source { actual fun readUtf8CodePoint(): Int + actual fun readBase64(): String + + actual fun readBase64Url(): String + actual fun indexOf(b: Byte): Long actual fun indexOf(b: Byte, fromIndex: Long): Long diff --git a/okio/src/jsMain/kotlin/okio/RealBufferedSink.kt b/okio/src/jsMain/kotlin/okio/RealBufferedSink.kt index ed03094ec3..924fe32316 100644 --- a/okio/src/jsMain/kotlin/okio/RealBufferedSink.kt +++ b/okio/src/jsMain/kotlin/okio/RealBufferedSink.kt @@ -24,6 +24,7 @@ import okio.internal.commonTimeout import okio.internal.commonToString import okio.internal.commonWrite import okio.internal.commonWriteAll +import okio.internal.commonWriteBase64 import okio.internal.commonWriteByte import okio.internal.commonWriteDecimalLong import okio.internal.commonWriteHexadecimalUnsignedLong @@ -66,6 +67,7 @@ internal actual class RealBufferedSink actual constructor( override fun writeLongLe(v: Long) = commonWriteLongLe(v) override fun writeDecimalLong(v: Long) = commonWriteDecimalLong(v) override fun writeHexadecimalUnsignedLong(v: Long) = commonWriteHexadecimalUnsignedLong(v) + override fun writeBase64(string: String): BufferedSink = commonWriteBase64(string) override fun emitCompleteSegments() = commonEmitCompleteSegments() override fun emit() = commonEmit() override fun flush() = commonFlush() diff --git a/okio/src/jsMain/kotlin/okio/RealBufferedSource.kt b/okio/src/jsMain/kotlin/okio/RealBufferedSource.kt index d6f4b94221..bec6bd7b87 100644 --- a/okio/src/jsMain/kotlin/okio/RealBufferedSource.kt +++ b/okio/src/jsMain/kotlin/okio/RealBufferedSource.kt @@ -23,6 +23,8 @@ import okio.internal.commonPeek import okio.internal.commonRangeEquals import okio.internal.commonRead import okio.internal.commonReadAll +import okio.internal.commonReadBase64 +import okio.internal.commonReadBase64Url import okio.internal.commonReadByte import okio.internal.commonReadByteArray import okio.internal.commonReadByteString @@ -62,6 +64,8 @@ internal actual class RealBufferedSource actual constructor( override fun select(options: Options): Int = commonSelect(options) override fun readByteArray(): ByteArray = commonReadByteArray() override fun readByteArray(byteCount: Long): ByteArray = commonReadByteArray(byteCount) + override fun readBase64(): String = commonReadBase64() + override fun readBase64Url(): String = commonReadBase64Url() override fun read(sink: ByteArray): Int = read(sink, 0, sink.size) override fun readFully(sink: ByteArray): Unit = commonReadFully(sink) override fun read(sink: ByteArray, offset: Int, byteCount: Int): Int = diff --git a/okio/src/jvmMain/kotlin/okio/Buffer.kt b/okio/src/jvmMain/kotlin/okio/Buffer.kt index c514bb6eba..9c76b01bb3 100644 --- a/okio/src/jvmMain/kotlin/okio/Buffer.kt +++ b/okio/src/jvmMain/kotlin/okio/Buffer.kt @@ -330,6 +330,12 @@ actual class Buffer : BufferedSource, BufferedSink, Cloneable, ByteChannel { @Throws(EOFException::class) override fun readUtf8CodePoint(): Int = commonReadUtf8CodePoint() + @Throws(EOFException::class) + override fun readBase64(): String = commonReadBase64() + + @Throws(EOFException::class) + override fun readBase64Url(): String = commonReadBase64Url() + override fun readByteArray() = commonReadByteArray() @Throws(EOFException::class) @@ -448,6 +454,9 @@ actual class Buffer : BufferedSource, BufferedSink, Cloneable, ByteChannel { actual override fun writeHexadecimalUnsignedLong(v: Long): Buffer = commonWriteHexadecimalUnsignedLong(v) + actual override fun writeBase64(string: String): Buffer = + commonWriteBase64(string) + internal actual fun writableSegment(minimumCapacity: Int): Segment = commonWritableSegment(minimumCapacity) diff --git a/okio/src/jvmMain/kotlin/okio/BufferedSink.kt b/okio/src/jvmMain/kotlin/okio/BufferedSink.kt index edb632f910..953293ae2b 100644 --- a/okio/src/jvmMain/kotlin/okio/BufferedSink.kt +++ b/okio/src/jvmMain/kotlin/okio/BufferedSink.kt @@ -90,6 +90,9 @@ actual interface BufferedSink : Sink, WritableByteChannel { @Throws(IOException::class) actual fun writeHexadecimalUnsignedLong(v: Long): BufferedSink + @Throws(IOException::class) + actual fun writeBase64(string: String): BufferedSink + @Throws(IOException::class) actual override fun flush() diff --git a/okio/src/jvmMain/kotlin/okio/BufferedSource.kt b/okio/src/jvmMain/kotlin/okio/BufferedSource.kt index b30c635ae2..24d328f2a1 100644 --- a/okio/src/jvmMain/kotlin/okio/BufferedSource.kt +++ b/okio/src/jvmMain/kotlin/okio/BufferedSource.kt @@ -117,6 +117,12 @@ actual interface BufferedSource : Source, ReadableByteChannel { @Throws(IOException::class) actual fun readUtf8CodePoint(): Int + @Throws(IOException::class) + actual fun readBase64(): String + + @Throws(IOException::class) + actual fun readBase64Url(): String + /** Removes all bytes from this, decodes them as `charset`, and returns the string. */ @Throws(IOException::class) fun readString(charset: Charset): String diff --git a/okio/src/jvmMain/kotlin/okio/RealBufferedSink.kt b/okio/src/jvmMain/kotlin/okio/RealBufferedSink.kt index 7df3f93776..d3fb3ec5ef 100644 --- a/okio/src/jvmMain/kotlin/okio/RealBufferedSink.kt +++ b/okio/src/jvmMain/kotlin/okio/RealBufferedSink.kt @@ -23,6 +23,7 @@ import okio.internal.commonTimeout import okio.internal.commonToString import okio.internal.commonWrite import okio.internal.commonWriteAll +import okio.internal.commonWriteBase64 import okio.internal.commonWriteByte import okio.internal.commonWriteDecimalLong import okio.internal.commonWriteHexadecimalUnsignedLong @@ -100,6 +101,7 @@ internal actual class RealBufferedSink actual constructor( override fun writeLongLe(v: Long) = commonWriteLongLe(v) override fun writeDecimalLong(v: Long) = commonWriteDecimalLong(v) override fun writeHexadecimalUnsignedLong(v: Long) = commonWriteHexadecimalUnsignedLong(v) + override fun writeBase64(string: String): BufferedSink = commonWriteBase64(string) override fun emitCompleteSegments() = commonEmitCompleteSegments() override fun emit() = commonEmit() diff --git a/okio/src/jvmMain/kotlin/okio/RealBufferedSource.kt b/okio/src/jvmMain/kotlin/okio/RealBufferedSource.kt index 109ef1402e..1a58f9e288 100644 --- a/okio/src/jvmMain/kotlin/okio/RealBufferedSource.kt +++ b/okio/src/jvmMain/kotlin/okio/RealBufferedSource.kt @@ -23,6 +23,8 @@ import okio.internal.commonPeek import okio.internal.commonRangeEquals import okio.internal.commonRead import okio.internal.commonReadAll +import okio.internal.commonReadBase64 +import okio.internal.commonReadBase64Url import okio.internal.commonReadByte import okio.internal.commonReadByteArray import okio.internal.commonReadByteString @@ -72,6 +74,8 @@ internal actual class RealBufferedSource actual constructor( override fun select(options: Options): Int = commonSelect(options) override fun readByteArray(): ByteArray = commonReadByteArray() override fun readByteArray(byteCount: Long): ByteArray = commonReadByteArray(byteCount) + override fun readBase64(): String = commonReadBase64() + override fun readBase64Url(): String = commonReadBase64Url() override fun read(sink: ByteArray): Int = read(sink, 0, sink.size) override fun readFully(sink: ByteArray): Unit = commonReadFully(sink) override fun read(sink: ByteArray, offset: Int, byteCount: Int): Int = diff --git a/okio/src/jvmTest/kotlin/okio/BufferBase64Test.kt b/okio/src/jvmTest/kotlin/okio/BufferBase64Test.kt new file mode 100644 index 0000000000..73df64b147 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/BufferBase64Test.kt @@ -0,0 +1,135 @@ +package okio + +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import java.util.* +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +@RunWith(Parameterized::class) +class BufferBase64Test(val size: Int) { + + private val random = Random(4975347L + size) + private val bytes = random.nextBytes(size) + + companion object { + @get:Parameterized.Parameters(name = "{0}") + @get:JvmStatic + val parameters: List + get() = (0..32).toList() + ((Segment.SIZE - 32)..Segment.SIZE + 32).toList() + + private val base64Encoder = Base64.getEncoder() + private val base64UrlEncoder = Base64.getUrlEncoder() + + private fun Random.nextBytes(size: Int): ByteArray = + ByteArray(size).also { nextBytes(it) } + } + + @Test + fun write() { + val encoded = base64Encoder.encodeToString(bytes) + + val buffer = Buffer().apply { writeBase64(encoded) } + + val byteArray = buffer.readByteArray() + assertArrayEquals(bytes, byteArray) + } + + @Test + fun writeWithWhitespace() { + val encoded = base64Encoder.encodeToString(bytes).chunked(8).joinToString("\n") + + val buffer = Buffer().apply { writeBase64(encoded) } + + val byteArray = buffer.readByteArray() + assertArrayEquals(bytes, byteArray) + } + + @Test + fun writeCorruptedInvalidChar() { + if (size == 0) return // Skip this test when there is no data + val corruptedIndex = random.nextInt(size) + val encoded = + base64Encoder.encodeToString(bytes).replaceRange(corruptedIndex..corruptedIndex, "?") + + val buffer = Buffer() + assertFailsWith { + buffer.writeBase64(encoded) + } + } + + @Test + fun writeCorruptedInvalidLength() { + if (size == 0) return // Skip this test when there is no data + val encoded = base64Encoder.encodeToString(bytes) + "A" + + val buffer = Buffer() + assertFailsWith { + buffer.writeBase64(encoded) + } + } + + @Test + fun writeMultiple() { + val buffer = Buffer().apply { + bytes.asList().chunked(4).forEach { + val encoded = base64Encoder.encodeToString(it.toByteArray()) + writeBase64(encoded) + } + } + + val byteArray = buffer.readByteArray() + assertArrayEquals(bytes, byteArray) + } + + @Test + fun writeUrlEncoded() { + val encoded = base64UrlEncoder.encodeToString(bytes) + + val buffer = Buffer().apply { writeBase64(encoded) } + + val byteArray = buffer.readByteArray() + assertArrayEquals(bytes, byteArray) + } + + @Test + fun read() { + val buffer = Buffer().apply { write(bytes) } + + val s = buffer.readBase64() + + assertEquals(base64Encoder.encodeToString(bytes), s) + } + + @Test + fun readUrlEncoded() { + val buffer = Buffer().apply { write(bytes) } + + val s = buffer.readBase64Url() + + val encoded = base64UrlEncoder.encodeToString(bytes) + assertEquals(encoded, s) + } + + @Test + fun readFragmented() { + // Buffer made of segments with only one byte, randomly located + val buffer = Buffer().apply { + bytes.forEach { + val s = writableSegment(Segment.SIZE) + check(s.pos == 0 && s.limit == 0) // Implementation should provide an empty segment + val pos = random.nextInt(Segment.SIZE) + s.pos = pos + s.data[pos] = it + s.limit = pos + 1 + size++ + } + } + + val s = buffer.readBase64() + + val encoded = base64Encoder.encodeToString(bytes) + assertEquals(encoded, s) + } +} diff --git a/okio/src/nativeMain/kotlin/okio/Buffer.kt b/okio/src/nativeMain/kotlin/okio/Buffer.kt index 49623cbae8..63be873dc7 100644 --- a/okio/src/nativeMain/kotlin/okio/Buffer.kt +++ b/okio/src/nativeMain/kotlin/okio/Buffer.kt @@ -132,6 +132,10 @@ actual class Buffer : BufferedSource, BufferedSink { override fun readUtf8CodePoint(): Int = commonReadUtf8CodePoint() + override fun readBase64(): String = commonReadBase64() + + override fun readBase64Url(): String = commonReadBase64Url() + override fun readByteArray(): ByteArray = commonReadByteArray() override fun readByteArray(byteCount: Long): ByteArray = commonReadByteArray(byteCount) @@ -192,6 +196,9 @@ actual class Buffer : BufferedSource, BufferedSink { actual override fun writeHexadecimalUnsignedLong(v: Long): Buffer = commonWriteHexadecimalUnsignedLong(v) + actual override fun writeBase64(string: String): Buffer = + commonWriteBase64(string) + override fun write(source: Buffer, byteCount: Long): Unit = commonWrite(source, byteCount) override fun read(sink: Buffer, byteCount: Long): Long = commonRead(sink, byteCount) diff --git a/okio/src/nativeMain/kotlin/okio/BufferedSink.kt b/okio/src/nativeMain/kotlin/okio/BufferedSink.kt index 65d717c60a..20db6cbc00 100644 --- a/okio/src/nativeMain/kotlin/okio/BufferedSink.kt +++ b/okio/src/nativeMain/kotlin/okio/BufferedSink.kt @@ -54,6 +54,8 @@ actual interface BufferedSink : Sink { actual fun writeHexadecimalUnsignedLong(v: Long): BufferedSink + actual fun writeBase64(string: String): BufferedSink + actual fun emit(): BufferedSink actual fun emitCompleteSegments(): BufferedSink diff --git a/okio/src/nativeMain/kotlin/okio/BufferedSource.kt b/okio/src/nativeMain/kotlin/okio/BufferedSource.kt index 98b7718a14..23323e81e6 100644 --- a/okio/src/nativeMain/kotlin/okio/BufferedSource.kt +++ b/okio/src/nativeMain/kotlin/okio/BufferedSource.kt @@ -76,6 +76,10 @@ actual interface BufferedSource : Source { actual fun readUtf8CodePoint(): Int + actual fun readBase64(): String + + actual fun readBase64Url(): String + actual fun indexOf(b: Byte): Long actual fun indexOf(b: Byte, fromIndex: Long): Long diff --git a/okio/src/nativeMain/kotlin/okio/RealBufferedSink.kt b/okio/src/nativeMain/kotlin/okio/RealBufferedSink.kt index ed03094ec3..924fe32316 100644 --- a/okio/src/nativeMain/kotlin/okio/RealBufferedSink.kt +++ b/okio/src/nativeMain/kotlin/okio/RealBufferedSink.kt @@ -24,6 +24,7 @@ import okio.internal.commonTimeout import okio.internal.commonToString import okio.internal.commonWrite import okio.internal.commonWriteAll +import okio.internal.commonWriteBase64 import okio.internal.commonWriteByte import okio.internal.commonWriteDecimalLong import okio.internal.commonWriteHexadecimalUnsignedLong @@ -66,6 +67,7 @@ internal actual class RealBufferedSink actual constructor( override fun writeLongLe(v: Long) = commonWriteLongLe(v) override fun writeDecimalLong(v: Long) = commonWriteDecimalLong(v) override fun writeHexadecimalUnsignedLong(v: Long) = commonWriteHexadecimalUnsignedLong(v) + override fun writeBase64(string: String): BufferedSink = commonWriteBase64(string) override fun emitCompleteSegments() = commonEmitCompleteSegments() override fun emit() = commonEmit() override fun flush() = commonFlush() diff --git a/okio/src/nativeMain/kotlin/okio/RealBufferedSource.kt b/okio/src/nativeMain/kotlin/okio/RealBufferedSource.kt index d6f4b94221..bec6bd7b87 100644 --- a/okio/src/nativeMain/kotlin/okio/RealBufferedSource.kt +++ b/okio/src/nativeMain/kotlin/okio/RealBufferedSource.kt @@ -23,6 +23,8 @@ import okio.internal.commonPeek import okio.internal.commonRangeEquals import okio.internal.commonRead import okio.internal.commonReadAll +import okio.internal.commonReadBase64 +import okio.internal.commonReadBase64Url import okio.internal.commonReadByte import okio.internal.commonReadByteArray import okio.internal.commonReadByteString @@ -62,6 +64,8 @@ internal actual class RealBufferedSource actual constructor( override fun select(options: Options): Int = commonSelect(options) override fun readByteArray(): ByteArray = commonReadByteArray() override fun readByteArray(byteCount: Long): ByteArray = commonReadByteArray(byteCount) + override fun readBase64(): String = commonReadBase64() + override fun readBase64Url(): String = commonReadBase64Url() override fun read(sink: ByteArray): Int = read(sink, 0, sink.size) override fun readFully(sink: ByteArray): Unit = commonReadFully(sink) override fun read(sink: ByteArray, offset: Int, byteCount: Int): Int =