// protobuf-swift/Source/CodedInputStream.swift

// Protocol Buffers for Swift
//
// Copyright 2014 Alexey Khohklov(AlexeyXo).
// Copyright 2008 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License")
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
/// Default maximum nesting depth of groups / sub-messages (assigned to `recursionLimit`).
let DEFAULT_RECURSION_LIMIT: Int = 64
/// Default byte-count limit assigned to `sizeLimit` (64 MiB).
let DEFAULT_SIZE_LIMIT: Int = 64 * 1024 * 1024
/// Capacity of the read buffer allocated for stream input; reads of at least this many
/// bytes bypass the buffer in `readRawData`.
let BUFFER_SIZE: Int = 4 * 1024
/// Reads protocol-buffer wire-format values from an in-memory `Data` blob or an `InputStream`.
///
/// Bytes are consumed through `buffer`. `totalBytesRetired + bufferPos` is the absolute read
/// position; `currentLimit` is an absolute position that reads may not cross (narrowed by
/// `pushLimit` for nested messages); `sizeLimit` caps the total number of bytes pulled from a stream.
public class CodedInputStream {
    /// Buffered bytes. For `init(data:)` this is the whole input; for `init(stream:)` it is a
    /// `BUFFER_SIZE` scratch area refilled from the stream.
    public var buffer:[UInt8]
    /// Source of additional bytes; `nil` when decoding in-memory `Data`.
    fileprivate var input:InputStream?
    /// Number of valid bytes at the start of `buffer` that lie before `currentLimit`.
    fileprivate var bufferSize:Int = 0
    /// Number of valid bytes in `buffer` past `currentLimit` (hidden until the limit is popped).
    fileprivate var bufferSizeAfterLimit:Int = 0
    /// Index of the next unread byte in `buffer`.
    fileprivate var bufferPos:Int = 0
    /// Most recent tag returned by `readTag()` (0 at end of input).
    fileprivate var lastTag:Int32 = 0
    /// Bytes consumed before the start of `buffer`.
    fileprivate var totalBytesRetired:Int = 0
    /// Absolute position reads may not cross; `Int.max` when no limit is pushed.
    fileprivate var currentLimit:Int = 0
    /// Current group / sub-message nesting depth.
    fileprivate var recursionDepth:Int = 0
    /// Maximum allowed `recursionDepth`.
    fileprivate var recursionLimit:Int = 0
    /// Maximum total number of bytes that may be read from `input`.
    fileprivate var sizeLimit:Int = 0

    // MARK: - Initialization

    /// Creates a stream that decodes the bytes of `data` (copied up front).
    public init (data:Data) {
        buffer = [UInt8](data)
        bufferSize = buffer.count
        currentLimit = Int.max
        recursionLimit = DEFAULT_RECURSION_LIMIT
        sizeLimit = DEFAULT_SIZE_LIMIT
    }

    /// Creates a stream that pulls bytes from `stream` in `BUFFER_SIZE` chunks.
    ///
    /// The stream is opened here; nothing in this class closes it.
    public init (stream:InputStream) {
        buffer = [UInt8](repeating: 0, count: BUFFER_SIZE)
        bufferSize = 0
        input = stream
        stream.open()
        currentLimit = Int.max
        recursionLimit = DEFAULT_RECURSION_LIMIT
        sizeLimit = DEFAULT_SIZE_LIMIT
    }

    // MARK: - Buffer management

    /// Returns `true` when no byte can be read before the current limit or the end of input.
    private func isAtEnd() throws -> Bool {
        guard bufferPos == bufferSize else {
            return false
        }
        let refilled = try refillBuffer(mustSucceed: false)
        return !refilled
    }

