336 lines
11 KiB
Swift
336 lines
11 KiB
Swift
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This source file is part of the swift-nio-redis open source project
|
|
//
|
|
// Copyright (c) 2018 ZeeZide GmbH. and the swift-nio-redis project authors
|
|
// Licensed under Apache License v2.0
|
|
//
|
|
// See LICENSE.txt for license information
|
|
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
|
//
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
import struct NIO.ByteBuffer
|
|
import struct NIO.ByteBufferAllocator
|
|
|
|
public enum RESPParserError : Error {
|
|
case UnexpectedStartByte(char: UInt8, buffer: ByteBuffer)
|
|
case UnexpectedEndByte (char: UInt8, buffer: ByteBuffer)
|
|
case TransportError(Swift.Error)
|
|
case ProtocolError
|
|
case UnexpectedNegativeCount
|
|
case InternalInconsistency
|
|
}
|
|
|
|
public struct RESPParser {
|
|
|
|
public typealias Yield = ( RESPValue ) -> Void
|
|
|
|
private let allocator = ByteBufferAllocator()
|
|
|
|
public mutating func feed(_ buffer: ByteBuffer, yield: Yield) throws {
|
|
try buffer.withUnsafeReadableBytes { bp in
|
|
let count = bp.count
|
|
var i = 0
|
|
|
|
@inline(__always)
|
|
func doSkipNL() {
|
|
if i >= count {
|
|
overflowSkipNL = true
|
|
}
|
|
else {
|
|
if bp[i] == 10 /* LF */ { i += 1 }
|
|
overflowSkipNL = false
|
|
}
|
|
}
|
|
|
|
if overflowSkipNL { doSkipNL() }
|
|
|
|
while i < count {
|
|
let c = bp[i]; i += 1
|
|
|
|
switch state {
|
|
|
|
case .protocolError:
|
|
throw RESPParserError.ProtocolError
|
|
|
|
case .start:
|
|
switch c {
|
|
case 43 /* + */: state = .simpleString
|
|
case 45 /* - */: state = .error
|
|
case 58 /* : */: state = .integer
|
|
case 36 /* $ */: state = .bulkStringLen
|
|
case 42 /* * */: state = .arrayCount
|
|
default: state = .telnet
|
|
}
|
|
countValue = 0
|
|
if state == .telnet || state == .simpleString || state == .error {
|
|
overflowBuffer = allocator.buffer(capacity: 80)
|
|
overflowBuffer?.write(integer: c)
|
|
}
|
|
else {
|
|
overflowBuffer = nil
|
|
}
|
|
|
|
case .telnet:
|
|
assert(overflowBuffer != nil, "missing overflow buffer")
|
|
if c == 13 || c == 10 {
|
|
if c == 13 { doSkipNL() }
|
|
let count = overflowBuffer?.readableBytes ?? 0
|
|
if count > 0 {
|
|
// just a quick hack for telnet mode
|
|
guard let s = overflowBuffer?.readString(length: count) else {
|
|
throw RESPParserError.ProtocolError
|
|
}
|
|
let vals = s.components(separatedBy: " ")
|
|
.lazy.map { RESPValue(bulkString: $0) }
|
|
decoded(value: .array(ContiguousArray(vals)), yield: yield)
|
|
}
|
|
}
|
|
else {
|
|
overflowBuffer?.write(integer: c)
|
|
}
|
|
|
|
case .arrayCount, .bulkStringLen, .integer:
|
|
let c0 : UInt8 = 48, c9 : UInt8 = 57, cMinus : UInt8 = 45
|
|
if c >= c0 && c <= c9 {
|
|
let digit = c - c0
|
|
countValue = (countValue * 10) + Int(digit)
|
|
}
|
|
else if !hadMinus && c == cMinus && countValue == 0 {
|
|
hadMinus = true
|
|
}
|
|
else if c == 13 || c == 10 {
|
|
let doNegate = hadMinus
|
|
hadMinus = false
|
|
if c == 13 { doSkipNL() }
|
|
|
|
switch state {
|
|
|
|
case .arrayCount:
|
|
if doNegate {
|
|
guard countValue == 1 else {
|
|
self.state = .protocolError
|
|
throw RESPParserError.UnexpectedNegativeCount
|
|
}
|
|
decoded(value: .array(nil), yield: yield)
|
|
}
|
|
else {
|
|
if countValue > 0 {
|
|
pushArrayContext(expectedCount: countValue)
|
|
}
|
|
else {
|
|
decoded(value: .array([]), yield: yield)
|
|
}
|
|
}
|
|
state = .start
|
|
|
|
case .bulkStringLen:
|
|
if doNegate {
|
|
state = .start
|
|
decoded(value: .bulkString(nil), yield: yield)
|
|
}
|
|
else {
|
|
if (count - i) >= (countValue + 2) { // include CRLF
|
|
let value = buffer.getSlice(at: buffer.readerIndex + i,
|
|
length: countValue)!
|
|
i += countValue
|
|
decoded(value: .bulkString(value), yield: yield)
|
|
|
|
let ec = bp[i]
|
|
guard ec == 13 || ec == 10 else {
|
|
self.state = .protocolError
|
|
throw RESPParserError.UnexpectedStartByte(char: bp[i],
|
|
buffer: buffer)
|
|
}
|
|
i += 1
|
|
if ec == 13 { doSkipNL() }
|
|
|
|
state = .start
|
|
}
|
|
else {
|
|
state = .bulkStringValue
|
|
overflowBuffer = allocator.buffer(capacity:countValue + 1)
|
|
}
|
|
}
|
|
|
|
case .integer:
|
|
let value = doNegate ? -countValue : countValue
|
|
countValue = 0 // reset
|
|
|
|
decoded(value: .integer(value), yield: yield)
|
|
state = .start
|
|
|
|
default:
|
|
assertionFailure("unexpected enum case \(state)")
|
|
state = .protocolError
|
|
throw RESPParserError.InternalInconsistency
|
|
}
|
|
}
|
|
else {
|
|
self.state = .protocolError
|
|
throw RESPParserError.UnexpectedStartByte(char: c, buffer: buffer)
|
|
}
|
|
|
|
case .bulkStringValue:
|
|
let pending = countValue - (overflowBuffer?.readableBytes ?? 0)
|
|
|
|
if pending > 0 {
|
|
overflowBuffer?.write(integer: c)
|
|
let stillPending = pending - 1
|
|
let avail = min(stillPending, (count - i))
|
|
if avail > 0 {
|
|
overflowBuffer?.write(bytes: bp[i..<(i + avail)])
|
|
i += avail
|
|
}
|
|
}
|
|
else if pending == 0 && (c == 13 || c == 10) {
|
|
if c == 13 { doSkipNL() }
|
|
|
|
let value = overflowBuffer
|
|
overflowBuffer = nil
|
|
|
|
decoded(value: .bulkString(value), yield: yield)
|
|
state = .start
|
|
}
|
|
else {
|
|
self.state = .protocolError
|
|
throw RESPParserError.UnexpectedEndByte(char: c, buffer: buffer)
|
|
}
|
|
|
|
case .simpleString, .error:
|
|
assert(overflowBuffer != nil, "missing overflow buffer")
|
|
if c == 13 || c == 10 {
|
|
if c == 13 { doSkipNL() }
|
|
|
|
if state == .simpleString {
|
|
if let v = overflowBuffer {
|
|
decoded(value: .simpleString(v), yield: yield)
|
|
}
|
|
}
|
|
else {
|
|
// TODO: make nice :-)
|
|
let avail = overflowBuffer?.readableBytes ?? 0
|
|
let value = overflowBuffer?.readBytes(length: avail) ?? []
|
|
let pair = value.split(separator: 32, maxSplits: 1)
|
|
let code = pair.count > 0 ? String.decode(utf8: pair[0]) ?? "" :""
|
|
let msg = pair.count > 1 ? String.decode(utf8: pair[1]) ?? "" :""
|
|
let error = RESPError(code: code, message: msg)
|
|
decoded(value: .error(error), yield: yield)
|
|
}
|
|
overflowBuffer = nil
|
|
|
|
state = .start
|
|
}
|
|
else {
|
|
overflowBuffer?.write(integer: c)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
assert(ctxIndex < 0 || !arrayContextBuffer[ctxIndex].isDone,
|
|
"array context on stack which is done? \(arrayContextBuffer)")
|
|
}
|
|
|
|
|
|
// MARK: - Parsing
|
|
|
|
@inline(__always)
|
|
private mutating func pushArrayContext(expectedCount: Int) {
|
|
if ctxIndex == ctxCapacity {
|
|
for _ in 0..<4 {
|
|
arrayContextBuffer.append(ArrayParseContext(expectedCount: -44))
|
|
}
|
|
ctxCapacity = arrayContextBuffer.count
|
|
}
|
|
assert(ctxIndex < ctxCapacity, "index overflow")
|
|
ctxIndex += 1
|
|
arrayContextBuffer[ctxIndex].expectedCount = expectedCount
|
|
arrayContextBuffer[ctxIndex].values.reserveCapacity(expectedCount)
|
|
}
|
|
|
|
@inline(__always)
|
|
private mutating func decoded(value: RESPValue, yield: Yield) {
|
|
if ctxIndex < 0 {
|
|
return yield(value)
|
|
}
|
|
|
|
let idx = ctxIndex
|
|
let isDone = arrayContextBuffer[idx].append(value: value)
|
|
if isDone {
|
|
let value = RESPValue.array(arrayContextBuffer[idx].values)
|
|
arrayContextBuffer[idx].values = emptyValueArray
|
|
arrayContextBuffer[idx].expectedCount = -1337
|
|
|
|
ctxIndex -= 1
|
|
if ctxIndex < 0 {
|
|
yield(value)
|
|
}
|
|
else {
|
|
decoded(value: value, yield: yield)
|
|
}
|
|
}
|
|
}
|
|
|
|
let emptyValueArray = ContiguousArray<RESPValue>()
|
|
|
|
private enum ParserState {
|
|
case protocolError
|
|
case start
|
|
case error
|
|
case integer
|
|
case bulkStringLen
|
|
case bulkStringValue
|
|
case simpleString
|
|
case arrayCount
|
|
case telnet
|
|
}
|
|
|
|
private var ctxIndex : Int
|
|
private var ctxCapacity : Int
|
|
private var arrayContextBuffer : ContiguousArray<ArrayParseContext>
|
|
|
|
init() {
|
|
ctxIndex = -1
|
|
ctxCapacity = 2
|
|
arrayContextBuffer = ContiguousArray<ArrayParseContext>()
|
|
arrayContextBuffer.reserveCapacity(8)
|
|
for _ in 0..<ctxCapacity {
|
|
arrayContextBuffer.append(ArrayParseContext(expectedCount: -42))
|
|
}
|
|
}
|
|
|
|
private struct ArrayParseContext {
|
|
|
|
var values = ContiguousArray<RESPValue>()
|
|
var expectedCount : Int
|
|
|
|
init(expectedCount: Int) {
|
|
self.expectedCount = expectedCount
|
|
values.reserveCapacity(expectedCount + 1)
|
|
}
|
|
|
|
var isDone : Bool {
|
|
@inline(__always) get { return expectedCount <= values.count }
|
|
}
|
|
|
|
@inline(__always)
|
|
mutating func append(value v: RESPValue) -> Bool {
|
|
assert(!isDone, "attempt to add to a context which is not TL or done")
|
|
values.append(v)
|
|
return isDone
|
|
}
|
|
}
|
|
|
|
private var state = ParserState.start
|
|
private var hadMinus = false
|
|
|
|
private var countValue = 0
|
|
private var overflowSkipNL = false
|
|
private var overflowBuffer : ByteBuffer?
|
|
|
|
}
|