HTTPObjectAggregator implementation (#1664)

* HTTPServerObjectAggregator for requests

* Apply suggestions from code review

Co-authored-by: Cory Benfield <lukasa@apple.com>

* most of code review comments addressed

* tidy up state machine close

* bad line breaks

* more verbose error reporting

* removes expect:continue functionality due to #1422

* public type naming and error handling tweaks

* wrong expectation in test

* do not quietly swallow unexpected errors in channelRead

Co-authored-by: Cory Benfield <lukasa@apple.com>
This commit is contained in:
Andrius 2020-10-08 11:37:32 +01:00 committed by GitHub
parent d2372de507
commit e6b7d718a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 932 additions and 0 deletions

View File

@ -1435,3 +1435,21 @@ extension HTTPMethod: RawRepresentable {
}
}
}
extension HTTPResponseHead {
internal var contentLength: Int? {
return headers.contentLength
}
}
extension HTTPRequestHead {
internal var contentLength: Int? {
return headers.contentLength
}
}
extension HTTPHeaders {
internal var contentLength: Int? {
return self.first(name: "content-length").flatMap { Int($0) }
}
}

View File

@ -0,0 +1,400 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2020 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIO
/// The parts of a complete HTTP response from the view of the client.
///
/// A full HTTP request is made up of a response header encoded by `.head`
/// and an optional `.body`.
public struct NIOHTTPServerRequestFull {
public var head: HTTPRequestHead
public var body: ByteBuffer?
public init(head: HTTPRequestHead, body: ByteBuffer?) {
self.head = head
self.body = body
}
}
extension NIOHTTPServerRequestFull: Equatable {}
/// The parts of a complete HTTP response from the view of the client.
///
/// Afull HTTP response is made up of a response header encoded by `.head`
/// and an optional `.body`.
public struct NIOHTTPClientResponseFull {
public var head: HTTPResponseHead
public var body: ByteBuffer?
public init(head: HTTPResponseHead, body: ByteBuffer?) {
self.head = head
self.body = body
}
}
extension NIOHTTPClientResponseFull: Equatable {}
public struct NIOHTTPObjectAggregatorError: Error, Equatable {
private enum Base {
case frameTooLong
case connectionClosed
case endingIgnoredMessage
case unexpectedMessageHead
case unexpectedMessageBody
case unexpectedMessageEnd
}
private var base: Base
private init(base: Base) {
self.base = base
}
public static let frameTooLong = NIOHTTPObjectAggregatorError(base: .frameTooLong)
public static let connectionClosed = NIOHTTPObjectAggregatorError(base: .connectionClosed)
public static let endingIgnoredMessage = NIOHTTPObjectAggregatorError(base: .endingIgnoredMessage)
public static let unexpectedMessageHead = NIOHTTPObjectAggregatorError(base: .unexpectedMessageHead)
public static let unexpectedMessageBody = NIOHTTPObjectAggregatorError(base: .unexpectedMessageBody)
public static let unexpectedMessageEnd = NIOHTTPObjectAggregatorError(base: .unexpectedMessageEnd)
}
public struct NIOHTTPObjectAggregatorEvent: Hashable {
private enum Base {
case httpExpectationFailed
case httpFrameTooLong
}
private var base: Base
private init(base: Base) {
self.base = base
}
public static let httpExpectationFailed = NIOHTTPObjectAggregatorEvent(base: .httpExpectationFailed)
public static let httpFrameTooLong = NIOHTTPObjectAggregatorEvent(base: .httpFrameTooLong)
}
/// The state of the aggregator connection.
internal enum AggregatorState {
/// Nothing is active on this connection, the next message we expect would be a request `.head`.
case idle
/// Ill-behaving client may be sending content that is too large
case ignoringContent
/// We are receiving and aggregating a request
case receiving
/// Connection should be closed
case closed
mutating func messageHeadReceived() throws {
switch self {
case .idle:
self = .receiving
case .ignoringContent, .receiving:
throw NIOHTTPObjectAggregatorError.unexpectedMessageHead
case .closed:
throw NIOHTTPObjectAggregatorError.connectionClosed
}
}
mutating func messageBodyReceived() throws {
switch self {
case .receiving:
()
case .ignoringContent:
throw NIOHTTPObjectAggregatorError.frameTooLong
case .idle:
throw NIOHTTPObjectAggregatorError.unexpectedMessageBody
case .closed:
throw NIOHTTPObjectAggregatorError.connectionClosed
}
}
mutating func messageEndReceived() throws {
switch self {
case .receiving:
// Got the request end we were waiting for.
self = .idle
case .ignoringContent:
// Expected transition from a state where message contents are getting
// ignored because the message is too large. Throwing an error prevents
// the normal control flow from continuing into dispatching the completed
// invalid message to the next handler.
self = .idle
throw NIOHTTPObjectAggregatorError.endingIgnoredMessage
case .idle:
throw NIOHTTPObjectAggregatorError.unexpectedMessageEnd
case .closed:
throw NIOHTTPObjectAggregatorError.connectionClosed
}
}
mutating func handlingOversizeMessage() {
switch self {
case .receiving, .idle:
self = .ignoringContent
case .ignoringContent, .closed:
// If we are already ignoring content or connection is closed, should not get here
preconditionFailure("Unreachable state: should never handle overized message in \(self)")
}
}
mutating func closed() {
self = .closed
}
}
/// A `ChannelInboundHandler` that handles HTTP chunked `HTTPServerRequestPart`
/// messages by aggregating individual message chunks into a single
/// `NIOHTTPServerRequestFull`.
///
/// This is achieved by buffering the contents of all received `HTTPServerRequestPart`
/// messages until `HTTPServerRequestPart.end` is received, then assembling the
/// full message and firing a channel read upstream with it. It is useful for when you do not
/// want to deal with chunked messages and just want to receive everything at once, and
/// are happy with the additional memory used and delay handling of the message until
/// everything has been received.
///
/// `NIOHTTPServerRequestAggregator` may end up sending a `HTTPResponseHead`:
/// - Response status `413 Request Entity Too Large` when either the
/// `content-length` or the bytes received so far exceed `maxContentLength`.
///
/// `NIOHTTPServerRequestAggregator` may close the connection if it is impossible
/// to recover:
/// - If `content-length` is too large and `keep-alive` is off.
/// - If the bytes received exceed `maxContentLength` and the client didn't signal
/// `content-length`
public final class NIOHTTPServerRequestAggregator: ChannelInboundHandler, RemovableChannelHandler {
public typealias InboundIn = HTTPServerRequestPart
public typealias InboundOut = NIOHTTPServerRequestFull
// Aggregator may generate responses of its own
public typealias OutboundOut = HTTPServerResponsePart
private var fullMessageHead: HTTPRequestHead? = nil
private var buffer: ByteBuffer! = nil
private var maxContentLength: Int
private var closeOnExpectationFailed: Bool
private var state: AggregatorState
public init(maxContentLength: Int, closeOnExpectationFailed: Bool = false) {
precondition(maxContentLength >= 0, "maxContentLength must not be negative")
self.maxContentLength = maxContentLength
self.closeOnExpectationFailed = closeOnExpectationFailed
self.state = .idle
}
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let msg = self.unwrapInboundIn(data)
var serverResponse: HTTPResponseHead? = nil
do {
switch msg {
case .head(let httpHead):
try self.state.messageHeadReceived()
serverResponse = self.beginAggregation(context: context, request: httpHead, message: msg)
case .body(var content):
try self.state.messageBodyReceived()
serverResponse = self.aggregate(context: context, content: &content, message: msg)
case .end(let trailingHeaders):
try self.state.messageEndReceived()
self.endAggregation(context: context, trailingHeaders: trailingHeaders)
}
} catch let error as NIOHTTPObjectAggregatorError {
context.fireErrorCaught(error)
// Won't be able to complete those
self.fullMessageHead = nil
self.buffer.clear()
} catch let error {
context.fireErrorCaught(error)
}
// Generated a server esponse to send back
if let response = serverResponse {
context.write(self.wrapOutboundOut(.head(response)), promise: nil)
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
if response.status == .payloadTooLarge {
// If indicated content length is too large
self.state.handlingOversizeMessage()
context.fireErrorCaught(NIOHTTPObjectAggregatorError.frameTooLong)
context.fireUserInboundEventTriggered(NIOHTTPObjectAggregatorEvent.httpFrameTooLong)
}
if !response.headers.isKeepAlive(version: response.version) {
context.close(promise: nil)
self.state.closed()
}
}
}
private func beginAggregation(context: ChannelHandlerContext, request: HTTPRequestHead, message: InboundIn) -> HTTPResponseHead? {
self.fullMessageHead = request
if let contentLength = request.contentLength, contentLength > self.maxContentLength {
return self.handleOversizeMessage(message: message)
}
return nil
}
private func aggregate(context: ChannelHandlerContext, content: inout ByteBuffer, message: InboundIn) -> HTTPResponseHead? {
if (content.readableBytes > self.maxContentLength - self.buffer.readableBytes) {
return self.handleOversizeMessage(message: message)
} else {
self.buffer.writeBuffer(&content)
return nil
}
}
private func endAggregation(context: ChannelHandlerContext, trailingHeaders: HTTPHeaders?) {
if var aggregated = self.fullMessageHead {
// Remove `Trailer` from existing header fields and append trailer fields to existing header fields
// See rfc7230 4.1.3 Decoding Chunked
if let headers = trailingHeaders {
aggregated.headers.remove(name: "trailer")
aggregated.headers.add(contentsOf: headers)
}
let fullMessage = NIOHTTPServerRequestFull(head: aggregated,
body: self.buffer.readableBytes > 0 ? self.buffer : nil)
self.fullMessageHead = nil
self.buffer.clear()
context.fireChannelRead(NIOAny(fullMessage))
}
}
private func handleOversizeMessage(message: InboundIn) -> HTTPResponseHead {
var payloadTooLargeHead = HTTPResponseHead(
version: self.fullMessageHead?.version ?? .init(major: 1, minor: 1),
status: .payloadTooLarge,
headers: HTTPHeaders([("content-length", "0")]))
switch message {
case .head(let request):
if !request.isKeepAlive {
// If keep-alive is off and, no need to leave the connection open.
// Send back a 413 and close the connection.
payloadTooLargeHead.headers.add(name: "connection", value: "close")
}
default:
// The client started to send data already, close because it's impossible to recover.
// Send back a 413 and close the connection.
payloadTooLargeHead.headers.add(name: "connection", value: "close")
}
return payloadTooLargeHead
}
public func handlerAdded(context: ChannelHandlerContext) {
self.buffer = context.channel.allocator.buffer(capacity: 0)
}
}
/// A `ChannelInboundHandler` that handles HTTP chunked `HTTPClientResponsePart`
/// messages by aggregating individual message chunks into a single
/// `NIOHTTPClientResponseFull`.
///
/// This is achieved by buffering the contents of all received `HTTPClientResponsePart`
/// messages until `HTTPClientResponsePart.end` is received, then assembling the
/// full message and firing a channel read upstream with it. Useful when you do not
/// want to deal with chunked messages and just want to receive everything at once, and
/// are happy with the additional memory used and delay handling of the message until
/// everything has been received.
///
/// If `NIOHTTPClientResponseAggregator` encounters a message larger than
/// `maxContentLength`, it discards the aggregated contents until the next
/// `HTTPClientResponsePart.end` and signals that via
/// `fireUserInboundEventTriggered`.
public final class NIOHTTPClientResponseAggregator: ChannelInboundHandler, RemovableChannelHandler {
public typealias InboundIn = HTTPClientResponsePart
public typealias InboundOut = NIOHTTPClientResponseFull
private var fullMessageHead: HTTPResponseHead? = nil
private var buffer: ByteBuffer! = nil
private var maxContentLength: Int
private var state: AggregatorState
public init(maxContentLength: Int) {
precondition(maxContentLength >= 0, "maxContentLength must not be negative")
self.maxContentLength = maxContentLength
self.state = .idle
}
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let msg = self.unwrapInboundIn(data)
do {
switch msg {
case .head(let httpHead):
try self.state.messageHeadReceived()
try self.beginAggregation(context: context, response: httpHead)
case .body(var content):
try self.state.messageBodyReceived()
try self.aggregate(context: context, content: &content)
case .end(let trailingHeaders):
try self.state.messageEndReceived()
self.endAggregation(context: context, trailingHeaders: trailingHeaders)
}
} catch let error as NIOHTTPObjectAggregatorError {
context.fireErrorCaught(error)
// Won't be able to complete those
self.fullMessageHead = nil
self.buffer.clear()
} catch let error {
context.fireErrorCaught(error)
}
}
private func beginAggregation(context: ChannelHandlerContext, response: HTTPResponseHead) throws {
self.fullMessageHead = response
if let contentLength = response.contentLength, contentLength > self.maxContentLength {
self.state.handlingOversizeMessage()
context.fireUserInboundEventTriggered(NIOHTTPObjectAggregatorEvent.httpFrameTooLong)
context.fireErrorCaught(NIOHTTPObjectAggregatorError.frameTooLong)
}
}
private func aggregate(context: ChannelHandlerContext, content: inout ByteBuffer) throws {
if (content.readableBytes > self.maxContentLength - self.buffer.readableBytes) {
self.state.handlingOversizeMessage()
context.fireUserInboundEventTriggered(NIOHTTPObjectAggregatorEvent.httpFrameTooLong)
context.fireErrorCaught(NIOHTTPObjectAggregatorError.frameTooLong)
} else {
self.buffer.writeBuffer(&content)
}
}
private func endAggregation(context: ChannelHandlerContext, trailingHeaders: HTTPHeaders?) {
if var aggregated = self.fullMessageHead {
// Remove `Trailer` from existing header fields and append trailer fields to existing header fields
// See rfc7230 4.1.3 Decoding Chunked
if let headers = trailingHeaders {
aggregated.headers.remove(name: "trailer")
aggregated.headers.add(contentsOf: headers)
}
let fullMessage = NIOHTTPClientResponseFull(
head: aggregated,
body: self.buffer.readableBytes > 0 ? self.buffer : nil)
self.fullMessageHead = nil
self.buffer.clear()
context.fireChannelRead(NIOAny(fullMessage))
}
}
public func handlerAdded(context: ChannelHandlerContext) {
self.buffer = context.channel.allocator.buffer(capacity: 0)
}
}

