Add `AsyncChannel` based `ServerBootstrap.bind()` methods (#2403)
* Add `AsyncChannel` based `ServerBootstrap.bind()` methods # Motivation In my previous PR, we added a new async bridge from a NIO `Channel` to Swift Concurrency primitives in the from of the `NIOAsyncChannel`. This type alone is already helpful in bridging `Channel`s to Concurrency; however, it is hard to use since it requires to wrap the `Channel` at the right time otherwise we will drop reads. Furthermore, in the case of protocol negotiation this becomes even trickier since we need to wait until it finishes and then wrap the `Channel`. # Modification This PR introduces a few things: 1. New methods on the `ServerBootstrap` which allow the creation of `NIOAsyncChannel` based channels. This can be used in all cases where no protocol negotiation is involved. 2. A new protocol and type called `NIOProtocolNegotiationHandler` and `NIOProtocolNegotiationResult` which is used to identify channel handlers that are doing protocol negotiation. 3. New methods on the `ServerBootstrap` that are aware of protocol negotiation. # Result We can now easily and safely create new `AsyncChannel`s from the `ServerBootstrap` * Code review * Fix typo * Fix up tests * Stop finishing the writer when an error is caught * Code review * Fix up writer tests * Introduce shared protocol negotiation handler state machine * Correctly handle multi threaded event loops * Adapt test to assert the channel was closed correctly. * Code review
This commit is contained in:
parent
f7c4655298
commit
d836d6bef5
|
@ -104,7 +104,7 @@ var targets: [PackageDescription.Target] = [
|
||||||
.testTarget(name: "NIOEmbeddedTests",
|
.testTarget(name: "NIOEmbeddedTests",
|
||||||
dependencies: ["NIOConcurrencyHelpers", "NIOCore", "NIOEmbedded"]),
|
dependencies: ["NIOConcurrencyHelpers", "NIOCore", "NIOEmbedded"]),
|
||||||
.testTarget(name: "NIOPosixTests",
|
.testTarget(name: "NIOPosixTests",
|
||||||
dependencies: ["NIOPosix", "NIOCore", "NIOFoundationCompat", "NIOTestUtils", "NIOConcurrencyHelpers", "NIOEmbedded", "CNIOLinux"]),
|
dependencies: ["NIOPosix", "NIOCore", "NIOFoundationCompat", "NIOTestUtils", "NIOConcurrencyHelpers", "NIOEmbedded", "CNIOLinux", "NIOTLS"]),
|
||||||
.testTarget(name: "NIOConcurrencyHelpersTests",
|
.testTarget(name: "NIOConcurrencyHelpersTests",
|
||||||
dependencies: ["NIOConcurrencyHelpers", "NIOCore"]),
|
dependencies: ["NIOConcurrencyHelpers", "NIOCore"]),
|
||||||
.testTarget(name: "NIODataStructuresTests",
|
.testTarget(name: "NIODataStructuresTests",
|
||||||
|
|
|
@ -62,7 +62,7 @@ public final class NIOAsyncChannel<Inbound: Sendable, Outbound: Sendable>: Senda
|
||||||
public init(
|
public init(
|
||||||
synchronouslyWrapping channel: Channel,
|
synchronouslyWrapping channel: Channel,
|
||||||
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
isOutboundHalfClosureEnabled: Bool = true,
|
isOutboundHalfClosureEnabled: Bool = false,
|
||||||
inboundType: Inbound.Type = Inbound.self,
|
inboundType: Inbound.Type = Inbound.self,
|
||||||
outboundType: Outbound.Type = Outbound.self
|
outboundType: Outbound.Type = Outbound.self
|
||||||
) throws {
|
) throws {
|
||||||
|
@ -92,7 +92,7 @@ public final class NIOAsyncChannel<Inbound: Sendable, Outbound: Sendable>: Senda
|
||||||
public init(
|
public init(
|
||||||
synchronouslyWrapping channel: Channel,
|
synchronouslyWrapping channel: Channel,
|
||||||
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
isOutboundHalfClosureEnabled: Bool = true,
|
isOutboundHalfClosureEnabled: Bool = false,
|
||||||
inboundType: Inbound.Type = Inbound.self
|
inboundType: Inbound.Type = Inbound.self
|
||||||
) throws where Outbound == Never {
|
) throws where Outbound == Never {
|
||||||
channel.eventLoop.preconditionInEventLoop()
|
channel.eventLoop.preconditionInEventLoop()
|
||||||
|
@ -104,6 +104,67 @@ public final class NIOAsyncChannel<Inbound: Sendable, Outbound: Sendable>: Senda
|
||||||
|
|
||||||
self.outboundWriter.finish()
|
self.outboundWriter.finish()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@inlinable
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public init(
|
||||||
|
channel: Channel,
|
||||||
|
inboundStream: NIOAsyncChannelInboundStream<Inbound>,
|
||||||
|
outboundWriter: NIOAsyncChannelOutboundWriter<Outbound>
|
||||||
|
) {
|
||||||
|
channel.eventLoop.preconditionInEventLoop()
|
||||||
|
self.channel = channel
|
||||||
|
self.inboundStream = inboundStream
|
||||||
|
self.outboundWriter = outboundWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
@inlinable
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public static func wrapAsyncChannelForBootstrapBind(
|
||||||
|
synchronouslyWrapping channel: Channel,
|
||||||
|
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
|
isOutboundHalfClosureEnabled: Bool = false,
|
||||||
|
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
|
||||||
|
) throws -> NIOAsyncChannel<Inbound, Outbound> where Outbound == Never {
|
||||||
|
channel.eventLoop.preconditionInEventLoop()
|
||||||
|
let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream<Inbound>, NIOAsyncChannelOutboundWriter<Outbound>) = try channel._syncAddAsyncHandlersForBootstrapBind(
|
||||||
|
backpressureStrategy: backpressureStrategy,
|
||||||
|
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
|
||||||
|
transformationClosure: transformationClosure
|
||||||
|
)
|
||||||
|
|
||||||
|
outboundWriter.finish()
|
||||||
|
|
||||||
|
return .init(
|
||||||
|
channel: channel,
|
||||||
|
inboundStream: inboundStream,
|
||||||
|
outboundWriter: outboundWriter
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@inlinable
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public static func wrapAsyncChannelForBootstrapBindWithProtocolNegotiation(
|
||||||
|
synchronouslyWrapping channel: Channel,
|
||||||
|
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
|
isOutboundHalfClosureEnabled: Bool = false,
|
||||||
|
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
|
||||||
|
) throws -> NIOAsyncChannel<Inbound, Outbound> where Outbound == Never {
|
||||||
|
channel.eventLoop.preconditionInEventLoop()
|
||||||
|
let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream<Inbound>, NIOAsyncChannelOutboundWriter<Outbound>) = try channel._syncAddAsyncHandlersForBootstrapProtocolNegotiation(
|
||||||
|
backpressureStrategy: backpressureStrategy,
|
||||||
|
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
|
||||||
|
transformationClosure: transformationClosure
|
||||||
|
)
|
||||||
|
|
||||||
|
outboundWriter.finish()
|
||||||
|
|
||||||
|
return .init(
|
||||||
|
channel: channel,
|
||||||
|
inboundStream: inboundStream,
|
||||||
|
outboundWriter: outboundWriter
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extension Channel {
|
extension Channel {
|
||||||
|
@ -118,7 +179,7 @@ extension Channel {
|
||||||
self.eventLoop.assertInEventLoop()
|
self.eventLoop.assertInEventLoop()
|
||||||
|
|
||||||
let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled)
|
let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled)
|
||||||
let inboundStream = try NIOAsyncChannelInboundStream<Inbound>(
|
let inboundStream = try NIOAsyncChannelInboundStream<Inbound>.makeWrappingHandler(
|
||||||
channel: self,
|
channel: self,
|
||||||
backpressureStrategy: backpressureStrategy,
|
backpressureStrategy: backpressureStrategy,
|
||||||
closeRatchet: closeRatchet
|
closeRatchet: closeRatchet
|
||||||
|
@ -129,4 +190,52 @@ extension Channel {
|
||||||
)
|
)
|
||||||
return (inboundStream, writer)
|
return (inboundStream, writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@inlinable
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func _syncAddAsyncHandlersForBootstrapBind<Inbound: Sendable, Outbound: Sendable>(
|
||||||
|
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
isOutboundHalfClosureEnabled: Bool,
|
||||||
|
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
|
||||||
|
) throws -> (NIOAsyncChannelInboundStream<Inbound>, NIOAsyncChannelOutboundWriter<Outbound>) {
|
||||||
|
self.eventLoop.assertInEventLoop()
|
||||||
|
|
||||||
|
let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled)
|
||||||
|
let inboundStream = try NIOAsyncChannelInboundStream<Inbound>.makeBindingHandler(
|
||||||
|
channel: self,
|
||||||
|
backpressureStrategy: backpressureStrategy,
|
||||||
|
closeRatchet: closeRatchet,
|
||||||
|
transformationClosure: transformationClosure
|
||||||
|
)
|
||||||
|
let writer = try NIOAsyncChannelOutboundWriter<Outbound>(
|
||||||
|
channel: self,
|
||||||
|
closeRatchet: closeRatchet
|
||||||
|
)
|
||||||
|
return (inboundStream, writer)
|
||||||
|
}
|
||||||
|
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@inlinable
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func _syncAddAsyncHandlersForBootstrapProtocolNegotiation<Inbound: Sendable, Outbound: Sendable>(
|
||||||
|
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
isOutboundHalfClosureEnabled: Bool,
|
||||||
|
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
|
||||||
|
) throws -> (NIOAsyncChannelInboundStream<Inbound>, NIOAsyncChannelOutboundWriter<Outbound>) {
|
||||||
|
self.eventLoop.assertInEventLoop()
|
||||||
|
|
||||||
|
let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled)
|
||||||
|
let inboundStream = try NIOAsyncChannelInboundStream<Inbound>.makeProtocolNegotiationHandler(
|
||||||
|
channel: self,
|
||||||
|
backpressureStrategy: backpressureStrategy,
|
||||||
|
closeRatchet: closeRatchet,
|
||||||
|
transformationClosure: transformationClosure
|
||||||
|
)
|
||||||
|
let writer = try NIOAsyncChannelOutboundWriter<Outbound>(
|
||||||
|
channel: self,
|
||||||
|
closeRatchet: closeRatchet
|
||||||
|
)
|
||||||
|
return (inboundStream, writer)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,22 +19,19 @@
|
||||||
@_spi(AsyncChannel)
|
@_spi(AsyncChannel)
|
||||||
public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable {
|
public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable {
|
||||||
@usableFromInline
|
@usableFromInline
|
||||||
typealias Producer = NIOThrowingAsyncSequenceProducer<Inbound, Error, NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, NIOAsyncChannelInboundStreamChannelHandler<Inbound>.Delegate>
|
typealias Producer = NIOThrowingAsyncSequenceProducer<Inbound, Error, NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate>
|
||||||
|
|
||||||
/// The underlying async sequence.
|
/// The underlying async sequence.
|
||||||
@usableFromInline let _producer: Producer
|
@usableFromInline let _producer: Producer
|
||||||
|
|
||||||
@inlinable
|
@inlinable
|
||||||
init(
|
init<HandlerInbound: Sendable>(
|
||||||
channel: Channel,
|
channel: Channel,
|
||||||
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
closeRatchet: CloseRatchet
|
closeRatchet: CloseRatchet,
|
||||||
|
handler: NIOAsyncChannelInboundStreamChannelHandler<HandlerInbound, Inbound>
|
||||||
) throws {
|
) throws {
|
||||||
channel.eventLoop.preconditionInEventLoop()
|
channel.eventLoop.preconditionInEventLoop()
|
||||||
let handler = NIOAsyncChannelInboundStreamChannelHandler<Inbound>(
|
|
||||||
eventLoop: channel.eventLoop,
|
|
||||||
closeRatchet: closeRatchet
|
|
||||||
)
|
|
||||||
let strategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark
|
let strategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark
|
||||||
|
|
||||||
if let userProvided = backpressureStrategy {
|
if let userProvided = backpressureStrategy {
|
||||||
|
@ -47,12 +44,77 @@ public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable {
|
||||||
|
|
||||||
let sequence = Producer.makeSequence(
|
let sequence = Producer.makeSequence(
|
||||||
backPressureStrategy: strategy,
|
backPressureStrategy: strategy,
|
||||||
delegate: NIOAsyncChannelInboundStreamChannelHandler<Inbound>.Delegate(handler: handler)
|
delegate: NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate(handler: handler)
|
||||||
)
|
)
|
||||||
handler.source = sequence.source
|
handler.source = sequence.source
|
||||||
try channel.pipeline.syncOperations.addHandler(handler)
|
try channel.pipeline.syncOperations.addHandler(handler)
|
||||||
self._producer = sequence.sequence
|
self._producer = sequence.sequence
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a new ``NIOAsyncChannelInboundStream`` which is used when the pipeline got synchronously wrapped.
|
||||||
|
@inlinable
|
||||||
|
static func makeWrappingHandler(
|
||||||
|
channel: Channel,
|
||||||
|
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
closeRatchet: CloseRatchet
|
||||||
|
) throws -> NIOAsyncChannelInboundStream {
|
||||||
|
let handler = NIOAsyncChannelInboundStreamChannelHandler<Inbound, Inbound>.makeWrappingHandler(
|
||||||
|
eventLoop: channel.eventLoop,
|
||||||
|
closeRatchet: closeRatchet
|
||||||
|
)
|
||||||
|
|
||||||
|
return try .init(
|
||||||
|
channel: channel,
|
||||||
|
backpressureStrategy: backpressureStrategy,
|
||||||
|
closeRatchet: closeRatchet,
|
||||||
|
handler: handler
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new ``NIOAsyncChannelInboundStreamChannelHandler`` which is used in the bootstrap for the ServerChannel.
|
||||||
|
@inlinable
|
||||||
|
static func makeBindingHandler(
|
||||||
|
channel: Channel,
|
||||||
|
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
closeRatchet: CloseRatchet,
|
||||||
|
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
|
||||||
|
) throws -> NIOAsyncChannelInboundStream {
|
||||||
|
let handler = NIOAsyncChannelInboundStreamChannelHandler<Channel, Inbound>.makeBindingHandler(
|
||||||
|
eventLoop: channel.eventLoop,
|
||||||
|
closeRatchet: closeRatchet,
|
||||||
|
transformationClosure: transformationClosure
|
||||||
|
)
|
||||||
|
|
||||||
|
return try .init(
|
||||||
|
channel: channel,
|
||||||
|
backpressureStrategy: backpressureStrategy,
|
||||||
|
closeRatchet: closeRatchet,
|
||||||
|
handler: handler
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new ``NIOAsyncChannelInboundStreamChannelHandler`` which is used in the bootstrap for the ServerChannel when the child
|
||||||
|
/// channel does protocol negotiation.
|
||||||
|
@inlinable
|
||||||
|
static func makeProtocolNegotiationHandler(
|
||||||
|
channel: Channel,
|
||||||
|
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
closeRatchet: CloseRatchet,
|
||||||
|
transformationClosure: @escaping (Channel) -> EventLoopFuture<Inbound>
|
||||||
|
) throws -> NIOAsyncChannelInboundStream {
|
||||||
|
let handler = NIOAsyncChannelInboundStreamChannelHandler<Channel, Inbound>.makeProtocolNegotiationHandler(
|
||||||
|
eventLoop: channel.eventLoop,
|
||||||
|
closeRatchet: closeRatchet,
|
||||||
|
transformationClosure: transformationClosure
|
||||||
|
)
|
||||||
|
|
||||||
|
return try .init(
|
||||||
|
channel: channel,
|
||||||
|
backpressureStrategy: backpressureStrategy,
|
||||||
|
closeRatchet: closeRatchet,
|
||||||
|
handler: handler
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@ -87,4 +149,3 @@ extension NIOAsyncChannelInboundStream: AsyncSequence {
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
@available(*, unavailable)
|
@available(*, unavailable)
|
||||||
extension NIOAsyncChannelInboundStream.AsyncIterator: Sendable {}
|
extension NIOAsyncChannelInboundStream.AsyncIterator: Sendable {}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
/// ``Channel`` into an asynchronous sequence that supports back-pressure.
|
/// ``Channel`` into an asynchronous sequence that supports back-pressure.
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
@usableFromInline
|
@usableFromInline
|
||||||
internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Sendable>: ChannelDuplexHandler {
|
internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Sendable, ProducerElement: Sendable>: ChannelDuplexHandler {
|
||||||
@usableFromInline
|
@usableFromInline
|
||||||
enum _ProducingState {
|
enum _ProducingState {
|
||||||
// Not .stopProducing
|
// Not .stopProducing
|
||||||
|
@ -37,10 +37,10 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Senda
|
||||||
|
|
||||||
@usableFromInline
|
@usableFromInline
|
||||||
typealias Source = NIOThrowingAsyncSequenceProducer<
|
typealias Source = NIOThrowingAsyncSequenceProducer<
|
||||||
InboundIn,
|
ProducerElement,
|
||||||
Error,
|
Error,
|
||||||
NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark,
|
NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark,
|
||||||
NIOAsyncChannelInboundStreamChannelHandler<InboundIn>.Delegate
|
NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate
|
||||||
>.Source
|
>.Source
|
||||||
|
|
||||||
/// The source of the asynchronous sequence.
|
/// The source of the asynchronous sequence.
|
||||||
|
@ -53,7 +53,7 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Senda
|
||||||
|
|
||||||
/// An array of reads which will be yielded to the source with the next channel read complete.
|
/// An array of reads which will be yielded to the source with the next channel read complete.
|
||||||
@usableFromInline
|
@usableFromInline
|
||||||
var buffer: [InboundIn] = []
|
var buffer: [ProducerElement] = []
|
||||||
|
|
||||||
/// The current producing state.
|
/// The current producing state.
|
||||||
@usableFromInline
|
@usableFromInline
|
||||||
|
@ -67,10 +67,73 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Senda
|
||||||
@usableFromInline
|
@usableFromInline
|
||||||
let closeRatchet: CloseRatchet
|
let closeRatchet: CloseRatchet
|
||||||
|
|
||||||
|
/// A type indicating what kind of transformation to apply to reads.
|
||||||
|
@usableFromInline
|
||||||
|
enum Transformation {
|
||||||
|
/// A synchronous transformation is applied to incoming reads. This is used when sync wrapping a channel.
|
||||||
|
case syncWrapping((InboundIn) -> ProducerElement)
|
||||||
|
/// This is used in the ServerBootstrap since we require to wrap the child channel on it's event loop but yield it on the parent's loop.
|
||||||
|
case bind((InboundIn) -> EventLoopFuture<ProducerElement>)
|
||||||
|
/// In the case of protocol negotiation we are applying a future based transformation where we wait for the transformation
|
||||||
|
/// to finish before we yield it to the source.
|
||||||
|
case protocolNegotiation((InboundIn) -> EventLoopFuture<ProducerElement>)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The transformation applied to incoming reads.
|
||||||
|
@usableFromInline
|
||||||
|
let transformation: Transformation
|
||||||
|
|
||||||
@inlinable
|
@inlinable
|
||||||
init(eventLoop: EventLoop, closeRatchet: CloseRatchet) {
|
init(
|
||||||
|
eventLoop: EventLoop,
|
||||||
|
closeRatchet: CloseRatchet,
|
||||||
|
transformation: Transformation
|
||||||
|
) {
|
||||||
self.eventLoop = eventLoop
|
self.eventLoop = eventLoop
|
||||||
self.closeRatchet = closeRatchet
|
self.closeRatchet = closeRatchet
|
||||||
|
self.transformation = transformation
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new ``NIOAsyncChannelInboundStreamChannelHandler`` which is used when the pipeline got synchronously wrapped.
|
||||||
|
@inlinable
|
||||||
|
static func makeWrappingHandler(
|
||||||
|
eventLoop: EventLoop,
|
||||||
|
closeRatchet: CloseRatchet
|
||||||
|
) -> NIOAsyncChannelInboundStreamChannelHandler where InboundIn == ProducerElement {
|
||||||
|
return .init(
|
||||||
|
eventLoop: eventLoop,
|
||||||
|
closeRatchet: closeRatchet,
|
||||||
|
transformation: .syncWrapping { $0 }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new ``NIOAsyncChannelInboundStreamChannelHandler`` which is used in the bootstrap for the ServerChannel.
|
||||||
|
@inlinable
|
||||||
|
static func makeBindingHandler(
|
||||||
|
eventLoop: EventLoop,
|
||||||
|
closeRatchet: CloseRatchet,
|
||||||
|
transformationClosure: @escaping (Channel) -> EventLoopFuture<ProducerElement>
|
||||||
|
) -> NIOAsyncChannelInboundStreamChannelHandler where InboundIn == Channel {
|
||||||
|
return .init(
|
||||||
|
eventLoop: eventLoop,
|
||||||
|
closeRatchet: closeRatchet,
|
||||||
|
transformation: .bind(transformationClosure)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new ``NIOAsyncChannelInboundStreamChannelHandler`` which is used in the bootstrap for the ServerChannel when the child
|
||||||
|
/// channel does protocol negotiation.
|
||||||
|
@inlinable
|
||||||
|
static func makeProtocolNegotiationHandler(
|
||||||
|
eventLoop: EventLoop,
|
||||||
|
closeRatchet: CloseRatchet,
|
||||||
|
transformationClosure: @escaping (Channel) -> EventLoopFuture<ProducerElement>
|
||||||
|
) -> NIOAsyncChannelInboundStreamChannelHandler where InboundIn == Channel {
|
||||||
|
return .init(
|
||||||
|
eventLoop: eventLoop,
|
||||||
|
closeRatchet: closeRatchet,
|
||||||
|
transformation: .protocolNegotiation(transformationClosure)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@inlinable
|
@inlinable
|
||||||
|
@ -86,10 +149,67 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler<InboundIn: Senda
|
||||||
|
|
||||||
@inlinable
|
@inlinable
|
||||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||||
self.buffer.append(self.unwrapInboundIn(data))
|
let unwrapped = self.unwrapInboundIn(data)
|
||||||
|
|
||||||
// We forward on reads here to enable better channel composition.
|
switch self.transformation {
|
||||||
context.fireChannelRead(data)
|
case .syncWrapping(let transformation):
|
||||||
|
self.buffer.append(transformation(unwrapped))
|
||||||
|
// We forward on reads here to enable better channel composition.
|
||||||
|
context.fireChannelRead(data)
|
||||||
|
|
||||||
|
case .bind(let transformation):
|
||||||
|
// The unsafe transfers here are required because we need to use self in whenComplete
|
||||||
|
// We are making sure to be on our event loop so we can safely use self in whenComplete
|
||||||
|
let unsafeSelf = NIOLoopBound(self, eventLoop: context.eventLoop)
|
||||||
|
let unsafeContext = NIOLoopBound(context, eventLoop: context.eventLoop)
|
||||||
|
transformation(unwrapped)
|
||||||
|
.hop(to: context.eventLoop)
|
||||||
|
.whenComplete { result in
|
||||||
|
unsafeSelf.value._transformationCompleted(context: unsafeContext.value, result: result)
|
||||||
|
|
||||||
|
// We forward the read only after the transformation has been completed. This is super important
|
||||||
|
// since we are setting up the NIOAsyncChannel handlers in the transformation and
|
||||||
|
// we must make sure to only generate reads once they are setup. Reads can only
|
||||||
|
// happen after the child channels hit `channelRead0` that's why we are holding the read here.
|
||||||
|
context.fireChannelRead(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
case .protocolNegotiation(let protocolNegotiation):
|
||||||
|
// The unsafe transfers here are required because we need to use self in whenComplete
|
||||||
|
// We are making sure to be on our event loop so we can safely use self in whenComplete
|
||||||
|
let unsafeSelf = NIOLoopBound(self, eventLoop: context.eventLoop)
|
||||||
|
let unsafeContext = NIOLoopBound(context, eventLoop: context.eventLoop)
|
||||||
|
protocolNegotiation(unwrapped)
|
||||||
|
.hop(to: context.eventLoop)
|
||||||
|
.whenComplete { result in
|
||||||
|
unsafeSelf.value._transformationCompleted(context: unsafeContext.value, result: result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We forwarding the read here right away since protocol negotiation often needs reads to progress.
|
||||||
|
// In this case, we expect the user to synchronously wrap the child channel into a NIOAsyncChannel
|
||||||
|
// hence we don't have the timing issue as in the `.bind` case.
|
||||||
|
context.fireChannelRead(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@inlinable
|
||||||
|
func _transformationCompleted(
|
||||||
|
context: ChannelHandlerContext,
|
||||||
|
result: Result<ProducerElement, Error>
|
||||||
|
) {
|
||||||
|
context.eventLoop.preconditionInEventLoop()
|
||||||
|
|
||||||
|
switch result {
|
||||||
|
case .success(let transformed):
|
||||||
|
self.buffer.append(transformed)
|
||||||
|
// We are delivering out of band here since the future can complete at any point
|
||||||
|
self._deliverReads(context: context)
|
||||||
|
|
||||||
|
case .failure:
|
||||||
|
// Transformation failed. Nothing to really do here this must be handled in the transformation
|
||||||
|
// futures themselves.
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@inlinable
|
@inlinable
|
||||||
|
@ -215,33 +335,35 @@ extension NIOAsyncChannelInboundStreamChannelHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
extension NIOAsyncChannelInboundStreamChannelHandler {
|
@usableFromInline
|
||||||
|
struct NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate: @unchecked Sendable, NIOAsyncSequenceProducerDelegate {
|
||||||
@usableFromInline
|
@usableFromInline
|
||||||
struct Delegate: @unchecked Sendable, NIOAsyncSequenceProducerDelegate {
|
let eventLoop: EventLoop
|
||||||
@usableFromInline
|
|
||||||
let eventLoop: EventLoop
|
|
||||||
|
|
||||||
@usableFromInline
|
@usableFromInline
|
||||||
let handler: NIOAsyncChannelInboundStreamChannelHandler<InboundIn>
|
let _didTerminate: () -> Void
|
||||||
|
|
||||||
@inlinable
|
@usableFromInline
|
||||||
init(handler: NIOAsyncChannelInboundStreamChannelHandler<InboundIn>) {
|
let _produceMore: () -> Void
|
||||||
self.eventLoop = handler.eventLoop
|
|
||||||
self.handler = handler
|
@inlinable
|
||||||
|
init<InboundIn, ProducerElement>(handler: NIOAsyncChannelInboundStreamChannelHandler<InboundIn, ProducerElement>) {
|
||||||
|
self.eventLoop = handler.eventLoop
|
||||||
|
self._didTerminate = handler._didTerminate
|
||||||
|
self._produceMore = handler._produceMore
|
||||||
|
}
|
||||||
|
|
||||||
|
@inlinable
|
||||||
|
func didTerminate() {
|
||||||
|
self.eventLoop.execute {
|
||||||
|
self._didTerminate()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@inlinable
|
@inlinable
|
||||||
func didTerminate() {
|
func produceMore() {
|
||||||
self.eventLoop.execute {
|
self.eventLoop.execute {
|
||||||
self.handler._didTerminate()
|
self._produceMore()
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@inlinable
|
|
||||||
func produceMore() {
|
|
||||||
self.eventLoop.execute {
|
|
||||||
self.handler._produceMore()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -249,4 +371,3 @@ extension NIOAsyncChannelInboundStreamChannelHandler {
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
@available(*, unavailable)
|
@available(*, unavailable)
|
||||||
extension NIOAsyncChannelInboundStreamChannelHandler: Sendable {}
|
extension NIOAsyncChannelInboundStreamChannelHandler: Sendable {}
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,6 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
|
||||||
/// A ``NIOAsyncChannelWriter`` is used to write and flush new outbound messages in a channel.
|
/// A ``NIOAsyncChannelWriter`` is used to write and flush new outbound messages in a channel.
|
||||||
///
|
///
|
||||||
/// The writer acts as a bridge between the Concurrency and NIO world. It allows to write and flush messages into the
|
/// The writer acts as a bridge between the Concurrency and NIO world. It allows to write and flush messages into the
|
||||||
|
@ -90,4 +89,3 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
|
||||||
self._outboundWriter.finish()
|
self._outboundWriter.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -113,12 +113,6 @@ internal final class NIOAsyncChannelOutboundWriterHandler<OutboundOut: Sendable>
|
||||||
self.sink = nil
|
self.sink = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@inlinable
|
|
||||||
func errorCaught(context: ChannelHandlerContext, error: Error) {
|
|
||||||
self.sink?.finish(error: error)
|
|
||||||
context.fireErrorCaught(error)
|
|
||||||
}
|
|
||||||
|
|
||||||
@inlinable
|
@inlinable
|
||||||
func channelInactive(context: ChannelHandlerContext) {
|
func channelInactive(context: ChannelHandlerContext) {
|
||||||
self.sink?.finish()
|
self.sink?.finish()
|
||||||
|
@ -172,4 +166,3 @@ extension NIOAsyncChannelOutboundWriterHandler {
|
||||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
@available(*, unavailable)
|
@available(*, unavailable)
|
||||||
extension NIOAsyncChannelOutboundWriterHandler: Sendable {}
|
extension NIOAsyncChannelOutboundWriterHandler: Sendable {}
|
||||||
|
|
||||||
|
|
|
@ -91,4 +91,3 @@ final class CloseRatchet {
|
||||||
return self._state.closeWrite()
|
return self._state.closeWrite()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -343,3 +343,30 @@ extension RemovableChannelHandler {
|
||||||
context.leavePipeline(removalToken: removalToken)
|
context.leavePipeline(removalToken: removalToken)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The result of protocol negotiation.
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public enum NIOProtocolNegotiationResult<NegotiationResult> {
|
||||||
|
/// Indicates that the protocol negotiation finished.
|
||||||
|
case finished(NegotiationResult)
|
||||||
|
/// Indicates that protocol negotiation has been deferred to the next handler.
|
||||||
|
case deferredResult(EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>)
|
||||||
|
}
|
||||||
|
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
extension NIOProtocolNegotiationResult: Equatable where NegotiationResult: Equatable {}
|
||||||
|
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
extension NIOProtocolNegotiationResult: Sendable where NegotiationResult: Sendable {}
|
||||||
|
|
||||||
|
/// A ``ProtocolNegotiationHandler`` is a ``ChannelHandler`` that is responsible for negotiating networking protocols.
|
||||||
|
///
|
||||||
|
/// Typically these handlers are at the tail of the pipeline and wait until the peer indicated what protocol should be used. Once, the protocol
|
||||||
|
/// has been negotiated the handlers allow user code to configure the pipeline.
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public protocol NIOProtocolNegotiationHandler: ChannelHandler {
|
||||||
|
associatedtype NegotiationResult
|
||||||
|
|
||||||
|
/// The future which gets succeeded with the protocol negotiation result.
|
||||||
|
var protocolNegotiationResult: EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>> { get }
|
||||||
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
import NIOCore
|
@_spi(AsyncChannel) import NIOCore
|
||||||
|
|
||||||
#if os(Windows)
|
#if os(Windows)
|
||||||
import ucrt
|
import ucrt
|
||||||
|
@ -311,8 +311,9 @@ public final class ServerBootstrap {
|
||||||
/// Bind the `ServerSocketChannel` to a UNIX Domain Socket.
|
/// Bind the `ServerSocketChannel` to a UNIX Domain Socket.
|
||||||
///
|
///
|
||||||
/// - parameters:
|
/// - parameters:
|
||||||
/// - unixDomainSocketPath: The _Unix domain socket_ path to bind to. `unixDomainSocketPath` must not exist, it will be created by the system.
|
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist,
|
||||||
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `path`.
|
/// unless `cleanupExistingSocketFile`is set to `true`.
|
||||||
|
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`.
|
||||||
public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> {
|
public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> {
|
||||||
if cleanupExistingSocketFile {
|
if cleanupExistingSocketFile {
|
||||||
do {
|
do {
|
||||||
|
@ -482,6 +483,454 @@ public final class ServerBootstrap {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: AsyncChannel based bind
|
||||||
|
extension ServerBootstrap {
|
||||||
|
/// Bind the `ServerSocketChannel` to the `host` and `port` parameters.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - host: The host to bind on.
|
||||||
|
/// - port: The port to bind on.
|
||||||
|
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
|
||||||
|
/// - childBackpressureStrategy: The back pressure strategy used by the child channels.
|
||||||
|
/// - childChannelInboundType: The child channel's inbound type.
|
||||||
|
/// - childChannelOutboundType: The child channel's outbound type.
|
||||||
|
/// - isChildChannelOutboundHalfClosureEnabled: Indicates if half closure is enabled on the child channels. If half closure is enabled
|
||||||
|
/// then finishing the ``NIOAsyncChannelWriter`` will lead to half closure.
|
||||||
|
/// - Returns: A ``NIOAsyncChannel`` of connection ``NIOAsyncChannel``s.
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func bind<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
|
||||||
|
host: String,
|
||||||
|
port: Int,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
|
childChannelInboundType: ChildChannelInbound.Type = ChildChannelInbound.self,
|
||||||
|
childChannelOutboundType: ChildChannelOutbound.Type = ChildChannelOutbound.self,
|
||||||
|
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: Bool = false
|
||||||
|
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
|
||||||
|
return try await self.bindAsyncChannel0(
|
||||||
|
serverBackpressureStrategy: serverBackpressureStrategy,
|
||||||
|
childBackpressureStrategy: childBackpressureStrategy,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled
|
||||||
|
) {
|
||||||
|
return try SocketAddress.makeAddressResolvingHost(host, port: port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Bind the `ServerSocketChannel` to the `address` parameter.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - address: The `SocketAddress` to bind on.
|
||||||
|
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
|
||||||
|
/// - childBackpressureStrategy: The back pressure strategy used by the child channels.
|
||||||
|
/// - childChannelInboundType: The child channel's inbound type.
|
||||||
|
/// - childChannelOutboundType: The child channel's outbound type.
|
||||||
|
/// - isChildChannelOutboundHalfClosureEnabled: Indicates if half closure is enabled on the child channels. If half closure is enabled
|
||||||
|
/// then finishing the ``NIOAsyncChannelWriter`` will lead to half closure.
|
||||||
|
/// - Returns: A ``NIOAsyncChannel`` of connection ``NIOAsyncChannel``s.
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func bind<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
|
||||||
|
to address: SocketAddress,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
|
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
|
childChannelInboundType: ChildChannelInbound.Type = ChildChannelInbound.self,
|
||||||
|
childChannelOutboundType: ChildChannelOutbound.Type = ChildChannelOutbound.self,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: Bool = false
|
||||||
|
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
|
||||||
|
return try await self.bindAsyncChannel0(
|
||||||
|
serverBackpressureStrategy: serverBackpressureStrategy,
|
||||||
|
childBackpressureStrategy: childBackpressureStrategy,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled
|
||||||
|
) { address }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Bind the `ServerSocketChannel` to the `unixDomainSocketPath` parameter.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist,
|
||||||
|
/// unless `cleanupExistingSocketFile`is set to `true`.
|
||||||
|
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`.
|
||||||
|
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
|
||||||
|
/// - childBackpressureStrategy: The back pressure strategy used by the child channels.
|
||||||
|
/// - childChannelInboundType: The child channel's inbound type.
|
||||||
|
/// - childChannelOutboundType: The child channel's outbound type.
|
||||||
|
/// - isChildChannelOutboundHalfClosureEnabled: Indicates if half closure is enabled on the child channels. If half closure is enabled
|
||||||
|
/// then finishing the ``NIOAsyncChannelWriter`` will lead to half closure.
|
||||||
|
/// - Returns: A ``NIOAsyncChannel`` of connection ``NIOAsyncChannel``s.
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func bind<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
|
||||||
|
unixDomainSocketPath: String,
|
||||||
|
cleanupExistingSocketFile: Bool = false,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
|
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
|
childChannelInboundType: ChildChannelInbound.Type = ChildChannelInbound.self,
|
||||||
|
childChannelOutboundType: ChildChannelOutbound.Type = ChildChannelOutbound.self,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: Bool = false
|
||||||
|
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
|
||||||
|
if cleanupExistingSocketFile {
|
||||||
|
try BaseSocket.cleanupSocket(unixDomainSocketPath: unixDomainSocketPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return try await self.bind(
|
||||||
|
unixDomainSocketPath: unixDomainSocketPath,
|
||||||
|
serverBackpressureStrategy: serverBackpressureStrategy,
|
||||||
|
childBackpressureStrategy: childBackpressureStrategy,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Use the existing bound socket file descriptor.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - socket: The _Unix file descriptor_ representing the bound stream socket.
|
||||||
|
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
|
||||||
|
/// - childBackpressureStrategy: The back pressure strategy used by the child channels.
|
||||||
|
/// - childChannelInboundType: The child channel's inbound type.
|
||||||
|
/// - childChannelOutboundType: The child channel's outbound type.
|
||||||
|
/// - isChildChannelOutboundHalfClosureEnabled: Indicates if half closure is enabled on the child channels. If half closure is enabled
|
||||||
|
/// then finishing the ``NIOAsyncChannelWriter`` will lead to half closure.
|
||||||
|
/// - Returns: A ``NIOAsyncChannel`` of connection ``NIOAsyncChannel``s.
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func withBoundSocket<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
|
||||||
|
_ socket: NIOBSDSocket.Handle,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
|
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||||
|
childChannelInboundType: ChildChannelInbound.Type = ChildChannelInbound.self,
|
||||||
|
childChannelOutboundType: ChildChannelOutbound.Type = ChildChannelOutbound.self,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: Bool = false
|
||||||
|
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
|
||||||
|
func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel {
|
||||||
|
if enableMPTCP {
|
||||||
|
throw ChannelError.operationUnsupported
|
||||||
|
}
|
||||||
|
return try ServerSocketChannel(socket: socket, eventLoop: eventLoop, group: childEventLoopGroup)
|
||||||
|
}
|
||||||
|
return try await self.bindAsyncChannel0(
|
||||||
|
makeServerChannel: makeChannel,
|
||||||
|
serverBackpressureStrategy: serverBackpressureStrategy,
|
||||||
|
childBackpressureStrategy: childBackpressureStrategy,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled
|
||||||
|
) { (eventLoop, serverChannel) in
|
||||||
|
let promise = eventLoop.makePromise(of: Void.self)
|
||||||
|
serverChannel.registerAlreadyConfigured0(promise: promise)
|
||||||
|
return promise.futureResult
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
private func bindAsyncChannel0<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: Bool,
|
||||||
|
_ makeSocketAddress: () throws -> SocketAddress
|
||||||
|
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
|
||||||
|
let address = try makeSocketAddress()
|
||||||
|
|
||||||
|
func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel {
|
||||||
|
return try ServerSocketChannel(eventLoop: eventLoop,
|
||||||
|
group: childEventLoopGroup,
|
||||||
|
protocolFamily: address.protocol,
|
||||||
|
enableMPTCP: enableMPTCP)
|
||||||
|
}
|
||||||
|
|
||||||
|
return try await self.bindAsyncChannel0(
|
||||||
|
makeServerChannel: makeChannel,
|
||||||
|
serverBackpressureStrategy: serverBackpressureStrategy,
|
||||||
|
childBackpressureStrategy: childBackpressureStrategy,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled
|
||||||
|
) { (eventLoop, serverChannel) in
|
||||||
|
serverChannel.registerAndDoSynchronously { serverChannel in
|
||||||
|
serverChannel.bind(to: address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private typealias MakeServerChannel = (_ eventLoop: SelectableEventLoop, _ childGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel
|
||||||
|
private typealias Register = @Sendable (EventLoop, ServerSocketChannel) -> EventLoopFuture<Void>
|
||||||
|
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
private func bindAsyncChannel0<ChildChannelInbound: Sendable, ChildChannelOutbound: Sendable>(
|
||||||
|
makeServerChannel: MakeServerChannel,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
childBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
isChildChannelOutboundHalfClosureEnabled: Bool,
|
||||||
|
_ register: @escaping Register
|
||||||
|
) async throws -> NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never> {
|
||||||
|
let eventLoop = self.group.next()
|
||||||
|
let childEventLoopGroup = self.childGroup
|
||||||
|
let serverChannelOptions = self._serverChannelOptions
|
||||||
|
let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) }
|
||||||
|
let childChannelInit = self.childChannelInit
|
||||||
|
let childChannelOptions = self._childChannelOptions
|
||||||
|
|
||||||
|
let serverChannel = try makeServerChannel(eventLoop as! SelectableEventLoop, childEventLoopGroup, self.enableMPTCP)
|
||||||
|
|
||||||
|
return try await eventLoop.submit {
|
||||||
|
serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap {
|
||||||
|
serverChannelInit(serverChannel)
|
||||||
|
}.flatMap {
|
||||||
|
do {
|
||||||
|
try serverChannel.pipeline.syncOperations.addHandler(
|
||||||
|
AcceptHandler(childChannelInitializer: childChannelInit, childChannelOptions: childChannelOptions),
|
||||||
|
name: "AcceptHandler"
|
||||||
|
)
|
||||||
|
|
||||||
|
// We are wrapping the inbound channels into `NIOAsyncChannel`s with the transformation
|
||||||
|
// closure of the `NIOAsyncChannel` that allows us to wrap them without adding
|
||||||
|
// wrapping/unwrapping handlers to the pipeline.
|
||||||
|
let asyncChannel = try NIOAsyncChannel<NIOAsyncChannel<ChildChannelInbound, ChildChannelOutbound>, Never>.wrapAsyncChannelForBootstrapBind(
|
||||||
|
synchronouslyWrapping: serverChannel,
|
||||||
|
backpressureStrategy: serverBackpressureStrategy,
|
||||||
|
transformationClosure: { channel in
|
||||||
|
// We must hop to the child channel event loop here to add the async channel handlers
|
||||||
|
channel.eventLoop.submit {
|
||||||
|
try NIOAsyncChannel(
|
||||||
|
synchronouslyWrapping: channel,
|
||||||
|
backpressureStrategy: childBackpressureStrategy,
|
||||||
|
isOutboundHalfClosureEnabled: isChildChannelOutboundHalfClosureEnabled,
|
||||||
|
inboundType: ChildChannelInbound.self,
|
||||||
|
outboundType: ChildChannelOutbound.self
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return register(eventLoop, serverChannel).map { asyncChannel }
|
||||||
|
} catch {
|
||||||
|
return eventLoop.makeFailedFuture(error)
|
||||||
|
}
|
||||||
|
}.flatMapError { error in
|
||||||
|
serverChannel.close0(error: error, mode: .all, promise: nil)
|
||||||
|
return eventLoop.makeFailedFuture(error)
|
||||||
|
}
|
||||||
|
}.flatMap {
|
||||||
|
$0
|
||||||
|
}.get()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: AsyncChannel based bind with protocol negotiation
|
||||||
|
extension ServerBootstrap {
|
||||||
|
/// Bind the `ServerSocketChannel` to the `host` and `port` parameters.
|
||||||
|
/// Waiting for protocol negotiation to finish before yielding the channel to the returned ``NIOAsyncChannel``.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - host: The host to bind on.
|
||||||
|
/// - port: The port to bind on.
|
||||||
|
/// - protocolNegotiationHandlerType: The protocol negotiation handler type that is awaited on. A handler of this type
|
||||||
|
/// must be added to the pipeline in the child channel initializer.
|
||||||
|
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
|
||||||
|
/// - Returns: A ``NIOAsyncChannel`` of the protocol negotiation results. It is expected that the protocol negotiation handler
|
||||||
|
/// is going to wrap the child channels into ``NIOAsyncChannel`` which are returned as part of the negotiation result.
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func bind<Handler: NIOProtocolNegotiationHandler>(
|
||||||
|
host: String,
|
||||||
|
port: Int,
|
||||||
|
protocolNegotiationHandlerType: Handler.Type,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil
|
||||||
|
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
|
||||||
|
return try await self.bindAsyncChannelWithProtocolNegotiation0(
|
||||||
|
protocolNegotiationHandlerType: protocolNegotiationHandlerType,
|
||||||
|
serverBackpressureStrategy: serverBackpressureStrategy
|
||||||
|
) {
|
||||||
|
return try SocketAddress.makeAddressResolvingHost(host, port: port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Bind the `ServerSocketChannel` to the `address` parameter.
|
||||||
|
/// Waiting for protocol negotiation to finish before yielding the channel to the returned ``NIOAsyncChannel``.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - address: The `SocketAddress` to bind on.
|
||||||
|
/// - protocolNegotiationHandlerType: The protocol negotiation handler type that is awaited on. A handler of this type
|
||||||
|
/// must be added to the pipeline in the child channel initializer.
|
||||||
|
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
|
||||||
|
/// - Returns: A ``NIOAsyncChannel`` of the protocol negotiation results. It is expected that the protocol negotiation handler
|
||||||
|
/// is going to wrap the child channels into ``NIOAsyncChannel`` which are returned as part of the negotiation result.
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func bind<Handler: NIOProtocolNegotiationHandler>(
|
||||||
|
to address: SocketAddress,
|
||||||
|
protocolNegotiationHandlerType: Handler.Type,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil
|
||||||
|
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
|
||||||
|
return try await self.bindAsyncChannelWithProtocolNegotiation0(
|
||||||
|
protocolNegotiationHandlerType: protocolNegotiationHandlerType,
|
||||||
|
serverBackpressureStrategy: serverBackpressureStrategy
|
||||||
|
) { address }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Bind the `ServerSocketChannel` to a UNIX Domain Socket.
|
||||||
|
/// Waiting for protocol negotiation to finish before yielding the channel to the returned ``NIOAsyncChannel``.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist,
|
||||||
|
/// unless `cleanupExistingSocketFile`is set to `true`.
|
||||||
|
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`.
|
||||||
|
/// - protocolNegotiationHandlerType: The protocol negotiation handler type that is awaited on. A handler of this type
|
||||||
|
/// must be added to the pipeline in the child channel initializer.
|
||||||
|
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
|
||||||
|
/// - Returns: A ``NIOAsyncChannel`` of the protocol negotiation results. It is expected that the protocol negotiation handler
|
||||||
|
/// is going to wrap the child channels into ``NIOAsyncChannel`` which are returned as part of the negotiation result.
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func bind<Handler: NIOProtocolNegotiationHandler>(
|
||||||
|
unixDomainSocketPath: String,
|
||||||
|
cleanupExistingSocketFile: Bool = false,
|
||||||
|
protocolNegotiationHandlerType: Handler.Type,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil
|
||||||
|
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
|
||||||
|
if cleanupExistingSocketFile {
|
||||||
|
try BaseSocket.cleanupSocket(unixDomainSocketPath: unixDomainSocketPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return try await self.bind(
|
||||||
|
unixDomainSocketPath: unixDomainSocketPath,
|
||||||
|
protocolNegotiationHandlerType: protocolNegotiationHandlerType,
|
||||||
|
serverBackpressureStrategy: serverBackpressureStrategy
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Use the existing bound socket file descriptor.
|
||||||
|
/// Waiting for protocol negotiation to finish before yielding the channel to the returned ``NIOAsyncChannel``.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - socket: The _Unix file descriptor_ representing the bound stream socket.
|
||||||
|
/// - protocolNegotiationHandlerType: The protocol negotiation handler type that is awaited on. A handler of this type
|
||||||
|
/// must be added to the pipeline in the child channel initializer.
|
||||||
|
/// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel.
|
||||||
|
/// - Returns: A ``NIOAsyncChannel`` of the protocol negotiation results. It is expected that the protocol negotiation handler
|
||||||
|
/// is going to wrap the child channels into ``NIOAsyncChannel`` which are returned as part of the negotiation result.
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func withBoundSocket<Handler: NIOProtocolNegotiationHandler>(
|
||||||
|
_ socket: NIOBSDSocket.Handle,
|
||||||
|
protocolNegotiationHandlerType: Handler.Type,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil
|
||||||
|
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
|
||||||
|
func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel {
|
||||||
|
if enableMPTCP {
|
||||||
|
throw ChannelError.operationUnsupported
|
||||||
|
}
|
||||||
|
return try ServerSocketChannel(socket: socket, eventLoop: eventLoop, group: childEventLoopGroup)
|
||||||
|
}
|
||||||
|
return try await self.bindAsyncChannelWithProtocolNegotiation0(
|
||||||
|
makeServerChannel: makeChannel,
|
||||||
|
protocolNegotiationHandlerType: protocolNegotiationHandlerType,
|
||||||
|
serverBackpressureStrategy: serverBackpressureStrategy
|
||||||
|
) { (eventLoop, serverChannel) in
|
||||||
|
let promise = eventLoop.makePromise(of: Void.self)
|
||||||
|
serverChannel.registerAlreadyConfigured0(promise: promise)
|
||||||
|
return promise.futureResult
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
private func bindAsyncChannelWithProtocolNegotiation0<Handler: NIOProtocolNegotiationHandler & ChannelHandler>(
|
||||||
|
protocolNegotiationHandlerType: Handler.Type,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
_ makeSocketAddress: () throws -> SocketAddress
|
||||||
|
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
|
||||||
|
let address = try makeSocketAddress()
|
||||||
|
|
||||||
|
func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel {
|
||||||
|
return try ServerSocketChannel(eventLoop: eventLoop,
|
||||||
|
group: childEventLoopGroup,
|
||||||
|
protocolFamily: address.protocol,
|
||||||
|
enableMPTCP: enableMPTCP)
|
||||||
|
}
|
||||||
|
|
||||||
|
return try await self.bindAsyncChannelWithProtocolNegotiation0(
|
||||||
|
makeServerChannel: makeChannel,
|
||||||
|
protocolNegotiationHandlerType: protocolNegotiationHandlerType,
|
||||||
|
serverBackpressureStrategy: serverBackpressureStrategy
|
||||||
|
) { (eventLoop, serverChannel) in
|
||||||
|
serverChannel.registerAndDoSynchronously { serverChannel in
|
||||||
|
serverChannel.bind(to: address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||||
|
private func bindAsyncChannelWithProtocolNegotiation0<Handler: NIOProtocolNegotiationHandler & ChannelHandler>(
|
||||||
|
makeServerChannel: MakeServerChannel,
|
||||||
|
protocolNegotiationHandlerType: Handler.Type,
|
||||||
|
serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
|
||||||
|
_ register: @escaping Register
|
||||||
|
) async throws -> NIOAsyncChannel<Handler.NegotiationResult, Never> {
|
||||||
|
let eventLoop = self.group.next()
|
||||||
|
let childEventLoopGroup = self.childGroup
|
||||||
|
let serverChannelOptions = self._serverChannelOptions
|
||||||
|
let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) }
|
||||||
|
let childChannelInit = self.childChannelInit
|
||||||
|
let childChannelOptions = self._childChannelOptions
|
||||||
|
|
||||||
|
let serverChannel = try makeServerChannel(eventLoop as! SelectableEventLoop, childEventLoopGroup, self.enableMPTCP)
|
||||||
|
|
||||||
|
return try await eventLoop.submit {
|
||||||
|
serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap {
|
||||||
|
serverChannelInit(serverChannel)
|
||||||
|
}.flatMap {
|
||||||
|
do {
|
||||||
|
try serverChannel.pipeline.syncOperations.addHandler(
|
||||||
|
AcceptHandler(childChannelInitializer: childChannelInit, childChannelOptions: childChannelOptions),
|
||||||
|
name: "AcceptHandler"
|
||||||
|
)
|
||||||
|
|
||||||
|
// In the case of protocol negotiation we cannot wrap the child channels into
|
||||||
|
// `NIOAsyncChannel`s for the user since we don't know the type. We rather expect
|
||||||
|
// the user to wrap the child channels themselves when the negotiation is done
|
||||||
|
// and return the wrapped async channel as part of the negotiation result.
|
||||||
|
let asyncChannel = try NIOAsyncChannel<Handler.NegotiationResult, Never>.wrapAsyncChannelForBootstrapBindWithProtocolNegotiation(
|
||||||
|
synchronouslyWrapping: serverChannel,
|
||||||
|
backpressureStrategy: serverBackpressureStrategy,
|
||||||
|
transformationClosure: { (channel: Channel) in
|
||||||
|
channel.pipeline.handler(type: protocolNegotiationHandlerType)
|
||||||
|
.flatMap { handler -> EventLoopFuture<NIOProtocolNegotiationResult<Handler.NegotiationResult>> in
|
||||||
|
handler.protocolNegotiationResult
|
||||||
|
}.flatMap { result in
|
||||||
|
ServerBootstrap.waitForFinalResult(result, eventLoop: eventLoop)
|
||||||
|
}.flatMapErrorThrowing { error in
|
||||||
|
channel.pipeline.fireErrorCaught(error)
|
||||||
|
channel.close(promise: nil)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
)
|
||||||
|
return register(eventLoop, serverChannel).map { asyncChannel }
|
||||||
|
} catch {
|
||||||
|
return eventLoop.makeFailedFuture(error)
|
||||||
|
}
|
||||||
|
}.flatMapError { error in
|
||||||
|
serverChannel.close0(error: error, mode: .all, promise: nil)
|
||||||
|
return eventLoop.makeFailedFuture(error)
|
||||||
|
}
|
||||||
|
}.flatMap {
|
||||||
|
$0
|
||||||
|
}.get()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This method recursively waits for the final result of protocol negotiation
|
||||||
|
static func waitForFinalResult<NegotiationResult>(
|
||||||
|
_ result: NIOProtocolNegotiationResult<NegotiationResult>,
|
||||||
|
eventLoop: EventLoop
|
||||||
|
) -> EventLoopFuture<NegotiationResult> {
|
||||||
|
switch result {
|
||||||
|
case .finished(let negotiationResult):
|
||||||
|
return eventLoop.makeSucceededFuture(negotiationResult)
|
||||||
|
|
||||||
|
case .deferredResult(let future):
|
||||||
|
return future.flatMap { result in
|
||||||
|
return waitForFinalResult(result, eventLoop: eventLoop)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@available(*, unavailable)
|
@available(*, unavailable)
|
||||||
extension ServerBootstrap: Sendable {}
|
extension ServerBootstrap: Sendable {}
|
||||||
|
|
||||||
|
@ -1058,8 +1507,9 @@ public final class DatagramBootstrap {
|
||||||
/// Bind the `DatagramChannel` to a UNIX Domain Socket.
|
/// Bind the `DatagramChannel` to a UNIX Domain Socket.
|
||||||
///
|
///
|
||||||
/// - parameters:
|
/// - parameters:
|
||||||
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. `path` must not exist, it will be created by the system.
|
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist,
|
||||||
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `path`.
|
/// unless `cleanupExistingSocketFile`is set to `true`.
|
||||||
|
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`.
|
||||||
public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> {
|
public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> {
|
||||||
if cleanupExistingSocketFile {
|
if cleanupExistingSocketFile {
|
||||||
do {
|
do {
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
import NIOCore
|
import NIOCore
|
||||||
import DequeModule
|
|
||||||
|
|
||||||
/// The result of an ALPN negotiation.
|
/// The result of an ALPN negotiation.
|
||||||
///
|
///
|
||||||
|
@ -35,6 +34,14 @@ public enum ALPNResult: Equatable, Sendable {
|
||||||
/// ALPN negotiation either failed, or never took place. The application
|
/// ALPN negotiation either failed, or never took place. The application
|
||||||
/// should fall back to a default protocol choice or close the connection.
|
/// should fall back to a default protocol choice or close the connection.
|
||||||
case fallback
|
case fallback
|
||||||
|
|
||||||
|
init(negotiated: String?) {
|
||||||
|
if let negotiated = negotiated {
|
||||||
|
self = .negotiated(negotiated)
|
||||||
|
} else {
|
||||||
|
self = .fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A helper `ChannelInboundHandler` that makes it easy to swap channel pipelines
|
/// A helper `ChannelInboundHandler` that makes it easy to swap channel pipelines
|
||||||
|
@ -62,8 +69,7 @@ public final class ApplicationProtocolNegotiationHandler: ChannelInboundHandler,
|
||||||
public typealias InboundOut = Any
|
public typealias InboundOut = Any
|
||||||
|
|
||||||
private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture<Void>
|
private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture<Void>
|
||||||
private var waitingForUser: Bool
|
private var stateMachine = ProtocolNegotiationHandlerStateMachine<Void>()
|
||||||
private var eventBuffer: Deque<NIOAny>
|
|
||||||
|
|
||||||
/// Create an `ApplicationProtocolNegotiationHandler` with the given completion
|
/// Create an `ApplicationProtocolNegotiationHandler` with the given completion
|
||||||
/// callback.
|
/// callback.
|
||||||
|
@ -72,8 +78,6 @@ public final class ApplicationProtocolNegotiationHandler: ChannelInboundHandler,
|
||||||
/// negotiation has completed.
|
/// negotiation has completed.
|
||||||
public init(alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture<Void>) {
|
public init(alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture<Void>) {
|
||||||
self.completionHandler = alpnCompleteHandler
|
self.completionHandler = alpnCompleteHandler
|
||||||
self.waitingForUser = false
|
|
||||||
self.eventBuffer = []
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an `ApplicationProtocolNegotiationHandler` with the given completion
|
/// Create an `ApplicationProtocolNegotiationHandler` with the given completion
|
||||||
|
@ -88,56 +92,71 @@ public final class ApplicationProtocolNegotiationHandler: ChannelInboundHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
|
public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
|
||||||
guard let tlsEvent = event as? TLSUserEvent else {
|
switch self.stateMachine.userInboundEventTriggered(event: event) {
|
||||||
|
case .fireUserInboundEventTriggered:
|
||||||
context.fireUserInboundEventTriggered(event)
|
context.fireUserInboundEventTriggered(event)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if case .handshakeCompleted(let p) = tlsEvent {
|
case .invokeUserClosure(let result):
|
||||||
handshakeCompleted(context: context, negotiatedProtocol: p)
|
self.invokeUserClosure(context: context, result: result)
|
||||||
} else {
|
|
||||||
context.fireUserInboundEventTriggered(event)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||||
if waitingForUser {
|
switch self.stateMachine.channelRead(data: data) {
|
||||||
eventBuffer.append(data)
|
case .fireChannelRead:
|
||||||
} else {
|
|
||||||
context.fireChannelRead(data)
|
context.fireChannelRead(data)
|
||||||
|
|
||||||
|
case .none:
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private func handshakeCompleted(context: ChannelHandlerContext, negotiatedProtocol: String?) {
|
public func channelInactive(context: ChannelHandlerContext) {
|
||||||
waitingForUser = true
|
self.stateMachine.channelInactive()
|
||||||
|
|
||||||
let result: ALPNResult
|
context.fireChannelInactive()
|
||||||
if let negotiatedProtocol = negotiatedProtocol {
|
}
|
||||||
result = .negotiated(negotiatedProtocol)
|
|
||||||
} else {
|
|
||||||
result = .fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
|
private func invokeUserClosure(context: ChannelHandlerContext, result: ALPNResult) {
|
||||||
let switchFuture = self.completionHandler(result, context.channel)
|
let switchFuture = self.completionHandler(result, context.channel)
|
||||||
switchFuture.whenComplete { (_: Result<Void, Error>) in
|
|
||||||
|
switchFuture
|
||||||
|
.hop(to: context.eventLoop)
|
||||||
|
.whenComplete { result in
|
||||||
|
self.userFutureCompleted(context: context, result: result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func userFutureCompleted(context: ChannelHandlerContext, result: Result<Void, Error>) {
|
||||||
|
switch self.stateMachine.userFutureCompleted(with: result) {
|
||||||
|
case .fireErrorCaughtAndRemoveHandler(let error):
|
||||||
|
context.fireErrorCaught(error)
|
||||||
|
context.pipeline.removeHandler(self, promise: nil)
|
||||||
|
|
||||||
|
case .fireErrorCaughtAndStartUnbuffering(let error):
|
||||||
|
context.fireErrorCaught(error)
|
||||||
self.unbuffer(context: context)
|
self.unbuffer(context: context)
|
||||||
|
|
||||||
|
case .startUnbuffering:
|
||||||
|
self.unbuffer(context: context)
|
||||||
|
|
||||||
|
case .removeHandler:
|
||||||
context.pipeline.removeHandler(self, promise: nil)
|
context.pipeline.removeHandler(self, promise: nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private func unbuffer(context: ChannelHandlerContext) {
|
private func unbuffer(context: ChannelHandlerContext) {
|
||||||
// First we check if we have anything to unbuffer
|
while true {
|
||||||
guard !self.eventBuffer.isEmpty else {
|
switch self.stateMachine.unbuffer() {
|
||||||
return
|
case .fireChannelRead(let data):
|
||||||
}
|
context.fireChannelRead(data)
|
||||||
|
|
||||||
// Now we unbuffer until there is nothing left.
|
case .fireChannelReadCompleteAndRemoveHandler:
|
||||||
// Importantly firing a channel read can lead to new reads being buffered due to reentrancy!
|
context.fireChannelReadComplete()
|
||||||
while let datum = self.eventBuffer.popFirst() {
|
context.pipeline.removeHandler(self, promise: nil)
|
||||||
context.fireChannelRead(datum)
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
context.fireChannelReadComplete()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,161 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This source file is part of the SwiftNIO open source project
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
|
||||||
|
// Licensed under Apache License v2.0
|
||||||
|
//
|
||||||
|
// See LICENSE.txt for license information
|
||||||
|
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
@_spi(AsyncChannel) import NIOCore
|
||||||
|
|
||||||
|
/// A helper ``ChannelInboundHandler`` that makes it easy to swap channel pipelines
|
||||||
|
/// based on the result of an ALPN negotiation.
|
||||||
|
///
|
||||||
|
/// The standard pattern used by applications that want to use ALPN is to select
|
||||||
|
/// an application protocol based on the result, optionally falling back to some
|
||||||
|
/// default protocol. To do this in SwiftNIO requires that the channel pipeline be
|
||||||
|
/// reconfigured based on the result of the ALPN negotiation. This channel handler
|
||||||
|
/// encapsulates that logic in a generic form that doesn't depend on the specific
|
||||||
|
/// TLS implementation in use by using ``TLSUserEvent``
|
||||||
|
///
|
||||||
|
/// The user of this channel handler provides a single closure that is called with
|
||||||
|
/// an ``ALPNResult`` when the ALPN negotiation is complete. Based on that result
|
||||||
|
/// the user is free to reconfigure the ``ChannelPipeline`` as required, and should
|
||||||
|
/// return an ``EventLoopFuture`` that will complete when the pipeline is reconfigured.
|
||||||
|
///
|
||||||
|
/// Until the ``EventLoopFuture`` completes, this channel handler will buffer inbound
|
||||||
|
/// data. When the ``EventLoopFuture`` completes, the buffered data will be replayed
|
||||||
|
/// down the channel. Then, finally, this channel handler will automatically remove
|
||||||
|
/// itself from the channel pipeline, leaving the pipeline in its final
|
||||||
|
/// configuration.
|
||||||
|
///
|
||||||
|
/// Importantly, this is a typed variant of the ``ApplicationProtocolNegotiationHandler`` and allows the user to
|
||||||
|
/// specify a type that must be returned from the supplied closure. The result will then be used to succeed the ``NIOTypedApplicationProtocolNegotiationHandler/protocolNegotiationResult``
|
||||||
|
/// promise. This allows us to construct pipelines that include protocol negotiation handlers and be able to bridge them into ``NIOAsyncChannel``
|
||||||
|
/// based bootstraps.
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public final class NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>: ChannelInboundHandler, RemovableChannelHandler, NIOProtocolNegotiationHandler {
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public typealias InboundIn = Any
|
||||||
|
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public typealias InboundOut = Any
|
||||||
|
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public var protocolNegotiationResult: EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>> {
|
||||||
|
self.negotiatedPromise.futureResult
|
||||||
|
}
|
||||||
|
|
||||||
|
private let negotiatedPromise: EventLoopPromise<NIOProtocolNegotiationResult<NegotiationResult>>
|
||||||
|
|
||||||
|
private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>
|
||||||
|
private var stateMachine = ProtocolNegotiationHandlerStateMachine<NIOProtocolNegotiationResult<NegotiationResult>>()
|
||||||
|
|
||||||
|
/// Create an `ApplicationProtocolNegotiationHandler` with the given completion
|
||||||
|
/// callback.
|
||||||
|
///
|
||||||
|
/// - Parameter alpnCompleteHandler: The closure that will fire when ALPN
|
||||||
|
/// negotiation has completed.
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public init(eventLoop: EventLoop, alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>) {
|
||||||
|
self.completionHandler = alpnCompleteHandler
|
||||||
|
self.negotiatedPromise = eventLoop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an `ApplicationProtocolNegotiationHandler` with the given completion
|
||||||
|
/// callback.
|
||||||
|
///
|
||||||
|
/// - Parameter alpnCompleteHandler: The closure that will fire when ALPN
|
||||||
|
/// negotiation has completed.
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public convenience init(eventLoop: EventLoop, alpnCompleteHandler: @escaping (ALPNResult) -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>>) {
|
||||||
|
self.init(eventLoop: eventLoop) { result, _ in
|
||||||
|
alpnCompleteHandler(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
|
||||||
|
switch self.stateMachine.userInboundEventTriggered(event: event) {
|
||||||
|
case .fireUserInboundEventTriggered:
|
||||||
|
context.fireUserInboundEventTriggered(event)
|
||||||
|
|
||||||
|
case .invokeUserClosure(let result):
|
||||||
|
self.invokeUserClosure(context: context, result: result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||||
|
switch self.stateMachine.channelRead(data: data) {
|
||||||
|
case .fireChannelRead:
|
||||||
|
context.fireChannelRead(data)
|
||||||
|
|
||||||
|
case .none:
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@_spi(AsyncChannel)
|
||||||
|
public func channelInactive(context: ChannelHandlerContext) {
|
||||||
|
self.stateMachine.channelInactive()
|
||||||
|
|
||||||
|
self.negotiatedPromise.fail(ChannelError.outputClosed)
|
||||||
|
context.fireChannelInactive()
|
||||||
|
}
|
||||||
|
|
||||||
|
private func invokeUserClosure(context: ChannelHandlerContext, result: ALPNResult) {
|
||||||
|
let switchFuture = self.completionHandler(result, context.channel)
|
||||||
|
|
||||||
|
switchFuture
|
||||||
|
.hop(to: context.eventLoop)
|
||||||
|
.whenComplete { result in
|
||||||
|
self.userFutureCompleted(context: context, result: result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func userFutureCompleted(context: ChannelHandlerContext, result: Result<NIOProtocolNegotiationResult<NegotiationResult>, Error>) {
|
||||||
|
switch self.stateMachine.userFutureCompleted(with: result) {
|
||||||
|
case .fireErrorCaughtAndRemoveHandler(let error):
|
||||||
|
self.negotiatedPromise.fail(error)
|
||||||
|
context.fireErrorCaught(error)
|
||||||
|
context.pipeline.removeHandler(self, promise: nil)
|
||||||
|
|
||||||
|
case .fireErrorCaughtAndStartUnbuffering(let error):
|
||||||
|
self.negotiatedPromise.fail(error)
|
||||||
|
context.fireErrorCaught(error)
|
||||||
|
self.unbuffer(context: context)
|
||||||
|
|
||||||
|
case .startUnbuffering(let value):
|
||||||
|
self.negotiatedPromise.succeed(value)
|
||||||
|
self.unbuffer(context: context)
|
||||||
|
|
||||||
|
case .removeHandler(let value):
|
||||||
|
self.negotiatedPromise.succeed(value)
|
||||||
|
context.pipeline.removeHandler(self, promise: nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func unbuffer(context: ChannelHandlerContext) {
|
||||||
|
while true {
|
||||||
|
switch self.stateMachine.unbuffer() {
|
||||||
|
case .fireChannelRead(let data):
|
||||||
|
context.fireChannelRead(data)
|
||||||
|
|
||||||
|
case .fireChannelReadCompleteAndRemoveHandler:
|
||||||
|
context.fireChannelReadComplete()
|
||||||
|
context.pipeline.removeHandler(self, promise: nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@available(*, unavailable)
|
||||||
|
extension NIOTypedApplicationProtocolNegotiationHandler: Sendable {}
|
|
@ -0,0 +1,159 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This source file is part of the SwiftNIO open source project
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
|
||||||
|
// Licensed under Apache License v2.0
|
||||||
|
//
|
||||||
|
// See LICENSE.txt for license information
|
||||||
|
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
import DequeModule
|
||||||
|
import NIOCore
|
||||||
|
|
||||||
|
struct ProtocolNegotiationHandlerStateMachine<NegotiationResult> {
|
||||||
|
enum State {
|
||||||
|
/// The state before we received a TLSUserEvent. We are just forwarding any read at this point.
|
||||||
|
case initial
|
||||||
|
/// The state after we received a ``TLSUserEvent`` and are waiting for the future of the user to complete.
|
||||||
|
case waitingForUser(buffer: Deque<NIOAny>)
|
||||||
|
/// The state after the users future finished and we are unbuffering all the reads.
|
||||||
|
case unbuffering(buffer: Deque<NIOAny>)
|
||||||
|
/// The state once the negotiation is done and we are finished with unbuffering.
|
||||||
|
case finished
|
||||||
|
}
|
||||||
|
|
||||||
|
private var state = State.initial
|
||||||
|
|
||||||
|
@usableFromInline
|
||||||
|
enum UserInboundEventTriggeredAction {
|
||||||
|
case fireUserInboundEventTriggered
|
||||||
|
case invokeUserClosure(ALPNResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
@inlinable
|
||||||
|
mutating func userInboundEventTriggered(event: Any) -> UserInboundEventTriggeredAction {
|
||||||
|
if case .handshakeCompleted(let negotiated) = event as? TLSUserEvent {
|
||||||
|
switch self.state {
|
||||||
|
case .initial:
|
||||||
|
self.state = .waitingForUser(buffer: .init())
|
||||||
|
|
||||||
|
return .invokeUserClosure(.init(negotiated: negotiated))
|
||||||
|
case .waitingForUser, .unbuffering:
|
||||||
|
preconditionFailure("Unexpectedly received two TLSUserEvents")
|
||||||
|
|
||||||
|
case .finished:
|
||||||
|
// This is weird but we can tolerate it and just forward the event
|
||||||
|
return .fireUserInboundEventTriggered
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return .fireUserInboundEventTriggered
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@usableFromInline
|
||||||
|
enum ChannelReadAction {
|
||||||
|
case fireChannelRead
|
||||||
|
}
|
||||||
|
|
||||||
|
@inlinable
|
||||||
|
mutating func channelRead(data: NIOAny) -> ChannelReadAction? {
|
||||||
|
switch self.state {
|
||||||
|
case .initial, .finished:
|
||||||
|
return .fireChannelRead
|
||||||
|
|
||||||
|
case .waitingForUser(var buffer):
|
||||||
|
buffer.append(data)
|
||||||
|
self.state = .waitingForUser(buffer: buffer)
|
||||||
|
|
||||||
|
return .none
|
||||||
|
|
||||||
|
case .unbuffering(var buffer):
|
||||||
|
buffer.append(data)
|
||||||
|
self.state = .unbuffering(buffer: buffer)
|
||||||
|
|
||||||
|
return .none
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@usableFromInline
|
||||||
|
enum UserFutureCompletedAction {
|
||||||
|
case fireErrorCaughtAndRemoveHandler(Error)
|
||||||
|
case fireErrorCaughtAndStartUnbuffering(Error)
|
||||||
|
case startUnbuffering(NegotiationResult)
|
||||||
|
case removeHandler(NegotiationResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
@inlinable
|
||||||
|
mutating func userFutureCompleted(with result: Result<NegotiationResult, Error>) -> UserFutureCompletedAction {
|
||||||
|
switch self.state {
|
||||||
|
case .initial, .finished:
|
||||||
|
preconditionFailure("Invalid state \(self.state)")
|
||||||
|
|
||||||
|
case .waitingForUser(let buffer):
|
||||||
|
|
||||||
|
switch result {
|
||||||
|
case .success(let value):
|
||||||
|
if !buffer.isEmpty {
|
||||||
|
self.state = .unbuffering(buffer: buffer)
|
||||||
|
return .startUnbuffering(value)
|
||||||
|
} else {
|
||||||
|
self.state = .finished
|
||||||
|
return .removeHandler(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
case .failure(let error):
|
||||||
|
if !buffer.isEmpty {
|
||||||
|
self.state = .unbuffering(buffer: buffer)
|
||||||
|
return .fireErrorCaughtAndStartUnbuffering(error)
|
||||||
|
} else {
|
||||||
|
self.state = .finished
|
||||||
|
return .fireErrorCaughtAndRemoveHandler(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case .unbuffering:
|
||||||
|
preconditionFailure("Invalid state \(self.state)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@usableFromInline
|
||||||
|
enum UnbufferAction {
|
||||||
|
case fireChannelRead(NIOAny)
|
||||||
|
case fireChannelReadCompleteAndRemoveHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
@inlinable
|
||||||
|
mutating func unbuffer() -> UnbufferAction {
|
||||||
|
switch self.state {
|
||||||
|
case .initial, .waitingForUser, .finished:
|
||||||
|
preconditionFailure("Invalid state \(self.state)")
|
||||||
|
|
||||||
|
case .unbuffering(var buffer):
|
||||||
|
if let element = buffer.popFirst() {
|
||||||
|
self.state = .unbuffering(buffer: buffer)
|
||||||
|
|
||||||
|
return .fireChannelRead(element)
|
||||||
|
} else {
|
||||||
|
self.state = .finished
|
||||||
|
|
||||||
|
return .fireChannelReadCompleteAndRemoveHandler
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@inlinable
|
||||||
|
mutating func channelInactive() {
|
||||||
|
switch self.state {
|
||||||
|
case .initial, .unbuffering, .waitingForUser:
|
||||||
|
self.state = .finished
|
||||||
|
|
||||||
|
case .finished:
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -78,7 +78,12 @@ final class AsyncChannelTests: XCTestCase {
|
||||||
|
|
||||||
do {
|
do {
|
||||||
let wrapped = try await channel.testingEventLoop.executeInContext {
|
let wrapped = try await channel.testingEventLoop.executeInContext {
|
||||||
try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: Never.self)
|
try NIOAsyncChannel(
|
||||||
|
synchronouslyWrapping: channel,
|
||||||
|
isOutboundHalfClosureEnabled: true,
|
||||||
|
inboundType: Never.self,
|
||||||
|
outboundType: Never.self
|
||||||
|
)
|
||||||
}
|
}
|
||||||
inboundReader = wrapped.inboundStream
|
inboundReader = wrapped.inboundStream
|
||||||
|
|
||||||
|
@ -140,7 +145,12 @@ final class AsyncChannelTests: XCTestCase {
|
||||||
|
|
||||||
do {
|
do {
|
||||||
let wrapped = try await channel.testingEventLoop.executeInContext {
|
let wrapped = try await channel.testingEventLoop.executeInContext {
|
||||||
try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: Never.self)
|
try NIOAsyncChannel(
|
||||||
|
synchronouslyWrapping: channel,
|
||||||
|
isOutboundHalfClosureEnabled: true,
|
||||||
|
inboundType: Never.self,
|
||||||
|
outboundType: Never.self
|
||||||
|
)
|
||||||
}
|
}
|
||||||
inboundReader = wrapped.inboundStream
|
inboundReader = wrapped.inboundStream
|
||||||
|
|
||||||
|
@ -239,24 +249,6 @@ final class AsyncChannelTests: XCTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testErrorsArePropagatedToWriters() {
|
|
||||||
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
|
|
||||||
XCTAsyncTest(timeout: 5) {
|
|
||||||
let channel = NIOAsyncTestingChannel()
|
|
||||||
let wrapped = try await channel.testingEventLoop.executeInContext {
|
|
||||||
try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: String.self)
|
|
||||||
}
|
|
||||||
|
|
||||||
try await channel.testingEventLoop.executeInContext {
|
|
||||||
channel.pipeline.fireErrorCaught(TestError.bang)
|
|
||||||
}
|
|
||||||
|
|
||||||
try await XCTAssertThrowsError(await wrapped.outboundWriter.write("hello")) { error in
|
|
||||||
XCTAssertEqual(error as? TestError, .bang)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testChannelBecomingNonWritableDelaysWriters() {
|
func testChannelBecomingNonWritableDelaysWriters() {
|
||||||
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
|
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
|
||||||
XCTAsyncTest(timeout: 5) {
|
XCTAsyncTest(timeout: 5) {
|
||||||
|
@ -312,7 +304,7 @@ final class AsyncChannelTests: XCTestCase {
|
||||||
do {
|
do {
|
||||||
let strongSentinel: Sentinel? = Sentinel()
|
let strongSentinel: Sentinel? = Sentinel()
|
||||||
sentinel = strongSentinel!
|
sentinel = strongSentinel!
|
||||||
try await XCTAsyncAssertNotNil(await channel.pipeline.handler(type: NIOAsyncChannelInboundStreamChannelHandler<Sentinel>.self).get())
|
try await XCTAsyncAssertNotNil(await channel.pipeline.handler(type: NIOAsyncChannelInboundStreamChannelHandler<Sentinel, Sentinel>.self).get())
|
||||||
try await channel.writeInbound(strongSentinel!)
|
try await channel.writeInbound(strongSentinel!)
|
||||||
_ = try await channel.readInbound(as: Sentinel.self)
|
_ = try await channel.readInbound(as: Sentinel.self)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,504 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This source file is part of the SwiftNIO open source project
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
|
||||||
|
// Licensed under Apache License v2.0
|
||||||
|
//
|
||||||
|
// See LICENSE.txt for license information
|
||||||
|
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
@_spi(AsyncChannel) import NIOCore
|
||||||
|
@_spi(AsyncChannel) @testable import NIOPosix
|
||||||
|
import XCTest
|
||||||
|
@_spi(AsyncChannel) import NIOTLS
|
||||||
|
import NIOConcurrencyHelpers
|
||||||
|
|
||||||
|
fileprivate final class LineDelimiterDecoder: ByteToMessageDecoder {
|
||||||
|
private let newLine = "\n".utf8.first!
|
||||||
|
|
||||||
|
typealias InboundIn = ByteBuffer
|
||||||
|
typealias InboundOut = ByteBuffer
|
||||||
|
|
||||||
|
func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
|
||||||
|
let readable = buffer.withUnsafeReadableBytes { $0.firstIndex(of: newLine) }
|
||||||
|
if let readable = readable {
|
||||||
|
context.fireChannelRead(self.wrapInboundOut(buffer.readSlice(length: readable)!))
|
||||||
|
buffer.moveReaderIndex(forwardBy: 1)
|
||||||
|
return .continue
|
||||||
|
}
|
||||||
|
return .needMoreData
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fileprivate final class TLSUserEventHandler: ChannelInboundHandler {
|
||||||
|
typealias InboundIn = ByteBuffer
|
||||||
|
typealias InboundOut = ByteBuffer
|
||||||
|
|
||||||
|
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||||
|
let buffer = self.unwrapInboundIn(data)
|
||||||
|
let alpn = String(buffer: buffer)
|
||||||
|
|
||||||
|
if alpn.hasPrefix("alpn:") {
|
||||||
|
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(alpn.dropFirst(5))))
|
||||||
|
} else {
|
||||||
|
context.fireChannelRead(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fileprivate final class ByteBufferToStringHandler: ChannelInboundHandler {
|
||||||
|
typealias InboundIn = ByteBuffer
|
||||||
|
typealias InboundOut = String
|
||||||
|
|
||||||
|
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||||
|
let buffer = self.unwrapInboundIn(data)
|
||||||
|
context.fireChannelRead(self.wrapInboundOut(String(buffer: buffer)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fileprivate final class ByteBufferToByteHandler: ChannelInboundHandler {
|
||||||
|
typealias InboundIn = ByteBuffer
|
||||||
|
typealias InboundOut = UInt8
|
||||||
|
|
||||||
|
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||||
|
var buffer = self.unwrapInboundIn(data)
|
||||||
|
let byte = buffer.readInteger(as: UInt8.self)!
|
||||||
|
context.fireChannelRead(self.wrapInboundOut(byte))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
final class AsyncChannelBootstrapTests: XCTestCase {
|
||||||
|
enum NegotiationResult {
|
||||||
|
case string(NIOAsyncChannel<String, String>)
|
||||||
|
case byte(NIOAsyncChannel<UInt8, UInt8>)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ProtocolNegotiationError: Error {}
|
||||||
|
|
||||||
|
enum StringOrByte: Hashable {
|
||||||
|
case string(String)
|
||||||
|
case byte(UInt8)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAsyncChannel() throws {
|
||||||
|
XCTAsyncTest {
|
||||||
|
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3)
|
||||||
|
|
||||||
|
let channel = try await ServerBootstrap(group: eventLoopGroup)
|
||||||
|
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||||
|
.childChannelOption(ChannelOptions.autoRead, value: true)
|
||||||
|
.childChannelInitializer { channel in
|
||||||
|
channel.eventLoop.makeCompletedFuture {
|
||||||
|
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
|
||||||
|
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.bind(
|
||||||
|
host: "127.0.0.1",
|
||||||
|
port: 1995,
|
||||||
|
childChannelInboundType: String.self,
|
||||||
|
childChannelOutboundType: String.self
|
||||||
|
)
|
||||||
|
|
||||||
|
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||||
|
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
|
||||||
|
var iterator = stream.makeAsyncIterator()
|
||||||
|
|
||||||
|
group.addTask {
|
||||||
|
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||||
|
for try await childChannel in channel.inboundStream {
|
||||||
|
for try await value in childChannel.inboundStream {
|
||||||
|
continuation.yield(.string(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
|
||||||
|
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
|
||||||
|
|
||||||
|
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
|
||||||
|
|
||||||
|
group.cancelAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAsyncChannelProtocolNegotiation() throws {
|
||||||
|
XCTAsyncTest {
|
||||||
|
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3)
|
||||||
|
|
||||||
|
let channel: NIOAsyncChannel<NegotiationResult, Never> = try await ServerBootstrap(group: eventLoopGroup)
|
||||||
|
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||||
|
.childChannelOption(ChannelOptions.autoRead, value: true)
|
||||||
|
.childChannelInitializer { channel in
|
||||||
|
channel.eventLoop.makeCompletedFuture {
|
||||||
|
try self.makeProtocolNegotiationChildChannel(channel: channel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.bind(
|
||||||
|
host: "127.0.0.1",
|
||||||
|
port: 1995,
|
||||||
|
protocolNegotiationHandlerType: NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>.self
|
||||||
|
)
|
||||||
|
|
||||||
|
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||||
|
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
|
||||||
|
var iterator = stream.makeAsyncIterator()
|
||||||
|
|
||||||
|
group.addTask {
|
||||||
|
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||||
|
for try await childChannel in channel.inboundStream {
|
||||||
|
group.addTask {
|
||||||
|
switch childChannel {
|
||||||
|
case .string(let channel):
|
||||||
|
for try await value in channel.inboundStream {
|
||||||
|
continuation.yield(.string(value))
|
||||||
|
}
|
||||||
|
case .byte(let channel):
|
||||||
|
for try await value in channel.inboundStream {
|
||||||
|
continuation.yield(.byte(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
|
||||||
|
|
||||||
|
// This is for negotiating the protocol
|
||||||
|
stringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is the actual content
|
||||||
|
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
|
||||||
|
|
||||||
|
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
|
||||||
|
|
||||||
|
let byteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
|
||||||
|
|
||||||
|
// This is for negotiating the protocol
|
||||||
|
byteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is the actual content
|
||||||
|
byteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
|
||||||
|
byteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
|
||||||
|
|
||||||
|
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
|
||||||
|
|
||||||
|
group.cancelAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAsyncChannelNestedProtocolNegotiation() throws {
|
||||||
|
XCTAsyncTest {
|
||||||
|
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3)
|
||||||
|
|
||||||
|
let channel: NIOAsyncChannel<NegotiationResult, Never> = try await ServerBootstrap(group: eventLoopGroup)
|
||||||
|
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||||
|
.childChannelOption(ChannelOptions.autoRead, value: true)
|
||||||
|
.childChannelInitializer { channel in
|
||||||
|
channel.eventLoop.makeCompletedFuture {
|
||||||
|
try self.makeNestedProtocolNegotiationChildChannel(channel: channel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.bind(
|
||||||
|
host: "127.0.0.1",
|
||||||
|
port: 1995,
|
||||||
|
protocolNegotiationHandlerType: NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>.self
|
||||||
|
)
|
||||||
|
|
||||||
|
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||||
|
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
|
||||||
|
var iterator = stream.makeAsyncIterator()
|
||||||
|
|
||||||
|
group.addTask {
|
||||||
|
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||||
|
for try await childChannel in channel.inboundStream {
|
||||||
|
group.addTask {
|
||||||
|
switch childChannel {
|
||||||
|
case .string(let channel):
|
||||||
|
for try await value in channel.inboundStream {
|
||||||
|
continuation.yield(.string(value))
|
||||||
|
}
|
||||||
|
case .byte(let channel):
|
||||||
|
for try await value in channel.inboundStream {
|
||||||
|
continuation.yield(.byte(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let stringStringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
|
||||||
|
|
||||||
|
// This is for negotiating the protocol
|
||||||
|
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is for negotiating the nested protocol
|
||||||
|
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is the actual content
|
||||||
|
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
|
||||||
|
|
||||||
|
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
|
||||||
|
|
||||||
|
let byteByteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
|
||||||
|
|
||||||
|
// This is for negotiating the protocol
|
||||||
|
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is for negotiating the nested protocol
|
||||||
|
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is the actual content
|
||||||
|
byteByteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
|
||||||
|
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
|
||||||
|
|
||||||
|
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
|
||||||
|
|
||||||
|
let stringByteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
|
||||||
|
|
||||||
|
// This is for negotiating the protocol
|
||||||
|
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is for negotiating the nested protocol
|
||||||
|
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is the actual content
|
||||||
|
stringByteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
|
||||||
|
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
|
||||||
|
|
||||||
|
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
|
||||||
|
|
||||||
|
let byteStringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
|
||||||
|
|
||||||
|
// This is for negotiating the protocol
|
||||||
|
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is for negotiating the nested protocol
|
||||||
|
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is the actual content
|
||||||
|
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
|
||||||
|
|
||||||
|
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
|
||||||
|
|
||||||
|
group.cancelAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAsyncChannelProtocolNegotiation_whenFails() throws {
|
||||||
|
final class CollectingHandler: ChannelInboundHandler {
|
||||||
|
typealias InboundIn = Channel
|
||||||
|
|
||||||
|
private let channels: NIOLockedValueBox<[Channel]>
|
||||||
|
|
||||||
|
init(channels: NIOLockedValueBox<[Channel]>) {
|
||||||
|
self.channels = channels
|
||||||
|
}
|
||||||
|
|
||||||
|
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||||
|
let channel = self.unwrapInboundIn(data)
|
||||||
|
|
||||||
|
self.channels.withLockedValue { $0.append(channel) }
|
||||||
|
|
||||||
|
context.fireChannelRead(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
XCTAsyncTest {
|
||||||
|
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3)
|
||||||
|
let channels = NIOLockedValueBox<[Channel]>([Channel]())
|
||||||
|
|
||||||
|
let channel: NIOAsyncChannel<NegotiationResult, Never> = try await ServerBootstrap(group: eventLoopGroup)
|
||||||
|
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||||
|
.serverChannelInitializer { channel in
|
||||||
|
channel.eventLoop.makeCompletedFuture {
|
||||||
|
try channel.pipeline.syncOperations.addHandler(CollectingHandler(channels: channels))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.childChannelOption(ChannelOptions.autoRead, value: true)
|
||||||
|
.childChannelInitializer { channel in
|
||||||
|
channel.eventLoop.makeCompletedFuture {
|
||||||
|
try self.makeProtocolNegotiationChildChannel(channel: channel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.bind(
|
||||||
|
host: "127.0.0.1",
|
||||||
|
port: 1995,
|
||||||
|
protocolNegotiationHandlerType: NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>.self
|
||||||
|
)
|
||||||
|
|
||||||
|
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||||
|
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
|
||||||
|
var iterator = stream.makeAsyncIterator()
|
||||||
|
|
||||||
|
group.addTask {
|
||||||
|
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||||
|
for try await childChannel in channel.inboundStream {
|
||||||
|
group.addTask {
|
||||||
|
switch childChannel {
|
||||||
|
case .string(let channel):
|
||||||
|
for try await value in channel.inboundStream {
|
||||||
|
continuation.yield(.string(value))
|
||||||
|
}
|
||||||
|
case .byte(let channel):
|
||||||
|
for try await value in channel.inboundStream {
|
||||||
|
continuation.yield(.byte(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let channel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
|
||||||
|
|
||||||
|
// This is for negotiating the protocol
|
||||||
|
channel.writeAndFlush(.init(ByteBuffer(string: "alpn:unknown\n")), promise: nil)
|
||||||
|
|
||||||
|
// Checking that we can still create new connections afterwards
|
||||||
|
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup)
|
||||||
|
|
||||||
|
// This is for negotiating the protocol
|
||||||
|
stringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||||
|
|
||||||
|
// This is the actual content
|
||||||
|
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
|
||||||
|
|
||||||
|
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
|
||||||
|
|
||||||
|
let failedInboundChannel = channels.withLockedValue { channels -> Channel in
|
||||||
|
XCTAssertEqual(channels.count, 2)
|
||||||
|
return channels[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// We are waiting here to make sure the channel got closed
|
||||||
|
try await failedInboundChannel.closeFuture.get()
|
||||||
|
|
||||||
|
group.cancelAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Test Helpers
|
||||||
|
|
||||||
|
private func makeClientChannel(eventLoopGroup: EventLoopGroup) async throws -> Channel {
|
||||||
|
return try await ClientBootstrap(group: eventLoopGroup)
|
||||||
|
.channelInitializer { channel in
|
||||||
|
channel.eventLoop.makeCompletedFuture {
|
||||||
|
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.connect(to: .init(ipAddress: "127.0.0.1", port: 1995))
|
||||||
|
.get()
|
||||||
|
}
|
||||||
|
|
||||||
|
private func makeProtocolNegotiationChildChannel(channel: Channel) throws {
|
||||||
|
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
|
||||||
|
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler())
|
||||||
|
try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func makeNestedProtocolNegotiationChildChannel(channel: Channel) throws {
|
||||||
|
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
|
||||||
|
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler())
|
||||||
|
try channel.pipeline.syncOperations.addHandler(
|
||||||
|
NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
|
||||||
|
switch alpnResult {
|
||||||
|
case .negotiated(let alpn):
|
||||||
|
switch alpn {
|
||||||
|
case "string":
|
||||||
|
return channel.eventLoop.makeCompletedFuture {
|
||||||
|
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
|
||||||
|
|
||||||
|
return NIOProtocolNegotiationResult.deferredResult(negotiationFuture)
|
||||||
|
}
|
||||||
|
case "byte":
|
||||||
|
return channel.eventLoop.makeCompletedFuture {
|
||||||
|
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
|
||||||
|
|
||||||
|
return NIOProtocolNegotiationResult.deferredResult(negotiationFuture)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
|
||||||
|
}
|
||||||
|
case .fallback:
|
||||||
|
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@discardableResult
|
||||||
|
private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>> {
|
||||||
|
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
|
||||||
|
switch alpnResult {
|
||||||
|
case .negotiated(let alpn):
|
||||||
|
switch alpn {
|
||||||
|
case "string":
|
||||||
|
return channel.eventLoop.makeCompletedFuture {
|
||||||
|
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
|
||||||
|
let asyncChannel = try NIOAsyncChannel(
|
||||||
|
synchronouslyWrapping: channel,
|
||||||
|
isOutboundHalfClosureEnabled: true,
|
||||||
|
inboundType: String.self,
|
||||||
|
outboundType: String.self
|
||||||
|
)
|
||||||
|
|
||||||
|
return NIOProtocolNegotiationResult.finished(NegotiationResult.string(asyncChannel))
|
||||||
|
}
|
||||||
|
case "byte":
|
||||||
|
return channel.eventLoop.makeCompletedFuture {
|
||||||
|
try channel.pipeline.syncOperations.addHandler(ByteBufferToByteHandler())
|
||||||
|
|
||||||
|
let asyncChannel = try NIOAsyncChannel(
|
||||||
|
synchronouslyWrapping: channel,
|
||||||
|
isOutboundHalfClosureEnabled: true,
|
||||||
|
inboundType: UInt8.self,
|
||||||
|
outboundType: UInt8.self
|
||||||
|
)
|
||||||
|
|
||||||
|
return NIOProtocolNegotiationResult.finished(NegotiationResult.byte(asyncChannel))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
|
||||||
|
}
|
||||||
|
case .fallback:
|
||||||
|
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try channel.pipeline.syncOperations.addHandler(negotiationHandler)
|
||||||
|
return negotiationHandler.protocolNegotiationResult
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension AsyncStream {
|
||||||
|
fileprivate static func makeStream(
|
||||||
|
of elementType: Element.Type = Element.self,
|
||||||
|
bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded
|
||||||
|
) -> (stream: AsyncStream<Element>, continuation: AsyncStream<Element>.Continuation) {
|
||||||
|
var continuation: AsyncStream<Element>.Continuation!
|
||||||
|
let stream = AsyncStream<Element>(bufferingPolicy: limit) { continuation = $0 }
|
||||||
|
return (stream: stream, continuation: continuation!)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
|
||||||
|
fileprivate func XCTAsyncAssertEqual<Element: Equatable>(_ lhs: @autoclosure () async throws -> Element, _ rhs: @autoclosure () async throws -> Element, file: StaticString = #filePath, line: UInt = #line) async rethrows {
|
||||||
|
let lhsResult = try await lhs()
|
||||||
|
let rhsResult = try await rhs()
|
||||||
|
XCTAssertEqual(lhsResult, rhsResult, file: file, line: line)
|
||||||
|
}
|
|
@ -0,0 +1,209 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This source file is part of the SwiftNIO open source project
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
|
||||||
|
// Licensed under Apache License v2.0
|
||||||
|
//
|
||||||
|
// See LICENSE.txt for license information
|
||||||
|
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
@_spi(AsyncChannel) import NIOTLS
|
||||||
|
@_spi(AsyncChannel) import NIOCore
|
||||||
|
import NIOEmbedded
|
||||||
|
import XCTest
|
||||||
|
import NIOTestUtils
|
||||||
|
|
||||||
|
final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase {
|
||||||
|
enum NegotiationResult: Equatable {
|
||||||
|
case negotiated(ALPNResult)
|
||||||
|
case failed
|
||||||
|
}
|
||||||
|
|
||||||
|
private let negotiatedEvent: TLSUserEvent = .handshakeCompleted(negotiatedProtocol: "h2")
|
||||||
|
private let negotiatedResult: ALPNResult = .negotiated("h2")
|
||||||
|
|
||||||
|
func testChannelProvidedToCallback() throws {
|
||||||
|
let emChannel = EmbeddedChannel()
|
||||||
|
let loop = emChannel.eventLoop as! EmbeddedEventLoop
|
||||||
|
var called = false
|
||||||
|
|
||||||
|
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result, channel in
|
||||||
|
called = true
|
||||||
|
XCTAssertEqual(result, self.negotiatedResult)
|
||||||
|
XCTAssertTrue(emChannel === channel)
|
||||||
|
return loop.makeSucceededFuture(.finished(.negotiated(result)))
|
||||||
|
}
|
||||||
|
|
||||||
|
try emChannel.pipeline.addHandler(handler).wait()
|
||||||
|
emChannel.pipeline.fireUserInboundEventTriggered(negotiatedEvent)
|
||||||
|
XCTAssertTrue(called)
|
||||||
|
|
||||||
|
XCTAssertEqual(try handler.protocolNegotiationResult.wait(), .finished(.negotiated(negotiatedResult)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testIgnoresUnknownUserEvents() throws {
|
||||||
|
let channel = EmbeddedChannel()
|
||||||
|
let loop = channel.eventLoop as! EmbeddedEventLoop
|
||||||
|
|
||||||
|
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
|
||||||
|
XCTFail("Negotiation fired")
|
||||||
|
return loop.makeSucceededFuture(.finished(.failed))
|
||||||
|
}
|
||||||
|
|
||||||
|
try channel.pipeline.addHandler(handler).wait()
|
||||||
|
|
||||||
|
// Fire a pair of events that should be ignored.
|
||||||
|
channel.pipeline.fireUserInboundEventTriggered("FakeEvent")
|
||||||
|
channel.pipeline.fireUserInboundEventTriggered(TLSUserEvent.shutdownCompleted)
|
||||||
|
|
||||||
|
// The channel handler should still be in the pipeline.
|
||||||
|
try channel.pipeline.assertContains(handler: handler)
|
||||||
|
|
||||||
|
XCTAssertTrue(try channel.finish().isClean)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testNoBufferingBeforeEventFires() throws {
|
||||||
|
let channel = EmbeddedChannel()
|
||||||
|
let loop = channel.eventLoop as! EmbeddedEventLoop
|
||||||
|
|
||||||
|
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
|
||||||
|
XCTFail("Should not be called")
|
||||||
|
return loop.makeSucceededFuture(.finished(.failed))
|
||||||
|
}
|
||||||
|
|
||||||
|
try channel.pipeline.addHandler(handler).wait()
|
||||||
|
|
||||||
|
// The data we write should not be buffered.
|
||||||
|
try channel.writeInbound("hello")
|
||||||
|
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "hello"))
|
||||||
|
|
||||||
|
XCTAssertTrue(try channel.finish().isClean)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testBufferingWhileWaitingForFuture() throws {
|
||||||
|
let channel = EmbeddedChannel()
|
||||||
|
let loop = channel.eventLoop as! EmbeddedEventLoop
|
||||||
|
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
|
||||||
|
|
||||||
|
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
|
||||||
|
return continuePromise.futureResult
|
||||||
|
}
|
||||||
|
|
||||||
|
try channel.pipeline.addHandler(handler).wait()
|
||||||
|
|
||||||
|
// Fire in the event.
|
||||||
|
channel.pipeline.fireUserInboundEventTriggered(negotiatedEvent)
|
||||||
|
|
||||||
|
// At this point all writes should be buffered.
|
||||||
|
try channel.writeInbound("writes")
|
||||||
|
try channel.writeInbound("are")
|
||||||
|
try channel.writeInbound("buffered")
|
||||||
|
XCTAssertNoThrow(XCTAssertNil(try channel.readInbound()))
|
||||||
|
|
||||||
|
// Complete the pipeline swap.
|
||||||
|
continuePromise.succeed(.finished(.failed))
|
||||||
|
|
||||||
|
// Now everything should have been unbuffered.
|
||||||
|
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "writes"))
|
||||||
|
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "are"))
|
||||||
|
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "buffered"))
|
||||||
|
|
||||||
|
XCTAssertTrue(try channel.finish().isClean)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testNothingBufferedDoesNotFireReadCompleted() throws {
|
||||||
|
let channel = EmbeddedChannel()
|
||||||
|
let loop = channel.eventLoop as! EmbeddedEventLoop
|
||||||
|
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
|
||||||
|
|
||||||
|
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
|
||||||
|
continuePromise.futureResult
|
||||||
|
}
|
||||||
|
let eventCounterHandler = EventCounterHandler()
|
||||||
|
|
||||||
|
try channel.pipeline.addHandler(handler).wait()
|
||||||
|
try channel.pipeline.addHandler(eventCounterHandler).wait()
|
||||||
|
|
||||||
|
// Fire in the event.
|
||||||
|
channel.pipeline.fireUserInboundEventTriggered(negotiatedEvent)
|
||||||
|
|
||||||
|
// At this time, readComplete hasn't fired.
|
||||||
|
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 0)
|
||||||
|
|
||||||
|
// Now satisfy the future, which forces data unbuffering. As we haven't buffered any data,
|
||||||
|
// readComplete should not be fired.
|
||||||
|
continuePromise.succeed(.finished(.failed))
|
||||||
|
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 0)
|
||||||
|
|
||||||
|
XCTAssertTrue(try channel.finish().isClean)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUnbufferingFiresReadCompleted() throws {
|
||||||
|
let channel = EmbeddedChannel()
|
||||||
|
let loop = channel.eventLoop as! EmbeddedEventLoop
|
||||||
|
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
|
||||||
|
|
||||||
|
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
|
||||||
|
continuePromise.futureResult
|
||||||
|
}
|
||||||
|
let eventCounterHandler = EventCounterHandler()
|
||||||
|
|
||||||
|
try channel.pipeline.addHandler(handler).wait()
|
||||||
|
try channel.pipeline.addHandler(eventCounterHandler).wait()
|
||||||
|
|
||||||
|
// Fire in the event.
|
||||||
|
channel.pipeline.fireUserInboundEventTriggered(negotiatedEvent)
|
||||||
|
|
||||||
|
// Send a write, which is buffered.
|
||||||
|
try channel.writeInbound("a write")
|
||||||
|
|
||||||
|
// At this time, readComplete hasn't fired.
|
||||||
|
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 1)
|
||||||
|
|
||||||
|
// Now satisfy the future, which forces data unbuffering. This should fire readComplete.
|
||||||
|
continuePromise.succeed(.finished(.failed))
|
||||||
|
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write"))
|
||||||
|
|
||||||
|
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 2)
|
||||||
|
|
||||||
|
XCTAssertTrue(try channel.finish().isClean)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUnbufferingHandlesReentrantReads() throws {
|
||||||
|
let channel = EmbeddedChannel()
|
||||||
|
let loop = channel.eventLoop as! EmbeddedEventLoop
|
||||||
|
let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult<NegotiationResult>.self)
|
||||||
|
|
||||||
|
let handler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: loop) { result in
|
||||||
|
continuePromise.futureResult
|
||||||
|
}
|
||||||
|
let eventCounterHandler = EventCounterHandler()
|
||||||
|
|
||||||
|
try channel.pipeline.addHandler(handler).wait()
|
||||||
|
try channel.pipeline.addHandler(DuplicatingReadHandler(embeddedChannel: channel)).wait()
|
||||||
|
try channel.pipeline.addHandler(eventCounterHandler).wait()
|
||||||
|
|
||||||
|
// Fire in the event.
|
||||||
|
channel.pipeline.fireUserInboundEventTriggered(negotiatedEvent)
|
||||||
|
|
||||||
|
// Send a write, which is buffered.
|
||||||
|
try channel.writeInbound("a write")
|
||||||
|
|
||||||
|
// At this time, readComplete hasn't fired.
|
||||||
|
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 1)
|
||||||
|
|
||||||
|
// Now satisfy the future, which forces data unbuffering. This should fire readComplete.
|
||||||
|
continuePromise.succeed(.finished(.failed))
|
||||||
|
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write"))
|
||||||
|
XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write"))
|
||||||
|
|
||||||
|
XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 3)
|
||||||
|
|
||||||
|
XCTAssertTrue(try channel.finish().isClean)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue