333 lines
13 KiB
Swift
333 lines
13 KiB
Swift
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This source file is part of the SwiftNIO open source project
|
|
//
|
|
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
|
|
// Licensed under Apache License v2.0
|
|
//
|
|
// See LICENSE.txt for license information
|
|
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
|
//
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
import NIO
|
|
|
|
/// Errors thrown by the NIO websocket module.
|
|
public enum NIOWebSocketError: Error {
|
|
/// The frame being sent is larger than the configured maximum
|
|
/// acceptable frame size
|
|
case invalidFrameLength
|
|
|
|
/// A control frame may not be fragmented.
|
|
case fragmentedControlFrame
|
|
|
|
/// A control frame may not have a length more than 125 bytes.
|
|
case multiByteControlFrameLength
|
|
}
|
|
|
|
internal extension WebSocketErrorCode {
|
|
internal init(_ error: NIOWebSocketError) {
|
|
switch error {
|
|
case .invalidFrameLength:
|
|
self = .messageTooLarge
|
|
case .fragmentedControlFrame,
|
|
.multiByteControlFrameLength:
|
|
self = .protocolError
|
|
}
|
|
}
|
|
}
|
|
|
|
public extension ByteBuffer {
|
|
/// Applies the WebSocket unmasking operation.
|
|
///
|
|
/// - parameters:
|
|
/// - maskingKey: The masking key.
|
|
public mutating func webSocketUnmask(_ maskingKey: WebSocketMaskingKey, indexOffset: Int = 0) {
|
|
/// Shhhh: secretly unmasking and masking are the same operation!
|
|
webSocketMask(maskingKey, indexOffset: indexOffset)
|
|
}
|
|
|
|
/// Applies the websocket masking operation.
|
|
///
|
|
/// - parameters:
|
|
/// - maskingKey: The masking key.
|
|
/// - indexOffset: An integer offset to apply to the index into the masking key.
|
|
/// This is used when masking multiple "contiguous" byte buffers, to ensure that
|
|
/// the masking key is applied uniformly to the collection rather than from the
|
|
/// start each time.
|
|
public mutating func webSocketMask(_ maskingKey: WebSocketMaskingKey, indexOffset: Int = 0) {
|
|
self.withUnsafeMutableReadableBytes {
|
|
for (index, byte) in $0.enumerated() {
|
|
$0[index] = byte ^ maskingKey[(index + indexOffset) % 4]
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// The current state of the frame decoder.
|
|
enum DecoderState {
|
|
/// Waiting for a frame.
|
|
case idle
|
|
|
|
/// The initial frame byte has been received, but the length byte
|
|
/// has not.
|
|
case firstByteReceived
|
|
|
|
/// The length byte indicates that we need to wait for the length word, and we're
|
|
/// currently waiting for it.
|
|
case waitingForLengthWord
|
|
|
|
/// The length byte indicates that we need to wait for the length qword, and
|
|
/// we're currently waiting for it.
|
|
case waitingForLengthQWord
|
|
|
|
/// The mask bit indicates we are expecting a mask key.
|
|
case waitingForMask
|
|
|
|
/// All the header data is complete, we are waiting for the application data.
|
|
case waitingForData
|
|
}
|
|
|
|
enum ParseResult {
|
|
case insufficientData
|
|
case continueParsing
|
|
case result(WebSocketFrame)
|
|
}
|
|
|
|
/// An incremental websocket frame parser.
|
|
///
|
|
/// This parser attempts to parse a websocket frame incrementally, keeping as much parsing state around as possible to ensure that
|
|
/// we don't repeatedly partially parse the data.
|
|
struct WSParser {
|
|
internal private(set) var firstByte: UInt8? = nil
|
|
internal private(set) var length: Int? = nil
|
|
internal private(set) var masked: Bool = false
|
|
internal private(set) var maskingKey: WebSocketMaskingKey? = nil
|
|
|
|
/// The current state of the decoder during incremental parse.
|
|
var state: DecoderState = .idle
|
|
|
|
private mutating func reset() {
|
|
self.state = .idle
|
|
self.firstByte = nil
|
|
self.length = nil
|
|
self.masked = false
|
|
self.maskingKey = nil
|
|
}
|
|
|
|
mutating func parseStep(_ buffer: inout ByteBuffer) -> ParseResult {
|
|
switch self.state {
|
|
case .idle:
|
|
// This is a new buffer. We want to find the first octet and save it off.
|
|
assert(self.firstByte == nil)
|
|
guard let firstByte = buffer.readInteger(as: UInt8.self) else {
|
|
return .insufficientData
|
|
}
|
|
self.firstByte = firstByte
|
|
self.state = .firstByteReceived
|
|
return .continueParsing
|
|
|
|
case .firstByteReceived:
|
|
// Now we're looking for the length. We begin by finding the length byte to see if we
|
|
// need any more data.
|
|
assert(self.length == nil)
|
|
assert(self.firstByte != nil)
|
|
guard let lengthByte = buffer.readInteger(as: UInt8.self) else {
|
|
return .insufficientData
|
|
}
|
|
|
|
self.masked = (lengthByte & 0x80) != 0
|
|
|
|
switch lengthByte & 0x7F {
|
|
case 126:
|
|
self.state = .waitingForLengthWord
|
|
case 127:
|
|
self.state = .waitingForLengthQWord
|
|
case let len:
|
|
assert(len <= 125)
|
|
self.length = Int(len)
|
|
self.state = self.masked ? .waitingForMask : .waitingForData
|
|
}
|
|
return .continueParsing
|
|
|
|
case .waitingForLengthWord:
|
|
// We've got a one-word length here.
|
|
assert(self.length == nil)
|
|
assert(self.firstByte != nil)
|
|
guard let lengthWord = buffer.readInteger(as: UInt16.self) else {
|
|
return .insufficientData
|
|
}
|
|
|
|
self.length = Int(lengthWord)
|
|
self.state = self.masked ? .waitingForMask : .waitingForData
|
|
return .continueParsing
|
|
|
|
case .waitingForLengthQWord:
|
|
// We've got a qword of length here.
|
|
assert(self.length == nil)
|
|
assert(self.firstByte != nil)
|
|
guard let lengthQWord = buffer.readInteger(as: UInt64.self) else {
|
|
return .insufficientData
|
|
}
|
|
|
|
self.length = Int(lengthQWord)
|
|
self.state = self.masked ? .waitingForMask : .waitingForData
|
|
return .continueParsing
|
|
|
|
case .waitingForMask:
|
|
// We're waiting for the masking key.
|
|
assert(maskingKey == nil)
|
|
assert(self.firstByte != nil)
|
|
assert(self.length != nil)
|
|
guard let maskingKey = buffer.readInteger(as: UInt32.self) else {
|
|
return .insufficientData
|
|
}
|
|
self.maskingKey = WebSocketMaskingKey(networkRepresentation: maskingKey)
|
|
self.state = .waitingForData
|
|
return .continueParsing
|
|
|
|
case .waitingForData:
|
|
assert(self.firstByte != nil)
|
|
assert(self.length != nil)
|
|
guard let data = buffer.readSlice(length: self.length!) else {
|
|
return .insufficientData
|
|
}
|
|
|
|
let frame = WebSocketFrame(firstByte: self.firstByte!, maskKey: self.maskingKey, applicationData: data)
|
|
self.reset()
|
|
return .result(frame)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// An inbound `ChannelHandler` that deserializes websocket frames into a structured
|
|
/// format for further processing.
|
|
///
|
|
/// This decoder has limited enforcement of compliance to RFC 6455. In particular, to guarantee
|
|
/// that the decoder can handle arbitrary extensions, only normative MUST/MUST NOTs that do not
|
|
/// relate to extensions (e.g. the requirement that control frames not have lengths larger than
|
|
/// 125 bytes) are enforced by this decoder.
|
|
///
|
|
/// This decoder does not have any support for decoding extensions. If you wish to support
|
|
/// extensions, you should implement a message-to-message decoder that performs the appropriate
|
|
/// frame transformation as needed. All the frame data is assumed to be application data by this
|
|
/// parser.
|
|
public final class WebSocketFrameDecoder: ByteToMessageDecoder {
|
|
public typealias InboundIn = ByteBuffer
|
|
public typealias InboundOut = WebSocketFrame
|
|
public typealias OutboundOut = WebSocketFrame
|
|
public var cumulationBuffer: ByteBuffer? = nil
|
|
|
|
/// The maximum frame size the decoder is willing to tolerate from the remote peer.
|
|
/* private but tests */ let maxFrameSize: Int
|
|
|
|
/// Our parser state.
|
|
private var parser = WSParser()
|
|
|
|
/// Whether we should continue to parse.
|
|
private var shouldKeepParsing = true
|
|
|
|
/// Whether this `ChannelHandler` should be performing automatic error handling.
|
|
private let automaticErrorHandling: Bool
|
|
|
|
/// Construct a new `WebSocketFrameDecoder`
|
|
///
|
|
/// - parameters:
|
|
/// - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the
|
|
/// remote peer. WebSockets in principle allows frame sizes up to `2**64` bytes, but
|
|
/// this is an objectively unreasonable maximum value (on AMD64 systems it is not
|
|
/// possible to even allocate a buffer large enough to handle this size), so we
|
|
/// set a lower one. The default value is the same as the default HTTP/2 max frame
|
|
/// size, `2**14` bytes. Users may override this to any value up to `UInt32.max`.
|
|
/// Users are strongly encouraged not to increase this value unless they absolutely
|
|
/// must, as the decoder will not produce partial frames, meaning that it will hold
|
|
/// on to data until the *entire* body is received.
|
|
/// - automaticErrorHandling: Whether this `ChannelHandler` should automatically handle
|
|
/// protocol errors in frame serialization, or whether it should allow the pipeline
|
|
/// to handle them.
|
|
public init(maxFrameSize: Int = 1 << 14, automaticErrorHandling: Bool = true) {
|
|
precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size")
|
|
self.maxFrameSize = maxFrameSize
|
|
self.automaticErrorHandling = automaticErrorHandling
|
|
}
|
|
|
|
public func decode(ctx: ChannelHandlerContext, buffer: inout ByteBuffer) -> DecodingState {
|
|
// Even though the calling code will loop around calling us in `decode`, we can't quite
|
|
// rely on that: sometimes we have zero-length elements to parse, and the caller doesn't
|
|
// guarantee to call us with zero-length bytes.
|
|
parseLoop: while self.shouldKeepParsing {
|
|
switch parser.parseStep(&buffer) {
|
|
case .result(let frame):
|
|
ctx.fireChannelRead(self.wrapInboundOut(frame))
|
|
case .continueParsing:
|
|
do {
|
|
try self.validateState()
|
|
} catch {
|
|
self.handleError(error, ctx: ctx)
|
|
}
|
|
case .insufficientData:
|
|
break parseLoop
|
|
}
|
|
}
|
|
|
|
// We parse eagerly, so once we get here we definitionally need more data.
|
|
return .needMoreData
|
|
}
|
|
|
|
public func decodeLast(ctx: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
|
|
// EOF is not semantic in WebSocket, so ignore this.
|
|
return .needMoreData
|
|
}
|
|
|
|
/// Apply a number of validations to the incremental state, ensuring that the frame we're
|
|
/// receiving is valid.
|
|
private func validateState() throws {
|
|
if let length = parser.length, length > self.maxFrameSize {
|
|
throw NIOWebSocketError.invalidFrameLength
|
|
}
|
|
|
|
if let length = parser.length, let firstByte = parser.firstByte {
|
|
let isControlFrame = (firstByte & 0x08) != 0
|
|
let isFragment = (firstByte & 0x80) == 0
|
|
|
|
if isControlFrame && isFragment {
|
|
throw NIOWebSocketError.fragmentedControlFrame
|
|
}
|
|
if isControlFrame && length > 125 {
|
|
throw NIOWebSocketError.multiByteControlFrameLength
|
|
}
|
|
}
|
|
}
|
|
|
|
/// We hit a decoding error, we're going to tear things down now. To do this we're
|
|
/// basically going to send an error frame and then close the connection. Once we're
|
|
/// in this state we do no further parsing.
|
|
///
|
|
/// A clean websocket shutdown is not really supposed to have an immediate close,
|
|
/// but we're doing that because the remote peer has prevented us from doing
|
|
/// further frame parsing, so we can't really wait for the next frame.
|
|
private func handleError(_ error: Error, ctx: ChannelHandlerContext) {
|
|
guard let error = error as? NIOWebSocketError else {
|
|
fatalError("Can only handle NIOWebSocketErrors")
|
|
}
|
|
self.shouldKeepParsing = false
|
|
|
|
// If we've been asked to handle the errors here, we should.
|
|
// TODO(cory): Remove this in 2.0, in favour of `WebSocketProtocolErrorHandler`.
|
|
if self.automaticErrorHandling {
|
|
var data = ctx.channel.allocator.buffer(capacity: 2)
|
|
data.write(webSocketErrorCode: WebSocketErrorCode(error))
|
|
let frame = WebSocketFrame(fin: true,
|
|
opcode: .connectionClose,
|
|
data: data)
|
|
ctx.writeAndFlush(self.wrapOutboundOut(frame)).whenComplete {
|
|
ctx.close(promise: nil)
|
|
}
|
|
}
|
|
|
|
ctx.fireErrorCaught(error)
|
|
}
|
|
}
|