Add `AsyncChannel` based `ServerBootstrap.bind()` methods (#2403)

* Add `AsyncChannel` based `ServerBootstrap.bind()` methods

# Motivation
In my previous PR, we added a new async bridge from a NIO `Channel` to Swift Concurrency primitives in the from of the `NIOAsyncChannel`. This type alone is already helpful in bridging `Channel`s to Concurrency; however, it is hard to use since it requires to wrap the `Channel` at the right time otherwise we will drop reads. Furthermore, in the case of protocol negotiation this becomes even trickier since we need to wait until it finishes and then wrap the `Channel`.

# Modification
This PR introduces a few things:
1. New methods on the `ServerBootstrap` which allow the creation of `NIOAsyncChannel` based channels. This can be used in all cases where no protocol negotiation is involved.
2. A new protocol and type called `NIOProtocolNegotiationHandler` and `NIOProtocolNegotiationResult` which is used to identify channel handlers that are doing protocol negotiation.
3. New methods on the `ServerBootstrap` that are aware of protocol negotiation.

# Result
We can now easily and safely create new `AsyncChannel`s from the `ServerBootstrap`

* Code review

* Fix typo

* Fix up tests

* Stop finishing the writer when an error is caught

* Code review

* Fix up writer tests

* Introduce shared protocol negotiation handler state machine

* Correctly handle multi threaded event loops

* Adapt test to assert the channel was closed correctly.

* Code review
This commit is contained in:
Franz Busch 2023-04-26 15:17:07 +01:00 committed by GitHub
parent f7c4655298
commit d836d6bef5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1916 additions and 114 deletions

View File

@ -104,7 +104,7 @@ var targets: [PackageDescription.Target] = [
.testTarget(name: "NIOEmbeddedTests", .testTarget(name: "NIOEmbeddedTests",
dependencies: ["NIOConcurrencyHelpers", "NIOCore", "NIOEmbedded"]), dependencies: ["NIOConcurrencyHelpers", "NIOCore", "NIOEmbedded"]),
.testTarget(name: "NIOPosixTests", .testTarget(name: "NIOPosixTests",
dependencies: ["NIOPosix", "NIOCore", "NIOFoundationCompat", "NIOTestUtils", "NIOConcurrencyHelpers", "NIOEmbedded", "CNIOLinux"]), dependencies: ["NIOPosix", "NIOCore", "NIOFoundationCompat", "NIOTestUtils", "NIOConcurrencyHelpers", "NIOEmbedded", "CNIOLinux", "NIOTLS"]),
.testTarget(name: "NIOConcurrencyHelpersTests", .testTarget(name: "NIOConcurrencyHelpersTests",
dependencies: ["NIOConcurrencyHelpers", "NIOCore"]), dependencies: ["NIOConcurrencyHelpers", "NIOCore"]),
.testTarget(name: "NIODataStructuresTests", .testTarget(name: "NIODataStructuresTests",

View File

@ -62,7 +62,7 @@ public final class NIOAsyncChannel<Inbound: Sendable, Outbound: Sendable>: Senda
public init( public init(
synchronouslyWrapping channel: Channel, synchronouslyWrapping channel: Channel,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
isOutboundHalfClosureEnabled: Bool = true, isOutboundHalfClosureEnabled: Bool = false,
inboundType: Inbound.Type = Inbound.self, inboundType: Inbound.Type = Inbound.self,
outboundType: Outbound.Type = Outbound.self outboundType: Outbound.Type = Outbound.self
) throws { ) throws {
@ -92,7 +92,7 @@ public final class NIOAsyncChannel<Inbound: Sendable, Outbound: Sendable>: Senda
public init( public init(
synchronouslyWrapping channel: Channel, synchronouslyWrapping channel: Channel,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
isOutboundHalfClosureEnabled: Bool = true, isOutboundHalfClosureEnabled: Bool = false,
inboundType: Inbound.Type = Inbound.self inboundType: Inbound.Type = Inbound.self
) throws where Outbound == Never { ) throws where Outbound == Never {
channel.eventLoop.preconditionInEventLoop() channel.eventLoop.preconditionInEventLoop()
@ -104,6 +104,67 @@ public final class NIOAsyncChannel<Inbound: Sendable, Outbound: Sendable>: Senda
self.outboundWriter.finish() self.outboundWriter.finish()
} }
@inlinable
@_spi(AsyncChannel)
public init(
channel: Channel,
inboundStream: NIOAsyncChannelInboundStream<Inbound>,
outboundWriter: NIOAsyncChannelOutboundWriter<Outbound>
) {
channel.eventLoop.preconditionInEventLoop()
self.channel = channel
self.inboundStream = inboundStream
self.outboundWriter = outboundWriter
}
@inlinable
@_spi(AsyncChannel)
public static func wrapAsyncChannelForBootstrapBind(
synchronouslyWrapping channel: Channel,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
isOutboundHalfClosureEnabled: Bool = false,
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
) throws -> NIOAsyncChannel<Inbound, Outbound> where Outbound == Never {
channel.eventLoop.preconditionInEventLoop()
let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream<Inbound>, NIOAsyncChannelOutboundWriter<Outbound>) = try channel._syncAddAsyncHandlersForBootstrapBind(
backpressureStrategy: backpressureStrategy,
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
transformationClosure: transformationClosure
)
outboundWriter.finish()
return .init(
channel: channel,
inboundStream: inboundStream,
outboundWriter: outboundWriter
)
}
@inlinable
@_spi(AsyncChannel)
public static func wrapAsyncChannelForBootstrapBindWithProtocolNegotiation(
synchronouslyWrapping channel: Channel,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
isOutboundHalfClosureEnabled: Bool = false,
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
) throws -> NIOAsyncChannel<Inbound, Outbound> where Outbound == Never {
channel.eventLoop.preconditionInEventLoop()
let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream<Inbound>, NIOAsyncChannelOutboundWriter<Outbound>) = try channel._syncAddAsyncHandlersForBootstrapProtocolNegotiation(
backpressureStrategy: backpressureStrategy,
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
transformationClosure: transformationClosure
)
outboundWriter.finish()
return .init(
channel: channel,
inboundStream: inboundStream,
outboundWriter: outboundWriter
)
}
} }
extension Channel { extension Channel {
@ -118,7 +179,7 @@ extension Channel {
self.eventLoop.assertInEventLoop() self.eventLoop.assertInEventLoop()
let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled)
let inboundStream = try NIOAsyncChannelInboundStream<Inbound>( let inboundStream = try NIOAsyncChannelInboundStream<Inbound>.makeWrappingHandler(
channel: self, channel: self,
backpressureStrategy: backpressureStrategy, backpressureStrategy: backpressureStrategy,
closeRatchet: closeRatchet closeRatchet: closeRatchet
@ -129,4 +190,52 @@ extension Channel {
) )
return (inboundStream, writer) return (inboundStream, writer)
} }
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@inlinable
@_spi(AsyncChannel)
public func _syncAddAsyncHandlersForBootstrapBind<Inbound: Sendable, Outbound: Sendable>(
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
isOutboundHalfClosureEnabled: Bool,
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
) throws -> (NIOAsyncChannelInboundStream<Inbound>, NIOAsyncChannelOutboundWriter<Outbound>) {
self.eventLoop.assertInEventLoop()
let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled)
let inboundStream = try NIOAsyncChannelInboundStream<Inbound>.makeBindingHandler(
channel: self,
backpressureStrategy: backpressureStrategy,
closeRatchet: closeRatchet,
transformationClosure: transformationClosure
)
let writer = try NIOAsyncChannelOutboundWriter<Outbound>(
channel: self,
closeRatchet: closeRatchet
)
return (inboundStream, writer)
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@inlinable
@_spi(AsyncChannel)
public func _syncAddAsyncHandlersForBootstrapProtocolNegotiation<Inbound: Sendable, Outbound: Sendable>(
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
isOutboundHalfClosureEnabled: Bool,
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
) throws -> (NIOAsyncChannelInboundStream<Inbound>, NIOAsyncChannelOutboundWriter<Outbound>) {
self.eventLoop.assertInEventLoop()
let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled)
let inboundStream = try NIOAsyncChannelInboundStream<Inbound>.makeProtocolNegotiationHandler(
channel: self,
backpressureStrategy: backpressureStrategy,
closeRatchet: closeRatchet,
transformationClosure: transformationClosure
)
let writer = try NIOAsyncChannelOutboundWriter<Outbound>(
channel: self,
closeRatchet: closeRatchet
)
return (inboundStream, writer)
}
} }

View File

