swift-nio/Sources/NIOWebSocket/NIOWebSocketFrameAggregator...

139 lines
5.8 KiB
Swift

//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIOCore
/// `NIOWebSocketFrameAggregator` buffers inbound fragmented `WebSocketFrame`'s and aggregates them into a single `WebSocketFrame`.
/// It guarantees that a `WebSocketFrame` with an `opcode` of `.continuation` is never forwarded.
/// Frames which are not fragmented are just forwarded without any processing.
/// Fragmented frames are unmasked, concatenated and forwarded as a new `WebSocketFrame` which is either a `.binary` or `.text` frame.
/// `extensionData`, `rsv1`, `rsv2` and `rsv3` are lost if a frame is fragmented because they cannot be concatenated.
/// - Note: `.ping`, `.pong`, `.closeConnection` frames are forwarded during frame aggregation
public final class NIOWebSocketFrameAggregator: ChannelInboundHandler {
public enum Error: Swift.Error {
case nonFinalFragmentSizeIsTooSmall
case tooManyFragments
case accumulatedFrameSizeIsTooLarge
case receivedNewFrameWithoutFinishingPrevious
case didReceiveFragmentBeforeReceivingTextOrBinaryFrame
}
public typealias InboundIn = WebSocketFrame
public typealias InboundOut = WebSocketFrame
private let minNonFinalFragmentSize: Int
private let maxAccumulatedFrameCount: Int
private let maxAccumulatedFrameSize: Int
private var bufferedFrames: [WebSocketFrame] = []
private var accumulatedFrameSize: Int = 0
/// Configures a `NIOWebSocketFrameAggregator`.
/// - Parameters:
/// - minNonFinalFragmentSize: Minimum size in bytes of a fragment which is not the last fragment of a complete frame. Used to defend against many really small payloads.
/// - maxAccumulatedFrameCount: Maximum number of fragments which are allowed to result in a complete frame.
/// - maxAccumulatedFrameSize: Maximum accumulated size in bytes of buffered fragments. It is essentially the maximum allowed size of an incoming frame after all fragments are concatenated.
public init(
minNonFinalFragmentSize: Int,
maxAccumulatedFrameCount: Int,
maxAccumulatedFrameSize: Int
) {
self.minNonFinalFragmentSize = minNonFinalFragmentSize
self.maxAccumulatedFrameCount = maxAccumulatedFrameCount
self.maxAccumulatedFrameSize = maxAccumulatedFrameSize
}
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let frame = unwrapInboundIn(data)
do {
switch frame.opcode {
case .continuation:
guard let firstFrameOpcode = self.bufferedFrames.first?.opcode else {
throw Error.didReceiveFragmentBeforeReceivingTextOrBinaryFrame
}
try self.bufferFrame(frame)
guard frame.fin else { break }
// final frame received
let aggregatedFrame = self.aggregateFrames(
opcode: firstFrameOpcode,
allocator: context.channel.allocator
)
self.clearBuffer()
context.fireChannelRead(wrapInboundOut(aggregatedFrame))
case .binary, .text:
if frame.fin {
guard self.bufferedFrames.isEmpty else {
throw Error.receivedNewFrameWithoutFinishingPrevious
}
// fast path: no need to check any constraints nor unmask and copy data
context.fireChannelRead(data)
} else {
try self.bufferFrame(frame)
}
default:
// control frames can't be fragmented
context.fireChannelRead(data)
}
} catch {
// free memory early
self.clearBuffer()
context.fireErrorCaught(error)
}
}
private func bufferFrame(_ frame: WebSocketFrame) throws {
guard self.bufferedFrames.isEmpty || frame.opcode == .continuation else {
throw Error.receivedNewFrameWithoutFinishingPrevious
}
guard frame.fin || frame.length >= self.minNonFinalFragmentSize else {
throw Error.nonFinalFragmentSizeIsTooSmall
}
guard self.bufferedFrames.count < self.maxAccumulatedFrameCount else {
throw Error.tooManyFragments
}
// if this is not a final frame, we will at least receive one more frame
guard frame.fin || (self.bufferedFrames.count + 1) < self.maxAccumulatedFrameCount else {
throw Error.tooManyFragments
}
self.bufferedFrames.append(frame)
self.accumulatedFrameSize += frame.length
guard self.accumulatedFrameSize <= self.maxAccumulatedFrameSize else {
throw Error.accumulatedFrameSizeIsTooLarge
}
}
private func aggregateFrames(opcode: WebSocketOpcode, allocator: ByteBufferAllocator) -> WebSocketFrame {
var dataBuffer = allocator.buffer(capacity: self.accumulatedFrameSize)
for frame in self.bufferedFrames {
var unmaskedData = frame.unmaskedData
dataBuffer.writeBuffer(&unmaskedData)
}
return WebSocketFrame(fin: true, opcode: opcode, data: dataBuffer)
}
private func clearBuffer() {
self.bufferedFrames.removeAll(keepingCapacity: true)
self.accumulatedFrameSize = 0
}
}