swift-nio-redis/Sources/NIORedis/RESPChannelHandler.swift

350 lines
9.7 KiB
Swift

//===----------------------------------------------------------------------===//
//
// This source file is part of the swift-nio-redis open source project
//
// Copyright (c) 2018 ZeeZide GmbH. and the swift-nio-redis project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIO
final class RedisChannelHandler : ChannelInboundHandler,
ChannelOutboundHandler
{
typealias InboundErr = RESPParserError
typealias InboundIn = ByteBuffer
typealias InboundOut = RESPValue
typealias OutboundIn = RESPEncodable
typealias OutboundOut = ByteBuffer
let nilStringBuffer = ConstantBuffers.nilStringBuffer
let nilArrayBuffer = ConstantBuffers.nilArrayBuffer
private final var parser = RESPParser()
// MARK: - Channel Open/Close
func channelActive(ctx: ChannelHandlerContext) {
ctx.fireChannelActive()
}
func channelInactive(ctx: ChannelHandlerContext) {
switch state {
case .protocolError, .start: break // all good
default:
ctx.fireErrorCaught(InboundErr.ProtocolError)
}
overflowBuffer = nil
ctx.fireChannelInactive()
}
// MARK: - Reading
public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
do {
let buffer = self.unwrapInboundIn(data)
try parser.feed(buffer) { respValue in
ctx.fireChannelRead(self.wrapInboundOut(respValue))
}
}
catch {
ctx.fireErrorCaught(error)
ctx.close(promise: nil)
return
}
}
public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
ctx.fireErrorCaught(InboundErr.TransportError(error))
}
// MARK: - Parsing
private enum ParserState {
case protocolError
case start
case error
case integer
case bulkStringLen
case bulkStringValue
case simpleString
case arrayCount
case telnet
}
@inline(__always)
private func makeArrayParseContext(_ parent: ArrayParseContext? = nil,
_ count: Int) -> ArrayParseContext
{
if parent == nil && cachedParseContext != nil {
let ctx = cachedParseContext!
cachedParseContext = nil
ctx.count = count
ctx.values.reserveCapacity(count)
return ctx
}
else {
return ArrayParseContext(parent, count)
}
}
private final var cachedParseContext : ArrayParseContext? = nil
final private class ArrayParseContext {
let parent : ArrayParseContext?
var values = ContiguousArray<RESPValue>()
var count : Int
init(_ parent: ArrayParseContext?, _ count: Int) {
self.parent = parent
self.count = count
}
var isDone : Bool { return count <= values.count }
var isNested : Bool { return parent != nil }
@inline(__always)
func append(value v: InboundOut) -> Bool {
assert(!isNested || !isDone,
"attempt to add to a context which is not TL or done")
values.append(v)
return isDone
}
}
private var state = ParserState.start
private var overflowSkipNL = false
private var hadMinus = false
private var countValue = 0
private var overflowBuffer : ByteBuffer?
private var arrayContext : ArrayParseContext?
// MARK: - Writing
func write(ctx: ChannelHandlerContext, data: NIOAny,
promise: EventLoopPromise<Void>?)
{
let data : RESPEncodable = self.unwrapOutboundIn(data)
let value = data.toRESPValue()
var out : ByteBuffer
switch value {
case .simpleString(var s): // +
out = ctx.channel.allocator.buffer(capacity: 1 + s.readableBytes + 3)
out.write(integer : UInt8(43)) // +
out.write(buffer : &s)
out.write(bytes : eol)
case .bulkString(.some(var s)): // $
let count = s.readableBytes
out = ctx.channel.allocator.buffer(capacity: 1 + 4 + 2 + count + 3)
out.write(integer : UInt8(36)) // $
out.write(integerAsString : count)
out.write(bytes : eol)
out.write(buffer : &s)
out.write(bytes : eol)
case .bulkString(.none): // $
out = nilStringBuffer
case .integer(let i): // :
out = ctx.channel.allocator.buffer(capacity: 1 + 20 + 3)
out.write(integer : UInt8(58)) // :
out.write(integerAsString : i)
out.write(bytes : eol)
case .error(let error): // -
out = ctx.channel.allocator.buffer(capacity: 256)
encode(error: error, out: &out)
case .array(.some(let array)): // *
let count = array.count
out = ctx.channel.allocator.buffer(capacity: 1 + 4 + 3 + count * 32)
out.write(integer: UInt8(42)) // *
out.write(string: String(array.count, radix: 10))
out.write(bytes: eol)
for item in array {
encode(ctx: ctx, data: item, level: 1, out: &out)
}
case .array(.none): // *
out = nilArrayBuffer
}
ctx.write(wrapOutboundOut(out), promise: promise)
}
@inline(__always)
func encode<S: ContiguousCollection>(simpleString bytes: S,
out: inout ByteBuffer)
where S.Element == UInt8
{
out.write(integer : UInt8(43)) // +
out.write(bytes : bytes)
out.write(bytes : eol)
}
@inline(__always)
func encode(simpleString bytes: ByteBuffer, out: inout ByteBuffer) {
var s = bytes
out.write(integer : UInt8(43)) // +
out.write(buffer : &s)
out.write(bytes : eol)
}
@inline(__always)
func encode(bulkString bytes: ByteBuffer?, out: inout ByteBuffer) {
if var s = bytes {
out.write(integer : UInt8(36)) // $
out.write(integerAsString : s.readableBytes)
out.write(bytes : eol)
out.write(buffer : &s)
out.write(bytes : eol)
}
else {
out.write(bytes: nilString)
}
}
@inline(__always)
func encode<S: ContiguousCollection>(bulkString bytes: S?,
out: inout ByteBuffer)
where S.Element == UInt8
{
if let s = bytes {
out.write(integer : UInt8(36)) // $
out.write(integerAsString : Int(s.count))
out.write(bytes : eol)
out.write(bytes : s)
out.write(bytes : eol)
}
else {
out.write(bytes: nilString)
}
}
@inline(__always)
func encode(integer i: Int, out: inout ByteBuffer) {
out.write(integer : UInt8(58)) // :
out.write(integerAsString : i)
out.write(bytes : eol)
}
@inline(__always)
func encode(error: RESPError, out: inout ByteBuffer) {
out.write(integer : UInt8(45)) // -
out.write(string : error.code)
out.write(integer : UInt8(32)) // ' '
out.write(string : error.message)
out.write(bytes : eol)
}
func encode(ctx : ChannelHandlerContext,
data : RESPValue,
out : inout ByteBuffer)
{
encode(ctx: ctx, data: data, level: 0, out: &out)
}
func encode(ctx : ChannelHandlerContext,
data : RESPValue,
level : Int,
out : inout ByteBuffer)
{
// FIXME: Creating a String for an Int is expensive, there is something
// something better in the HTTP-API async imp.
switch data {
case .simpleString(let s): // +
encode(simpleString: s, out: &out)
case .bulkString(let s): // $
encode(bulkString: s, out: &out)
case .integer(let i): // :
encode(integer: i, out: &out)
case .error(let error): // -
encode(error: error, out: &out)
case .array(let array): // *
if let array = array {
out.write(integer: UInt8(42)) // *
out.write(string: String(array.count, radix: 10))
out.write(bytes: eol)
for item in array {
encode(ctx: ctx, data: item, level: level + 1, out: &out)
}
}
else {
out.write(bytes: nilArray)
}
}
}
}
private let eol : ContiguousArray<UInt8> = [ 13, 10 ] // \r\n
private let nilString : ContiguousArray<UInt8> = [ 36, 45, 49, 13, 10 ] // $-1\r\n
private let nilArray : ContiguousArray<UInt8> = [ 42, 45, 49, 13, 10 ] // *-1\r\n
fileprivate enum ConstantBuffers {
static let nilStringBuffer : ByteBuffer = {
let alloc = ByteBufferAllocator()
var bb = alloc.buffer(capacity: 6)
bb.write(bytes: nilString)
return bb
}()
static let nilArrayBuffer : ByteBuffer = {
let alloc = ByteBufferAllocator()
var bb = alloc.buffer(capacity: 6)
bb.write(bytes: nilArray)
return bb
}()
}
extension ByteBuffer {
@discardableResult
public mutating func write<T: FixedWidthInteger>(integerAsString integer: T,
as: T.Type = T.self) -> Int
{
let bytesWritten = set(integerAsString: integer, at: self.writerIndex)
moveWriterIndex(forwardBy: bytesWritten)
return Int(bytesWritten)
}
@discardableResult
public mutating func set<T: FixedWidthInteger>(integerAsString integer: T,
at index: Int,
as: T.Type = T.self) -> Int
{
var value = integer
#if true // slow, fixme
let len = String(value, radix: 10)
return set(string: len, at: index)! // TBD: why is this returning an Opt?
#else
return Swift.withUnsafeBytes(of: &value) { ptr in
// TODO: do the itoa thing to make it fast
}
#endif
}
}