@ -19,22 +19,19 @@
@_spi(AsyncChannel) @_spi(AsyncChannel)
public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable { public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable {
@usableFromInline @usableFromInline
typealias Producer = NIOThrowingAsyncSequenceProducer<Inbound, Error, NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, NIOAsyncChannelInboundStreamChannelHandler<Inbound>.Delegate> typealias Producer = NIOThrowingAsyncSequenceProducer<Inbound, Error, NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate>
/// The underlying async sequence. /// The underlying async sequence.
@usableFromInline let _producer: Producer @usableFromInline let _producer: Producer
@inlinable @inlinable
init( init<HandlerInbound: Sendable>(
channel: Channel, channel: Channel,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
closeRatchet: CloseRatchet closeRatchet: CloseRatchet,
handler: NIOAsyncChannelInboundStreamChannelHandler<HandlerInbound, Inbound>
) throws { ) throws {
channel.eventLoop.preconditionInEventLoop() channel.eventLoop.preconditionInEventLoop()
let handler = NIOAsyncChannelInboundStreamChannelHandler<Inbound>(
eventLoop: channel.eventLoop,
closeRatchet: closeRatchet
)
let strategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark let strategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark
if let userProvided = backpressureStrategy { if let userProvided = backpressureStrategy {
@ -47,12 +44,77 @@ public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable {
let sequence = Producer.makeSequence( let sequence = Producer.makeSequence(
backPressureStrategy: strategy, backPressureStrategy: strategy,
delegate: NIOAsyncChannelInboundStreamChannelHandler<Inbound>.Delegate(handler: handler) delegate: NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate(handler: handler)
) )
handler.source = sequence.source handler.source = sequence.source
try channel.pipeline.syncOperations.addHandler(handler) try channel.pipeline.syncOperations.addHandler(handler)
self._producer = sequence.sequence self._producer = sequence.sequence
} }
/// Creates a new ``NIOAsyncChannelInboundStream`` which is used when the pipeline got synchronously wrapped.
@inlinable
static func makeWrappingHandler(
channel: Channel,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
closeRatchet: CloseRatchet
) throws -> NIOAsyncChannelInboundStream {
let handler = NIOAsyncChannelInboundStreamChannelHandler<Inbound, Inbound>.makeWrappingHandler(
eventLoop: channel.eventLoop,
closeRatchet: closeRatchet
)
return try .init(
channel: channel,
backpressureStrategy: backpressureStrategy,
closeRatchet: closeRatchet,
handler: handler
)
}
/// Creates a new ``NIOAsyncChannelInboundStreamChannelHandler`` which is used in the bootstrap for the ServerChannel.
@inlinable
static func makeBindingHandler(
channel: Channel,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
closeRatchet: CloseRatchet,
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
) throws -> NIOAsyncChannelInboundStream {
let handler = NIOAsyncChannelInboundStreamChannelHandler<Channel, Inbound>.makeBindingHandler(
eventLoop: channel.eventLoop,
closeRatchet: closeRatchet,
transformationClosure: transformationClosure
)
return try .init(
channel: channel,
backpressureStrategy: backpressureStrategy,
closeRatchet: closeRatchet,
handler: handler
)
}
/// Creates a new ``NIOAsyncChannelInboundStreamChannelHandler`` which is used in the bootstrap for the ServerChannel when the child
/// channel does protocol negotiation.
@inlinable
static func makeProtocolNegotiationHandler(
channel: Channel,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
closeRatchet: CloseRatchet,
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
) throws -> NIOAsyncChannelInboundStream {
let handler = NIOAsyncChannelInboundStreamChannelHandler<Channel, Inbound>.makeProtocolNegotiationHandler(
eventLoop: channel.eventLoop,
closeRatchet: closeRatchet,
transformationClosure: transformationClosure
)
return try .init(
channel: channel,
backpressureStrategy: backpressureStrategy,
closeRatchet: closeRatchet,
handler: handler
)
}
} }
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@ -87,4 +149,3 @@ extension NIOAsyncChannelInboundStream: AsyncSequence {
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@available(*, unavailable) @available(*, unavailable)
extension NIOAsyncChannelInboundStream.AsyncIterator: Sendable {} extension NIOAsyncChannelInboundStream.AsyncIterator: Sendable {}

View File

@ -16,7 +16,7 @@
/// ``Channel`` into an asynchronous sequence that supports back-pressure. /// ``Channel`` into an asynchronous sequence that supports back-pressure.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@usableFromInline @usableFromInline
internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Sendable>: ChannelDuplexHandler { internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Sendable, ProducerElement: Sendable>: ChannelDuplexHandler {
@usableFromInline @usableFromInline
enum _ProducingState { enum _ProducingState {
// Not .stopProducing // Not .stopProducing
@ -37,10 +37,10 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Senda
@usableFromInline @usableFromInline
typealias Source = NIOThrowingAsyncSequenceProducer< typealias Source = NIOThrowingAsyncSequenceProducer<
InboundIn, ProducerElement,
Error, Error,
NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark,
NIOAsyncChannelInboundStreamChannelHandler<InboundIn>.Delegate NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate
>.Source >.Source
/// The source of the asynchronous sequence. /// The source of the asynchronous sequence.
@ -53,7 +53,7 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Senda
/// An array of reads which will be yielded to the source with the next channel read complete. /// An array of reads which will be yielded to the source with the next channel read complete.
@usableFromInline @usableFromInline
var buffer: [InboundIn] = [] var buffer: [ProducerElement] = []
/// The current producing state. /// The current producing state.
@usableFromInline @usableFromInline
@ -67,10 +67,73 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Senda
@usableFromInline @usableFromInline
let closeRatchet: CloseRatchet let closeRatchet: CloseRatchet
/// A type indicating what kind of transformation to apply to reads.
@usableFromInline
enum Transformation {
/// A synchronous transformation is applied to incoming reads. This is used when sync wrapping a channel.
case syncWrapping((InboundIn) -> ProducerElement)
/// This is used in the ServerBootstrap since we require to wrap the child channel on it's event loop but yield it on the parent's loop.
case bind((InboundIn) -> EventLoopFuture<ProducerElement>)
/// In the case of protocol negotiation we are applying a future based transformation where we wait for the transformation
/// to finish before we yield it to the source.
case protocolNegotiation((InboundIn) -> EventLoopFuture<ProducerElement>)
}
/// The transformation applied to incoming reads.
@usableFromInline
let transformation: Transformation
@inlinable @inlinable
init(eventLoop: EventLoop, closeRatchet: CloseRatchet) { init(
eventLoop: EventLoop,
closeRatchet: CloseRatchet,
transformation: Transformation
) {
self.eventLoop = eventLoop self.eventLoop = eventLoop
self.closeRatchet = closeRatchet self.closeRatchet = closeRatchet
self.transformation = transformation
}
/// Creates a new ``NIOAsyncChannelInboundStreamChannelHandler`` which is used when the pipeline got synchronously wrapped.
@inlinable
static func makeWrappingHandler(
eventLoop: EventLoop,
closeRatchet: CloseRatchet
) -> NIOAsyncChannelInboundStreamChannelHandler where InboundIn == ProducerElement {
return .init(
eventLoop: eventLoop,
closeRatchet: closeRatchet,
transformation: .syncWrapping { $0 }
)
}
/// Creates a new ``NIOAsyncChannelInboundStreamChannelHandler`` which is used in the bootstrap for the ServerChannel.
@inlinable
static func makeBindingHandler(
eventLoop: EventLoop,
closeRatchet: CloseRatchet,
transformationClosure: @escaping (Channel) -> EventLoopFuture<ProducerElement>
) -> NIOAsyncChannelInboundStreamChannelHandler where InboundIn == Channel {
return .init(
eventLoop: eventLoop,
closeRatchet: closeRatchet,
transformation: .bind(transformationClosure)
)
}
/// Creates a new ``NIOAsyncChannelInboundStreamChannelHandler`` which is used in the bootstrap for the ServerChannel when the child
/// channel does protocol negotiation.
@inlinable
static func makeProtocolNegotiationHandler(
eventLoop: EventLoop,
closeRatchet: CloseRatchet,
transformationClosure: @escaping (Channel) -> EventLoopFuture<ProducerElement>
) -> NIOAsyncChannelInboundStreamChannelHandler where InboundIn == Channel {
return .init(
eventLoop: eventLoop,
closeRatchet: closeRatchet,
transformation: .protocolNegotiation(transformationClosure)
)
} }
@inlinable @inlinable
@ -86,10 +149,67 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Senda
@inlinable @inlinable
func channelRead(context: ChannelHandlerContext, data: NIOAny) { func channelRead(context: ChannelHandlerContext, data: NIOAny) {
self.buffer.append(self.unwrapInboundIn(data)) let unwrapped = self.unwrapInboundIn(data)
// We forward on reads here to enable better channel composition. switch self.transformation {
context.fireChannelRead(data) case .syncWrapping(let transformation):
self.buffer.append(transformation(unwrapped))
// We forward on reads here to enable better channel composition.
context.fireChannelRead(data)
case .bind(let transformation):
// The unsafe transfers here are required because we need to use self in whenComplete
// We are making sure to be on our event loop so we can safely use self in whenComplete
let unsafeSelf = NIOLoopBound(self, eventLoop: context.eventLoop)
let unsafeContext = NIOLoopBound(context, eventLoop: context.eventLoop)
transformation(unwrapped)
.hop(to: context.eventLoop)
.whenComplete { result in
unsafeSelf.value._transformationCompleted(context: unsafeContext.value, result: result)
// We forward the read only after the transformation has been completed. This is super important
// since we are setting up the NIOAsyncChannel handlers in the transformation and
// we must make sure to only generate reads once they are setup. Reads can only
// happen after the child channels hit `channelRead0` that's why we are holding the read here.
context.fireChannelRead(data)
}
case .protocolNegotiation(let protocolNegotiation):
// The unsafe transfers here are required because we need to use self in whenComplete
// We are making sure to be on our event loop so we can safely use self in whenComplete
let unsafeSelf = NIOLoopBound(self, eventLoop: context.eventLoop)
let unsafeContext = NIOLoopBound(context, eventLoop: context.eventLoop)
protocolNegotiation(unwrapped)
.hop(to: context.eventLoop)
.whenComplete { result in
unsafeSelf.value._transformationCompleted(context: unsafeContext.value, result: result)
}
// We forwarding the read here right away since protocol negotiation often needs reads to progress.
// In this case, we expect the user to synchronously wrap the child channel into a NIOAsyncChannel
// hence we don't have the timing issue as in the `.bind` case.
context.fireChannelRead(data)
}
}
@inlinable
func _transformationCompleted(
context: ChannelHandlerContext,
result: Result<ProducerElement, Error>
) {
context.eventLoop.preconditionInEventLoop()
switch result {
case .success(let transformed):
self.buffer.append(transformed)
// We are delivering out of band here since the future can complete at any point
self._deliverReads(context: context)
case .failure:
// Transformation failed. Nothing to really do here this must be handled in the transformation
// futures themselves.
break
}
} }
@inlinable @inlinable
@ -215,33 +335,35 @@ extension NIOAsyncChannelInboundStreamChannelHandler {
} }
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension NIOAsyncChannelInboundStreamChannelHandler { @usableFromInline
struct NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate: @unchecked Sendable, NIOAsyncSequenceProducerDelegate {
@usableFromInline @usableFromInline
struct Delegate: @unchecked Sendable, NIOAsyncSequenceProducerDelegate { let eventLoop: EventLoop
@usableFromInline
let eventLoop: EventLoop
@usableFromInline @usableFromInline
let handler: NIOAsyncChannelInboundStreamChannelHandler<InboundIn> let _didTerminate: () -> Void
@inlinable @usableFromInline
init(handler: NIOAsyncChannelInboundStreamChannelHandler<InboundIn>) { let _produceMore: () -> Void
self.eventLoop = handler.eventLoop
self.handler = handler @inlinable
init<InboundIn, ProducerElement>(handler: NIOAsyncChannelInboundStreamChannelHandler<InboundIn, ProducerElement>) {
self.eventLoop = handler.eventLoop
self._didTerminate = handler._didTerminate
self._produceMore = handler._produceMore
}
@inlinable
func didTerminate() {
self.eventLoop.execute {
self._didTerminate()
} }
}
@inlinable @inlinable
func didTerminate() { func produceMore() {
self.eventLoop.execute { self.eventLoop.execute {
self.handler._didTerminate() self._produceMore()
}
}
@inlinable
func produceMore() {
self.eventLoop.execute {
self.handler._produceMore()
}
} }
} }
} }
@ -249,4 +371,3 @@ extension NIOAsyncChannelInboundStreamChannelHandler {
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@available(*, unavailable) @available(*, unavailable)
extension NIOAsyncChannelInboundStreamChannelHandler: Sendable {} extension NIOAsyncChannelInboundStreamChannelHandler: Sendable {}

View File

@ -12,7 +12,6 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// A ``NIOAsyncChannelWriter`` is used to write and flush new outbound messages in a channel. /// A ``NIOAsyncChannelWriter`` is used to write and flush new outbound messages in a channel.
/// ///
/// The writer acts as a bridge between the Concurrency and NIO world. It allows to write and flush messages into the /// The writer acts as a bridge between the Concurrency and NIO world. It allows to write and flush messages into the
@ -90,4 +89,3 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
self._outboundWriter.finish() self._outboundWriter.finish()
} }
} }

View File

@ -113,12 +113,6 @@ internal final class NIOAsyncChannelOutboundWriterHandler<OutboundOut: Sendable>
self.sink = nil self.sink = nil
} }
@inlinable
func errorCaught(context: ChannelHandlerContext, error: Error) {
self.sink?.finish(error: error)
context.fireErrorCaught(error)
}
@inlinable @inlinable
func channelInactive(context: ChannelHandlerContext) { func channelInactive(context: ChannelHandlerContext) {
self.sink?.finish() self.sink?.finish()
@ -172,4 +166,3 @@ extension NIOAsyncChannelOutboundWriterHandler {
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@available(*, unavailable) @available(*, unavailable)
extension NIOAsyncChannelOutboundWriterHandler: Sendable {} extension NIOAsyncChannelOutboundWriterHandler: Sendable {}

View File

@ -91,4 +91,3 @@ final class CloseRatchet {
return self._state.closeWrite() return self._state.closeWrite()
} }
} }

View File

@ -343,3 +343,30 @@ extension RemovableChannelHandler {
context.leavePipeline(removalToken: removalToken) context.leavePipeline(removalToken: removalToken)
} }
} }
/// The result of protocol negotiation.
@_spi(AsyncChannel)
public enum NIOProtocolNegotiationResult<NegotiationResult> {
/// Indicates that the protocol negotiation finished.
case finished(NegotiationResult)
/// Indicates that protocol negotiation has been deferred to the next handler.
case deferredResult(EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>)
}
@_spi(AsyncChannel)
extension NIOProtocolNegotiationResult: Equatable where NegotiationResult: Equatable {}
@_spi(AsyncChannel)
extension NIOProtocolNegotiationResult: Sendable where NegotiationResult: Sendable {}
/// A ``ProtocolNegotiationHandler`` is a ``ChannelHandler`` that is responsible for negotiating networking protocols.
///
/// Typically these handlers are at the tail of the pipeline and wait until the peer indicated what protocol should be used. Once, the protocol
/// has been negotiated the handlers allow user code to configure the pipeline.
@_spi(AsyncChannel)
public protocol NIOProtocolNegotiationHandler: ChannelHandler {
associatedtype NegotiationResult
/// The future which gets succeeded with the protocol negotiation result.
var protocolNegotiationResult: EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>> { get }
}