    /// Replaces the (fully consumed) buffer with the next chunk of input.
    ///
    /// - Parameter mustSucceed: when `true`, running out of data throws instead of returning `false`.
    /// - Returns: `true` if new bytes are available.
    /// - Throws: `illegalState` if the buffer still holds unread bytes; `invalidProtocolBuffer`
    ///   on truncation (when `mustSucceed`) or when `sizeLimit` is exceeded.
    private func refillBuffer(mustSucceed:Bool) throws -> Bool {
        guard bufferPos >= bufferSize else {
            throw ProtocolBuffersError.illegalState("RefillBuffer called when buffer wasn't empty.")
        }
        if totalBytesRetired + bufferSize == currentLimit {
            // The current limit has been reached; nothing more may be read.
            if mustSucceed {
                throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
            }
            return false
        }
        totalBytesRetired += bufferSize
        bufferPos = 0
        bufferSize = 0
        if let input = input {
            let capacity = buffer.count
            bufferSize = input.read(&buffer, maxLength: capacity)
        }
        guard bufferSize > 0 else {
            // 0 or a negative result from `read` (end of stream / stream error), or no stream at all.
            bufferSize = 0
            if mustSucceed {
                throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
            }
            return false
        }
        recomputeBufferSizeAfterLimit()
        let totalBytesRead = totalBytesRetired + bufferSize + bufferSizeAfterLimit
        // Both conditions must hold; the previous `||` made this check a no-op.
        guard totalBytesRead >= 0 && totalBytesRead <= sizeLimit else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Size Limit Exceeded")
        }
        return true
    }

    // MARK: - Raw bytes

    /// Reads exactly `size` raw bytes.
    ///
    /// - Throws: `invalidProtocolBuffer` for a negative size or when the input (or current limit)
    ///   ends before `size` bytes are available.
    public func readRawData(size:Int) throws -> Data {
        guard size >= 0 else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Negative Size")
        }
        if totalBytesRetired + bufferPos + size > currentLimit {
            // Consume up to the limit anyway, then fail.
            try skipRawData(size: currentLimit - totalBytesRetired - bufferPos)
            throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
        }
        if size <= bufferSize - bufferPos {
            // Everything requested is already buffered.
            let data = Data(buffer[bufferPos..<(bufferPos + size)])
            bufferPos += size
            return data
        }
        if size < BUFFER_SIZE {
            // Small read: stitch the buffered tail together with one or more refills.
            var bytes = [UInt8]()
            bytes.reserveCapacity(size)
            bytes.append(contentsOf: buffer[bufferPos..<bufferSize])
            bufferPos = bufferSize
            _ = try refillBuffer(mustSucceed: true)
            while size - bytes.count > bufferSize {
                bytes.append(contentsOf: buffer[0..<bufferSize])
                bufferPos = bufferSize
                _ = try refillBuffer(mustSucceed: true)
            }
            let remaining = size - bytes.count
            bytes.append(contentsOf: buffer[0..<remaining])
            bufferPos = remaining
            return Data(bytes)
        }
        // Large read: take the buffered tail, then read the rest straight from the stream.
        // Storage grows as bytes actually arrive, so a bogus huge length cannot force a huge
        // allocation before the input runs out.
        let originalBufferPos = bufferPos
        let originalBufferSize = bufferSize
        totalBytesRetired += bufferSize
        bufferPos = 0
        bufferSize = 0
        var bytes = Array(buffer[originalBufferPos..<originalBufferSize])
        guard let input = input else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
        }
        var chunk = [UInt8](repeating: 0, count: BUFFER_SIZE)
        while bytes.count < size {
            let wanted = min(size - bytes.count, chunk.count)
            let n = input.read(&chunk, maxLength: wanted)
            guard n > 0 else {
                throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
            }
            bytes.append(contentsOf: chunk[0..<n])
            totalBytesRetired += n
        }
        return Data(bytes)
    }

    /// Discards exactly `size` raw bytes.
    ///
    /// - Throws: `invalidProtocolBuffer` for a negative size or when the input (or current limit)
    ///   ends before `size` bytes could be skipped.
    public func skipRawData(size:Int) throws {
        guard size >= 0 else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Negative Size")
        }
        if totalBytesRetired + bufferPos + size > currentLimit {
            // Consume up to the limit anyway, then fail.
            try skipRawData(size: currentLimit - totalBytesRetired - bufferPos)
            throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
        }
        if size <= bufferSize - bufferPos {
            bufferPos += size
            return
        }
        // Drop the whole buffered tail; the absolute position becomes
        // (old totalBytesRetired + bufferSize), so retire the full buffer size.
        var skipped = bufferSize - bufferPos
        totalBytesRetired += bufferSize
        bufferPos = 0
        bufferSize = 0
        guard let input = input else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
        }
        // Reuse one bounded scratch buffer instead of allocating `size` bytes per iteration.
        var scratch = [UInt8](repeating: 0, count: min(size - skipped, BUFFER_SIZE))
        while skipped < size {
            let wanted = min(size - skipped, scratch.count)
            let n = input.read(&scratch, maxLength: wanted)
            guard n > 0 else {
                // End of stream or stream error: previously this looped forever.
                throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
            }
            skipped += n
            totalBytesRetired += n
        }
    }

    /// Reads a 4-byte little-endian value.
    public func readRawLittleEndian32() throws -> Int32 {
        var result: Int32 = 0
        for shift in stride(from: 0, to: 32, by: 8) {
            let byte = try readRawByte()
            result |= (Int32(byte) & 0xff) << shift
        }
        return result
    }

    /// Reads an 8-byte little-endian value.
    public func readRawLittleEndian64() throws -> Int64 {
        var result: Int64 = 0
        for shift in stride(from: 0, to: 64, by: 8) {
            let byte = try readRawByte()
            result |= (Int64(byte) & 0xff) << shift
        }
        return result
    }

    // MARK: - Tags and field skipping

    /// Reads the next field tag, or returns 0 at the end of input / current limit.
    ///
    /// - Throws: `invalidProtocolBuffer` if a literal zero tag is encountered.
    public func readTag() throws -> Int32 {
        if try isAtEnd() {
            lastTag = 0
            return 0
        }
        let previousTag = lastTag
        lastTag = try readRawVarint32()
        guard lastTag != 0 else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Invalid Tag: after tag \(previousTag)")
        }
        return lastTag
    }

    /// Throws unless the most recently read tag equals `value`.
    public func checkLastTagWas(value:Int32) throws {
        guard lastTag == value else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Invalid Tag: after tag \(lastTag)")
        }
    }

    /// Skips the field whose tag was just read.
    ///
    /// - Returns: `false` if `tag` is an end-group tag, otherwise `true`.
    /// - Throws: `invalidProtocolBuffer` for an unknown wire type, truncated data, or group
    ///   nesting deeper than `recursionLimit`.
    @discardableResult
    public func skipField(tag:Int32) throws -> Bool {
        let wireType = WireFormat.getTagWireType(tag: tag)
        guard let format = WireFormat(rawValue: wireType) else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Invalid Wire Type")
        }
        switch format {
        case .varint:
            _ = try readInt32()
            return true
        case .fixed64:
            _ = try readRawLittleEndian64()
            return true
        case .lengthDelimited:
            let length = try readRawVarint32()
            try skipRawData(size: Int(length))
            return true
        case .startGroup:
            // Nested groups recurse; bound the depth so hostile input cannot overflow the stack.
            guard recursionDepth < recursionLimit else {
                throw ProtocolBuffersError.invalidProtocolBuffer("Recursion Limit Exceeded")
            }
            recursionDepth += 1
            defer { recursionDepth -= 1 }
            try skipMessage()
            let endTag = WireFormat.endGroup.makeTag(fieldNumber: WireFormat.getTagFieldNumber(tag: tag))
            try checkLastTagWas(value: endTag)
            return true
        case .endGroup:
            return false
        case .fixed32:
            _ = try readRawLittleEndian32()
            return true
        default:
            throw ProtocolBuffersError.invalidProtocolBuffer("Invalid Wire Type")
        }
    }

    /// Skips fields until end of input, the current limit, or an end-group tag.
    private func skipMessage() throws {
        while true {
            let tag = try readTag()
            // Check for end of input *before* trying to skip a field with tag 0.
            if tag == 0 {
                return
            }
            let skipped = try skipField(tag: tag)
            if !skipped {
                return
            }
        }
    }

    // MARK: - Scalar readers

    /// Reads a `double` field (8 little-endian bytes).
    public func readDouble() throws -> Double {
        let bits: Int64 = try readRawLittleEndian64()
        let zero: Double = 0.0
        return WireFormat.convertTypes(convertValue: bits, defaultValue: zero)
    }

    /// Reads a `float` field (4 little-endian bytes).
    public func readFloat() throws -> Float {
        let bits: Int32 = try readRawLittleEndian32()
        let zero: Float = 0.0
        return WireFormat.convertTypes(convertValue: bits, defaultValue: zero)
    }

    /// Reads a `uint64` field (varint).
    public func readUInt64() throws -> UInt64 {
        let raw = try readRawVarint64()
        let zero: UInt64 = 0
        return WireFormat.convertTypes(convertValue: raw, defaultValue: zero)
    }

    /// Reads an `int64` field (varint).
    public func readInt64() throws -> Int64 {
        return try readRawVarint64()
    }

    /// Reads an `int32` field (varint).
    public func readInt32() throws -> Int32 {
        return try readRawVarint32()
    }

    /// Reads a `fixed64` field (8 little-endian bytes).
    public func readFixed64() throws -> UInt64 {
        let raw = try readRawLittleEndian64()
        let zero: UInt64 = 0
        return WireFormat.convertTypes(convertValue: raw, defaultValue: zero)
    }

    /// Reads a `fixed32` field (4 little-endian bytes).
    public func readFixed32() throws -> UInt32 {
        let raw = try readRawLittleEndian32()
        let zero: UInt32 = 0
        return WireFormat.convertTypes(convertValue: raw, defaultValue: zero)
    }

    /// Reads a `bool` field (varint; any non-zero value is `true`).
    public func readBool() throws ->Bool {
        return try readRawVarint32() != 0
    }

    /// Reads one raw byte, reinterpreting its bits as `Int8` (so bytes >= 0x80 are negative).
    public func readRawByte() throws -> Int8 {
        if bufferPos == bufferSize {
            _ = try refillBuffer(mustSucceed: true)
        }
        let byte = buffer[bufferPos]
        bufferPos += 1
        return Int8(bitPattern: byte)
    }

    // MARK: - Varints

    /// Decodes a 32-bit varint whose first byte has already been read, pulling the remaining
    /// bytes one at a time from `inputStream`.
    ///
    /// - Throws: `invalidProtocolBuffer` if the stream ends early or the varint exceeds 10 bytes.
    public class func readRawVarint32(firstByte:UInt8, inputStream:InputStream) throws -> Int32
    {
        if (firstByte & 0x80) == 0 {
            return Int32(firstByte)
        }
        var result = Int32(firstByte) & 0x7f
        var offset: Int32 = 7
        var byte: UInt8 = 0
        while offset < 32 {
            guard inputStream.read(&byte, maxLength: 1) > 0 else {
                throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
            }
            result |= (Int32(byte) & 0x7f) << offset
            if (byte & 0x80) == 0 {
                return result
            }
            offset += 7
        }
        // Discard the upper bytes of a wider varint.
        while offset < 64 {
            guard inputStream.read(&byte, maxLength: 1) > 0 else {
                throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
            }
            if (byte & 0x80) == 0 {
                return result
            }
            offset += 7
        }
        // Every byte had its continuation bit set: the encoding is malformed, not truncated.
        throw ProtocolBuffersError.invalidProtocolBuffer("MalformedVarint")
    }

    /// Decodes a varint, keeping only its low 32 bits.
    ///
    /// - Throws: `invalidProtocolBuffer` if the varint exceeds 10 bytes or the input is truncated.
    public func readRawVarint32() throws -> Int32 {
        var result: Int32 = 0
        var shift: Int32 = 0
        while shift < 32 {
            let byte = try readRawByte()
            result |= (Int32(byte) & 0x7f) << shift
            if byte >= 0 {
                // High bit clear: this was the last byte.
                return result
            }
            shift += 7
        }
        // Discard the upper 32 bits of a wider varint (up to 10 bytes in total).
        for _ in 0..<5 {
            let byte = try readRawByte()
            if byte >= 0 {
                return result
            }
        }
        throw ProtocolBuffersError.invalidProtocolBuffer("MalformedVarint")
    }

    /// Decodes a 64-bit varint.
    ///
    /// - Throws: `invalidProtocolBuffer` if the varint exceeds 10 bytes or the input is truncated.
    public func readRawVarint64() throws -> Int64 {
        var result: Int64 = 0
        var shift: Int64 = 0
        while shift < 64 {
            let byte = try readRawByte()
            result |= Int64(byte & 0x7F) << shift
            if byte >= 0 {
                return result
            }
            shift += 7
        }
        throw ProtocolBuffersError.invalidProtocolBuffer("MalformedVarint")
    }

    // MARK: - Length-delimited readers

    /// Reads a length-prefixed UTF-8 string.
    ///
    /// - Throws: `invalidProtocolBuffer` on truncation or when the bytes are not valid UTF-8.
    public func readString() throws -> String {
        let size = Int(try readRawVarint32())
        // Copy the bytes out: a string must never alias `buffer`, which is overwritten on refill.
        let data = try readRawData(size: size)
        guard let result = String(data: data, encoding: .utf8) else {
            throw ProtocolBuffersError.invalidProtocolBuffer("InvalidUTF8StringData")
        }
        return result
    }

    /// Reads a length-prefixed `bytes` field.
    public func readData() throws -> Data {
        let size = Int(try readRawVarint32())
        return try readRawData(size: size)
    }

    /// Reads a `uint32` field (varint).
    public func readUInt32() throws -> UInt32 {
        let value: Int32 = try readRawVarint32()
        let zero: UInt32 = 0
        return WireFormat.convertTypes(convertValue: value, defaultValue: zero)
    }

    /// Reads an enum field's raw value (varint).
    public func readEnum() throws -> Int32 {
        return try readRawVarint32()
    }

    /// Reads an `sfixed32` field.
    public func readSFixed32() throws -> Int32 {
        return try readRawLittleEndian32()
    }

    /// Reads an `sfixed64` field.
    public func readSFixed64() throws -> Int64 {
        return try readRawLittleEndian64()
    }

    /// Reads a zig-zag encoded `sint32` field.
    public func readSInt32() throws -> Int32 {
        let raw = try readRawVarint32()
        return WireFormat.decodeZigZag32(n: raw)
    }

    /// Reads a zig-zag encoded `sint64` field.
    public func readSInt64() throws -> Int64 {
        let raw = try readRawVarint64()
        return WireFormat.decodeZigZag64(n: raw)
    }

    // MARK: - Limits

    /// Sets the maximum group / sub-message nesting depth and returns the previous value.
    @discardableResult
    public func setRecursionLimit(limit:Int) throws -> Int {
        guard limit >= 0 else {
            throw ProtocolBuffersError.illegalArgument("Recursion limit cannot be negative")
        }
        let previous = recursionLimit
        recursionLimit = limit
        return previous
    }

    /// Sets the maximum total number of bytes read from a stream and returns the previous value.
    @discardableResult
    public func setSizeLimit(limit:Int) throws -> Int {
        guard limit >= 0 else {
            throw ProtocolBuffersError.illegalArgument("Size limit cannot be negative")
        }
        let previous = sizeLimit
        sizeLimit = limit
        return previous
    }

    /// Resets the byte counter checked against `sizeLimit`.
    ///
    /// Useful when one long-lived stream carries many independent messages.
    public func resetSizeCounter() {
        totalBytesRetired = 0
    }

    /// Re-splits the buffered bytes into the part before `currentLimit` (`bufferSize`) and the
    /// part after it (`bufferSizeAfterLimit`).
    private func recomputeBufferSizeAfterLimit() {
        bufferSize += bufferSizeAfterLimit
        let bufferEnd = totalBytesRetired + bufferSize
        if bufferEnd > currentLimit {
            bufferSizeAfterLimit = bufferEnd - currentLimit
            bufferSize -= bufferSizeAfterLimit
        } else {
            bufferSizeAfterLimit = 0
        }
    }

    /// Restricts reading to the next `byteLimit` bytes.
    ///
    /// - Returns: the previous limit, to be passed to `popLimit(oldLimit:)`.
    /// - Throws: `invalidProtocolBuffer` if `byteLimit` is negative or extends past the current limit.
    public func pushLimit(byteLimit:Int) throws -> Int {
        guard byteLimit >= 0 else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Negative Size")
        }
        let newLimit = byteLimit + totalBytesRetired + bufferPos
        let oldLimit = currentLimit
        guard newLimit <= oldLimit else {
            // The nested length claims more bytes than the enclosing limit allows.
            throw ProtocolBuffersError.invalidProtocolBuffer("Truncated Message")
        }
        currentLimit = newLimit
        recomputeBufferSizeAfterLimit()
        return oldLimit
    }

    /// Restores a limit previously returned by `pushLimit(byteLimit:)`.
    public func popLimit(oldLimit:Int) {
        currentLimit = oldLimit
        recomputeBufferSizeAfterLimit()
    }

    /// Bytes remaining before the current limit, or -1 when no limit is in effect.
    public func bytesUntilLimit() ->Int {
        if currentLimit == Int.max {
            return -1
        }
        let currentAbsolutePosition = totalBytesRetired + bufferPos
        return currentLimit - currentAbsolutePosition
    }

    // MARK: - Nested messages

    /// Merges a group field into `builder` and checks that it ends with the matching end-group tag.
    ///
    /// - Throws: `invalidProtocolBuffer` when the recursion limit is exceeded or the end tag is wrong.
    public func readGroup(fieldNumber:Int, builder:ProtocolBuffersMessageBuilder, extensionRegistry:ExtensionRegistry) throws {
        guard recursionDepth < recursionLimit else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Recursion Limit Exceeded")
        }
        recursionDepth += 1
        // Restore the depth even when parsing throws.
        defer { recursionDepth -= 1 }
        _ = try builder.mergeFrom(codedInputStream: self, extensionRegistry:extensionRegistry)
        try checkLastTagWas(value: WireFormat.endGroup.makeTag(fieldNumber: Int32(fieldNumber)))
    }

    /// Merges an unknown group field into `builder` and checks that it ends with the matching end-group tag.
    ///
    /// - Throws: `invalidProtocolBuffer` when the recursion limit is exceeded or the end tag is wrong.
    public func readUnknownGroup(fieldNumber:Int32, builder:UnknownFieldSet.Builder) throws {
        guard recursionDepth < recursionLimit else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Recursion Limit Exceeded")
        }
        recursionDepth += 1
        defer { recursionDepth -= 1 }
        _ = try builder.mergeFrom(codedInputStream: self)
        try checkLastTagWas(value: WireFormat.endGroup.makeTag(fieldNumber: fieldNumber))
    }

    /// Reads a length-prefixed embedded message into `builder`, confined to its declared length.
    ///
    /// - Throws: `invalidProtocolBuffer` when the recursion limit is exceeded, the length is
    ///   invalid, or the message does not end cleanly at its limit.
    public func readMessage(builder:ProtocolBuffersMessageBuilder, extensionRegistry:ExtensionRegistry) throws {
        let length = try readRawVarint32()
        guard recursionDepth < recursionLimit else {
            throw ProtocolBuffersError.invalidProtocolBuffer("Recursion Limit Exceeded")
        }
        let oldLimit = try pushLimit(byteLimit: Int(length))
        recursionDepth += 1
        // Undo depth and limit on every exit, including errors.
        defer {
            recursionDepth -= 1
            popLimit(oldLimit: oldLimit)
        }
        _ = try builder.mergeFrom(codedInputStream: self, extensionRegistry:extensionRegistry)
        try checkLastTagWas(value: 0)
    }
}