View File

@ -98,6 +98,8 @@ class LinuxMainRunnerImpl: LinuxMainRunner {
testCase(NIOCloseOnErrorHandlerTest.allTests),
testCase(NIOConcurrencyHelpersTests.allTests),
testCase(NIOHTTP1TestServerTest.allTests),
testCase(NIOHTTPClientResponseAggregatorTest.allTests),
testCase(NIOHTTPServerRequestAggregatorTest.allTests),
testCase(NIOSingleStepByteToMessageDecoderTest.allTests),
testCase(NIOThreadPoolTest.allTests),
testCase(NonBlockingFileIOTest.allTests),

View File

@ -0,0 +1,53 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
//
// NIOHTTPObjectAggregatorTest+XCTest.swift
//
import XCTest
///
/// NOTE: This file was generated by generate_linux_tests.rb
///
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
///
extension NIOHTTPServerRequestAggregatorTest {
@available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings")
static var allTests : [(String, (NIOHTTPServerRequestAggregatorTest) -> () throws -> Void)] {
return [
("testAggregateNoBody", testAggregateNoBody),
("testAggregateWithBody", testAggregateWithBody),
("testAggregateChunkedBody", testAggregateChunkedBody),
("testAggregateWithTrailer", testAggregateWithTrailer),
("testOversizeRequest", testOversizeRequest),
("testOversizedRequestWithoutKeepAlive", testOversizedRequestWithoutKeepAlive),
("testOversizedRequestWithContentLength", testOversizedRequestWithContentLength),
]
}
}
extension NIOHTTPClientResponseAggregatorTest {
@available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings")
static var allTests : [(String, (NIOHTTPClientResponseAggregatorTest) -> () throws -> Void)] {
return [
("testOversizeResponseHead", testOversizeResponseHead),
("testOversizeResponse", testOversizeResponse),
("testAggregatedResponse", testAggregatedResponse),
("testOkAfterOversized", testOkAfterOversized),
]
}
}

View File

@ -0,0 +1,459 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import XCTest
import NIO
import NIOHTTP1
import NIOTestUtils
private final class ReadRecorder<T: Equatable>: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = T
enum Event: Equatable {
case channelRead(InboundIn)
case httpFrameTooLongEvent
case httpExpectationFailedEvent
static func ==(lhs: Event, rhs: Event) -> Bool {
switch (lhs, rhs) {
case (.channelRead(let b1), .channelRead(let b2)):
return b1 == b2
case (.httpFrameTooLongEvent, .httpFrameTooLongEvent):
return true
case (.httpExpectationFailedEvent, .httpExpectationFailedEvent):
return true
default:
return false
}
}
}
public var reads: [Event] = []
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
self.reads.append(.channelRead(self.unwrapInboundIn(data)))
context.fireChannelRead(data)
}
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch event {
case let evt as NIOHTTPObjectAggregatorEvent where evt == NIOHTTPObjectAggregatorEvent.httpFrameTooLong:
self.reads.append(.httpFrameTooLongEvent)
case let evt as NIOHTTPObjectAggregatorEvent where evt == NIOHTTPObjectAggregatorEvent.httpExpectationFailed:
self.reads.append(.httpExpectationFailedEvent)
default:
context.fireUserInboundEventTriggered(event)
}
}
func clear() {
self.reads.removeAll(keepingCapacity: true)
}
}
private final class WriteRecorder: ChannelOutboundHandler, RemovableChannelHandler {
typealias OutboundIn = HTTPServerResponsePart
public var writes: [HTTPServerResponsePart] = []
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
self.writes.append(self.unwrapOutboundIn(data))
context.write(data, promise: promise)
}
func clear() {
self.writes.removeAll(keepingCapacity: true)
}
}
private extension ByteBuffer {
func assertContainsOnly(_ string: String) {
let innerData = self.getString(at: self.readerIndex, length: self.readableBytes)!
XCTAssertEqual(innerData, string)
}
}
private func asHTTPResponseHead(_ response: HTTPServerResponsePart) -> HTTPResponseHead? {
switch response {
case .head(let resHead):
return resHead
default:
return nil
}
}
class NIOHTTPServerRequestAggregatorTest: XCTestCase {
var channel: EmbeddedChannel! = nil
var requestHead: HTTPRequestHead! = nil
var responseHead: HTTPResponseHead! = nil
fileprivate var readRecorder: ReadRecorder<NIOHTTPServerRequestFull>! = nil
fileprivate var writeRecorder: WriteRecorder! = nil
fileprivate var aggregatorHandler: NIOHTTPServerRequestAggregator! = nil
override func setUp() {
self.channel = EmbeddedChannel()
self.readRecorder = ReadRecorder()
self.writeRecorder = WriteRecorder()
self.aggregatorHandler = NIOHTTPServerRequestAggregator(maxContentLength: 1024 * 1024)
XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPResponseEncoder()).wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(self.writeRecorder).wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(self.aggregatorHandler).wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(self.readRecorder).wait())
self.requestHead = HTTPRequestHead(version: .init(major: 1, minor: 1), method: .PUT, uri: "/path")
self.requestHead.headers.add(name: "Host", value: "example.com")
self.requestHead.headers.add(name: "X-Test", value: "True")
self.responseHead = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok)
self.responseHead.headers.add(name: "Server", value: "SwiftNIO")
// this activates the channel
XCTAssertNoThrow(try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
}
/// Modify pipeline setup to use aggregator with a smaller `maxContentLength`
private func resetSmallHandler(maxContentLength: Int) {
XCTAssertNoThrow(try self.channel.pipeline.removeHandler(self.readRecorder!).wait())
XCTAssertNoThrow(try self.channel.pipeline.removeHandler(self.aggregatorHandler!).wait())
self.aggregatorHandler = NIOHTTPServerRequestAggregator(maxContentLength: maxContentLength)
XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.aggregatorHandler).wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(self.readRecorder).wait())
}
override func tearDown() {
if let channel = self.channel {
XCTAssertNoThrow(try channel.finish(acceptAlreadyClosed: true))
self.channel = nil
}
self.requestHead = nil
self.responseHead = nil
self.readRecorder = nil
self.writeRecorder = nil
self.aggregatorHandler = nil
}
func testAggregateNoBody() {
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead)))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil)))
// Only one request should have made it through.
XCTAssertEqual(self.readRecorder.reads,
[.channelRead(NIOHTTPServerRequestFull(head: self.requestHead, body: nil))])
}
func testAggregateWithBody() {
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead)))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
channel.allocator.buffer(string: "hello"))))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil)))
// Only one request should have made it through.
XCTAssertEqual(self.readRecorder.reads, [
.channelRead(NIOHTTPServerRequestFull(
head: self.requestHead,
body: channel.allocator.buffer(string: "hello")))])
}
func testAggregateChunkedBody() {
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead)))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
channel.allocator.buffer(string: "hello"))))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
channel.allocator.buffer(string: "world"))))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil)))
// Only one request should have made it through.
XCTAssertEqual(self.readRecorder.reads, [
.channelRead(NIOHTTPServerRequestFull(
head: self.requestHead,
body: channel.allocator.buffer(string: "helloworld")))])
}
func testAggregateWithTrailer() {
var reqWithChunking: HTTPRequestHead = self.requestHead
reqWithChunking.headers.add(name: "transfer-encoding", value: "chunked")
reqWithChunking.headers.add(name: "Trailer", value: "X-Trailer")
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(reqWithChunking)))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
channel.allocator.buffer(string: "hello"))))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
channel.allocator.buffer(string: "world"))))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(
HTTPHeaders.init([("X-Trailer", "true")]))))
reqWithChunking.headers.remove(name: "Trailer")
reqWithChunking.headers.add(name: "X-Trailer", value: "true")
// Trailer headers should get moved to normal ones
XCTAssertEqual(self.readRecorder.reads, [
.channelRead(NIOHTTPServerRequestFull(
head: reqWithChunking,
body: channel.allocator.buffer(string: "helloworld")))])
}
func testOversizeRequest() {
resetSmallHandler(maxContentLength: 4)
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead)))
XCTAssertTrue(channel.isActive)
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
channel.allocator.buffer(string: "he"))))
XCTAssertEqual(self.writeRecorder.writes, [])
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.body(
channel.allocator.buffer(string: "llo")))) { error in
XCTAssertEqual(NIOHTTPObjectAggregatorError.frameTooLong, error as? NIOHTTPObjectAggregatorError)
}
let resTooLarge = HTTPResponseHead(
version: .init(major: 1, minor: 1),
status: .payloadTooLarge,
headers: HTTPHeaders([("Content-Length", "0"), ("connection", "close")]))
XCTAssertEqual(self.writeRecorder.writes, [
HTTPServerResponsePart.head(resTooLarge),
HTTPServerResponsePart.end(nil)])
XCTAssertFalse(channel.isActive)
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in
XCTAssertEqual(NIOHTTPObjectAggregatorError.connectionClosed, error as? NIOHTTPObjectAggregatorError)
}
}
func testOversizedRequestWithoutKeepAlive() {
resetSmallHandler(maxContentLength: 4)
// send an HTTP/1.0 request with no keep-alive header
let requestHead: HTTPRequestHead = HTTPRequestHead(
version: .init(major: 1, minor: 0),
method: .PUT, uri: "/path",
headers: HTTPHeaders(
[("Host", "example.com"), ("X-Test", "True"), ("content-length", "5")]))
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.head(requestHead)))
let resTooLarge = HTTPResponseHead(
version: .init(major: 1, minor: 0),
status: .payloadTooLarge,
headers: HTTPHeaders([("Content-Length", "0"), ("connection", "close")]))
XCTAssertEqual(self.writeRecorder.writes, [
HTTPServerResponsePart.head(resTooLarge),
HTTPServerResponsePart.end(nil)])
// Connection should be closed right away
XCTAssertFalse(channel.isActive)
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in
XCTAssertEqual(NIOHTTPObjectAggregatorError.connectionClosed, error as? NIOHTTPObjectAggregatorError)
}
}
func testOversizedRequestWithContentLength() {
resetSmallHandler(maxContentLength: 4)
// HTTP/1.1 uses Keep-Alive unless told otherwise
let requestHead: HTTPRequestHead = HTTPRequestHead(
version: .init(major: 1, minor: 1),
method: .PUT, uri: "/path",
headers: HTTPHeaders(
[("Host", "example.com"), ("X-Test", "True"), ("content-length", "8")]))
resetSmallHandler(maxContentLength: 4)
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.head(requestHead)))
let response = asHTTPResponseHead(self.writeRecorder.writes.first!)!
XCTAssertEqual(response.status, .payloadTooLarge)
XCTAssertEqual(response.headers[canonicalForm: "content-length"], ["0"])
XCTAssertEqual(response.version, requestHead.version)
// Connection should be kept open
XCTAssertTrue(channel.isActive)
// An ill-behaved client may continue writing the request
let requestParts = [
HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [1, 2, 3, 4])),
HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [5,6])),
HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [7,8]))
]
for requestPart in requestParts {
XCTAssertThrowsError(try self.channel.writeInbound(requestPart))
}
// The aggregated message should not get passed up as it is too large.
// There should only be one "frame too long" event despite multiple writes
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil)))
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
// Write another request that is small enough
var secondReqWithContentLength: HTTPRequestHead = self.requestHead
secondReqWithContentLength.headers.replaceOrAdd(name: "content-length", value: "2")
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(secondReqWithContentLength)))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
channel.allocator.buffer(bytes: [1]))))
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(
channel.allocator.buffer(bytes: [2]))))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil)))
XCTAssertEqual(self.readRecorder.reads, [
.httpFrameTooLongEvent,
.channelRead(NIOHTTPServerRequestFull(
head: secondReqWithContentLength,
body: channel.allocator.buffer(bytes: [1, 2])))])
}
}
class NIOHTTPClientResponseAggregatorTest: XCTestCase {
var channel: EmbeddedChannel! = nil
var requestHead: HTTPRequestHead! = nil
var responseHead: HTTPResponseHead! = nil
fileprivate var readRecorder: ReadRecorder<NIOHTTPClientResponseFull>! = nil
fileprivate var aggregatorHandler: NIOHTTPClientResponseAggregator! = nil
override func setUp() {
self.channel = EmbeddedChannel()
self.readRecorder = ReadRecorder()
self.aggregatorHandler = NIOHTTPClientResponseAggregator(maxContentLength: 1024 * 1024)
XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPRequestEncoder()).wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(self.aggregatorHandler).wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(self.readRecorder).wait())
self.requestHead = HTTPRequestHead(version: .init(major: 1, minor: 1), method: .PUT, uri: "/path")
self.requestHead.headers.add(name: "Host", value: "example.com")
self.requestHead.headers.add(name: "X-Test", value: "True")
self.responseHead = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok)
self.responseHead.headers.add(name: "Server", value: "SwiftNIO")
// this activates the channel
XCTAssertNoThrow(try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
}
/// Modify pipeline setup to use aggregator with a smaller `maxContentLength`
private func resetSmallHandler(maxContentLength: Int) {
XCTAssertNoThrow(try self.channel.pipeline.removeHandler(self.readRecorder!).wait())
XCTAssertNoThrow(try self.channel.pipeline.removeHandler(self.aggregatorHandler!).wait())
self.aggregatorHandler = NIOHTTPClientResponseAggregator(maxContentLength: maxContentLength)
XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.aggregatorHandler).wait())
XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.readRecorder!).wait())
}
override func tearDown() {
if let channel = self.channel {
XCTAssertNoThrow(try channel.finish(acceptAlreadyClosed: true))
self.channel = nil
}
self.requestHead = nil
self.responseHead = nil
self.readRecorder = nil
self.aggregatorHandler = nil
}
func testOversizeResponseHead() {
resetSmallHandler(maxContentLength: 5)
var resHead: HTTPResponseHead = self.responseHead
resHead.headers.replaceOrAdd(name: "content-length", value: "10")
XCTAssertThrowsError(try self.channel.writeInbound(HTTPClientResponsePart.head(resHead)))
XCTAssertThrowsError(try self.channel.writeInbound(HTTPClientResponsePart.end(nil)))
// User event triggered
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
}
func testOversizeResponse() {
resetSmallHandler(maxContentLength: 5)
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead)))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.body(
self.channel.allocator.buffer(string: "hello"))))
XCTAssertThrowsError(try self.channel.writeInbound(
HTTPClientResponsePart.body(
self.channel.allocator.buffer(string: "world"))))
XCTAssertThrowsError(try self.channel.writeInbound(HTTPClientResponsePart.end(nil)))
// User event triggered
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
}
func testAggregatedResponse() {
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead)))
XCTAssertNoThrow(try self.channel.writeInbound(
HTTPClientResponsePart.body(
self.channel.allocator.buffer(string: "hello"))))
XCTAssertNoThrow(try self.channel.writeInbound(
HTTPClientResponsePart.body(
self.channel.allocator.buffer(string: "world"))))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.end(HTTPHeaders([("X-Trail", "true")]))))
var aggregatedHead: HTTPResponseHead = self.responseHead
aggregatedHead.headers.add(name: "X-Trail", value: "true")
XCTAssertEqual(self.readRecorder.reads, [
.channelRead(NIOHTTPClientResponseFull(
head: aggregatedHead,
body: self.channel.allocator.buffer(string: "helloworld")))])
}
func testOkAfterOversized() {
resetSmallHandler(maxContentLength: 4)
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead)))
XCTAssertNoThrow(try self.channel.writeInbound(
HTTPClientResponsePart.body(
self.channel.allocator.buffer(string: "hell"))))
XCTAssertThrowsError(try self.channel.writeInbound(
HTTPClientResponsePart.body(
self.channel.allocator.buffer(string: "owor"))))
XCTAssertThrowsError(try self.channel.writeInbound(
HTTPClientResponsePart.body(
self.channel.allocator.buffer(string: "ld"))))
XCTAssertThrowsError(try self.channel.writeInbound(HTTPClientResponsePart.end(nil)))
// User event triggered
XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent])
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead)))
XCTAssertNoThrow(try self.channel.writeInbound(
HTTPClientResponsePart.body(
self.channel.allocator.buffer(string: "test"))))
XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.end(nil)))
XCTAssertEqual(self.readRecorder.reads, [
.httpFrameTooLongEvent,
.channelRead(NIOHTTPClientResponseFull(
head: self.responseHead,
body: self.channel.allocator.buffer(string: "test")))])
}
}