View File

@ -11,7 +11,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
import NIOCore @_spi(AsyncChannel) import NIOCore
#if os(Windows) #if os(Windows)
import ucrt import ucrt
@ -311,8 +311,9 @@ public final class ServerBootstrap {
/// Bind the `ServerSocketChannel` to a UNIX Domain Socket. /// Bind the `ServerSocketChannel` to a UNIX Domain Socket.
/// ///
/// - parameters: /// - parameters:
/// - unixDomainSocketPath: The _Unix domain socket_ path to bind to. `unixDomainSocketPath` must not exist, it will be created by the system. /// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist,
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `path`. /// unless `cleanupExistingSocketFile`is set to `true`.
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`.
public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> { public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> {
if cleanupExistingSocketFile { if cleanupExistingSocketFile {
do { do {
@ -482,6 +483,454 @@ public final class ServerBootstrap {
} }
} }
// MARK: AsyncChannel based bind
extension ServerBootstrap {
/// Bind the `ServerSocketChannel` to the `host` and `port` parameters.
///
/// - Parameters:
/// - host: The host to bind on.
/// - port: The port to bind on.
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
/// - childBackpressureStrategy: The back pressure strategy used by the child channels.
/// - childChannelInboundType: The child channel's inbound type.
/// - childChannelOutboundType: The child channel's outbound type.
/// - isChildChannelOutboundHalfClosureEnabled: Indicates if half closure is enabled on the child channels. If half closure is enabled
/// then finishing the ``NIOAsyncChannelWriter`` will lead to half closure.
/// - Returns: A ``NIOAsyncChannel`` of connection ``NIOAsyncChannel``s.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func bind<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
host: String,
port: Int,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childChannelInboundType: ChildChannelInbound.Type = ChildChannelInbound.self,
childChannelOutboundType: ChildChannelOutbound.Type = ChildChannelOutbound.self,
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
isChildChannelOutboundHalfClosureEnabled: Bool = false
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
return try await self.bindAsyncChannel0(
serverBackpressureStrategy: serverBackpressureStrategy,
childBackpressureStrategy: childBackpressureStrategy,
isChildChannelOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled
) {
return try SocketAddress.makeAddressResolvingHost(host, port: port)
}
}
/// Bind the `ServerSocketChannel` to the `address` parameter.
///
/// - Parameters:
/// - address: The `SocketAddress` to bind on.
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
/// - childBackpressureStrategy: The back pressure strategy used by the child channels.
/// - childChannelInboundType: The child channel's inbound type.
/// - childChannelOutboundType: The child channel's outbound type.
/// - isChildChannelOutboundHalfClosureEnabled: Indicates if half closure is enabled on the child channels. If half closure is enabled
/// then finishing the ``NIOAsyncChannelWriter`` will lead to half closure.
/// - Returns: A ``NIOAsyncChannel`` of connection ``NIOAsyncChannel``s.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func bind<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
to address: SocketAddress,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childChannelInboundType: ChildChannelInbound.Type = ChildChannelInbound.self,
childChannelOutboundType: ChildChannelOutbound.Type = ChildChannelOutbound.self,
isChildChannelOutboundHalfClosureEnabled: Bool = false
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
return try await self.bindAsyncChannel0(
serverBackpressureStrategy: serverBackpressureStrategy,
childBackpressureStrategy: childBackpressureStrategy,
isChildChannelOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled
) { address }
}
/// Bind the `ServerSocketChannel` to the `unixDomainSocketPath` parameter.
///
/// - Parameters:
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist,
/// unless `cleanupExistingSocketFile`is set to `true`.
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`.
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
/// - childBackpressureStrategy: The back pressure strategy used by the child channels.
/// - childChannelInboundType: The child channel's inbound type.
/// - childChannelOutboundType: The child channel's outbound type.
/// - isChildChannelOutboundHalfClosureEnabled: Indicates if half closure is enabled on the child channels. If half closure is enabled
/// then finishing the ``NIOAsyncChannelWriter`` will lead to half closure.
/// - Returns: A ``NIOAsyncChannel`` of connection ``NIOAsyncChannel``s.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func bind<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
unixDomainSocketPath: String,
cleanupExistingSocketFile: Bool = false,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childChannelInboundType: ChildChannelInbound.Type = ChildChannelInbound.self,
childChannelOutboundType: ChildChannelOutbound.Type = ChildChannelOutbound.self,
isChildChannelOutboundHalfClosureEnabled: Bool = false
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
if cleanupExistingSocketFile {
try BaseSocket.cleanupSocket(unixDomainSocketPath: unixDomainSocketPath)
}
return try await self.bind(
unixDomainSocketPath: unixDomainSocketPath,
serverBackpressureStrategy: serverBackpressureStrategy,
childBackpressureStrategy: childBackpressureStrategy,
isChildChannelOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled
)
}
/// Use the existing bound socket file descriptor.
///
/// - Parameters:
/// - socket: The _Unix file descriptor_ representing the bound stream socket.
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
/// - childBackpressureStrategy: The back pressure strategy used by the child channels.
/// - childChannelInboundType: The child channel's inbound type.
/// - childChannelOutboundType: The child channel's outbound type.
/// - isChildChannelOutboundHalfClosureEnabled: Indicates if half closure is enabled on the child channels. If half closure is enabled
/// then finishing the ``NIOAsyncChannelWriter`` will lead to half closure.
/// - Returns: A ``NIOAsyncChannel`` of connection ``NIOAsyncChannel``s.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func withBoundSocket<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
_ socket: NIOBSDSocket.Handle,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childChannelInboundType: ChildChannelInbound.Type = ChildChannelInbound.self,
childChannelOutboundType: ChildChannelOutbound.Type = ChildChannelOutbound.self,
isChildChannelOutboundHalfClosureEnabled: Bool = false
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel {
if enableMPTCP {
throw ChannelError.operationUnsupported
}
return try ServerSocketChannel(socket: socket, eventLoop: eventLoop, group: childEventLoopGroup)
}
return try await self.bindAsyncChannel0(
makeServerChannel: makeChannel,
serverBackpressureStrategy: serverBackpressureStrategy,
childBackpressureStrategy: childBackpressureStrategy,
isChildChannelOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled
) { (eventLoop, serverChannel) in
let promise = eventLoop.makePromise(of: Void.self)
serverChannel.registerAlreadyConfigured0(promise: promise)
return promise.futureResult
}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func bindAsyncChannel0<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
isChildChannelOutboundHalfClosureEnabled: Bool,
_ makeSocketAddress: () throws -> SocketAddress
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
let address = try makeSocketAddress()
func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel {
return try ServerSocketChannel(eventLoop: eventLoop,
group: childEventLoopGroup,
protocolFamily: address.protocol,
enableMPTCP: enableMPTCP)
}
return try await self.bindAsyncChannel0(
makeServerChannel: makeChannel,
serverBackpressureStrategy: serverBackpressureStrategy,
childBackpressureStrategy: childBackpressureStrategy,
isChildChannelOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled
) { (eventLoop, serverChannel) in
serverChannel.registerAndDoSynchronously { serverChannel in
serverChannel.bind(to: address)
}
}
}
private typealias MakeServerChannel = (_ eventLoop: SelectableEventLoop, _ childGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel
private typealias Register = @Sendable (EventLoop, ServerSocketChannel) -> EventLoopFuture<Void>
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func bindAsyncChannel0<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
makeServerChannel: MakeServerChannel,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
isChildChannelOutboundHalfClosureEnabled: Bool,
_ register: @escaping Register
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
let eventLoop = self.group.next()
let childEventLoopGroup = self.childGroup
let serverChannelOptions = self._serverChannelOptions
let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) }
let childChannelInit = self.childChannelInit
let childChannelOptions = self._childChannelOptions
let serverChannel = try makeServerChannel(eventLoop as! SelectableEventLoop, childEventLoopGroup, self.enableMPTCP)
return try await eventLoop.submit {
serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap {
serverChannelInit(serverChannel)
}.flatMap {
do {
try serverChannel.pipeline.syncOperations.addHandler(
AcceptHandler(childChannelInitializer: childChannelInit, childChannelOptions: childChannelOptions),
name: "AcceptHandler"
)
// We are wrapping the inbound channels into `NIOAsyncChannel`s with the transformation
// closure of the `NIOAsyncChannel` that allows us to wrap them without adding
// wrapping/unwrapping handlers to the pipeline.
let asyncChannel = try NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never>.wrapAsyncChannelForBootstrapBind(
synchronouslyWrapping: serverChannel,
backpressureStrategy: serverBackpressureStrategy,
transformationClosure: { channel in
// We must hop to the child channel event loop here to add the async channel handlers
channel.eventLoop.submit {
try NIOAsyncChannel(
synchronouslyWrapping: channel,
backpressureStrategy: childBackpressureStrategy,
isOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled,
inboundType: ChildChannelInbound.self,
outboundType: ChildChannelOutbound.self
)
}
}
)
return register(eventLoop, serverChannel).map { asyncChannel }
} catch {
return eventLoop.makeFailedFuture(error)
}
}.flatMapError { error in
serverChannel.close0(error: error, mode: .all, promise: nil)
return eventLoop.makeFailedFuture(error)
}
}.flatMap {
$0
}.get()
}
}
// MARK: AsyncChannel based bind with protocol negotiation
extension ServerBootstrap {
/// Bind the `ServerSocketChannel` to the `host` and `port` parameters.
/// Waiting for protocol negotiation to finish before yielding the channel to the returned ``NIOAsyncChannel``.
///
/// - Parameters:
/// - host: The host to bind on.
/// - port: The port to bind on.
/// - protocolNegotiationHandlerType: The protocol negotiation handler type that is awaited on. A handler of this type
/// must be added to the pipeline in the child channel initializer.
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
/// - Returns: A ``NIOAsyncChannel`` of the protocol negotiation results. It is expected that the protocol negotiation handler
/// is going to wrap the child channels into ``NIOAsyncChannel`` which are returned as part of the negotiation result.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func bind<Handler: NIOProtocolNegotiationHandler>(
host: String,
port: Int,
protocolNegotiationHandlerType: Handler.Type,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
return try await self.bindAsyncChannelWithProtocolNegotiation0(
protocolNegotiationHandlerType: protocolNegotiationHandlerType,
serverBackpressureStrategy: serverBackpressureStrategy
) {
return try SocketAddress.makeAddressResolvingHost(host, port: port)
}
}
/// Bind the `ServerSocketChannel` to the `address` parameter.
/// Waiting for protocol negotiation to finish before yielding the channel to the returned ``NIOAsyncChannel``.
///
/// - Parameters:
/// - address: The `SocketAddress` to bind on.
/// - protocolNegotiationHandlerType: The protocol negotiation handler type that is awaited on. A handler of this type
/// must be added to the pipeline in the child channel initializer.
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
/// - Returns: A ``NIOAsyncChannel`` of the protocol negotiation results. It is expected that the protocol negotiation handler
/// is going to wrap the child channels into ``NIOAsyncChannel`` which are returned as part of the negotiation result.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func bind<Handler: NIOProtocolNegotiationHandler>(
to address: SocketAddress,
protocolNegotiationHandlerType: Handler.Type,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
return try await self.bindAsyncChannelWithProtocolNegotiation0(
protocolNegotiationHandlerType: protocolNegotiationHandlerType,
serverBackpressureStrategy: serverBackpressureStrategy
) { address }
}
/// Bind the `ServerSocketChannel` to a UNIX Domain Socket.
/// Waiting for protocol negotiation to finish before yielding the channel to the returned ``NIOAsyncChannel``.
///
/// - Parameters:
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist,
/// unless `cleanupExistingSocketFile`is set to `true`.
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`.
/// - protocolNegotiationHandlerType: The protocol negotiation handler type that is awaited on. A handler of this type
/// must be added to the pipeline in the child channel initializer.
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
/// - Returns: A ``NIOAsyncChannel`` of the protocol negotiation results. It is expected that the protocol negotiation handler
/// is going to wrap the child channels into ``NIOAsyncChannel`` which are returned as part of the negotiation result.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func bind<Handler: NIOProtocolNegotiationHandler>(
unixDomainSocketPath: String,
cleanupExistingSocketFile: Bool = false,
protocolNegotiationHandlerType: Handler.Type,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
if cleanupExistingSocketFile {
try BaseSocket.cleanupSocket(unixDomainSocketPath: unixDomainSocketPath)
}
return try await self.bind(
unixDomainSocketPath: unixDomainSocketPath,
protocolNegotiationHandlerType: protocolNegotiationHandlerType,
serverBackpressureStrategy: serverBackpressureStrategy
)
}
/// Use the existing bound socket file descriptor.
/// Waiting for protocol negotiation to finish before yielding the channel to the returned ``NIOAsyncChannel``.
///
/// - Parameters:
/// - socket: The _Unix file descriptor_ representing the bound stream socket.
/// - protocolNegotiationHandlerType: The protocol negotiation handler type that is awaited on. A handler of this type
/// must be added to the pipeline in the child channel initializer.
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
/// - Returns: A ``NIOAsyncChannel`` of the protocol negotiation results. It is expected that the protocol negotiation handler
/// is going to wrap the child channels into ``NIOAsyncChannel`` which are returned as part of the negotiation result.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func withBoundSocket<Handler: NIOProtocolNegotiationHandler>(
_ socket: NIOBSDSocket.Handle,
protocolNegotiationHandlerType: Handler.Type,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel {
if enableMPTCP {
throw ChannelError.operationUnsupported
}
return try ServerSocketChannel(socket: socket, eventLoop: eventLoop, group: childEventLoopGroup)
}
return try await self.bindAsyncChannelWithProtocolNegotiation0(
makeServerChannel: makeChannel,
protocolNegotiationHandlerType: protocolNegotiationHandlerType,
serverBackpressureStrategy: serverBackpressureStrategy
) { (eventLoop, serverChannel) in
let promise = eventLoop.makePromise(of: Void.self)
serverChannel.registerAlreadyConfigured0(promise: promise)
return promise.futureResult
}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func bindAsyncChannelWithProtocolNegotiation0<Handler: NIOProtocolNegotiationHandler & ChannelHandler>(
protocolNegotiationHandlerType: Handler.Type,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
_ makeSocketAddress: () throws -> SocketAddress
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
let address = try makeSocketAddress()
func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel {
return try ServerSocketChannel(eventLoop: eventLoop,
group: childEventLoopGroup,
protocolFamily: address.protocol,
enableMPTCP: enableMPTCP)
}
return try await self.bindAsyncChannelWithProtocolNegotiation0(
makeServerChannel: makeChannel,
protocolNegotiationHandlerType: protocolNegotiationHandlerType,
serverBackpressureStrategy: serverBackpressureStrategy
) { (eventLoop, serverChannel) in
serverChannel.registerAndDoSynchronously { serverChannel in
serverChannel.bind(to: address)
}
}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func bindAsyncChannelWithProtocolNegotiation0<Handler: NIOProtocolNegotiationHandler & ChannelHandler>(
makeServerChannel: MakeServerChannel,
protocolNegotiationHandlerType: Handler.Type,
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
_ register: @escaping Register
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
let eventLoop = self.group.next()
let childEventLoopGroup = self.childGroup
let serverChannelOptions = self._serverChannelOptions
let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) }
let childChannelInit = self.childChannelInit
let childChannelOptions = self._childChannelOptions
let serverChannel = try makeServerChannel(eventLoop as! SelectableEventLoop, childEventLoopGroup, self.enableMPTCP)
return try await eventLoop.submit {
serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap {
serverChannelInit(serverChannel)
}.flatMap {
do {
try serverChannel.pipeline.syncOperations.addHandler(
AcceptHandler(childChannelInitializer: childChannelInit, childChannelOptions: childChannelOptions),
name: "AcceptHandler"
)
// In the case of protocol negotiation we cannot wrap the child channels into
// `NIOAsyncChannel`s for the user since we don't know the type. We rather expect
// the user to wrap the child channels themselves when the negotiation is done
// and return the wrapped async channel as part of the negotiation result.
let asyncChannel = try NIOAsyncChannel<Handler.NegotiationResult, Never>.wrapAsyncChannelForBootstrapBindWithProtocolNegotiation(
synchronouslyWrapping: serverChannel,
backpressureStrategy: serverBackpressureStrategy,
transformationClosure: { (channel: Channel) in
channel.pipeline.handler(type: protocolNegotiationHandlerType)
.flatMap { handler -> EventLoopFuture<NIOProtocolNegotiationResult<Handler.NegotiationResult>> in
handler.protocolNegotiationResult
}.flatMap { result in
ServerBootstrap.waitForFinalResult(result, eventLoop: eventLoop)
}.flatMapErrorThrowing { error in
channel.pipeline.fireErrorCaught(error)
channel.close(promise: nil)
throw error
}
}
)
return register(eventLoop, serverChannel).map { asyncChannel }
} catch {
return eventLoop.makeFailedFuture(error)
}
}.flatMapError { error in
serverChannel.close0(error: error, mode: .all, promise: nil)
return eventLoop.makeFailedFuture(error)
}
}.flatMap {
$0
}.get()
}
/// This method recursively waits for the final result of protocol negotiation
static func waitForFinalResult<NegotiationResult>(
_ result: NIOProtocolNegotiationResult<NegotiationResult>,
eventLoop: EventLoop
) -> EventLoopFuture<NegotiationResult> {
switch result {
case .finished(let negotiationResult):
return eventLoop.makeSucceededFuture(negotiationResult)
case .deferredResult(let future):
return future.flatMap { result in
return waitForFinalResult(result, eventLoop: eventLoop)
}
}
}
}
@available(*, unavailable) @available(*, unavailable)
extension ServerBootstrap: Sendable {} extension ServerBootstrap: Sendable {}
@ -1058,8 +1507,9 @@ public final class DatagramBootstrap {
/// Bind the `DatagramChannel` to a UNIX Domain Socket. /// Bind the `DatagramChannel` to a UNIX Domain Socket.
/// ///
/// - parameters: /// - parameters:
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. `path` must not exist, it will be created by the system. /// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist,
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `path`. /// unless `cleanupExistingSocketFile`is set to `true`.
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`.
public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> { public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> {
if cleanupExistingSocketFile { if cleanupExistingSocketFile {
do { do {

View File

@ -13,7 +13,6 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
import NIOCore import NIOCore
import DequeModule
/// The result of an ALPN negotiation. /// The result of an ALPN negotiation.
/// ///
@ -35,6 +34,14 @@ public enum ALPNResult: Equatable, Sendable {
/// ALPN negotiation either failed, or never took place. The application /// ALPN negotiation either failed, or never took place. The application
/// should fall back to a default protocol choice or close the connection. /// should fall back to a default protocol choice or close the connection.
case fallback case fallback
init(negotiated: String?) {
if let negotiated = negotiated {
self = .negotiated(negotiated)
} else {
self = .fallback
}
}
} }
/// A helper `ChannelInboundHandler` that makes it easy to swap channel pipelines /// A helper `ChannelInboundHandler` that makes it easy to swap channel pipelines
@ -62,8 +69,7 @@ public final class ApplicationProtocolNegotiationHandler: ChannelInboundHandler,
public typealias InboundOut = Any public typealias InboundOut = Any
private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture<Void> private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture<Void>
private var waitingForUser: Bool private var stateMachine = ProtocolNegotiationHandlerStateMachine<Void>()
private var eventBuffer: Deque<NIOAny>
/// Create an `ApplicationProtocolNegotiationHandler` with the given completion /// Create an `ApplicationProtocolNegotiationHandler` with the given completion
/// callback. /// callback.
@ -72,8 +78,6 @@ public final class ApplicationProtocolNegotiationHandler: ChannelInboundHandler,
/// negotiation has completed. /// negotiation has completed.
public init(alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture<Void>) { public init(alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture<Void>) {
self.completionHandler = alpnCompleteHandler self.completionHandler = alpnCompleteHandler
self.waitingForUser = false
self.eventBuffer = []
} }
/// Create an `ApplicationProtocolNegotiationHandler` with the given completion /// Create an `ApplicationProtocolNegotiationHandler` with the given completion
@ -88,56 +92,71 @@ public final class ApplicationProtocolNegotiationHandler: ChannelInboundHandler,
} }
public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
guard let tlsEvent = event as? TLSUserEvent else { switch self.stateMachine.userInboundEventTriggered(event: event) {
case .fireUserInboundEventTriggered:
context.fireUserInboundEventTriggered(event) context.fireUserInboundEventTriggered(event)
return
}
if case .handshakeCompleted(let p) = tlsEvent { case .invokeUserClosure(let result):
handshakeCompleted(context: context, negotiatedProtocol: p) self.invokeUserClosure(context: context, result: result)
} else {
context.fireUserInboundEventTriggered(event)
} }
} }
public func channelRead(context: ChannelHandlerContext, data: NIOAny) { public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
if waitingForUser { switch self.stateMachine.channelRead(data: data) {
eventBuffer.append(data) case .fireChannelRead:
} else {
context.fireChannelRead(data) context.fireChannelRead(data)
case .none:
break
} }
} }
private func handshakeCompleted(context: ChannelHandlerContext, negotiatedProtocol: String?) { public func channelInactive(context: ChannelHandlerContext) {
waitingForUser = true self.stateMachine.channelInactive()
let result: ALPNResult context.fireChannelInactive()
if let negotiatedProtocol = negotiatedProtocol { }
result = .negotiated(negotiatedProtocol)
} else {
result = .fallback
}
private func invokeUserClosure(context: ChannelHandlerContext, result: ALPNResult) {
let switchFuture = self.completionHandler(result, context.channel) let switchFuture = self.completionHandler(result, context.channel)
switchFuture.whenComplete { (_: Result<Void, Error>) in
switchFuture
.hop(to: context.eventLoop)
.whenComplete { result in
self.userFutureCompleted(context: context, result: result)
}
}
private func userFutureCompleted(context: ChannelHandlerContext, result: Result<Void, Error>) {
switch self.stateMachine.userFutureCompleted(with: result) {
case .fireErrorCaughtAndRemoveHandler(let error):
context.fireErrorCaught(error)
context.pipeline.removeHandler(self, promise: nil)
case .fireErrorCaughtAndStartUnbuffering(let error):
context.fireErrorCaught(error)
self.unbuffer(context: context) self.unbuffer(context: context)
case .startUnbuffering:
self.unbuffer(context: context)
case .removeHandler:
context.pipeline.removeHandler(self, promise: nil) context.pipeline.removeHandler(self, promise: nil)
} }
} }
private func unbuffer(context: ChannelHandlerContext) { private func unbuffer(context: ChannelHandlerContext) {
// First we check if we have anything to unbuffer while true {
guard !self.eventBuffer.isEmpty else { switch self.stateMachine.unbuffer() {
return case .fireChannelRead(let data):
} context.fireChannelRead(data)
// Now we unbuffer until there is nothing left. case .fireChannelReadCompleteAndRemoveHandler:
// Importantly firing a channel read can lead to new reads being buffered due to reentrancy! context.fireChannelReadComplete()
while let datum = self.eventBuffer.popFirst() { context.pipeline.removeHandler(self, promise: nil)
context.fireChannelRead(datum) return
}
} }
context.fireChannelReadComplete()
} }
} }

