HTTPObjectAggregator implementation (#1664)
* HTTPServerObjectAggregator for requests * Apply suggestions from code review Co-authored-by: Cory Benfield <lukasa@apple.com> * most of code review comments addressed * tidy up state machine close * bad line breaks * more verbose error reporting * removes expect:continue functionality due to #1422 * public type naming and error handling tweaks * wrong expectation in test * do not quietly swallow unexpected errors in channelRead Co-authored-by: Cory Benfield <lukasa@apple.com>
This commit is contained in:
parent
d2372de507
commit
e6b7d718a8
|
@ -1435,3 +1435,21 @@ extension HTTPMethod: RawRepresentable {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
extension HTTPResponseHead {
|
||||
internal var contentLength: Int? {
|
||||
return headers.contentLength
|
||||
}
|
||||
}
|
||||
|
||||
extension HTTPRequestHead {
|
||||
internal var contentLength: Int? {
|
||||
return headers.contentLength
|
||||
}
|
||||
}
|
||||
|
||||
extension HTTPHeaders {
|
||||
internal var contentLength: Int? {
|
||||
return self.first(name: "content-length").flatMap { Int($0) }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,400 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2020 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
import NIO
|
||||
|
||||
/// The parts of a complete HTTP response from the view of the client.
|
||||
///
|
||||
/// A full HTTP request is made up of a response header encoded by `.head`
|
||||
/// and an optional `.body`.
|
||||
public struct NIOHTTPServerRequestFull {
|
||||
public var head: HTTPRequestHead
|
||||
public var body: ByteBuffer?
|
||||
|
||||
public init(head: HTTPRequestHead, body: ByteBuffer?) {
|
||||
self.head = head
|
||||
self.body = body
|
||||
}
|
||||
}
|
||||
|
||||
extension NIOHTTPServerRequestFull: Equatable {}
|
||||
|
||||
/// The parts of a complete HTTP response from the view of the client.
|
||||
///
|
||||
/// Afull HTTP response is made up of a response header encoded by `.head`
|
||||
/// and an optional `.body`.
|
||||
public struct NIOHTTPClientResponseFull {
|
||||
public var head: HTTPResponseHead
|
||||
public var body: ByteBuffer?
|
||||
|
||||
public init(head: HTTPResponseHead, body: ByteBuffer?) {
|
||||
self.head = head
|
||||
self.body = body
|
||||
}
|
||||
}
|
||||
|
||||
extension NIOHTTPClientResponseFull: Equatable {}
|
||||
|
||||
public struct NIOHTTPObjectAggregatorError: Error, Equatable {
|
||||
private enum Base {
|
||||
case frameTooLong
|
||||
case connectionClosed
|
||||
case endingIgnoredMessage
|
||||
case unexpectedMessageHead
|
||||
case unexpectedMessageBody
|
||||
case unexpectedMessageEnd
|
||||
}
|
||||
|
||||
private var base: Base
|
||||
|
||||
private init(base: Base) {
|
||||
self.base = base
|
||||
}
|
||||
|
||||
public static let frameTooLong = NIOHTTPObjectAggregatorError(base: .frameTooLong)
|
||||
public static let connectionClosed = NIOHTTPObjectAggregatorError(base: .connectionClosed)
|
||||
public static let endingIgnoredMessage = NIOHTTPObjectAggregatorError(base: .endingIgnoredMessage)
|
||||
public static let unexpectedMessageHead = NIOHTTPObjectAggregatorError(base: .unexpectedMessageHead)
|
||||
public static let unexpectedMessageBody = NIOHTTPObjectAggregatorError(base: .unexpectedMessageBody)
|
||||
public static let unexpectedMessageEnd = NIOHTTPObjectAggregatorError(base: .unexpectedMessageEnd)
|
||||
}
|
||||
|
||||
public struct NIOHTTPObjectAggregatorEvent: Hashable {
|
||||
private enum Base {
|
||||
case httpExpectationFailed
|
||||
case httpFrameTooLong
|
||||
}
|
||||
|
||||
private var base: Base
|
||||
|
||||
private init(base: Base) {
|
||||
self.base = base
|
||||
}
|
||||
|
||||
public static let httpExpectationFailed = NIOHTTPObjectAggregatorEvent(base: .httpExpectationFailed)
|
||||
public static let httpFrameTooLong = NIOHTTPObjectAggregatorEvent(base: .httpFrameTooLong)
|
||||
}
|
||||
|
||||
/// The state of the aggregator connection.
|
||||
internal enum AggregatorState {
|
||||
/// Nothing is active on this connection, the next message we expect would be a request `.head`.
|
||||
case idle
|
||||
|
||||
/// Ill-behaving client may be sending content that is too large
|
||||
case ignoringContent
|
||||
|
||||
/// We are receiving and aggregating a request
|
||||
case receiving
|
||||
|
||||
/// Connection should be closed
|
||||
case closed
|
||||
|
||||
mutating func messageHeadReceived() throws {
|
||||
switch self {
|
||||
case .idle:
|
||||
self = .receiving
|
||||
case .ignoringContent, .receiving:
|
||||
throw NIOHTTPObjectAggregatorError.unexpectedMessageHead
|
||||
case .closed:
|
||||
throw NIOHTTPObjectAggregatorError.connectionClosed
|
||||
}
|
||||
}
|
||||
|
||||
mutating func messageBodyReceived() throws {
|
||||
switch self {
|
||||
case .receiving:
|
||||
()
|
||||
case .ignoringContent:
|
||||
throw NIOHTTPObjectAggregatorError.frameTooLong
|
||||
case .idle:
|
||||
throw NIOHTTPObjectAggregatorError.unexpectedMessageBody
|
||||
case .closed:
|
||||
throw NIOHTTPObjectAggregatorError.connectionClosed
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
mutating func messageEndReceived() throws {
|
||||
switch self {
|
||||
case .receiving:
|
||||
// Got the request end we were waiting for.
|
||||
self = .idle
|
||||
case .ignoringContent:
|
||||
// Expected transition from a state where message contents are getting
|
||||
// ignored because the message is too large. Throwing an error prevents
|
||||
// the normal control flow from continuing into dispatching the completed
|
||||
// invalid message to the next handler.
|
||||
self = .idle
|
||||
throw NIOHTTPObjectAggregatorError.endingIgnoredMessage
|
||||
case .idle:
|
||||
throw NIOHTTPObjectAggregatorError.unexpectedMessageEnd
|
||||
case .closed:
|
||||
throw NIOHTTPObjectAggregatorError.connectionClosed
|
||||
}
|
||||
}
|
||||
|
||||
mutating func handlingOversizeMessage() {
|
||||
switch self {
|
||||
case .receiving, .idle:
|
||||
self = .ignoringContent
|
||||
case .ignoringContent, .closed:
|
||||
// If we are already ignoring content or connection is closed, should not get here
|
||||
preconditionFailure("Unreachable state: should never handle overized message in \(self)")
|
||||
}
|
||||
}
|
||||
|
||||
mutating func closed() {
|
||||
self = .closed
|
||||
}
|
||||
}
|
||||
|
||||
/// A `ChannelInboundHandler` that handles HTTP chunked `HTTPServerRequestPart`
|
||||
/// messages by aggregating individual message chunks into a single
|
||||
/// `NIOHTTPServerRequestFull`.
|
||||
///
|
||||
/// This is achieved by buffering the contents of all received `HTTPServerRequestPart`
|
||||
/// messages until `HTTPServerRequestPart.end` is received, then assembling the
|
||||
/// full message and firing a channel read upstream with it. It is useful for when you do not
|
||||
/// want to deal with chunked messages and just want to receive everything at once, and
|
||||
/// are happy with the additional memory used and delay handling of the message until
|
||||
/// everything has been received.
|
||||
///
|
||||
/// `NIOHTTPServerRequestAggregator` may end up sending a `HTTPResponseHead`:
|
||||
/// - Response status `413 Request Entity Too Large` when either the
|
||||
/// `content-length` or the bytes received so far exceed `maxContentLength`.
|
||||
///
|
||||
/// `NIOHTTPServerRequestAggregator` may close the connection if it is impossible
|
||||
/// to recover:
|
||||
/// - If `content-length` is too large and `keep-alive` is off.
|
||||
/// - If the bytes received exceed `maxContentLength` and the client didn't signal
|
||||
/// `content-length`
|
||||
public final class NIOHTTPServerRequestAggregator: ChannelInboundHandler, RemovableChannelHandler {
|
||||
public typealias InboundIn = HTTPServerRequestPart
|
||||
public typealias InboundOut = NIOHTTPServerRequestFull
|
||||
|
||||
// Aggregator may generate responses of its own
|
||||
public typealias OutboundOut = HTTPServerResponsePart
|
||||
|
||||
private var fullMessageHead: HTTPRequestHead? = nil
|
||||
private var buffer: ByteBuffer! = nil
|
||||
private var maxContentLength: Int
|
||||
private var closeOnExpectationFailed: Bool
|
||||
private var state: AggregatorState
|
||||
|
||||
public init(maxContentLength: Int, closeOnExpectationFailed: Bool = false) {
|
||||
precondition(maxContentLength >= 0, "maxContentLength must not be negative")
|
||||
self.maxContentLength = maxContentLength
|
||||
self.closeOnExpectationFailed = closeOnExpectationFailed
|
||||
self.state = .idle
|
||||
}
|
||||
|
||||
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let msg = self.unwrapInboundIn(data)
|
||||
var serverResponse: HTTPResponseHead? = nil
|
||||
|
||||
do {
|
||||
switch msg {
|
||||
case .head(let httpHead):
|
||||
try self.state.messageHeadReceived()
|
||||
serverResponse = self.beginAggregation(context: context, request: httpHead, message: msg)
|
||||
case .body(var content):
|
||||
try self.state.messageBodyReceived()
|
||||
serverResponse = self.aggregate(context: context, content: &content, message: msg)
|
||||
case .end(let trailingHeaders):
|
||||
try self.state.messageEndReceived()
|
||||
self.endAggregation(context: context, trailingHeaders: trailingHeaders)
|
||||
}
|
||||
} catch let error as NIOHTTPObjectAggregatorError {
|
||||
context.fireErrorCaught(error)
|
||||
// Won't be able to complete those
|
||||
self.fullMessageHead = nil
|
||||
self.buffer.clear()
|
||||
} catch let error {
|
||||
context.fireErrorCaught(error)
|
||||
}
|
||||
|
||||
// Generated a server esponse to send back
|
||||
if let response = serverResponse {
|
||||
context.write(self.wrapOutboundOut(.head(response)), promise: nil)
|
||||
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
|
||||
if response.status == .payloadTooLarge {
|
||||
// If indicated content length is too large
|
||||
self.state.handlingOversizeMessage()
|
||||
context.fireErrorCaught(NIOHTTPObjectAggregatorError.frameTooLong)
|
||||
context.fireUserInboundEventTriggered(NIOHTTPObjectAggregatorEvent.httpFrameTooLong)
|
||||
}
|
||||
if !response.headers.isKeepAlive(version: response.version) {
|
||||
context.close(promise: nil)
|
||||
self.state.closed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func beginAggregation(context: ChannelHandlerContext, request: HTTPRequestHead, message: InboundIn) -> HTTPResponseHead? {
|
||||
self.fullMessageHead = request
|
||||
if let contentLength = request.contentLength, contentLength > self.maxContentLength {
|
||||
return self.handleOversizeMessage(message: message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
private func aggregate(context: ChannelHandlerContext, content: inout ByteBuffer, message: InboundIn) -> HTTPResponseHead? {
|
||||
if (content.readableBytes > self.maxContentLength - self.buffer.readableBytes) {
|
||||
return self.handleOversizeMessage(message: message)
|
||||
} else {
|
||||
self.buffer.writeBuffer(&content)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private func endAggregation(context: ChannelHandlerContext, trailingHeaders: HTTPHeaders?) {
|
||||
if var aggregated = self.fullMessageHead {
|
||||
// Remove `Trailer` from existing header fields and append trailer fields to existing header fields
|
||||
// See rfc7230 4.1.3 Decoding Chunked
|
||||
if let headers = trailingHeaders {
|
||||
aggregated.headers.remove(name: "trailer")
|
||||
aggregated.headers.add(contentsOf: headers)
|
||||
}
|
||||
|
||||
let fullMessage = NIOHTTPServerRequestFull(head: aggregated,
|
||||
body: self.buffer.readableBytes > 0 ? self.buffer : nil)
|
||||
self.fullMessageHead = nil
|
||||
self.buffer.clear()
|
||||
context.fireChannelRead(NIOAny(fullMessage))
|
||||
}
|
||||
}
|
||||
|
||||
private func handleOversizeMessage(message: InboundIn) -> HTTPResponseHead {
|
||||
var payloadTooLargeHead = HTTPResponseHead(
|
||||
version: self.fullMessageHead?.version ?? .init(major: 1, minor: 1),
|
||||
status: .payloadTooLarge,
|
||||
headers: HTTPHeaders([("content-length", "0")]))
|
||||
|
||||
switch message {
|
||||
case .head(let request):
|
||||
if !request.isKeepAlive {
|
||||
// If keep-alive is off and, no need to leave the connection open.
|
||||
// Send back a 413 and close the connection.
|
||||
payloadTooLargeHead.headers.add(name: "connection", value: "close")
|
||||
}
|
||||
default:
|
||||
// The client started to send data already, close because it's impossible to recover.
|
||||
// Send back a 413 and close the connection.
|
||||
payloadTooLargeHead.headers.add(name: "connection", value: "close")
|
||||
}
|
||||
|
||||
return payloadTooLargeHead
|
||||
}
|
||||
|
||||
public func handlerAdded(context: ChannelHandlerContext) {
|
||||
self.buffer = context.channel.allocator.buffer(capacity: 0)
|
||||
}
|
||||
}
|
||||
|
||||
/// A `ChannelInboundHandler` that handles HTTP chunked `HTTPClientResponsePart`
|
||||
/// messages by aggregating individual message chunks into a single
|
||||
/// `NIOHTTPClientResponseFull`.
|
||||
///
|
||||
/// This is achieved by buffering the contents of all received `HTTPClientResponsePart`
|
||||
/// messages until `HTTPClientResponsePart.end` is received, then assembling the
|
||||
/// full message and firing a channel read upstream with it. Useful when you do not
|
||||
/// want to deal with chunked messages and just want to receive everything at once, and
|
||||
/// are happy with the additional memory used and delay handling of the message until
|
||||
/// everything has been received.
|
||||
///
|
||||
/// If `NIOHTTPClientResponseAggregator` encounters a message larger than
|
||||
/// `maxContentLength`, it discards the aggregated contents until the next
|
||||
/// `HTTPClientResponsePart.end` and signals that via
|
||||
/// `fireUserInboundEventTriggered`.
|
||||
public final class NIOHTTPClientResponseAggregator: ChannelInboundHandler, RemovableChannelHandler {
|
||||
public typealias InboundIn = HTTPClientResponsePart
|
||||
public typealias InboundOut = NIOHTTPClientResponseFull
|
||||
|
||||
private var fullMessageHead: HTTPResponseHead? = nil
|
||||
private var buffer: ByteBuffer! = nil
|
||||
private var maxContentLength: Int
|
||||
private var state: AggregatorState
|
||||
|
||||
public init(maxContentLength: Int) {
|
||||
precondition(maxContentLength >= 0, "maxContentLength must not be negative")
|
||||
self.maxContentLength = maxContentLength
|
||||
self.state = .idle
|
||||
}
|
||||
|
||||
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let msg = self.unwrapInboundIn(data)
|
||||
|
||||
do {
|
||||
switch msg {
|
||||
case .head(let httpHead):
|
||||
try self.state.messageHeadReceived()
|
||||
try self.beginAggregation(context: context, response: httpHead)
|
||||
case .body(var content):
|
||||
try self.state.messageBodyReceived()
|
||||
try self.aggregate(context: context, content: &content)
|
||||
case .end(let trailingHeaders):
|
||||
try self.state.messageEndReceived()
|
||||
self.endAggregation(context: context, trailingHeaders: trailingHeaders)
|
||||
}
|
||||
} catch let error as NIOHTTPObjectAggregatorError {
|
||||
context.fireErrorCaught(error)
|
||||
// Won't be able to complete those
|
||||
self.fullMessageHead = nil
|
||||
self.buffer.clear()
|
||||
} catch let error {
|
||||
context.fireErrorCaught(error)
|
||||
}
|
||||
}
|
||||
|
||||
private func beginAggregation(context: ChannelHandlerContext, response: HTTPResponseHead) throws {
|
||||
self.fullMessageHead = response
|
||||
if let contentLength = response.contentLength, contentLength > self.maxContentLength {
|
||||
self.state.handlingOversizeMessage()
|
||||
context.fireUserInboundEventTriggered(NIOHTTPObjectAggregatorEvent.httpFrameTooLong)
|
||||
context.fireErrorCaught(NIOHTTPObjectAggregatorError.frameTooLong)
|
||||
}
|
||||
}
|
||||
|
||||
private func aggregate(context: ChannelHandlerContext, content: inout ByteBuffer) throws {
|
||||
if (content.readableBytes > self.maxContentLength - self.buffer.readableBytes) {
|
||||
self.state.handlingOversizeMessage()
|
||||
context.fireUserInboundEventTriggered(NIOHTTPObjectAggregatorEvent.httpFrameTooLong)
|
||||
context.fireErrorCaught(NIOHTTPObjectAggregatorError.frameTooLong)
|
||||
} else {
|
||||
self.buffer.writeBuffer(&content)
|
||||
}
|
||||
}
|
||||
|
||||
private func endAggregation(context: ChannelHandlerContext, trailingHeaders: HTTPHeaders?) {
|
||||
if var aggregated = self.fullMessageHead {
|
||||
// Remove `Trailer` from existing header fields and append trailer fields to existing header fields
|
||||
// See rfc7230 4.1.3 Decoding Chunked
|
||||
if let headers = trailingHeaders {
|
||||
aggregated.headers.remove(name: "trailer")
|
||||
aggregated.headers.add(contentsOf: headers)
|
||||
}
|
||||
|
||||
let fullMessage = NIOHTTPClientResponseFull(
|
||||
head: aggregated,
|
||||
body: self.buffer.readableBytes > 0 ? self.buffer : nil)
|
||||
self.fullMessageHead = nil
|
||||
self.buffer.clear()
|
||||
context.fireChannelRead(NIOAny(fullMessage))
|
||||
}
|
||||
}
|
||||
|
||||
public func handlerAdded(context: ChannelHandlerContext) {
|
||||
self.buffer = context.channel.allocator.buffer(capacity: 0)
|
||||
}
|
||||
}
|
|
@ -98,6 +98,8 @@ class LinuxMainRunnerImpl: LinuxMainRunner {
|
|||
testCase(NIOCloseOnErrorHandlerTest.allTests),
|
||||
testCase(NIOConcurrencyHelpersTests.allTests),
|
||||
testCase(NIOHTTP1TestServerTest.allTests),
|
||||
testCase(NIOHTTPClientResponseAggregatorTest.allTests),
|
||||
testCase(NIOHTTPServerRequestAggregatorTest.allTests),
|
||||
testCase(NIOSingleStepByteToMessageDecoderTest.allTests),
|
||||
testCase(NIOThreadPoolTest.allTests),
|
||||
testCase(NonBlockingFileIOTest.allTests),
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// NIOHTTPObjectAggregatorTest+XCTest.swift
|
||||
//
|
||||
import XCTest
|
||||
|
||||
///
|
||||
/// NOTE: This file was generated by generate_linux_tests.rb
|
||||
///
|
||||
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
|
||||
///
|
||||
|
||||
extension NIOHTTPServerRequestAggregatorTest {
|
||||
|
||||
@available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings")
|
||||
static var allTests : [(String, (NIOHTTPServerRequestAggregatorTest) -> () throws -> Void)] {
|
||||
return [
|
||||
("testAggregateNoBody", testAggregateNoBody),
|
||||
("testAggregateWithBody", testAggregateWithBody),
|
||||
("testAggregateChunkedBody", testAggregateChunkedBody),
|
||||
("testAggregateWithTrailer", testAggregateWithTrailer),
|
||||
("testOversizeRequest", testOversizeRequest),
|
||||
("testOversizedRequestWithoutKeepAlive", testOversizedRequestWithoutKeepAlive),
|
||||
("testOversizedRequestWithContentLength", testOversizedRequestWithContentLength),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
extension NIOHTTPClientResponseAggregatorTest {
|
||||
|
||||
@available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings")
|
||||
static var allTests : [(String, (NIOHTTPClientResponseAggregatorTest) -> () throws -> Void)] {
|
||||
return [
|
||||
("testOversizeResponseHead", testOversizeResponseHead),
|
||||
("testOversizeResponse", testOversizeResponse),
|
||||
("testAggregatedResponse", testAggregatedResponse),
|
||||
("testOkAfterOversized", testOkAfterOversized),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,459 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
import XCTest
|
||||
import NIO
|
||||
import NIOHTTP1
|
||||
import NIOTestUtils
|
||||
|
||||
|
||||
private final class ReadRecorder<T: Equatable>: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = T
|
||||
|
||||
enum Event: Equatable {
|
||||
case channelRead(InboundIn)
|
||||
case httpFrameTooLongEvent
|
||||
case httpExpectationFailedEvent
|
||||
|
||||
static func ==(lhs: Event, rhs: Event) -> Bool {
|
||||
switch (lhs, rhs) {
|
||||
case (.channelRead(let b1), .channelRead(let b2)):
|
||||
return b1 == b2
|
||||
case (.httpFrameTooLongEvent, .httpFrameTooLongEvent):
|
||||
return true
|
||||
case (.httpExpectationFailedEvent, .httpExpectationFailedEvent):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public var reads: [Event] = []
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
self.reads.append(.channelRead(self.unwrapInboundIn(data)))
|
||||
context.fireChannelRead(data)
|
||||
}
|
||||
|
||||
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
|
||||
switch event {
|
||||
case let evt as NIOHTTPObjectAggregatorEvent where evt == NIOHTTPObjectAggregatorEvent.httpFrameTooLong:
|
||||
self.reads.append(.httpFrameTooLongEvent)
|
||||
case let evt as NIOHTTPObjectAggregatorEvent where evt == NIOHTTPObjectAggregatorEvent.httpExpectationFailed:
|
||||
self.reads.append(.httpExpectationFailedEvent)
|
||||
default:
|
||||
context.fireUserInboundEventTriggered(event)
|
||||
}
|
||||
}
|
||||
|
||||
func clear() {
|
||||
self.reads.removeAll(keepingCapacity: true)
|
||||
}
|
||||
}
|
||||
|
||||
private final class WriteRecorder: ChannelOutboundHandler, RemovableChannelHandler {
|
||||
typealias OutboundIn = HTTPServerResponsePart
|
||||
|
||||
public var writes: [HTTPServerResponsePart] = []
|
||||
|
||||
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
|
||||
self.writes.append(self.unwrapOutboundIn(data))
|
||||
|
||||
context.write(data, promise: promise)
|
||||
}
|
||||
|
||||
func clear() {
|
||||
self.writes.removeAll(keepingCapacity: true)
|
||||
}
|
||||
}
|
||||
|
||||
private extension ByteBuffer {
|
||||
func assertContainsOnly(_ string: String) {
|
||||
let innerData = self.getString(at: self.readerIndex, length: self.readableBytes)!
|
||||
XCTAssertEqual(innerData, string)
|
||||
}
|
||||
}
|
||||
|
||||
private func asHTTPResponseHead(_ response: HTTPServerResponsePart) -> HTTPResponseHead? {
|
||||
switch response {
|
||||
case .head(let resHead):
|
||||
return resHead
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class NIOHTTPServerRequestAggregatorTest: XCTestCase {
|
||||
var channel: EmbeddedChannel! = nil
|
||||
var requestHead: HTTPRequestHead! = nil
|
||||
var responseHead: HTTPResponseHead! = nil
|
||||
fileprivate var readRecorder: ReadRecorder<NIOHTTPServerRequestFull>! = nil
|
||||
fileprivate var writeRecorder: WriteRecorder! = nil
|
||||
fileprivate var aggregatorHandler: NIOHTTPServerRequestAggregator! = nil
|
||||
|
||||
override func setUp() {
|
||||
self.channel = EmbeddedChannel()
|
||||
self.readRecorder = ReadRecorder()
|
||||
self.writeRecorder = WriteRecorder()
|
||||
self.aggregatorHandler = NIOHTTPServerRequestAggregator(maxContentLength: 1024 * 1024)
|
||||
|
||||
XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPResponseEncoder()).wait())
|
||||
XCTAssertNoThrow(try channel.pipeline.addHandler(self.writeRecorder).wait())
|
||||
XCTAssertNoThrow(try channel.pipeline.addHandler(self.aggregatorHandler).wait())
|
||||
XCTAssertNoThrow(try channel.pipeline.addHandler(self.readRecorder).wait())
|
||||
|
||||
self.requestHead = HTTPRequestHead(version: .init(major: 1, minor: 1), method: .PUT, uri: "/path")
|
||||
self.requestHead.headers.add(name: "Host", value: "example.com")
|
||||
self.requestHead.headers.add(name: "X-Test", value: "True")
|
||||
|
||||
self.responseHead = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok)
|
||||
self.responseHead.headers.add(name: "Server", value: "SwiftNIO")
|
||||
|
||||
// this activates the channel
|
||||
XCTAssertNoThrow(try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
|
||||
}
|
||||
|
||||
/// Modify pipeline setup to use aggregator with a smaller `maxContentLength`
|
||||
private func resetSmallHandler(maxContentLength: Int) {
|
||||
XCTAssertNoThrow(try self.channel.pipeline.removeHandler(self.readRecorder!).wait())
|
||||
XCTAssertNoThrow(try self.channel.pipeline.removeHandler(self.aggregatorHandler!).wait())
|
||||
self.aggregatorHandler = NIOHTTPServerRequestAggregator(maxContentLength: maxContentLength)
|
||||
XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.aggregatorHandler).wait())
|
||||
XCTAssertNoThrow(try channel.pipeline.addHandler(self.readRecorder).wait())
|
||||
}
|
||||
|
||||
override func tearDown() {
|
||||
if let channel = self.channel {
|
||||
XCTAssertNoThrow(try channel.finish(acceptAlreadyClosed: true))
|
||||
self.channel = nil
|
||||
}
|
||||
self.requestHead = nil
|
||||
self.responseHead = nil
|
||||
self.readRecorder = nil
|
||||
self.writeRecorder = nil
|
||||
self.aggregatorHandler = nil
|
||||
}
|
||||
|
||||
func testAggregateNoBody() {
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead)))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil)))
|
||||
|
||||
// Only one request should have made it through.
|
||||
XCTAssertEqual(self.readRecorder.reads,
|
||||
[.channelRead(NIOHTTPServerRequestFull(head: self.requestHead, body: nil))])
|
||||
}
|
||||
|
||||
func testAggregateWithBody() {
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead)))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
|
||||
channel.allocator.buffer(string: "hello"))))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil)))
|
||||
|
||||
// Only one request should have made it through.
|
||||
XCTAssertEqual(self.readRecorder.reads, [
|
||||
.channelRead(NIOHTTPServerRequestFull(
|
||||
head: self.requestHead,
|
||||
body: channel.allocator.buffer(string: "hello")))])
|
||||
}
|
||||
|
||||
func testAggregateChunkedBody() {
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead)))
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
|
||||
channel.allocator.buffer(string: "hello"))))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
|
||||
channel.allocator.buffer(string: "world"))))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil)))
|
||||
|
||||
// Only one request should have made it through.
|
||||
XCTAssertEqual(self.readRecorder.reads, [
|
||||
.channelRead(NIOHTTPServerRequestFull(
|
||||
head: self.requestHead,
|
||||
body: channel.allocator.buffer(string: "helloworld")))])
|
||||
}
|
||||
|
||||
func testAggregateWithTrailer() {
|
||||
var reqWithChunking: HTTPRequestHead = self.requestHead
|
||||
reqWithChunking.headers.add(name: "transfer-encoding", value: "chunked")
|
||||
reqWithChunking.headers.add(name: "Trailer", value: "X-Trailer")
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(reqWithChunking)))
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
|
||||
channel.allocator.buffer(string: "hello"))))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
|
||||
channel.allocator.buffer(string: "world"))))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(
|
||||
HTTPHeaders.init([("X-Trailer", "true")]))))
|
||||
|
||||
reqWithChunking.headers.remove(name: "Trailer")
|
||||
reqWithChunking.headers.add(name: "X-Trailer", value: "true")
|
||||
|
||||
// Trailer headers should get moved to normal ones
|
||||
XCTAssertEqual(self.readRecorder.reads, [
|
||||
.channelRead(NIOHTTPServerRequestFull(
|
||||
head: reqWithChunking,
|
||||
body: channel.allocator.buffer(string: "helloworld")))])
|
||||
}
|
||||
|
||||
func testOversizeRequest() {
|
||||
resetSmallHandler(maxContentLength: 4)
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead)))
|
||||
XCTAssertTrue(channel.isActive)
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
|
||||
channel.allocator.buffer(string: "he"))))
|
||||
XCTAssertEqual(self.writeRecorder.writes, [])
|
||||
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.body(
|
||||
channel.allocator.buffer(string: "llo")))) { error in
|
||||
XCTAssertEqual(NIOHTTPObjectAggregatorError.frameTooLong, error as? NIOHTTPObjectAggregatorError)
|
||||
}
|
||||
|
||||
let resTooLarge = HTTPResponseHead(
|
||||
version: .init(major: 1, minor: 1),
|
||||
status: .payloadTooLarge,
|
||||
headers: HTTPHeaders([("Content-Length", "0"), ("connection", "close")]))
|
||||
|
||||
XCTAssertEqual(self.writeRecorder.writes, [
|
||||
HTTPServerResponsePart.head(resTooLarge),
|
||||
HTTPServerResponsePart.end(nil)])
|
||||
|
||||
XCTAssertFalse(channel.isActive)
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in
|
||||
XCTAssertEqual(NIOHTTPObjectAggregatorError.connectionClosed, error as? NIOHTTPObjectAggregatorError)
|
||||
}
|
||||
}
|
||||
|
||||
func testOversizedRequestWithoutKeepAlive() {
|
||||
resetSmallHandler(maxContentLength: 4)
|
||||
|
||||
// send an HTTP/1.0 request with no keep-alive header
|
||||
let requestHead: HTTPRequestHead = HTTPRequestHead(
|
||||
version: .init(major: 1, minor: 0),
|
||||
method: .PUT, uri: "/path",
|
||||
headers: HTTPHeaders(
|
||||
[("Host", "example.com"), ("X-Test", "True"), ("content-length", "5")]))
|
||||
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.head(requestHead)))
|
||||
|
||||
let resTooLarge = HTTPResponseHead(
|
||||
version: .init(major: 1, minor: 0),
|
||||
status: .payloadTooLarge,
|
||||
headers: HTTPHeaders([("Content-Length", "0"), ("connection", "close")]))
|
||||
|
||||
XCTAssertEqual(self.writeRecorder.writes, [
|
||||
HTTPServerResponsePart.head(resTooLarge),
|
||||
HTTPServerResponsePart.end(nil)])
|
||||
|
||||
// Connection should be closed right away
|
||||
XCTAssertFalse(channel.isActive)
|
||||
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in
|
||||
XCTAssertEqual(NIOHTTPObjectAggregatorError.connectionClosed, error as? NIOHTTPObjectAggregatorError)
|
||||
}
|
||||
}
|
||||
|
||||
func testOversizedRequestWithContentLength() {
|
||||
resetSmallHandler(maxContentLength: 4)
|
||||
|
||||
// HTTP/1.1 uses Keep-Alive unless told otherwise
|
||||
let requestHead: HTTPRequestHead = HTTPRequestHead(
|
||||
version: .init(major: 1, minor: 1),
|
||||
method: .PUT, uri: "/path",
|
||||
headers: HTTPHeaders(
|
||||
[("Host", "example.com"), ("X-Test", "True"), ("content-length", "8")]))
|
||||
|
||||
resetSmallHandler(maxContentLength: 4)
|
||||
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.head(requestHead)))
|
||||
|
||||
let response = asHTTPResponseHead(self.writeRecorder.writes.first!)!
|
||||
XCTAssertEqual(response.status, .payloadTooLarge)
|
||||
XCTAssertEqual(response.headers[canonicalForm: "content-length"], ["0"])
|
||||
XCTAssertEqual(response.version, requestHead.version)
|
||||
|
||||
// Connection should be kept open
|
||||
XCTAssertTrue(channel.isActive)
|
||||
|
||||
// An ill-behaved client may continue writing the request
|
||||
let requestParts = [
|
||||
HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [1, 2, 3, 4])),
|
||||
HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [5,6])),
|
||||
HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [7,8]))
|
||||
]
|
||||
|
||||
for requestPart in requestParts {
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(requestPart))
|
||||
}
|
||||
|
||||
// The aggregated message should not get passed up as it is too large.
|
||||
// There should only be one "frame too long" event despite multiple writes
|
||||
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil)))
|
||||
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
|
||||
|
||||
// Write another request that is small enough
|
||||
var secondReqWithContentLength: HTTPRequestHead = self.requestHead
|
||||
secondReqWithContentLength.headers.replaceOrAdd(name: "content-length", value: "2")
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(secondReqWithContentLength)))
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
|
||||
channel.allocator.buffer(bytes: [1]))))
|
||||
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
|
||||
channel.allocator.buffer(bytes: [2]))))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil)))
|
||||
|
||||
XCTAssertEqual(self.readRecorder.reads, [
|
||||
.httpFrameTooLongEvent,
|
||||
.channelRead(NIOHTTPServerRequestFull(
|
||||
head: secondReqWithContentLength,
|
||||
body: channel.allocator.buffer(bytes: [1, 2])))])
|
||||
}
|
||||
}
|
||||
|
||||
class NIOHTTPClientResponseAggregatorTest: XCTestCase {
|
||||
var channel: EmbeddedChannel! = nil
|
||||
var requestHead: HTTPRequestHead! = nil
|
||||
var responseHead: HTTPResponseHead! = nil
|
||||
fileprivate var readRecorder: ReadRecorder<NIOHTTPClientResponseFull>! = nil
|
||||
fileprivate var aggregatorHandler: NIOHTTPClientResponseAggregator! = nil
|
||||
|
||||
override func setUp() {
|
||||
self.channel = EmbeddedChannel()
|
||||
self.readRecorder = ReadRecorder()
|
||||
self.aggregatorHandler = NIOHTTPClientResponseAggregator(maxContentLength: 1024 * 1024)
|
||||
|
||||
XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPRequestEncoder()).wait())
|
||||
XCTAssertNoThrow(try channel.pipeline.addHandler(self.aggregatorHandler).wait())
|
||||
XCTAssertNoThrow(try channel.pipeline.addHandler(self.readRecorder).wait())
|
||||
|
||||
self.requestHead = HTTPRequestHead(version: .init(major: 1, minor: 1), method: .PUT, uri: "/path")
|
||||
self.requestHead.headers.add(name: "Host", value: "example.com")
|
||||
self.requestHead.headers.add(name: "X-Test", value: "True")
|
||||
|
||||
self.responseHead = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok)
|
||||
self.responseHead.headers.add(name: "Server", value: "SwiftNIO")
|
||||
|
||||
// this activates the channel
|
||||
XCTAssertNoThrow(try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
|
||||
}
|
||||
|
||||
/// Modify pipeline setup to use aggregator with a smaller `maxContentLength`
|
||||
private func resetSmallHandler(maxContentLength: Int) {
|
||||
XCTAssertNoThrow(try self.channel.pipeline.removeHandler(self.readRecorder!).wait())
|
||||
XCTAssertNoThrow(try self.channel.pipeline.removeHandler(self.aggregatorHandler!).wait())
|
||||
self.aggregatorHandler = NIOHTTPClientResponseAggregator(maxContentLength: maxContentLength)
|
||||
XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.aggregatorHandler).wait())
|
||||
XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.readRecorder!).wait())
|
||||
}
|
||||
|
||||
override func tearDown() {
|
||||
if let channel = self.channel {
|
||||
XCTAssertNoThrow(try channel.finish(acceptAlreadyClosed: true))
|
||||
self.channel = nil
|
||||
}
|
||||
self.requestHead = nil
|
||||
self.responseHead = nil
|
||||
self.readRecorder = nil
|
||||
self.aggregatorHandler = nil
|
||||
}
|
||||
|
||||
func testOversizeResponseHead() {
|
||||
resetSmallHandler(maxContentLength: 5)
|
||||
|
||||
var resHead: HTTPResponseHead = self.responseHead
|
||||
resHead.headers.replaceOrAdd(name: "content-length", value: "10")
|
||||
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(HTTPClientResponsePart.head(resHead)))
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(HTTPClientResponsePart.end(nil)))
|
||||
|
||||
// User event triggered
|
||||
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
|
||||
}
|
||||
|
||||
func testOversizeResponse() {
|
||||
resetSmallHandler(maxContentLength: 5)
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead)))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.body(
|
||||
self.channel.allocator.buffer(string: "hello"))))
|
||||
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(
|
||||
HTTPClientResponsePart.body(
|
||||
self.channel.allocator.buffer(string: "world"))))
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(HTTPClientResponsePart.end(nil)))
|
||||
|
||||
// User event triggered
|
||||
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
|
||||
}
|
||||
|
||||
|
||||
func testAggregatedResponse() {
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead)))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(
|
||||
HTTPClientResponsePart.body(
|
||||
self.channel.allocator.buffer(string: "hello"))))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(
|
||||
HTTPClientResponsePart.body(
|
||||
self.channel.allocator.buffer(string: "world"))))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.end(HTTPHeaders([("X-Trail", "true")]))))
|
||||
|
||||
var aggregatedHead: HTTPResponseHead = self.responseHead
|
||||
aggregatedHead.headers.add(name: "X-Trail", value: "true")
|
||||
|
||||
XCTAssertEqual(self.readRecorder.reads, [
|
||||
.channelRead(NIOHTTPClientResponseFull(
|
||||
head: aggregatedHead,
|
||||
body: self.channel.allocator.buffer(string: "helloworld")))])
|
||||
}
|
||||
|
||||
func testOkAfterOversized() {
|
||||
resetSmallHandler(maxContentLength: 4)
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead)))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(
|
||||
HTTPClientResponsePart.body(
|
||||
self.channel.allocator.buffer(string: "hell"))))
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(
|
||||
HTTPClientResponsePart.body(
|
||||
self.channel.allocator.buffer(string: "owor"))))
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(
|
||||
HTTPClientResponsePart.body(
|
||||
self.channel.allocator.buffer(string: "ld"))))
|
||||
XCTAssertThrowsError(try self.channel.writeInbound(HTTPClientResponsePart.end(nil)))
|
||||
|
||||
// User event triggered
|
||||
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead)))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(
|
||||
HTTPClientResponsePart.body(
|
||||
self.channel.allocator.buffer(string: "test"))))
|
||||
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.end(nil)))
|
||||
|
||||
XCTAssertEqual(self.readRecorder.reads, [
|
||||
.httpFrameTooLongEvent,
|
||||
.channelRead(NIOHTTPClientResponseFull(
|
||||
head: self.responseHead,
|
||||
body: self.channel.allocator.buffer(string: "test")))])
|
||||
}
|
||||
|
||||
|
||||
}
|
Loading…
Reference in New Issue