diff --git a/Sources/PotentASN1/ASN1DERReader.swift b/Sources/PotentASN1/ASN1DERReader.swift index a63d58289..4d683c2b2 100644 --- a/Sources/PotentASN1/ASN1DERReader.swift +++ b/Sources/PotentASN1/ASN1DERReader.swift @@ -82,20 +82,20 @@ internal enum ASN1DERReader { private static func parseItem(_ buffer: inout UnsafeBufferPointer) throws -> ASN1 { var (tagValue, itemBuffer) = try parseTagged(&buffer) - defer { - assert(itemBuffer.isEmpty) + + let item = try parseItem(&itemBuffer, as: tagValue) + + if !itemBuffer.isEmpty { + throw ASN1Serialization.Error.invalidTaggedItem } - return try parseItem(&itemBuffer, as: tagValue) + return item } private static func parseItem( _ itemBuffer: inout UnsafeBufferPointer, as tagValue: ASN1.AnyTag ) throws -> ASN1 { - defer { - assert(itemBuffer.isEmpty) - } guard let tag = ASN1.Tag(rawValue: tagValue) else { // Check required constructed types @@ -110,80 +110,87 @@ internal enum ASN1DERReader { } } + let item: ASN1 + switch tag { case .boolean: - return .boolean(try itemBuffer.pop() != 0) + item = .boolean(try itemBuffer.pop() != 0) case .integer: - return .integer(try parseInt(&itemBuffer)) + item = .integer(try parseInt(&itemBuffer)) case .bitString: let unusedBits = try itemBuffer.pop() let data = Data(try itemBuffer.popAll()) - return .bitString((data.count * 8) - Int(unusedBits), data) + item = .bitString((data.count * 8) - Int(unusedBits), data) case .octetString: - return .octetString(Data(try itemBuffer.popAll())) + item = .octetString(Data(try itemBuffer.popAll())) case .null: - return .null + item = .null case .objectIdentifier: - return .objectIdentifier(try parseOID(&itemBuffer)) + item = .objectIdentifier(try parseOID(&itemBuffer)) case .real: - return .real(try parseReal(&itemBuffer)) + item = .real(try parseReal(&itemBuffer)) case .utf8String: - return .utf8String(try parseString(&itemBuffer, encoding: .utf8)) + item = .utf8String(try parseString(&itemBuffer, encoding: .utf8)) case .numericString: - return .numericString(try parseString(&itemBuffer, encoding: .ascii)) + item = .numericString(try parseString(&itemBuffer, encoding: .ascii)) case .printableString: - return .printableString(try parseString(&itemBuffer, encoding: .ascii)) + item = .printableString(try parseString(&itemBuffer, encoding: .ascii)) case .teletexString: - return .teletexString(try parseString(&itemBuffer, encoding: .ascii)) + item = .teletexString(try parseString(&itemBuffer, encoding: .ascii)) case .videotexString: - return .videotexString(try parseString(&itemBuffer, encoding: .ascii)) + item = .videotexString(try parseString(&itemBuffer, encoding: .ascii)) case .ia5String: - return .ia5String(try parseString(&itemBuffer, encoding: .ascii)) + item = .ia5String(try parseString(&itemBuffer, encoding: .ascii)) case .utcTime: - return .utcTime(try parseTime(&itemBuffer, formatter: utcFormatter)) + item = .utcTime(try parseTime(&itemBuffer, formatter: utcFormatter)) case .generalizedTime: - return .generalizedTime(try parseTime(&itemBuffer, formatter: generalizedFormatter)) + item = .generalizedTime(try parseTime(&itemBuffer, formatter: generalizedFormatter)) case .graphicString: - return .graphicString(try parseString(&itemBuffer, encoding: .ascii)) + item = .graphicString(try parseString(&itemBuffer, encoding: .ascii)) case .visibleString: - return .visibleString(try parseString(&itemBuffer, encoding: .ascii)) + item = .visibleString(try parseString(&itemBuffer, encoding: .ascii)) case .generalString: - return .generalString(try parseString(&itemBuffer, encoding: .ascii)) + item = .generalString(try parseString(&itemBuffer, encoding: .ascii)) case .universalString: - return .universalString(try parseString(&itemBuffer, encoding: .ascii)) + item = .universalString(try parseString(&itemBuffer, encoding: .ascii)) case .characterString: - return .characterString(try parseString(&itemBuffer, encoding: .ascii)) + item = .characterString(try parseString(&itemBuffer, encoding: .ascii)) case .bmpString: - return .bmpString(try parseString(&itemBuffer, encoding: .ascii)) + item = .bmpString(try parseString(&itemBuffer, encoding: .ascii)) case .sequence, .set: throw ASN1Serialization.Error.nonConstructedCollection case .objectDescriptor, .external, .enumerated, .embedded, .relativeOID: // Default to saving tagged version - return .tagged(tag.rawValue, Data(try itemBuffer.popAll())) + item = .tagged(tag.rawValue, Data(try itemBuffer.popAll())) + } + + if !itemBuffer.isEmpty { + throw ASN1Serialization.Error.invalidTaggedItem } + return item } private static func parseTime( @@ -213,7 +220,7 @@ internal enum ASN1DERReader { private static func parseReal(_ buffer: inout UnsafeBufferPointer) throws -> Decimal { let lead = try buffer.pop() if lead & 0x40 == 0x40 { - return lead & 0x1 == 0 ? Decimal(Double.infinity) : Decimal(-Double.infinity) + throw ASN1Serialization.Error.unsupportedReal } else if lead & 0xC0 == 0 { let bytes = try buffer.popAll() @@ -303,7 +310,20 @@ internal enum ASN1DERReader { } for _ in 0 ..< numBytes { - length = (length * 0x100) + Int(try buffer.pop()) + + let newLength = (length &* 0x100) &+ Int(try buffer.pop()) + + // Check for overflow + if newLength < length { + throw ASN1Serialization.Error.lengthOverflow + } + + // Check avaiable data + if newLength > buffer.count { + throw ASN1Serialization.Error.unexpectedEOF + } + + length = newLength } return length diff --git a/Sources/PotentASN1/ASN1Serialization.swift b/Sources/PotentASN1/ASN1Serialization.swift index 406efbba1..c91aecdf1 100644 --- a/Sources/PotentASN1/ASN1Serialization.swift +++ b/Sources/PotentASN1/ASN1Serialization.swift @@ -31,10 +31,12 @@ public enum ASN1Serialization { case invalidGeneralizedTime /// Unsupported REAL type. case unsupportedReal - /// Encoded value length could not be stored. + /// Encoded value length could not be stored or exceeds available data. case lengthOverflow /// Number of fields in OID is invalid case invalidObjectIdentifierLength + /// Tagged item was encoded incorrectly + case invalidTaggedItem } /// Read ASN.1/DER encoded data as a collection of ``ASN1`` values. diff --git a/Tests/ASN1Tests.swift b/Tests/ASN1Tests.swift index c9f25ece5..bd6ec8b84 100644 --- a/Tests/ASN1Tests.swift +++ b/Tests/ASN1Tests.swift @@ -249,4 +249,31 @@ class ASN1Tests: XCTestCase { XCTAssertEqual(offset(5), -40920) } + func testRandomData() throws { + + struct TestStruct: Codable, SchemaSpecified { + var id: OID + var data: Data + + static var asn1Schema: Schema { + .sequence([ + "id": .objectIdentifier(), + "data": .octetString(size: .is(16)) + ]) + } + } + + let encoded = try ASN1.Encoder.encode(TestStruct(id: [1, 2, 3, 4, 5], data: Data(count: 16))) + + for _ in 0 ..< 10000 { + + var random = Data(capacity: encoded.count) + for _ in 0 ..< encoded.count { + random.append(UInt8.random(in: 0 ..< .max)) + } + + XCTAssertThrowsError(try ASN1.Decoder.decode(TestStruct.self, from: random)) + } + } + } diff --git a/Tests/AnyValueTests.swift b/Tests/AnyValueTests.swift index 39c0da99f..62fb86ec2 100644 --- a/Tests/AnyValueTests.swift +++ b/Tests/AnyValueTests.swift @@ -270,7 +270,7 @@ class AnyValueTests: XCTestCase { XCTAssertEqual(try AnyValue.wrapped(uuid), .uuid(uuid)) let date = Date() XCTAssertEqual(try AnyValue.wrapped(date), .date(date)) - XCTAssertEqual(try AnyValue.wrapped([1, "test", true]), .array([1, "test", true])) + XCTAssertEqual(try AnyValue.wrapped([1, "test", true] as [Any]), .array([1, "test", true])) // Unorderd dictionaries XCTAssertEqual( @@ -283,9 +283,9 @@ class AnyValueTests: XCTestCase { .dictionaryValue.map { val in Dictionary(uniqueKeysWithValues: val.map { ($0, $1) }) }, [.int(1): .string("a"), .int(2): .string("b"), .int(3): .string("c")] ) - XCTAssertEqual(try AnyValue.wrapped(["a": 1, "b": "test", "c": true] as OrderedDictionary), + XCTAssertEqual(try AnyValue.wrapped(["a": 1, "b": "test", "c": true] as OrderedDictionary), .dictionary(["a": 1, "b": "test", "c": true])) - XCTAssertEqual(try AnyValue.wrapped([1: 1, 2: "test", 3: true] as OrderedDictionary), + XCTAssertEqual(try AnyValue.wrapped([1: 1, 2: "test", 3: true] as OrderedDictionary), .dictionary([1: 1, 2: "test", 3: true])) // Passthrough diff --git a/Tests/CBORTests.swift b/Tests/CBORTests.swift index 852452f30..53fa63676 100644 --- a/Tests/CBORTests.swift +++ b/Tests/CBORTests.swift @@ -92,4 +92,24 @@ class CBORTests: XCTestCase { XCTAssertEqual(map, try CBORSerialization.cbor(from: Data([0xA3, 0x61, 0x63, 0x01, 0x61, 0x61, 0x02, 0x61, 0x62, 0x03]))) } + func testRandomData() throws { + + struct TestStruct: Codable { + var id: [Int] + var data: Data + } + + let encoded = try CBOR.Encoder.default.encode(TestStruct(id: [1, 2, 3, 4, 5], data: Data(count: 16))) + + for _ in 0 ..< 10000 { + + var random = Data(capacity: encoded.count) + for _ in 0 ..< encoded.count { + random.append(UInt8.random(in: 0 ..< .max)) + } + + XCTAssertThrowsError(try CBOR.Decoder.default.decode(TestStruct.self, from: random)) + } + } + }