View File

@ -0,0 +1,161 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
@_spi(AsyncChannel) import NIOCore
/// A helper ``ChannelInboundHandler`` that makes it easy to swap channel pipelines
/// based on the result of an ALPN negotiation.
///
/// The standard pattern used by applications that want to use ALPN is to select
/// an application protocol based on the result, optionally falling back to some
/// default protocol. To do this in SwiftNIO requires that the channel pipeline be
/// reconfigured based on the result of the ALPN negotiation. This channel handler
/// encapsulates that logic in a generic form that doesn't depend on the specific
/// TLS implementation in use by using ``TLSUserEvent``
///
/// The user of this channel handler provides a single closure that is called with
/// an ``ALPNResult`` when the ALPN negotiation is complete. Based on that result
/// the user is free to reconfigure the ``ChannelPipeline`` as required, and should
/// return an ``EventLoopFuture`` that will complete when the pipeline is reconfigured.
///
/// Until the ``EventLoopFuture`` completes, this channel handler will buffer inbound
/// data. When the ``EventLoopFuture`` completes, the buffered data will be replayed
/// down the channel. Then, finally, this channel handler will automatically remove
/// itself from the channel pipeline, leaving the pipeline in its final
/// configuration.
///
/// Importantly, this is a typed variant of the ``ApplicationProtocolNegotiationHandler`` and allows the user to
/// specify a type that must be returned from the supplied closure. The result will then be used to succeed the ``NIOTypedApplicationProtocolNegotiationHandler/protocolNegotiationResult``
/// promise. This allows us to construct pipelines that include protocol negotiation handlers and be able to bridge them into ``NIOAsyncChannel``
/// based bootstraps.
@_spi(AsyncChannel)
public final class NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>: ChannelInboundHandler, RemovableChannelHandler, NIOProtocolNegotiationHandler {
@_spi(AsyncChannel)
public typealias InboundIn = Any
@_spi(AsyncChannel)
public typealias InboundOut = Any
@_spi(AsyncChannel)
public var protocolNegotiationResult: EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>> {
self.negotiatedPromise.futureResult
}
private let negotiatedPromise: EventLoopPromise<NIOProtocolNegotiationResult<NegotiationResult>>
private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>
private var stateMachine = ProtocolNegotiationHandlerStateMachine<NIOProtocolNegotiationResult<NegotiationResult>>()
/// Create an `ApplicationProtocolNegotiationHandler` with the given completion
/// callback.
///
/// - Parameter alpnCompleteHandler: The closure that will fire when ALPN
/// negotiation has completed.
@_spi(AsyncChannel)
public init(eventLoop: EventLoop, alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>) {
self.completionHandler = alpnCompleteHandler
self.negotiatedPromise = eventLoop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
}
/// Create an `ApplicationProtocolNegotiationHandler` with the given completion
/// callback.
///
/// - Parameter alpnCompleteHandler: The closure that will fire when ALPN
/// negotiation has completed.
@_spi(AsyncChannel)
public convenience init(eventLoop: EventLoop, alpnCompleteHandler: @escaping (ALPNResult) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>) {
self.init(eventLoop: eventLoop) { result, _ in
alpnCompleteHandler(result)
}
}
@_spi(AsyncChannel)
public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch self.stateMachine.userInboundEventTriggered(event: event) {
case .fireUserInboundEventTriggered:
context.fireUserInboundEventTriggered(event)
case .invokeUserClosure(let result):
self.invokeUserClosure(context: context, result: result)
}
}
@_spi(AsyncChannel)
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
switch self.stateMachine.channelRead(data: data) {
case .fireChannelRead:
context.fireChannelRead(data)
case .none:
break
}
}
@_spi(AsyncChannel)
public func channelInactive(context: ChannelHandlerContext) {
self.stateMachine.channelInactive()
self.negotiatedPromise.fail(ChannelError.outputClosed)
context.fireChannelInactive()
}
private func invokeUserClosure(context: ChannelHandlerContext, result: ALPNResult) {
let switchFuture = self.completionHandler(result, context.channel)
switchFuture
.hop(to: context.eventLoop)
.whenComplete { result in
self.userFutureCompleted(context: context, result: result)
}
}
private func userFutureCompleted(context: ChannelHandlerContext, result: Result<NIOProtocolNegotiationResult<NegotiationResult>, Error>) {
switch self.stateMachine.userFutureCompleted(with: result) {
case .fireErrorCaughtAndRemoveHandler(let error):
self.negotiatedPromise.fail(error)
context.fireErrorCaught(error)
context.pipeline.removeHandler(self, promise: nil)
case .fireErrorCaughtAndStartUnbuffering(let error):
self.negotiatedPromise.fail(error)
context.fireErrorCaught(error)
self.unbuffer(context: context)
case .startUnbuffering(let value):
self.negotiatedPromise.succeed(value)
self.unbuffer(context: context)
case .removeHandler(let value):
self.negotiatedPromise.succeed(value)
context.pipeline.removeHandler(self, promise: nil)
}
}
private func unbuffer(context: ChannelHandlerContext) {
while true {
switch self.stateMachine.unbuffer() {
case .fireChannelRead(let data):
context.fireChannelRead(data)
case .fireChannelReadCompleteAndRemoveHandler:
context.fireChannelReadComplete()
context.pipeline.removeHandler(self, promise: nil)
return
}
}
}
}
@available(*, unavailable)
extension NIOTypedApplicationProtocolNegotiationHandler: Sendable {}

View File

@ -0,0 +1,159 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import DequeModule
import NIOCore
struct ProtocolNegotiationHandlerStateMachine<NegotiationResult> {
enum State {
/// The state before we received a TLSUserEvent. We are just forwarding any read at this point.
case initial
/// The state after we received a ``TLSUserEvent`` and are waiting for the future of the user to complete.
case waitingForUser(buffer: Deque<NIOAny>)
/// The state after the users future finished and we are unbuffering all the reads.
case unbuffering(buffer: Deque<NIOAny>)
/// The state once the negotiation is done and we are finished with unbuffering.
case finished
}
private var state = State.initial
@usableFromInline
enum UserInboundEventTriggeredAction {
case fireUserInboundEventTriggered
case invokeUserClosure(ALPNResult)
}
@inlinable
mutating func userInboundEventTriggered(event: Any) -> UserInboundEventTriggeredAction {
if case .handshakeCompleted(let negotiated) = event as? TLSUserEvent {
switch self.state {
case .initial:
self.state = .waitingForUser(buffer: .init())
return .invokeUserClosure(.init(negotiated: negotiated))
case .waitingForUser, .unbuffering:
preconditionFailure("Unexpectedly received two TLSUserEvents")
case .finished:
// This is weird but we can tolerate it and just forward the event
return .fireUserInboundEventTriggered
}
} else {
return .fireUserInboundEventTriggered
}
}
@usableFromInline
enum ChannelReadAction {
case fireChannelRead
}
@inlinable
mutating func channelRead(data: NIOAny) -> ChannelReadAction? {
switch self.state {
case .initial, .finished:
return .fireChannelRead
case .waitingForUser(var buffer):
buffer.append(data)
self.state = .waitingForUser(buffer: buffer)
return .none
case .unbuffering(var buffer):
buffer.append(data)
self.state = .unbuffering(buffer: buffer)
return .none
}
}
@usableFromInline
enum UserFutureCompletedAction {
case fireErrorCaughtAndRemoveHandler(Error)
case fireErrorCaughtAndStartUnbuffering(Error)
case startUnbuffering(NegotiationResult)
case removeHandler(NegotiationResult)
}
@inlinable
mutating func userFutureCompleted(with result: Result<NegotiationResult, Error>) -> UserFutureCompletedAction {
switch self.state {
case .initial, .finished:
preconditionFailure("Invalid state \(self.state)")
case .waitingForUser(let buffer):
switch result {
case .success(let value):
if !buffer.isEmpty {
self.state = .unbuffering(buffer: buffer)
return .startUnbuffering(value)
} else {
self.state = .finished
return .removeHandler(value)
}
case .failure(let error):
if !buffer.isEmpty {
self.state = .unbuffering(buffer: buffer)
return .fireErrorCaughtAndStartUnbuffering(error)
} else {
self.state = .finished
return .fireErrorCaughtAndRemoveHandler(error)
}
}
case .unbuffering:
preconditionFailure("Invalid state \(self.state)")
}
}
@usableFromInline
enum UnbufferAction {
case fireChannelRead(NIOAny)
case fireChannelReadCompleteAndRemoveHandler
}
@inlinable
mutating func unbuffer() -> UnbufferAction {
switch self.state {
case .initial, .waitingForUser, .finished:
preconditionFailure("Invalid state \(self.state)")
case .unbuffering(var buffer):
if let element = buffer.popFirst() {
self.state = .unbuffering(buffer: buffer)
return .fireChannelRead(element)
} else {
self.state = .finished
return .fireChannelReadCompleteAndRemoveHandler
}
}
}
@inlinable
mutating func channelInactive() {
switch self.state {
case .initial, .unbuffering, .waitingForUser:
self.state = .finished
case .finished:
break
}
}
}

View File

@ -78,7 +78,12 @@ final class AsyncChannelTests: XCTestCase {
do { do {
let wrapped = try await channel.testingEventLoop.executeInContext { let wrapped = try await channel.testingEventLoop.executeInContext {
try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: Never.self) try NIOAsyncChannel(
synchronouslyWrapping: channel,
isOutboundHalfClosureEnabled: true,
inboundType: Never.self,
outboundType: Never.self
)
} }
inboundReader = wrapped.inboundStream inboundReader = wrapped.inboundStream
@ -140,7 +145,12 @@ final class AsyncChannelTests: XCTestCase {
do { do {
let wrapped = try await channel.testingEventLoop.executeInContext { let wrapped = try await channel.testingEventLoop.executeInContext {
try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: Never.self) try NIOAsyncChannel(
synchronouslyWrapping: channel,
isOutboundHalfClosureEnabled: true,
inboundType: Never.self,
outboundType: Never.self
)
} }
inboundReader = wrapped.inboundStream inboundReader = wrapped.inboundStream
@ -239,24 +249,6 @@ final class AsyncChannelTests: XCTestCase {
} }
} }
func testErrorsArePropagatedToWriters() {
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
XCTAsyncTest(timeout: 5) {
let channel = NIOAsyncTestingChannel()
let wrapped = try await channel.testingEventLoop.executeInContext {
try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: String.self)
}
try await channel.testingEventLoop.executeInContext {
channel.pipeline.fireErrorCaught(TestError.bang)
}
try await XCTAssertThrowsError(await wrapped.outboundWriter.write("hello")) { error in
XCTAssertEqual(error as? TestError, .bang)
}
}
}
func testChannelBecomingNonWritableDelaysWriters() { func testChannelBecomingNonWritableDelaysWriters() {
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
XCTAsyncTest(timeout: 5) { XCTAsyncTest(timeout: 5) {
@ -312,7 +304,7 @@ final class AsyncChannelTests: XCTestCase {
do { do {
let strongSentinel: Sentinel? = Sentinel() let strongSentinel: Sentinel? = Sentinel()
sentinel = strongSentinel! sentinel = strongSentinel!
try await XCTAsyncAssertNotNil(await channel.pipeline.handler(type: NIOAsyncChannelInboundStreamChannelHandler<Sentinel>.self).get()) try await XCTAsyncAssertNotNil(await channel.pipeline.handler(type: NIOAsyncChannelInboundStreamChannelHandler<Sentinel, Sentinel>.self).get())
try await channel.writeInbound(strongSentinel!) try await channel.writeInbound(strongSentinel!)
_ = try await channel.readInbound(as: Sentinel.self) _ = try await channel.readInbound(as: Sentinel.self)
} }

View File

@ -0,0 +1,504 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
@_spi(AsyncChannel) import NIOCore
@_spi(AsyncChannel) @testable import NIOPosix
import XCTest
@_spi(AsyncChannel) import NIOTLS
import NIOConcurrencyHelpers
fileprivate final class LineDelimiterDecoder: ByteToMessageDecoder {
private let newLine = "\n".utf8.first!
typealias InboundIn = ByteBuffer
typealias InboundOut = ByteBuffer
func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
let readable = buffer.withUnsafeReadableBytes { $0.firstIndex(of: newLine) }
if let readable = readable {
context.fireChannelRead(self.wrapInboundOut(buffer.readSlice(length: readable)!))
buffer.moveReaderIndex(forwardBy: 1)
return .continue
}
return .needMoreData
}
}
fileprivate final class TLSUserEventHandler: ChannelInboundHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = ByteBuffer
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let buffer = self.unwrapInboundIn(data)
let alpn = String(buffer: buffer)
if alpn.hasPrefix("alpn:") {
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(alpn.dropFirst(5))))
} else {
context.fireChannelRead(data)
}
}
}
fileprivate final class ByteBufferToStringHandler: ChannelInboundHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = String
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let buffer = self.unwrapInboundIn(data)
context.fireChannelRead(self.wrapInboundOut(String(buffer: buffer)))
}
}
fileprivate final class ByteBufferToByteHandler: ChannelInboundHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = UInt8
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
var buffer = self.unwrapInboundIn(data)
let byte = buffer.readInteger(as: UInt8.self)!
context.fireChannelRead(self.wrapInboundOut(byte))
}
}
final class AsyncChannelBootstrapTests: XCTestCase {
enum NegotiationResult {
case string(NIOAsyncChannel<String, String>)
case byte(NIOAsyncChannel<UInt8, UInt8>)
}
struct ProtocolNegotiationError: Error {}
enum StringOrByte: Hashable {
case string(String)
case byte(UInt8)
}
func testAsyncChannel() throws {
XCTAsyncTest {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3)
let channel = try await ServerBootstrap(group: eventLoopGroup)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(ChannelOptions.autoRead, value: true)
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
}
}
.bind(
host: "127.0.0.1",
port: 1995,
childChannelInboundType: String.self,
childChannelOutboundType: String.self
)
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var iterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { group in
for try await childChannel in channel.inboundStream {
for try await value in childChannel.inboundStream {
continuation.yield(.string(value))
}
}
}
}
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
group.cancelAll()
}
}
}
func testAsyncChannelProtocolNegotiation() throws {
XCTAsyncTest {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3)
let channel: NIOAsyncChannel<NegotiationResult, Never> = try await ServerBootstrap(group: eventLoopGroup)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(ChannelOptions.autoRead, value: true)
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try self.makeProtocolNegotiationChildChannel(channel: channel)
}
}
.bind(
host: "127.0.0.1",
port: 1995,
protocolNegotiationHandlerType: NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>.self
)
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var iterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { group in
for try await childChannel in channel.inboundStream {
group.addTask {
switch childChannel {
case .string(let channel):
for try await value in channel.inboundStream {
continuation.yield(.string(value))
}
case .byte(let channel):
for try await value in channel.inboundStream {
continuation.yield(.byte(value))
}
}
}
}
}
}
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
// This is for negotiating the protocol
stringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
// This is the actual content
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
let byteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
// This is for negotiating the protocol
byteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
// This is the actual content
byteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
byteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
group.cancelAll()
}
}
}
func testAsyncChannelNestedProtocolNegotiation() throws {
XCTAsyncTest {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3)
let channel: NIOAsyncChannel<NegotiationResult, Never> = try await ServerBootstrap(group: eventLoopGroup)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(ChannelOptions.autoRead, value: true)
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try self.makeNestedProtocolNegotiationChildChannel(channel: channel)
}
}
.bind(
host: "127.0.0.1",
port: 1995,
protocolNegotiationHandlerType: NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>.self
)
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var iterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { group in
for try await childChannel in channel.inboundStream {
group.addTask {
switch childChannel {
case .string(let channel):
for try await value in channel.inboundStream {
continuation.yield(.string(value))
}
case .byte(let channel):
for try await value in channel.inboundStream {
continuation.yield(.byte(value))
}
}
}
}
}
}
let stringStringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
// This is for negotiating the protocol
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
// This is for negotiating the nested protocol
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
// This is the actual content
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
let byteByteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
// This is for negotiating the protocol
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
// This is for negotiating the nested protocol
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
// This is the actual content
byteByteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
let stringByteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
// This is for negotiating the protocol
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
// This is for negotiating the nested protocol
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
// This is the actual content
stringByteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
let byteStringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
// This is for negotiating the protocol
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
// This is for negotiating the nested protocol
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
// This is the actual content
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
group.cancelAll()
}
}
}
func testAsyncChannelProtocolNegotiation_whenFails() throws {
final class CollectingHandler: ChannelInboundHandler {
typealias InboundIn = Channel
private let channels: NIOLockedValueBox<[Channel]>
init(channels: NIOLockedValueBox<[Channel]>) {
self.channels = channels
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let channel = self.unwrapInboundIn(data)
self.channels.withLockedValue { $0.append(channel) }
context.fireChannelRead(data)
}
}
XCTAsyncTest {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3)
let channels = NIOLockedValueBox<[Channel]>([Channel]())
let channel: NIOAsyncChannel<NegotiationResult, Never> = try await ServerBootstrap(group: eventLoopGroup)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.serverChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(CollectingHandler(channels: channels))
}
}
.childChannelOption(ChannelOptions.autoRead, value: true)
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try self.makeProtocolNegotiationChildChannel(channel: channel)
}
}
.bind(
host: "127.0.0.1",
port: 1995,
protocolNegotiationHandlerType: NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>.self
)
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var iterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { group in
for try await childChannel in channel.inboundStream {
group.addTask {
switch childChannel {
case .string(let channel):
for try await value in channel.inboundStream {
continuation.yield(.string(value))
}
case .byte(let channel):
for try await value in channel.inboundStream {
continuation.yield(.byte(value))
}
}
}
}
}
}
let channel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
// This is for negotiating the protocol
channel.writeAndFlush(.init(ByteBuffer(string: "alpn:unknown\n")), promise: nil)
// Checking that we can still create new connections afterwards
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
// This is for negotiating the protocol
stringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
// This is the actual content
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
let failedInboundChannel = channels.withLockedValue { channels -> Channel in
XCTAssertEqual(channels.count, 2)
return channels[0]
}
// We are waiting here to make sure the channel got closed
try await failedInboundChannel.closeFuture.get()
group.cancelAll()
}
}
}
// MARK: - Test Helpers
private func makeClientChannel(eventLoopGroup: EventLoopGroup) async throws -> Channel {
return try await ClientBootstrap(group: eventLoopGroup)
.channelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
}
}
.connect(to: .init(ipAddress: "127.0.0.1", port: 1995))
.get()
}
private func makeProtocolNegotiationChildChannel(channel: Channel) throws {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler())
try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
}
private func makeNestedProtocolNegotiationChildChannel(channel: Channel) throws {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler())
try channel.pipeline.syncOperations.addHandler(
NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
switch alpnResult {
case .negotiated(let alpn):
switch alpn {
case "string":
return channel.eventLoop.makeCompletedFuture {
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return NIOProtocolNegotiationResult.deferredResult(negotiationFuture)
}
case "byte":
return channel.eventLoop.makeCompletedFuture {
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return NIOProtocolNegotiationResult.deferredResult(negotiationFuture)
}
default:
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
}
case .fallback:
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
}
}
)
}
@discardableResult
private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>> {
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
switch alpnResult {
case .negotiated(let alpn):
switch alpn {
case "string":
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
let asyncChannel = try NIOAsyncChannel(
synchronouslyWrapping: channel,
isOutboundHalfClosureEnabled: true,
inboundType: String.self,
outboundType: String.self
)
return NIOProtocolNegotiationResult.finished(NegotiationResult.string(asyncChannel))
}
case "byte":
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(ByteBufferToByteHandler())
let asyncChannel = try NIOAsyncChannel(
synchronouslyWrapping: channel,
isOutboundHalfClosureEnabled: true,
inboundType: UInt8.self,
outboundType: UInt8.self
)
return NIOProtocolNegotiationResult.finished(NegotiationResult.byte(asyncChannel))
}
default:
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
}
case .fallback:
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
}
}
try channel.pipeline.syncOperations.addHandler(negotiationHandler)
return negotiationHandler.protocolNegotiationResult
}
}
extension AsyncStream {
fileprivate static func makeStream(
of elementType: Element.Type = Element.self,
bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded
) -> (stream: AsyncStream<Element>, continuation: AsyncStream<Element>.Continuation) {
var continuation: AsyncStream<Element>.Continuation!
let stream = AsyncStream<Element>(bufferingPolicy: limit) { continuation = $0 }
return (stream: stream, continuation: continuation!)
}
}
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
fileprivate func XCTAsyncAssertEqual<Element: Equatable>(_ lhs: @autoclosure () async throws -> Element, _ rhs: @autoclosure () async throws -> Element, file: StaticString = #filePath, line: UInt = #line) async rethrows {
let lhsResult = try await lhs()
let rhsResult = try await rhs()
XCTAssertEqual(lhsResult, rhsResult, file: file, line: line)
}

View File

@ -0,0 +1,209 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
@_spi(AsyncChannel) import NIOTLS
@_spi(AsyncChannel) import NIOCore
import NIOEmbedded
import XCTest
import NIOTestUtils
final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase {
enum NegotiationResult: Equatable {
case negotiated(ALPNResult)
case failed
}
private let negotiatedEvent: TLSUserEvent = .handshakeCompleted(negotiatedProtocol: "h2")
private let negotiatedResult: ALPNResult = .negotiated("h2")
func testChannelProvidedToCallback() throws {
let emChannel = EmbeddedChannel()
let loop = emChannel.eventLoop as! EmbeddedEventLoop
var called = false
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result, channel in
called = true
XCTAssertEqual(result, self.negotiatedResult)
XCTAssertTrue(emChannel === channel)
return loop.makeSucceededFuture(.finished(.negotiated(result)))
}
try emChannel.pipeline.addHandler(handler).wait()
emChannel.pipeline.fireUserInboundEventTriggered(negotiatedEvent)
XCTAssertTrue(called)
XCTAssertEqual(try handler.protocolNegotiationResult.wait(), .finished(.negotiated(negotiatedResult)))
}
func testIgnoresUnknownUserEvents() throws {
let channel = EmbeddedChannel()
let loop = channel.eventLoop as! EmbeddedEventLoop
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
XCTFail("Negotiation fired")
return loop.makeSucceededFuture(.finished(.failed))
}
try channel.pipeline.addHandler(handler).wait()
// Fire a pair of events that should be ignored.
channel.pipeline.fireUserInboundEventTriggered("FakeEvent")
channel.pipeline.fireUserInboundEventTriggered(TLSUserEvent.shutdownCompleted)
// The channel handler should still be in the pipeline.
try channel.pipeline.assertContains(handler: handler)
XCTAssertTrue(try channel.finish().isClean)
}
func testNoBufferingBeforeEventFires() throws {
let channel = EmbeddedChannel()
let loop = channel.eventLoop as! EmbeddedEventLoop
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
XCTFail("Should not be called")
return loop.makeSucceededFuture(.finished(.failed))
}
try channel.pipeline.addHandler(handler).wait()
// The data we write should not be buffered.
try channel.writeInbound("hello")
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "hello"))
XCTAssertTrue(try channel.finish().isClean)
}
func testBufferingWhileWaitingForFuture() throws {
let channel = EmbeddedChannel()
let loop = channel.eventLoop as! EmbeddedEventLoop
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
return continuePromise.futureResult
}
try channel.pipeline.addHandler(handler).wait()
// Fire in the event.
channel.pipeline.fireUserInboundEventTriggered(negotiatedEvent)
// At this point all writes should be buffered.
try channel.writeInbound("writes")
try channel.writeInbound("are")
try channel.writeInbound("buffered")
XCTAssertNoThrow(XCTAssertNil(try channel.readInbound()))
// Complete the pipeline swap.
continuePromise.succeed(.finished(.failed))
// Now everything should have been unbuffered.
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "writes"))
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "are"))
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "buffered"))
XCTAssertTrue(try channel.finish().isClean)
}
func testNothingBufferedDoesNotFireReadCompleted() throws {
let channel = EmbeddedChannel()
let loop = channel.eventLoop as! EmbeddedEventLoop
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
continuePromise.futureResult
}
let eventCounterHandler = EventCounterHandler()
try channel.pipeline.addHandler(handler).wait()
try channel.pipeline.addHandler(eventCounterHandler).wait()
// Fire in the event.
channel.pipeline.fireUserInboundEventTriggered(negotiatedEvent)
// At this time, readComplete hasn't fired.
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 0)
// Now satisfy the future, which forces data unbuffering. As we haven't buffered any data,
// readComplete should not be fired.
continuePromise.succeed(.finished(.failed))
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 0)
XCTAssertTrue(try channel.finish().isClean)
}
func testUnbufferingFiresReadCompleted() throws {
let channel = EmbeddedChannel()
let loop = channel.eventLoop as! EmbeddedEventLoop
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
continuePromise.futureResult
}
let eventCounterHandler = EventCounterHandler()
try channel.pipeline.addHandler(handler).wait()
try channel.pipeline.addHandler(eventCounterHandler).wait()
// Fire in the event.
channel.pipeline.fireUserInboundEventTriggered(negotiatedEvent)
// Send a write, which is buffered.
try channel.writeInbound("a write")
// At this time, readComplete hasn't fired.
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 1)
// Now satisfy the future, which forces data unbuffering. This should fire readComplete.
continuePromise.succeed(.finished(.failed))
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write"))
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 2)
XCTAssertTrue(try channel.finish().isClean)
}
func testUnbufferingHandlesReentrantReads() throws {
let channel = EmbeddedChannel()
let loop = channel.eventLoop as! EmbeddedEventLoop
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
continuePromise.futureResult
}
let eventCounterHandler = EventCounterHandler()
try channel.pipeline.addHandler(handler).wait()
try channel.pipeline.addHandler(DuplicatingReadHandler(embeddedChannel: channel)).wait()
try channel.pipeline.addHandler(eventCounterHandler).wait()
// Fire in the event.
channel.pipeline.fireUserInboundEventTriggered(negotiatedEvent)
// Send a write, which is buffered.
try channel.writeInbound("a write")
// At this time, readComplete hasn't fired.
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 1)
// Now satisfy the future, which forces data unbuffering. This should fire readComplete.
continuePromise.succeed(.finished(.failed))
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write"))
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write"))
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 3)
XCTAssertTrue(try channel.finish().isClean)
}
}