fix Channel thread-safety (local/remoteAddress)
Motivation: There were a couple of places in the Channel implementations that just weren't thread-safe at all, namely: - localAddress - remoetAddress - parent those are now fixed. I also re-ordered the code so it should be easier to maintain in the future (the `// MARK` markers were totally incorrect). Modifications: - made Channel.{local,remote}Address return a future - made Channel.parent a `let` - unified more of `SocketChannel` and `DatagramSocketChannel` - fixed the `// MARK`s by reordering code - annotated methods/properties that are in the Channel API and need to be thread-safe Result: slightly more thread-safety :)
This commit is contained in:
parent
273932ae36
commit
68724be6e6
|
@ -21,6 +21,19 @@ protocol SockAddrProtocol {
|
||||||
mutating func withMutableSockAddr<R>(_ fn: (UnsafeMutablePointer<sockaddr>, Int) throws -> R) rethrows -> R
|
mutating func withMutableSockAddr<R>(_ fn: (UnsafeMutablePointer<sockaddr>, Int) throws -> R) rethrows -> R
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private func descriptionForAddress(family: CInt, bytes: UnsafeRawPointer, length byteCount: Int) -> String {
|
||||||
|
var addressBytes: [Int8] = Array(repeating: 0, count: byteCount)
|
||||||
|
return addressBytes.withUnsafeMutableBufferPointer { (addressBytesPtr: inout UnsafeMutableBufferPointer<Int8>) -> String in
|
||||||
|
try! Posix.inet_ntop(addressFamily: family,
|
||||||
|
addressBytes: bytes,
|
||||||
|
addressDescription: addressBytesPtr.baseAddress!,
|
||||||
|
addressDescriptionLength: socklen_t(byteCount))
|
||||||
|
return addressBytesPtr.baseAddress!.withMemoryRebound(to: UInt8.self, capacity: byteCount) { addressBytesPtr -> String in
|
||||||
|
return String(decoding: UnsafeBufferPointer<UInt8>(start: addressBytesPtr, count: byteCount), as: Unicode.ASCII.self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
extension sockaddr_in: SockAddrProtocol {
|
extension sockaddr_in: SockAddrProtocol {
|
||||||
mutating func withSockAddr<R>(_ fn: (UnsafePointer<sockaddr>, Int) throws -> R) rethrows -> R {
|
mutating func withSockAddr<R>(_ fn: (UnsafePointer<sockaddr>, Int) throws -> R) rethrows -> R {
|
||||||
var me = self
|
var me = self
|
||||||
|
@ -35,6 +48,12 @@ extension sockaddr_in: SockAddrProtocol {
|
||||||
try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count)
|
try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mutating func addressDescription() -> String {
|
||||||
|
return withUnsafePointer(to: &self.sin_addr) { addrPtr in
|
||||||
|
descriptionForAddress(family: AF_INET, bytes: addrPtr, length: Int(INET_ADDRSTRLEN))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extension sockaddr_in6: SockAddrProtocol {
|
extension sockaddr_in6: SockAddrProtocol {
|
||||||
|
@ -51,6 +70,12 @@ extension sockaddr_in6: SockAddrProtocol {
|
||||||
try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count)
|
try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mutating func addressDescription() -> String {
|
||||||
|
return withUnsafePointer(to: &self.sin6_addr) { addrPtr in
|
||||||
|
descriptionForAddress(family: AF_INET6, bytes: addrPtr, length: Int(INET6_ADDRSTRLEN))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extension sockaddr_un: SockAddrProtocol {
|
extension sockaddr_un: SockAddrProtocol {
|
||||||
|
@ -119,54 +144,45 @@ extension sockaddr_storage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mutating func convert() -> SocketAddress {
|
||||||
|
switch self.ss_family {
|
||||||
|
case Posix.AF_INET:
|
||||||
|
var sockAddr: sockaddr_in = self.convert()
|
||||||
|
return SocketAddress(sockAddr, host: sockAddr.addressDescription())
|
||||||
|
case Posix.AF_INET6:
|
||||||
|
var sockAddr: sockaddr_in6 = self.convert()
|
||||||
|
return SocketAddress(sockAddr, host: sockAddr.addressDescription())
|
||||||
|
case Posix.AF_UNIX:
|
||||||
|
return SocketAddress(self.convert() as sockaddr_un)
|
||||||
|
default:
|
||||||
|
fatalError("unknown sockaddr family \(self.ss_family)")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseSocket: Selectable {
|
class BaseSocket: Selectable {
|
||||||
public let descriptor: Int32
|
public let descriptor: Int32
|
||||||
public private(set) var open: Bool
|
public private(set) var open: Bool
|
||||||
|
|
||||||
final var localAddress: SocketAddress? {
|
final func localAddress() throws -> SocketAddress {
|
||||||
get {
|
return try get_addr { try Posix.getsockname(socket: $0, address: $1, addressLength: $2) }
|
||||||
return get_addr { getsockname($0, $1, $2) }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final var remoteAddress: SocketAddress? {
|
final func remoteAddress() throws -> SocketAddress {
|
||||||
get {
|
return try get_addr { try Posix.getpeername(socket: $0, address: $1, addressLength: $2) }
|
||||||
return get_addr { getpeername($0, $1, $2) }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private func get_addr(_ fn: (Int32, UnsafeMutablePointer<sockaddr>, UnsafeMutablePointer<socklen_t>) -> Int32) -> SocketAddress? {
|
private func get_addr(_ fn: (Int32, UnsafeMutablePointer<sockaddr>, UnsafeMutablePointer<socklen_t>) throws -> Void) throws -> SocketAddress {
|
||||||
var addr = sockaddr_storage()
|
var addr = sockaddr_storage()
|
||||||
var len: socklen_t = socklen_t(MemoryLayout<sockaddr_storage>.size)
|
|
||||||
|
try addr.withMutableSockAddr { addressPtr, size in
|
||||||
return withUnsafeMutablePointer(to: &addr) {
|
var size = socklen_t(size)
|
||||||
$0.withMemoryRebound(to: sockaddr.self, capacity: 1, { address in
|
try fn(self.descriptor, addressPtr, &size)
|
||||||
guard fn(descriptor, address, &len) == 0 else {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
switch Int32(address.pointee.sa_family) {
|
|
||||||
case AF_INET:
|
|
||||||
return address.withMemoryRebound(to: sockaddr_in.self, capacity: 1, { ipv4 in
|
|
||||||
var ipAddressString = [CChar](repeating: 0, count: Int(INET_ADDRSTRLEN))
|
|
||||||
return SocketAddress(ipv4.pointee, host: String(cString: inet_ntop(AF_INET, &ipv4.pointee.sin_addr, &ipAddressString, socklen_t(INET_ADDRSTRLEN))))
|
|
||||||
})
|
|
||||||
case AF_INET6:
|
|
||||||
return address.withMemoryRebound(to: sockaddr_in6.self, capacity: 1, { ipv6 in
|
|
||||||
var ipAddressString = [CChar](repeating: 0, count: Int(INET6_ADDRSTRLEN))
|
|
||||||
return SocketAddress(ipv6.pointee, host: String(cString: inet_ntop(AF_INET6, &ipv6.pointee.sin6_addr, &ipAddressString, socklen_t(INET6_ADDRSTRLEN))))
|
|
||||||
})
|
|
||||||
case AF_UNIX:
|
|
||||||
return address.withMemoryRebound(to: sockaddr_un.self, capacity: 1) { uds in
|
|
||||||
return SocketAddress(uds.pointee)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
fatalError("address family \(address.pointee.sa_family) not supported")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
return addr.convert()
|
||||||
}
|
}
|
||||||
|
|
||||||
static func newSocket(protocolFamily: Int32, type: CInt) throws -> Int32 {
|
static func newSocket(protocolFamily: Int32, type: CInt) throws -> Int32 {
|
||||||
let sock = try Posix.socket(domain: protocolFamily,
|
let sock = try Posix.socket(domain: protocolFamily,
|
||||||
type: type,
|
type: type,
|
||||||
|
|
|
@ -18,6 +18,8 @@ import NIOConcurrencyHelpers
|
||||||
///
|
///
|
||||||
/// - note: All methods must be called from the EventLoop thread
|
/// - note: All methods must be called from the EventLoop thread
|
||||||
public protocol ChannelCore : class {
|
public protocol ChannelCore : class {
|
||||||
|
func localAddress0() throws -> SocketAddress
|
||||||
|
func remoteAddress0() throws -> SocketAddress
|
||||||
func register0(promise: EventLoopPromise<Void>?)
|
func register0(promise: EventLoopPromise<Void>?)
|
||||||
func bind0(to: SocketAddress, promise: EventLoopPromise<Void>?)
|
func bind0(to: SocketAddress, promise: EventLoopPromise<Void>?)
|
||||||
func connect0(to: SocketAddress, promise: EventLoopPromise<Void>?)
|
func connect0(to: SocketAddress, promise: EventLoopPromise<Void>?)
|
||||||
|
@ -82,7 +84,9 @@ public protocol Channel : class, ChannelOutboundInvoker {
|
||||||
|
|
||||||
/// A `SelectableChannel` is a `Channel` that can be used with a `Selector` which notifies a user when certain events
|
/// A `SelectableChannel` is a `Channel` that can be used with a `Selector` which notifies a user when certain events
|
||||||
/// before possible. On UNIX a `Selector` is usually an abstraction of `select`, `poll`, `epoll` or `kqueue`.
|
/// before possible. On UNIX a `Selector` is usually an abstraction of `select`, `poll`, `epoll` or `kqueue`.
|
||||||
protocol SelectableChannel : Channel {
|
///
|
||||||
|
/// - warning: `SelectableChannel` methods and properties are _not_ thread-safe (unless they also belong to `Channel`).
|
||||||
|
internal protocol SelectableChannel : Channel {
|
||||||
/// The type of the `Selectable`. A `Selectable` is usually wrapping a file descriptor that can be registered in a
|
/// The type of the `Selectable`. A `Selectable` is usually wrapping a file descriptor that can be registered in a
|
||||||
/// `Selector`.
|
/// `Selector`.
|
||||||
associatedtype SelectableType: Selectable
|
associatedtype SelectableType: Selectable
|
||||||
|
|
|
@ -809,6 +809,14 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
||||||
return self.inboundHandler ?? self.outboundHandler!
|
return self.inboundHandler ?? self.outboundHandler!
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public var remoteAddress: SocketAddress? {
|
||||||
|
return try? self.channel._unsafe.remoteAddress0()
|
||||||
|
}
|
||||||
|
|
||||||
|
public var localAddress: SocketAddress? {
|
||||||
|
return try? self.channel._unsafe.localAddress0()
|
||||||
|
}
|
||||||
|
|
||||||
public let name: String
|
public let name: String
|
||||||
public let eventLoop: EventLoop
|
public let eventLoop: EventLoop
|
||||||
private let inboundHandler: _ChannelInboundHandler?
|
private let inboundHandler: _ChannelInboundHandler?
|
||||||
|
|
|
@ -16,6 +16,14 @@
|
||||||
/// the original `Channel` is closed. Given that the original `Channel` is closed the `DeadChannelCore` should fail
|
/// the original `Channel` is closed. Given that the original `Channel` is closed the `DeadChannelCore` should fail
|
||||||
/// all operations.
|
/// all operations.
|
||||||
private final class DeadChannelCore: ChannelCore {
|
private final class DeadChannelCore: ChannelCore {
|
||||||
|
func localAddress0() throws -> SocketAddress {
|
||||||
|
throw ChannelError.ioOnClosedChannel
|
||||||
|
}
|
||||||
|
|
||||||
|
func remoteAddress0() throws -> SocketAddress {
|
||||||
|
throw ChannelError.ioOnClosedChannel
|
||||||
|
}
|
||||||
|
|
||||||
func register0(promise: EventLoopPromise<Void>?) {
|
func register0(promise: EventLoopPromise<Void>?) {
|
||||||
promise?.fail(error: ChannelError.ioOnClosedChannel)
|
promise?.fail(error: ChannelError.ioOnClosedChannel)
|
||||||
}
|
}
|
||||||
|
@ -65,27 +73,30 @@ private final class DeadChannelCore: ChannelCore {
|
||||||
/// channel as it only holds an unowned reference to the original `Channel`. `DeadChannel` serves as a replacement
|
/// channel as it only holds an unowned reference to the original `Channel`. `DeadChannel` serves as a replacement
|
||||||
/// that can be used when the original `Channel` might no longer be valid.
|
/// that can be used when the original `Channel` might no longer be valid.
|
||||||
internal final class DeadChannel: Channel {
|
internal final class DeadChannel: Channel {
|
||||||
|
let eventLoop: EventLoop
|
||||||
let pipeline: ChannelPipeline
|
let pipeline: ChannelPipeline
|
||||||
|
|
||||||
var eventLoop: EventLoop {
|
public var closeFuture: EventLoopFuture<()> {
|
||||||
return self.pipeline.eventLoop
|
return self.eventLoop.newSucceededFuture(result: ())
|
||||||
}
|
}
|
||||||
|
|
||||||
internal init(pipeline: ChannelPipeline) {
|
internal init(pipeline: ChannelPipeline) {
|
||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
|
self.eventLoop = pipeline.eventLoop
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
var allocator: ByteBufferAllocator {
|
var allocator: ByteBufferAllocator {
|
||||||
return ByteBufferAllocator()
|
return ByteBufferAllocator()
|
||||||
}
|
}
|
||||||
|
|
||||||
var closeFuture: EventLoopFuture<Void> {
|
var localAddress: SocketAddress? {
|
||||||
return self.pipeline.eventLoop.newSucceededFuture(result: ())
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
let localAddress: SocketAddress? = nil
|
var remoteAddress: SocketAddress? {
|
||||||
|
return nil
|
||||||
let remoteAddress: SocketAddress? = nil
|
}
|
||||||
|
|
||||||
let parent: Channel? = nil
|
let parent: Channel? = nil
|
||||||
|
|
||||||
|
|
|
@ -140,7 +140,6 @@ public class EmbeddedEventLoop: EventLoop {
|
||||||
class EmbeddedChannelCore : ChannelCore {
|
class EmbeddedChannelCore : ChannelCore {
|
||||||
var closed: Bool = false
|
var closed: Bool = false
|
||||||
var isActive: Bool = false
|
var isActive: Bool = false
|
||||||
|
|
||||||
|
|
||||||
var eventLoop: EventLoop
|
var eventLoop: EventLoop
|
||||||
var closePromise: EventLoopPromise<Void>
|
var closePromise: EventLoopPromise<Void>
|
||||||
|
@ -163,6 +162,14 @@ class EmbeddedChannelCore : ChannelCore {
|
||||||
var outboundBuffer: [IOData] = []
|
var outboundBuffer: [IOData] = []
|
||||||
var inboundBuffer: [NIOAny] = []
|
var inboundBuffer: [NIOAny] = []
|
||||||
|
|
||||||
|
func localAddress0() throws -> SocketAddress {
|
||||||
|
throw ChannelError.operationUnsupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func remoteAddress0() throws -> SocketAddress {
|
||||||
|
throw ChannelError.operationUnsupported
|
||||||
|
}
|
||||||
|
|
||||||
func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
|
func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
|
||||||
if closed {
|
if closed {
|
||||||
promise?.fail(error: ChannelError.alreadyClosed)
|
promise?.fail(error: ChannelError.alreadyClosed)
|
||||||
|
@ -269,8 +276,8 @@ public class EmbeddedChannel : Channel {
|
||||||
public var allocator: ByteBufferAllocator = ByteBufferAllocator()
|
public var allocator: ByteBufferAllocator = ByteBufferAllocator()
|
||||||
public var eventLoop: EventLoop = EmbeddedEventLoop()
|
public var eventLoop: EventLoop = EmbeddedEventLoop()
|
||||||
|
|
||||||
public var localAddress: SocketAddress? = nil
|
public let localAddress: SocketAddress? = nil
|
||||||
public var remoteAddress: SocketAddress? = nil
|
public let remoteAddress: SocketAddress? = nil
|
||||||
|
|
||||||
// Embedded channels never have parents.
|
// Embedded channels never have parents.
|
||||||
public let parent: Channel? = nil
|
public let parent: Channel? = nil
|
||||||
|
|
|
@ -465,6 +465,7 @@ internal protocol PendingWritesManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
extension PendingWritesManager {
|
extension PendingWritesManager {
|
||||||
|
// This is called from `Channel` API so must be thread-safe.
|
||||||
var isWritable: Bool {
|
var isWritable: Bool {
|
||||||
return self.channelWritabilityFlag.load()
|
return self.channelWritabilityFlag.load()
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Represents a selectable resource which can be registered to a `Selector` to be notified once there are some events ready for it.
|
/// Represents a selectable resource which can be registered to a `Selector` to be notified once there are some events ready for it.
|
||||||
|
///
|
||||||
|
/// - warning: `Selectable`s are not thread-safe, only to be used on the appropriate `EventLoop`.
|
||||||
protocol Selectable {
|
protocol Selectable {
|
||||||
|
|
||||||
/// The file descriptor itself.
|
/// The file descriptor itself.
|
||||||
|
|
|
@ -43,6 +43,38 @@ private extension ByteBuffer {
|
||||||
class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
typealias SelectableType = T
|
typealias SelectableType = T
|
||||||
|
|
||||||
|
// MARK: Stored Properties
|
||||||
|
// Visible to access from EventLoop directly
|
||||||
|
public let parent: Channel?
|
||||||
|
internal let socket: T
|
||||||
|
private let closePromise: EventLoopPromise<Void>
|
||||||
|
private let selectableEventLoop: SelectableEventLoop
|
||||||
|
private let localAddressCached: AtomicBox<Box<SocketAddress?>> = AtomicBox(value: Box(nil))
|
||||||
|
private let remoteAddressCached: AtomicBox<Box<SocketAddress?>> = AtomicBox(value: Box(nil))
|
||||||
|
private let bufferAllocatorCached: AtomicBox<Box<ByteBufferAllocator>>
|
||||||
|
|
||||||
|
internal var interestedEvent: IOEvent = .none
|
||||||
|
|
||||||
|
fileprivate var readPending = false
|
||||||
|
fileprivate var pendingConnect: EventLoopPromise<Void>?
|
||||||
|
fileprivate var recvAllocator: RecvByteBufferAllocator
|
||||||
|
fileprivate var maxMessagesPerRead: UInt = 4
|
||||||
|
|
||||||
|
private var inFlushNow: Bool = false // Guard against re-entrance of flushNow() method.
|
||||||
|
private var neverRegistered = true
|
||||||
|
private var active: Atomic<Bool> = Atomic(value: false)
|
||||||
|
private var _closed: Bool = false
|
||||||
|
private var autoRead: Bool = true
|
||||||
|
private var _pipeline: ChannelPipeline!
|
||||||
|
private var bufferAllocator: ByteBufferAllocator = ByteBufferAllocator() {
|
||||||
|
didSet {
|
||||||
|
assert(self.eventLoop.inEventLoop)
|
||||||
|
self.bufferAllocatorCached.store(Box(self.bufferAllocator))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Datatypes
|
||||||
|
|
||||||
/// Indicates if a selectable should registered or not for IO notifications.
|
/// Indicates if a selectable should registered or not for IO notifications.
|
||||||
enum IONotificationState {
|
enum IONotificationState {
|
||||||
/// We should be registered for IO notifications.
|
/// We should be registered for IO notifications.
|
||||||
|
@ -52,7 +84,26 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
case unregister
|
case unregister
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: Methods to override in subclasses.
|
fileprivate enum ReadResult {
|
||||||
|
/// Nothing was read by the read operation.
|
||||||
|
case none
|
||||||
|
|
||||||
|
/// Some data was read by the read operation.
|
||||||
|
case some
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Computed Properties
|
||||||
|
public final var _unsafe: ChannelCore { return self }
|
||||||
|
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
|
public final var localAddress: SocketAddress? {
|
||||||
|
return self.localAddressCached.load().value
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
|
public final var remoteAddress: SocketAddress? {
|
||||||
|
return self.remoteAddressCached.load().value
|
||||||
|
}
|
||||||
|
|
||||||
/// `true` if the whole `Channel` is closed and so no more IO operation can be done.
|
/// `true` if the whole `Channel` is closed and so no more IO operation can be done.
|
||||||
public var closed: Bool {
|
public var closed: Bool {
|
||||||
|
@ -60,6 +111,48 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
return _closed
|
return _closed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal var selectable: T {
|
||||||
|
return self.socket
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
|
public var isActive: Bool {
|
||||||
|
return self.active.load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
|
public final var closeFuture: EventLoopFuture<Void> {
|
||||||
|
return self.closePromise.futureResult
|
||||||
|
}
|
||||||
|
|
||||||
|
public final var eventLoop: EventLoop {
|
||||||
|
return selectableEventLoop
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
|
public var isWritable: Bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
|
public final var allocator: ByteBufferAllocator {
|
||||||
|
if eventLoop.inEventLoop {
|
||||||
|
return bufferAllocator
|
||||||
|
} else {
|
||||||
|
return self.bufferAllocatorCached.load().value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
|
public final var pipeline: ChannelPipeline {
|
||||||
|
return _pipeline
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Methods to override in subclasses.
|
||||||
|
func writeToSocket() throws -> OverallWriteResult {
|
||||||
|
fatalError("must be overridden")
|
||||||
|
}
|
||||||
|
|
||||||
/// Provides the registration for this selector. Must be implemented by subclasses.
|
/// Provides the registration for this selector. Must be implemented by subclasses.
|
||||||
func registrationFor(interested: IOEvent) -> NIORegistration {
|
func registrationFor(interested: IOEvent) -> NIORegistration {
|
||||||
fatalError("must override")
|
fatalError("must override")
|
||||||
|
@ -72,14 +165,6 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
fatalError("this must be overridden by sub class")
|
fatalError("this must be overridden by sub class")
|
||||||
}
|
}
|
||||||
|
|
||||||
fileprivate enum ReadResult {
|
|
||||||
/// Nothing was read by the read operation.
|
|
||||||
case none
|
|
||||||
|
|
||||||
/// Some data was read by the read operation.
|
|
||||||
case some
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Begin connection of the underlying socket.
|
/// Begin connection of the underlying socket.
|
||||||
///
|
///
|
||||||
/// - parameters:
|
/// - parameters:
|
||||||
|
@ -111,70 +196,78 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
fatalError("this must be overridden by sub class")
|
fatalError("this must be overridden by sub class")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: Common base socket logic.
|
||||||
|
fileprivate init(socket: T, parent: Channel? = nil, eventLoop: SelectableEventLoop, recvAllocator: RecvByteBufferAllocator) throws {
|
||||||
|
self.bufferAllocatorCached = AtomicBox(value: Box(self.bufferAllocator))
|
||||||
|
self.socket = socket
|
||||||
|
self.selectableEventLoop = eventLoop
|
||||||
|
self.closePromise = eventLoop.newPromise()
|
||||||
|
self.parent = parent
|
||||||
|
self.active.store(false)
|
||||||
|
self.recvAllocator = recvAllocator
|
||||||
|
self._pipeline = ChannelPipeline(channel: self)
|
||||||
|
}
|
||||||
|
|
||||||
|
deinit {
|
||||||
|
assert(self._closed, "leak of open Channel")
|
||||||
|
}
|
||||||
|
|
||||||
|
public final func localAddress0() throws -> SocketAddress {
|
||||||
|
assert(self.eventLoop.inEventLoop)
|
||||||
|
guard self.open else {
|
||||||
|
throw ChannelError.ioOnClosedChannel
|
||||||
|
}
|
||||||
|
return try self.socket.localAddress()
|
||||||
|
}
|
||||||
|
|
||||||
|
public final func remoteAddress0() throws -> SocketAddress {
|
||||||
|
assert(self.eventLoop.inEventLoop)
|
||||||
|
guard self.open else {
|
||||||
|
throw ChannelError.ioOnClosedChannel
|
||||||
|
}
|
||||||
|
return try self.socket.remoteAddress()
|
||||||
|
}
|
||||||
|
|
||||||
/// Flush data to the underlying socket and return if this socket needs to be registered for write notifications.
|
/// Flush data to the underlying socket and return if this socket needs to be registered for write notifications.
|
||||||
///
|
///
|
||||||
/// - returns: If this socket should be registered for write notifications. Ie. `IONotificationState.register` if _not_ all data could be written, so notifications are necessary; and `IONotificationState.unregister` if everything was written and we don't need to be notified about writability at the moment.
|
/// - returns: If this socket should be registered for write notifications. Ie. `IONotificationState.register` if _not_ all data could be written, so notifications are necessary; and `IONotificationState.unregister` if everything was written and we don't need to be notified about writability at the moment.
|
||||||
fileprivate func flushNow() -> IONotificationState {
|
fileprivate func flushNow() -> IONotificationState {
|
||||||
fatalError("this must be overridden by sub class")
|
// Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by
|
||||||
}
|
// `writeToSocket`.
|
||||||
|
guard !inFlushNow && !closed else {
|
||||||
|
return .unregister
|
||||||
|
}
|
||||||
|
|
||||||
// MARK: Common base socket logic.
|
defer {
|
||||||
|
inFlushNow = false
|
||||||
|
}
|
||||||
|
inFlushNow = true
|
||||||
|
|
||||||
var selectable: T {
|
do {
|
||||||
return self.socket
|
switch try self.writeToSocket() {
|
||||||
}
|
case .couldNotWriteEverything:
|
||||||
|
return .register
|
||||||
|
case .writtenCompletely:
|
||||||
|
return .unregister
|
||||||
|
}
|
||||||
|
} catch let err {
|
||||||
|
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
|
||||||
|
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
|
||||||
|
if readIfNeeded0() {
|
||||||
|
|
||||||
public final var _unsafe: ChannelCore { return self }
|
// We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
|
||||||
|
while let read = try? readFromSocket(), read == .some {
|
||||||
|
pipeline.fireChannelReadComplete()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Visible to access from EventLoop directly
|
close0(error: err, mode: .all, promise: nil)
|
||||||
let socket: T
|
|
||||||
public var interestedEvent: IOEvent = .none
|
|
||||||
|
|
||||||
fileprivate var readPending = false
|
// we handled all writes
|
||||||
private var neverRegistered = true
|
return .unregister
|
||||||
fileprivate var pendingConnect: EventLoopPromise<Void>?
|
|
||||||
private let closePromise: EventLoopPromise<Void>
|
|
||||||
private var active: Atomic<Bool> = Atomic(value: false)
|
|
||||||
private var _closed: Bool = false
|
|
||||||
public var isActive: Bool {
|
|
||||||
return active.load()
|
|
||||||
}
|
|
||||||
|
|
||||||
public var parent: Channel? = nil
|
|
||||||
|
|
||||||
public final var closeFuture: EventLoopFuture<Void> {
|
|
||||||
return closePromise.futureResult
|
|
||||||
}
|
|
||||||
|
|
||||||
private let selectableEventLoop: SelectableEventLoop
|
|
||||||
|
|
||||||
public final var eventLoop: EventLoop {
|
|
||||||
return selectableEventLoop
|
|
||||||
}
|
|
||||||
|
|
||||||
public var isWritable: Bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
public final var allocator: ByteBufferAllocator {
|
|
||||||
if eventLoop.inEventLoop {
|
|
||||||
return bufferAllocator
|
|
||||||
} else {
|
|
||||||
return try! eventLoop.submit{ self.bufferAllocator }.wait()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private var bufferAllocator: ByteBufferAllocator = ByteBufferAllocator()
|
|
||||||
fileprivate var recvAllocator: RecvByteBufferAllocator
|
|
||||||
fileprivate var autoRead: Bool = true
|
|
||||||
fileprivate var maxMessagesPerRead: UInt = 4
|
|
||||||
|
|
||||||
// We don't use lazy var here as this is more expensive then doing this :/
|
|
||||||
public final var pipeline: ChannelPipeline {
|
|
||||||
return _pipeline
|
|
||||||
}
|
|
||||||
|
|
||||||
private var _pipeline: ChannelPipeline!
|
|
||||||
|
|
||||||
public final func setOption<T: ChannelOption>(option: T, value: T.OptionType) -> EventLoopFuture<Void> {
|
public final func setOption<T: ChannelOption>(option: T, value: T.OptionType) -> EventLoopFuture<Void> {
|
||||||
if eventLoop.inEventLoop {
|
if eventLoop.inEventLoop {
|
||||||
|
@ -240,18 +333,6 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public final var localAddress: SocketAddress? {
|
|
||||||
get {
|
|
||||||
return socket.localAddress
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public final var remoteAddress: SocketAddress? {
|
|
||||||
get {
|
|
||||||
return socket.remoteAddress
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Triggers a `ChannelPipeline.read()` if `autoRead` is enabled.`
|
/// Triggers a `ChannelPipeline.read()` if `autoRead` is enabled.`
|
||||||
///
|
///
|
||||||
/// - returns: `true` if `readPending` is `true`, `false` otherwise.
|
/// - returns: `true` if `readPending` is `true`, `false` otherwise.
|
||||||
|
@ -275,6 +356,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
|
|
||||||
executeAndComplete(promise) {
|
executeAndComplete(promise) {
|
||||||
try socket.bind(to: address)
|
try socket.bind(to: address)
|
||||||
|
self.updateCachedAddressesFromSocket(updateRemote: false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -401,6 +483,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
|
|
||||||
// Fail all pending writes and so ensure all pending promises are notified
|
// Fail all pending writes and so ensure all pending promises are notified
|
||||||
self._closed = true
|
self._closed = true
|
||||||
|
self.unsetCachedAddressesFromSocket()
|
||||||
self.cancelWritesOnClose(error: error)
|
self.cancelWritesOnClose(error: error)
|
||||||
|
|
||||||
becomeInactive0()
|
becomeInactive0()
|
||||||
|
@ -515,6 +598,22 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
readIfNeeded0()
|
readIfNeeded0()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal final func updateCachedAddressesFromSocket(updateLocal: Bool = true, updateRemote: Bool = true) {
|
||||||
|
assert(self.eventLoop.inEventLoop)
|
||||||
|
if updateLocal {
|
||||||
|
self.localAddressCached.store(Box(try? self.localAddress0()))
|
||||||
|
}
|
||||||
|
if updateRemote {
|
||||||
|
self.remoteAddressCached.store(Box(try? self.remoteAddress0()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal final func unsetCachedAddressesFromSocket() {
|
||||||
|
assert(self.eventLoop.inEventLoop)
|
||||||
|
self.localAddressCached.store(Box(nil))
|
||||||
|
self.remoteAddressCached.store(Box(nil))
|
||||||
|
}
|
||||||
|
|
||||||
public final func connect0(to address: SocketAddress, promise: EventLoopPromise<Void>?) {
|
public final func connect0(to address: SocketAddress, promise: EventLoopPromise<Void>?) {
|
||||||
assert(eventLoop.inEventLoop)
|
assert(eventLoop.inEventLoop)
|
||||||
|
|
||||||
|
@ -529,6 +628,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
}
|
}
|
||||||
do {
|
do {
|
||||||
if try !connectSocket(to: address) {
|
if try !connectSocket(to: address) {
|
||||||
|
self.updateCachedAddressesFromSocket()
|
||||||
if promise != nil {
|
if promise != nil {
|
||||||
pendingConnect = promise
|
pendingConnect = promise
|
||||||
} else {
|
} else {
|
||||||
|
@ -536,6 +636,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
}
|
}
|
||||||
registerForWritable()
|
registerForWritable()
|
||||||
} else {
|
} else {
|
||||||
|
self.updateCachedAddressesFromSocket()
|
||||||
promise?.succeed(result: ())
|
promise?.succeed(result: ())
|
||||||
}
|
}
|
||||||
} catch let error {
|
} catch let error {
|
||||||
|
@ -602,19 +703,6 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
||||||
active.store(false)
|
active.store(false)
|
||||||
pipeline.fireChannelInactive0()
|
pipeline.fireChannelInactive0()
|
||||||
}
|
}
|
||||||
|
|
||||||
fileprivate init(socket: T, eventLoop: SelectableEventLoop, recvAllocator: RecvByteBufferAllocator) throws {
|
|
||||||
self.socket = socket
|
|
||||||
self.selectableEventLoop = eventLoop
|
|
||||||
self.closePromise = eventLoop.newPromise()
|
|
||||||
active.store(false)
|
|
||||||
self.recvAllocator = recvAllocator
|
|
||||||
self._pipeline = ChannelPipeline(channel: self)
|
|
||||||
}
|
|
||||||
|
|
||||||
deinit {
|
|
||||||
assert(self._closed, "leak of open Channel")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A `Channel` for a client socket.
|
/// A `Channel` for a client socket.
|
||||||
|
@ -628,10 +716,9 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
||||||
private var inputShutdown: Bool = false
|
private var inputShutdown: Bool = false
|
||||||
private var outputShutdown: Bool = false
|
private var outputShutdown: Bool = false
|
||||||
|
|
||||||
// Guard against re-entrance of flushNow() method.
|
|
||||||
private var inFlushNow: Bool = false
|
|
||||||
private let pendingWrites: PendingStreamWritesManager
|
private let pendingWrites: PendingStreamWritesManager
|
||||||
|
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
override public var isWritable: Bool {
|
override public var isWritable: Bool {
|
||||||
return pendingWrites.isWritable
|
return pendingWrites.isWritable
|
||||||
}
|
}
|
||||||
|
@ -694,15 +781,10 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
||||||
return .socketChannel(self, interested)
|
return .socketChannel(self, interested)
|
||||||
}
|
}
|
||||||
|
|
||||||
fileprivate init(socket: Socket, eventLoop: SelectableEventLoop) throws {
|
fileprivate init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws {
|
||||||
try socket.setNonBlocking()
|
try socket.setNonBlocking()
|
||||||
self.pendingWrites = PendingStreamWritesManager(iovecs: eventLoop.iovecs, storageRefs: eventLoop.storageRefs)
|
self.pendingWrites = PendingStreamWritesManager(iovecs: eventLoop.iovecs, storageRefs: eventLoop.storageRefs)
|
||||||
try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator())
|
try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator())
|
||||||
}
|
|
||||||
|
|
||||||
fileprivate convenience init(socket: Socket, eventLoop: SelectableEventLoop, parent: Channel) throws {
|
|
||||||
try self.init(socket: socket, eventLoop: eventLoop)
|
|
||||||
self.parent = parent
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fileprivate func readFromSocket() throws -> ReadResult {
|
override fileprivate func readFromSocket() throws -> ReadResult {
|
||||||
|
@ -749,8 +831,8 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
private func writeToSocket(pendingWrites: PendingStreamWritesManager) throws -> OverallWriteResult {
|
override func writeToSocket() throws -> OverallWriteResult {
|
||||||
let result = try pendingWrites.triggerAppropriateWriteOperations(scalarBufferWriteOperation: { ptr in
|
let result = try self.pendingWrites.triggerAppropriateWriteOperations(scalarBufferWriteOperation: { ptr in
|
||||||
guard ptr.count > 0 else {
|
guard ptr.count > 0 else {
|
||||||
// No need to call write if the buffer is empty.
|
// No need to call write if the buffer is empty.
|
||||||
return .processed(0)
|
return .processed(0)
|
||||||
|
@ -882,43 +964,6 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
||||||
pipeline.fireChannelWritabilityChanged0()
|
pipeline.fireChannelWritabilityChanged0()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fileprivate func flushNow() -> IONotificationState {
|
|
||||||
// Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by
|
|
||||||
// `writeToSocket`.
|
|
||||||
guard !inFlushNow && !closed else {
|
|
||||||
return .unregister
|
|
||||||
}
|
|
||||||
|
|
||||||
defer {
|
|
||||||
inFlushNow = false
|
|
||||||
}
|
|
||||||
inFlushNow = true
|
|
||||||
|
|
||||||
do {
|
|
||||||
switch try self.writeToSocket(pendingWrites: pendingWrites) {
|
|
||||||
case .couldNotWriteEverything:
|
|
||||||
return .register
|
|
||||||
case .writtenCompletely:
|
|
||||||
return .unregister
|
|
||||||
}
|
|
||||||
} catch let err {
|
|
||||||
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
|
|
||||||
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
|
|
||||||
if readIfNeeded0() {
|
|
||||||
|
|
||||||
// We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
|
|
||||||
while let read = try? readFromSocket(), read == .some {
|
|
||||||
pipeline.fireChannelReadComplete()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
close0(error: err, mode: .all, promise: nil)
|
|
||||||
|
|
||||||
// we handled all writes
|
|
||||||
return .unregister
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A `Channel` for a server socket.
|
/// A `Channel` for a server socket.
|
||||||
|
@ -930,6 +975,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
|
||||||
private let group: EventLoopGroup
|
private let group: EventLoopGroup
|
||||||
|
|
||||||
/// The server socket channel is never writable.
|
/// The server socket channel is never writable.
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
override public var isWritable: Bool { return false }
|
override public var isWritable: Bool { return false }
|
||||||
|
|
||||||
init(eventLoop: SelectableEventLoop, group: EventLoopGroup, protocolFamily: Int32) throws {
|
init(eventLoop: SelectableEventLoop, group: EventLoopGroup, protocolFamily: Int32) throws {
|
||||||
|
@ -987,6 +1033,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
|
||||||
}
|
}
|
||||||
executeAndComplete(p) {
|
executeAndComplete(p) {
|
||||||
try socket.bind(to: address)
|
try socket.bind(to: address)
|
||||||
|
self.updateCachedAddressesFromSocket(updateRemote: false)
|
||||||
try self.socket.listen(backlog: backlog)
|
try self.socket.listen(backlog: backlog)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1009,7 +1056,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
|
||||||
readPending = false
|
readPending = false
|
||||||
result = .some
|
result = .some
|
||||||
do {
|
do {
|
||||||
let chan = try SocketChannel(socket: accepted, eventLoop: group.next() as! SelectableEventLoop, parent: self)
|
let chan = try SocketChannel(socket: accepted, parent: self, eventLoop: group.next() as! SelectableEventLoop)
|
||||||
pipeline.fireChannelRead0(NIOAny(chan))
|
pipeline.fireChannelRead0(NIOAny(chan))
|
||||||
} catch let err {
|
} catch let err {
|
||||||
_ = try? accepted.close()
|
_ = try? accepted.close()
|
||||||
|
@ -1054,9 +1101,9 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
|
||||||
final class DatagramChannel: BaseSocketChannel<Socket> {
|
final class DatagramChannel: BaseSocketChannel<Socket> {
|
||||||
|
|
||||||
// Guard against re-entrance of flushNow() method.
|
// Guard against re-entrance of flushNow() method.
|
||||||
private var inFlushNow: Bool = false
|
|
||||||
private let pendingWrites: PendingDatagramWritesManager
|
private let pendingWrites: PendingDatagramWritesManager
|
||||||
|
|
||||||
|
// This is `Channel` API so must be thread-safe.
|
||||||
override public var isWritable: Bool {
|
override public var isWritable: Bool {
|
||||||
return pendingWrites.isWritable
|
return pendingWrites.isWritable
|
||||||
}
|
}
|
||||||
|
@ -1082,18 +1129,13 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
|
||||||
try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
|
try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
|
||||||
}
|
}
|
||||||
|
|
||||||
fileprivate init(socket: Socket, eventLoop: SelectableEventLoop) throws {
|
fileprivate init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws {
|
||||||
try socket.setNonBlocking()
|
try socket.setNonBlocking()
|
||||||
self.pendingWrites = PendingDatagramWritesManager(msgs: eventLoop.msgs,
|
self.pendingWrites = PendingDatagramWritesManager(msgs: eventLoop.msgs,
|
||||||
iovecs: eventLoop.iovecs,
|
iovecs: eventLoop.iovecs,
|
||||||
addresses: eventLoop.addresses,
|
addresses: eventLoop.addresses,
|
||||||
storageRefs: eventLoop.storageRefs)
|
storageRefs: eventLoop.storageRefs)
|
||||||
try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
|
try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
|
||||||
}
|
|
||||||
|
|
||||||
fileprivate convenience init(socket: Socket, eventLoop: SelectableEventLoop, parent: Channel) throws {
|
|
||||||
try self.init(socket: socket, eventLoop: eventLoop)
|
|
||||||
self.parent = parent
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: Datagram Channel overrides required by BaseSocketChannel
|
// MARK: Datagram Channel overrides required by BaseSocketChannel
|
||||||
|
@ -1156,22 +1198,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
|
||||||
let mayGrow = recvAllocator.record(actualReadBytes: bytesRead)
|
let mayGrow = recvAllocator.record(actualReadBytes: bytesRead)
|
||||||
readPending = false
|
readPending = false
|
||||||
|
|
||||||
let sourceAddress: SocketAddress
|
let msg = AddressedEnvelope(remoteAddress: rawAddress.convert(), data: buffer)
|
||||||
switch Int(rawAddressLength) {
|
|
||||||
case MemoryLayout<sockaddr_in>.size:
|
|
||||||
let addr: sockaddr_in = rawAddress.convert()
|
|
||||||
sourceAddress = .init(addr, host: "")
|
|
||||||
case MemoryLayout<sockaddr_in6>.size:
|
|
||||||
let addr: sockaddr_in6 = rawAddress.convert()
|
|
||||||
sourceAddress = .init(addr, host: "")
|
|
||||||
case MemoryLayout<sockaddr_un>.size:
|
|
||||||
let addr: sockaddr_un = rawAddress.convert()
|
|
||||||
sourceAddress = .init(addr)
|
|
||||||
default:
|
|
||||||
fatalError("Unexpected sockaddr size")
|
|
||||||
}
|
|
||||||
|
|
||||||
let msg = AddressedEnvelope(remoteAddress: sourceAddress, data: buffer)
|
|
||||||
pipeline.fireChannelRead0(NIOAny(msg))
|
pipeline.fireChannelRead0(NIOAny(msg))
|
||||||
if mayGrow && i < maxMessagesPerRead {
|
if mayGrow && i < maxMessagesPerRead {
|
||||||
buffer = recvAllocator.buffer(allocator: allocator)
|
buffer = recvAllocator.buffer(allocator: allocator)
|
||||||
|
@ -1211,45 +1238,8 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
|
||||||
self.pendingWrites.failAll(error: error, close: true)
|
self.pendingWrites.failAll(error: error, close: true)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fileprivate func flushNow() -> IONotificationState {
|
override func writeToSocket() throws -> OverallWriteResult {
|
||||||
// Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by
|
let result = try self.pendingWrites.triggerAppropriateWriteOperations(scalarWriteOperation: { (ptr, destinationPtr, destinationSize) in
|
||||||
// `writeToSocket`.
|
|
||||||
guard !inFlushNow && !closed else {
|
|
||||||
return .unregister
|
|
||||||
}
|
|
||||||
|
|
||||||
defer {
|
|
||||||
inFlushNow = false
|
|
||||||
}
|
|
||||||
inFlushNow = true
|
|
||||||
|
|
||||||
do {
|
|
||||||
switch try self.writeToSocket(pendingWrites: pendingWrites) {
|
|
||||||
case .couldNotWriteEverything:
|
|
||||||
return .register
|
|
||||||
case .writtenCompletely:
|
|
||||||
return .unregister
|
|
||||||
}
|
|
||||||
} catch let err {
|
|
||||||
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
|
|
||||||
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
|
|
||||||
if readIfNeeded0() {
|
|
||||||
|
|
||||||
// We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
|
|
||||||
while let read = try? readFromSocket(), read == .some {
|
|
||||||
pipeline.fireChannelReadComplete()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
close0(error: err, mode: .all, promise: nil)
|
|
||||||
|
|
||||||
// we handled all writes
|
|
||||||
return .unregister
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func writeToSocket(pendingWrites: PendingDatagramWritesManager) throws -> OverallWriteResult {
|
|
||||||
let result = try pendingWrites.triggerAppropriateWriteOperations(scalarWriteOperation: { (ptr, destinationPtr, destinationSize) in
|
|
||||||
guard ptr.count > 0 else {
|
guard ptr.count > 0 else {
|
||||||
// No need to call write if the buffer is empty.
|
// No need to call write if the buffer is empty.
|
||||||
return .processed(0)
|
return .processed(0)
|
||||||
|
@ -1273,6 +1263,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
|
||||||
assert(self.eventLoop.inEventLoop)
|
assert(self.eventLoop.inEventLoop)
|
||||||
do {
|
do {
|
||||||
try socket.bind(to: address)
|
try socket.bind(to: address)
|
||||||
|
self.updateCachedAddressesFromSocket(updateRemote: false)
|
||||||
promise?.succeed(result: ())
|
promise?.succeed(result: ())
|
||||||
becomeActive0()
|
becomeActive0()
|
||||||
readIfNeeded0()
|
readIfNeeded0()
|
||||||
|
|
|
@ -50,6 +50,12 @@ private let sysLseek = lseek
|
||||||
private let sysRecvFrom = recvfrom
|
private let sysRecvFrom = recvfrom
|
||||||
private let sysSendTo = sendto
|
private let sysSendTo = sendto
|
||||||
private let sysDup = dup
|
private let sysDup = dup
|
||||||
|
private let sysGetpeername = getpeername
|
||||||
|
private let sysGetsockname = getsockname
|
||||||
|
private let sysAF_INET = AF_INET
|
||||||
|
private let sysAF_INET6 = AF_INET6
|
||||||
|
private let sysAF_UNIX = AF_UNIX
|
||||||
|
private let sysInet_ntop = inet_ntop
|
||||||
|
|
||||||
#if os(Linux)
|
#if os(Linux)
|
||||||
private let sysSendMmsg = CNIOLinux_sendmmsg
|
private let sysSendMmsg = CNIOLinux_sendmmsg
|
||||||
|
@ -110,6 +116,23 @@ internal func wrapSyscall<T: FixedWidthInteger>(where function: StaticString = #
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Sorry, we really try hard to not use underscored attributes. In this case however we seem to break the inlining threshold which makes a system call take twice the time, ie. we need this exception. */
|
||||||
|
@inline(__always)
|
||||||
|
internal func wrapErrorIsNullReturnCall(where function: StaticString = #function, _ fn: () throws -> UnsafePointer<CChar>?) throws -> UnsafePointer<CChar>? {
|
||||||
|
while true {
|
||||||
|
let res = try fn()
|
||||||
|
if res == nil {
|
||||||
|
let err = errno
|
||||||
|
if err == EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
assert(!isBlacklistedErrno(err), "blacklisted errno \(err) \(strerror(err)!)")
|
||||||
|
throw IOError(errnoCode: err, function: function)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
enum Shutdown {
|
enum Shutdown {
|
||||||
case RD
|
case RD
|
||||||
case WR
|
case WR
|
||||||
|
@ -148,6 +171,9 @@ internal enum Posix {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
static let AF_INET = sa_family_t(sysAF_INET)
|
||||||
|
static let AF_INET6 = sa_family_t(sysAF_INET6)
|
||||||
|
static let AF_UNIX = sa_family_t(sysAF_UNIX)
|
||||||
|
|
||||||
@inline(never)
|
@inline(never)
|
||||||
public static func shutdown(descriptor: CInt, how: Shutdown) throws {
|
public static func shutdown(descriptor: CInt, how: Shutdown) throws {
|
||||||
|
@ -321,6 +347,14 @@ internal enum Posix {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@discardableResult
|
||||||
|
@inline(never)
|
||||||
|
public static func inet_ntop(addressFamily: CInt, addressBytes: UnsafeRawPointer, addressDescription: UnsafeMutablePointer<CChar>, addressDescriptionLength: socklen_t) throws -> UnsafePointer<CChar>? {
|
||||||
|
return try wrapErrorIsNullReturnCall {
|
||||||
|
sysInet_ntop(addressFamily, addressBytes, addressDescription, addressDescriptionLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Its not really posix but exists on Linux and MacOS / BSD so just put it here for now to keep it simple
|
// Its not really posix but exists on Linux and MacOS / BSD so just put it here for now to keep it simple
|
||||||
@inline(never)
|
@inline(never)
|
||||||
public static func sendfile(descriptor: CInt, fd: CInt, offset: off_t, count: size_t) throws -> IOResult<Int> {
|
public static func sendfile(descriptor: CInt, fd: CInt, offset: off_t, count: size_t) throws -> IOResult<Int> {
|
||||||
|
@ -365,6 +399,21 @@ internal enum Posix {
|
||||||
Int(sysRecvMmsg(sockfd, msgvec, vlen, flags, timeout))
|
Int(sysRecvMmsg(sockfd, msgvec, vlen, flags, timeout))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@inline(never)
|
||||||
|
public static func getpeername(socket: CInt, address: UnsafeMutablePointer<sockaddr>, addressLength: UnsafeMutablePointer<socklen_t>) throws {
|
||||||
|
_ = try wrapSyscall {
|
||||||
|
return sysGetpeername(socket, address, addressLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@inline(never)
|
||||||
|
public static func getsockname(socket: CInt, address: UnsafeMutablePointer<sockaddr>, addressLength: UnsafeMutablePointer<socklen_t>) throws {
|
||||||
|
_ = try wrapSyscall {
|
||||||
|
return sysGetsockname(socket, address, addressLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
|
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
|
||||||
|
|
|
@ -74,7 +74,7 @@ final class ChatHandler: ChannelInboundHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
public func channelActive(ctx: ChannelHandlerContext) {
|
public func channelActive(ctx: ChannelHandlerContext) {
|
||||||
let remoteAddress = ctx.channel.remoteAddress!
|
let remoteAddress = ctx.remoteAddress!
|
||||||
let channel = ctx.channel
|
let channel = ctx.channel
|
||||||
self.channelsSyncQueue.async {
|
self.channelsSyncQueue.async {
|
||||||
// broadcast the message to all the connected clients except the one that just became active.
|
// broadcast the message to all the connected clients except the one that just became active.
|
||||||
|
@ -84,7 +84,7 @@ final class ChatHandler: ChannelInboundHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
var buffer = channel.allocator.buffer(capacity: 64)
|
var buffer = channel.allocator.buffer(capacity: 64)
|
||||||
buffer.write(string: "(ChatServer) - Welcome to: \(channel.localAddress!)\n")
|
buffer.write(string: "(ChatServer) - Welcome to: \(ctx.localAddress!)\n")
|
||||||
ctx.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
|
ctx.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -307,3 +307,91 @@ extension UInt: AtomicPrimitive {
|
||||||
public static let atomic_load = catmc_atomic_unsigned_long_load
|
public static let atomic_load = catmc_atomic_unsigned_long_load
|
||||||
public static let atomic_store = catmc_atomic_unsigned_long_store
|
public static let atomic_store = catmc_atomic_unsigned_long_store
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `AtomicBox` is a heap-allocated box which allows atomic access to an instance of a Swift class.
|
||||||
|
///
|
||||||
|
/// It behaves very much like `Atomic<T>` but for objects, maintaining the correct retain counts.
|
||||||
|
public class AtomicBox<T: AnyObject> {
|
||||||
|
private let storage: Atomic<Int>
|
||||||
|
|
||||||
|
public init(value: T) {
|
||||||
|
let ptr = Unmanaged<T>.passRetained(value)
|
||||||
|
self.storage = Atomic(value: Int(bitPattern: ptr.toOpaque()))
|
||||||
|
}
|
||||||
|
|
||||||
|
deinit {
|
||||||
|
let oldPtrBits = self.storage.exchange(with: 0xdeadbeef)
|
||||||
|
let oldPtr = Unmanaged<T>.fromOpaque(UnsafeRawPointer(bitPattern: oldPtrBits)!)
|
||||||
|
oldPtr.release()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Atomically compares the value against `expected` and, if they are equal,
|
||||||
|
/// replaces the value with `desired`.
|
||||||
|
///
|
||||||
|
/// This implementation conforms to C11's `atomic_compare_exchange_strong`. This
|
||||||
|
/// means that the compare-and-swap will always succeed if `expected` is equal to
|
||||||
|
/// value. Additionally, it uses a *sequentially consistent ordering*. For more
|
||||||
|
/// details on atomic memory models, check the documentation for C11's
|
||||||
|
/// `stdatomic.h`.
|
||||||
|
///
|
||||||
|
/// - Parameter expected: The value that this object must currently hold for the
|
||||||
|
/// compare-and-swap to succeed.
|
||||||
|
/// - Parameter desired: The new value that this object will hold if the compare
|
||||||
|
/// succeeds.
|
||||||
|
/// - Returns: `True` if the exchange occurred, or `False` if `expected` did not
|
||||||
|
/// match the current value and so no exchange occurred.
|
||||||
|
public func compareAndExchange(expected: T, desired: T) -> Bool {
|
||||||
|
return withExtendedLifetime(desired) {
|
||||||
|
let expectedPtr = Unmanaged<T>.passUnretained(expected)
|
||||||
|
let desiredPtr = Unmanaged<T>.passUnretained(desired)
|
||||||
|
|
||||||
|
if self.storage.compareAndExchange(expected: Int(bitPattern: expectedPtr.toOpaque()),
|
||||||
|
desired: Int(bitPattern: desiredPtr.toOpaque())) {
|
||||||
|
_ = desiredPtr.retain()
|
||||||
|
expectedPtr.release()
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Atomically exchanges `value` for the current value of this object.
|
||||||
|
///
|
||||||
|
/// This implementation uses a *relaxed* memory ordering. This guarantees nothing
|
||||||
|
/// more than that this operation is atomic: there is no guarantee that any other
|
||||||
|
/// event will be ordered before or after this one.
|
||||||
|
///
|
||||||
|
/// - Parameter value: The new value to set this object to.
|
||||||
|
/// - Returns: The value previously held by this object.
|
||||||
|
public func exchange(with value: T) -> T {
|
||||||
|
let newPtr = Unmanaged<T>.passRetained(value)
|
||||||
|
let oldPtrBits = self.storage.exchange(with: Int(bitPattern: newPtr.toOpaque()))
|
||||||
|
let oldPtr = Unmanaged<T>.fromOpaque(UnsafeRawPointer(bitPattern: oldPtrBits)!)
|
||||||
|
return oldPtr.takeRetainedValue()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Atomically loads and returns the value of this object.
|
||||||
|
///
|
||||||
|
/// This implementation uses a *relaxed* memory ordering. This guarantees nothing
|
||||||
|
/// more than that this operation is atomic: there is no guarantee that any other
|
||||||
|
/// event will be ordered before or after this one.
|
||||||
|
///
|
||||||
|
/// - Returns: The value of this object
|
||||||
|
public func load() -> T {
|
||||||
|
let ptrBits = self.storage.load()
|
||||||
|
let ptr = Unmanaged<T>.fromOpaque(UnsafeRawPointer(bitPattern: ptrBits)!)
|
||||||
|
return ptr.takeUnretainedValue()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Atomically replaces the value of this object with `value`.
|
||||||
|
///
|
||||||
|
/// This implementation uses a *relaxed* memory ordering. This guarantees nothing
|
||||||
|
/// more than that this operation is atomic: there is no guarantee that any other
|
||||||
|
/// event will be ordered before or after this one.
|
||||||
|
///
|
||||||
|
/// - Parameter value: The new value to set the object to.
|
||||||
|
public func store(_ value: T) -> Void {
|
||||||
|
_ = self.exchange(with: value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -41,7 +41,7 @@ private final class EchoHandler: ChannelInboundHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
public func channelActive(ctx: ChannelHandlerContext) {
|
public func channelActive(ctx: ChannelHandlerContext) {
|
||||||
print("Client connected to \(ctx.channel.remoteAddress!)")
|
print("Client connected to \(ctx.remoteAddress!)")
|
||||||
|
|
||||||
// We are connected its time to send the message to the server to initialize the ping-pong sequence.
|
// We are connected its time to send the message to the server to initialize the ping-pong sequence.
|
||||||
var buffer = ctx.channel.allocator.buffer(capacity: line.utf8.count)
|
var buffer = ctx.channel.allocator.buffer(capacity: line.utf8.count)
|
||||||
|
|
|
@ -72,7 +72,7 @@ private final class HTTPHandler: ChannelInboundHandler {
|
||||||
URL: \(self.infoSavedRequestHead!.uri)\r
|
URL: \(self.infoSavedRequestHead!.uri)\r
|
||||||
body length: \(self.infoSavedBodyBytes)\r
|
body length: \(self.infoSavedBodyBytes)\r
|
||||||
headers: \(self.infoSavedRequestHead!.headers)\r
|
headers: \(self.infoSavedRequestHead!.headers)\r
|
||||||
client: \(ctx.channel.remoteAddress?.description ?? "zombie")\r
|
client: \(ctx.remoteAddress?.description ?? "zombie")\r
|
||||||
IO: SwiftNIO Electric Boogaloo™️\r\n
|
IO: SwiftNIO Electric Boogaloo™️\r\n
|
||||||
"""
|
"""
|
||||||
self.buffer.clear()
|
self.buffer.clear()
|
||||||
|
@ -200,7 +200,7 @@ private final class HTTPHandler: ChannelInboundHandler {
|
||||||
case "/dynamic/count-to-ten":
|
case "/dynamic/count-to-ten":
|
||||||
return { self.handleMultipleWrites(ctx: $0, request: $1, strings: (1...10).map { "\($0)\r\n" }, delay: .milliseconds(100)) }
|
return { self.handleMultipleWrites(ctx: $0, request: $1, strings: (1...10).map { "\($0)\r\n" }, delay: .milliseconds(100)) }
|
||||||
case "/dynamic/client-ip":
|
case "/dynamic/client-ip":
|
||||||
return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, string: "\(ctx.channel.remoteAddress.debugDescription)") }
|
return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, string: "\(ctx.remoteAddress.debugDescription)") }
|
||||||
default:
|
default:
|
||||||
return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, statusCode: .notFound, string: "not found") }
|
return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, statusCode: .notFound, string: "not found") }
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,11 @@ extension NIOConcurrencyHelpersTests {
|
||||||
("testConditionLockMutualExclusion", testConditionLockMutualExclusion),
|
("testConditionLockMutualExclusion", testConditionLockMutualExclusion),
|
||||||
("testConditionLock", testConditionLock),
|
("testConditionLock", testConditionLock),
|
||||||
("testConditionLockWithDifferentConditions", testConditionLockWithDifferentConditions),
|
("testConditionLockWithDifferentConditions", testConditionLockWithDifferentConditions),
|
||||||
|
("testAtomicBoxDoesNotTriviallyLeak", testAtomicBoxDoesNotTriviallyLeak),
|
||||||
|
("testAtomicBoxCompareAndExchangeWorksIfEqual", testAtomicBoxCompareAndExchangeWorksIfEqual),
|
||||||
|
("testAtomicBoxCompareAndExchangeWorksIfNotEqual", testAtomicBoxCompareAndExchangeWorksIfNotEqual),
|
||||||
|
("testAtomicBoxStoreWorks", testAtomicBoxStoreWorks),
|
||||||
|
("testAtomicBoxCompareAndExchangeOntoItselfWorks", testAtomicBoxCompareAndExchangeOntoItselfWorks),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -401,4 +401,149 @@ class NIOConcurrencyHelpersTests: XCTestCase {
|
||||||
doneSem.wait() /* job on 'q2' is done */
|
doneSem.wait() /* job on 'q2' is done */
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testAtomicBoxDoesNotTriviallyLeak() throws {
|
||||||
|
class SomeClass {}
|
||||||
|
weak var weakSomeInstance1: SomeClass? = nil
|
||||||
|
weak var weakSomeInstance2: SomeClass? = nil
|
||||||
|
({
|
||||||
|
let someInstance = SomeClass()
|
||||||
|
weakSomeInstance1 = someInstance
|
||||||
|
let someAtomic = AtomicBox(value: someInstance)
|
||||||
|
let loadedFromAtomic = someAtomic.load()
|
||||||
|
weakSomeInstance2 = loadedFromAtomic
|
||||||
|
XCTAssertNotNil(weakSomeInstance1)
|
||||||
|
XCTAssertNotNil(weakSomeInstance2)
|
||||||
|
XCTAssert(someInstance === loadedFromAtomic)
|
||||||
|
})()
|
||||||
|
XCTAssertNil(weakSomeInstance1)
|
||||||
|
XCTAssertNil(weakSomeInstance2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAtomicBoxCompareAndExchangeWorksIfEqual() throws {
|
||||||
|
class SomeClass {}
|
||||||
|
weak var weakSomeInstance1: SomeClass? = nil
|
||||||
|
weak var weakSomeInstance2: SomeClass? = nil
|
||||||
|
weak var weakSomeInstance3: SomeClass? = nil
|
||||||
|
({
|
||||||
|
let someInstance1 = SomeClass()
|
||||||
|
let someInstance2 = SomeClass()
|
||||||
|
weakSomeInstance1 = someInstance1
|
||||||
|
|
||||||
|
let atomic = AtomicBox(value: someInstance1)
|
||||||
|
var loadedFromAtomic = atomic.load()
|
||||||
|
XCTAssert(someInstance1 === loadedFromAtomic)
|
||||||
|
weakSomeInstance2 = loadedFromAtomic
|
||||||
|
|
||||||
|
XCTAssertTrue(atomic.compareAndExchange(expected: loadedFromAtomic, desired: someInstance2))
|
||||||
|
|
||||||
|
loadedFromAtomic = atomic.load()
|
||||||
|
weakSomeInstance3 = loadedFromAtomic
|
||||||
|
XCTAssert(someInstance1 !== loadedFromAtomic)
|
||||||
|
XCTAssert(someInstance2 === loadedFromAtomic)
|
||||||
|
|
||||||
|
XCTAssertNotNil(weakSomeInstance1)
|
||||||
|
XCTAssertNotNil(weakSomeInstance2)
|
||||||
|
XCTAssertNotNil(weakSomeInstance3)
|
||||||
|
XCTAssert(weakSomeInstance1 === weakSomeInstance2 && weakSomeInstance2 !== weakSomeInstance3)
|
||||||
|
})()
|
||||||
|
XCTAssertNil(weakSomeInstance1)
|
||||||
|
XCTAssertNil(weakSomeInstance2)
|
||||||
|
XCTAssertNil(weakSomeInstance3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAtomicBoxCompareAndExchangeWorksIfNotEqual() throws {
|
||||||
|
class SomeClass {}
|
||||||
|
weak var weakSomeInstance1: SomeClass? = nil
|
||||||
|
weak var weakSomeInstance2: SomeClass? = nil
|
||||||
|
weak var weakSomeInstance3: SomeClass? = nil
|
||||||
|
({
|
||||||
|
let someInstance1 = SomeClass()
|
||||||
|
let someInstance2 = SomeClass()
|
||||||
|
weakSomeInstance1 = someInstance1
|
||||||
|
|
||||||
|
let atomic = AtomicBox(value: someInstance1)
|
||||||
|
var loadedFromAtomic = atomic.load()
|
||||||
|
XCTAssert(someInstance1 === loadedFromAtomic)
|
||||||
|
weakSomeInstance2 = loadedFromAtomic
|
||||||
|
|
||||||
|
XCTAssertFalse(atomic.compareAndExchange(expected: someInstance2, desired: someInstance2))
|
||||||
|
XCTAssertFalse(atomic.compareAndExchange(expected: SomeClass(), desired: someInstance2))
|
||||||
|
XCTAssertTrue(atomic.load() === someInstance1)
|
||||||
|
|
||||||
|
loadedFromAtomic = atomic.load()
|
||||||
|
weakSomeInstance3 = someInstance2
|
||||||
|
XCTAssert(someInstance1 === loadedFromAtomic)
|
||||||
|
XCTAssert(someInstance2 !== loadedFromAtomic)
|
||||||
|
|
||||||
|
XCTAssertNotNil(weakSomeInstance1)
|
||||||
|
XCTAssertNotNil(weakSomeInstance2)
|
||||||
|
XCTAssertNotNil(weakSomeInstance3)
|
||||||
|
})()
|
||||||
|
XCTAssertNil(weakSomeInstance1)
|
||||||
|
XCTAssertNil(weakSomeInstance2)
|
||||||
|
XCTAssertNil(weakSomeInstance3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAtomicBoxStoreWorks() throws {
|
||||||
|
class SomeClass {}
|
||||||
|
weak var weakSomeInstance1: SomeClass? = nil
|
||||||
|
weak var weakSomeInstance2: SomeClass? = nil
|
||||||
|
weak var weakSomeInstance3: SomeClass? = nil
|
||||||
|
({
|
||||||
|
let someInstance1 = SomeClass()
|
||||||
|
let someInstance2 = SomeClass()
|
||||||
|
weakSomeInstance1 = someInstance1
|
||||||
|
|
||||||
|
let atomic = AtomicBox(value: someInstance1)
|
||||||
|
var loadedFromAtomic = atomic.load()
|
||||||
|
XCTAssert(someInstance1 === loadedFromAtomic)
|
||||||
|
weakSomeInstance2 = loadedFromAtomic
|
||||||
|
|
||||||
|
atomic.store(someInstance2)
|
||||||
|
|
||||||
|
loadedFromAtomic = atomic.load()
|
||||||
|
weakSomeInstance3 = loadedFromAtomic
|
||||||
|
XCTAssert(someInstance1 !== loadedFromAtomic)
|
||||||
|
XCTAssert(someInstance2 === loadedFromAtomic)
|
||||||
|
|
||||||
|
XCTAssertNotNil(weakSomeInstance1)
|
||||||
|
XCTAssertNotNil(weakSomeInstance2)
|
||||||
|
XCTAssertNotNil(weakSomeInstance3)
|
||||||
|
})()
|
||||||
|
XCTAssertNil(weakSomeInstance1)
|
||||||
|
XCTAssertNil(weakSomeInstance2)
|
||||||
|
XCTAssertNil(weakSomeInstance3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAtomicBoxCompareAndExchangeOntoItselfWorks() {
|
||||||
|
let q = DispatchQueue(label: "q")
|
||||||
|
let g = DispatchGroup()
|
||||||
|
let sem1 = DispatchSemaphore(value: 0)
|
||||||
|
let sem2 = DispatchSemaphore(value: 0)
|
||||||
|
class SomeClass {}
|
||||||
|
weak var weakInstance: SomeClass?
|
||||||
|
({
|
||||||
|
let instance = SomeClass()
|
||||||
|
weakInstance = instance
|
||||||
|
|
||||||
|
let atomic = AtomicBox(value: instance)
|
||||||
|
q.async(group: g) {
|
||||||
|
sem1.signal()
|
||||||
|
sem2.wait()
|
||||||
|
for _ in 0..<1000 {
|
||||||
|
XCTAssertTrue(atomic.compareAndExchange(expected: instance, desired: instance))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sem2.signal()
|
||||||
|
sem1.wait()
|
||||||
|
for _ in 0..<1000 {
|
||||||
|
XCTAssertTrue(atomic.compareAndExchange(expected: instance, desired: instance))
|
||||||
|
}
|
||||||
|
g.wait()
|
||||||
|
let v = atomic.load()
|
||||||
|
XCTAssert(v === instance)
|
||||||
|
})()
|
||||||
|
XCTAssertNil(weakInstance)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,18 @@ import NIOFoundationCompat
|
||||||
import Dispatch
|
import Dispatch
|
||||||
@testable import NIOHTTP1
|
@testable import NIOHTTP1
|
||||||
|
|
||||||
|
internal extension Channel {
|
||||||
|
func syncCloseAcceptingAlreadyClosed() throws {
|
||||||
|
do {
|
||||||
|
try self.close().wait()
|
||||||
|
} catch ChannelError.alreadyClosed {
|
||||||
|
/* we're happy with this one */
|
||||||
|
} catch let e {
|
||||||
|
throw e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
extension Array where Array.Element == ByteBuffer {
|
extension Array where Array.Element == ByteBuffer {
|
||||||
public func allAsBytes() -> [UInt8] {
|
public func allAsBytes() -> [UInt8] {
|
||||||
var out: [UInt8] = []
|
var out: [UInt8] = []
|
||||||
|
@ -61,17 +73,6 @@ internal class ArrayAccumulationHandler<T>: ChannelInboundHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
class HTTPServerClientTest : XCTestCase {
|
class HTTPServerClientTest : XCTestCase {
|
||||||
|
|
||||||
private func syncCloseAcceptingAlreadyClosed(channel: Channel) throws {
|
|
||||||
do {
|
|
||||||
try channel.close().wait()
|
|
||||||
} catch ChannelError.alreadyClosed {
|
|
||||||
/* we're happy with this one */
|
|
||||||
} catch let e {
|
|
||||||
throw e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* needs to be something reasonably large and odd so it has good odds producing incomplete writes even on the loopback interface */
|
/* needs to be something reasonably large and odd so it has good odds producing incomplete writes even on the loopback interface */
|
||||||
private static let massiveResponseLength = 5 * 1024 * 1024 + 7
|
private static let massiveResponseLength = 5 * 1024 * 1024 + 7
|
||||||
private static let massiveResponseBytes: [UInt8] = {
|
private static let massiveResponseBytes: [UInt8] = {
|
||||||
|
@ -374,7 +375,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
}.bind(host: "127.0.0.1", port: 0).wait()
|
}.bind(host: "127.0.0.1", port: 0).wait()
|
||||||
|
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
|
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
let clientChannel = try ClientBootstrap(group: group)
|
let clientChannel = try ClientBootstrap(group: group)
|
||||||
|
@ -387,7 +388,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
.wait()
|
.wait()
|
||||||
|
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
|
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/helloworld")
|
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/helloworld")
|
||||||
|
@ -432,7 +433,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
}.bind(host: "127.0.0.1", port: 0).wait()
|
}.bind(host: "127.0.0.1", port: 0).wait()
|
||||||
|
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
|
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
let clientChannel = try ClientBootstrap(group: group)
|
let clientChannel = try ClientBootstrap(group: group)
|
||||||
|
@ -445,7 +446,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
.wait()
|
.wait()
|
||||||
|
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
|
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/count-to-ten")
|
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/count-to-ten")
|
||||||
|
@ -490,7 +491,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
}.bind(host: "127.0.0.1", port: 0).wait()
|
}.bind(host: "127.0.0.1", port: 0).wait()
|
||||||
|
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
|
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
let clientChannel = try ClientBootstrap(group: group)
|
let clientChannel = try ClientBootstrap(group: group)
|
||||||
|
@ -503,7 +504,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
.wait()
|
.wait()
|
||||||
|
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
|
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/trailers")
|
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/trailers")
|
||||||
|
@ -550,7 +551,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
}.bind(host: "127.0.0.1", port: 0).wait()
|
}.bind(host: "127.0.0.1", port: 0).wait()
|
||||||
|
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
|
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
let clientChannel = try ClientBootstrap(group: group)
|
let clientChannel = try ClientBootstrap(group: group)
|
||||||
|
@ -559,7 +560,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
.wait()
|
.wait()
|
||||||
|
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
|
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
var buffer = clientChannel.allocator.buffer(capacity: numBytes)
|
var buffer = clientChannel.allocator.buffer(capacity: numBytes)
|
||||||
|
@ -591,7 +592,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
}
|
}
|
||||||
}.bind(host: "127.0.0.1", port: 0).wait()
|
}.bind(host: "127.0.0.1", port: 0).wait()
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
|
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
let clientChannel = try ClientBootstrap(group: group)
|
let clientChannel = try ClientBootstrap(group: group)
|
||||||
|
@ -604,7 +605,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
.wait()
|
.wait()
|
||||||
|
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
|
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .HEAD, uri: "/head")
|
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .HEAD, uri: "/head")
|
||||||
|
@ -636,7 +637,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
}
|
}
|
||||||
}.bind(host: "127.0.0.1", port: 0).wait()
|
}.bind(host: "127.0.0.1", port: 0).wait()
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
|
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
let clientChannel = try ClientBootstrap(group: group)
|
let clientChannel = try ClientBootstrap(group: group)
|
||||||
|
@ -649,7 +650,7 @@ class HTTPServerClientTest : XCTestCase {
|
||||||
.wait()
|
.wait()
|
||||||
|
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
|
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/204")
|
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/204")
|
||||||
|
|
|
@ -55,6 +55,7 @@ extension ChannelTests {
|
||||||
("testHalfClosure", testHalfClosure),
|
("testHalfClosure", testHalfClosure),
|
||||||
("testRejectsInvalidData", testRejectsInvalidData),
|
("testRejectsInvalidData", testRejectsInvalidData),
|
||||||
("testWeDontCrashIfChannelReleasesBeforePipeline", testWeDontCrashIfChannelReleasesBeforePipeline),
|
("testWeDontCrashIfChannelReleasesBeforePipeline", testWeDontCrashIfChannelReleasesBeforePipeline),
|
||||||
|
("testAskForLocalAndRemoteAddressesAfterChannelIsClosed", testAskForLocalAndRemoteAddressesAfterChannelIsClosed),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1157,7 +1157,7 @@ public class ChannelTests: XCTestCase {
|
||||||
return channel.pipeline.add(handler: byteCountingHandler)
|
return channel.pipeline.add(handler: byteCountingHandler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.connect(to: server.localAddress!)
|
.connect(to: try! server.localAddress())
|
||||||
let accepted = try server.accept()!
|
let accepted = try server.accept()!
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try accepted.close())
|
XCTAssertNoThrow(try accepted.close())
|
||||||
|
@ -1219,7 +1219,7 @@ public class ChannelTests: XCTestCase {
|
||||||
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
|
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.connect(to: server.localAddress!)
|
.connect(to: try! server.localAddress())
|
||||||
let accepted = try server.accept()!
|
let accepted = try server.accept()!
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try accepted.close())
|
XCTAssertNoThrow(try accepted.close())
|
||||||
|
@ -1269,7 +1269,7 @@ public class ChannelTests: XCTestCase {
|
||||||
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
|
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
|
||||||
}
|
}
|
||||||
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
|
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
|
||||||
.connect(to: server.localAddress!)
|
.connect(to: try! server.localAddress())
|
||||||
let accepted = try server.accept()!
|
let accepted = try server.accept()!
|
||||||
defer {
|
defer {
|
||||||
XCTAssertNoThrow(try accepted.close())
|
XCTAssertNoThrow(try accepted.close())
|
||||||
|
@ -1412,4 +1412,27 @@ public class ChannelTests: XCTestCase {
|
||||||
XCTAssertNil(weakServerChannel, "weakServerChannel not nil, looks like we leaked it!")
|
XCTAssertNil(weakServerChannel, "weakServerChannel not nil, looks like we leaked it!")
|
||||||
XCTAssertNil(weakServerChildChannel, "weakServerChildChannel not nil, looks like we leaked it!")
|
XCTAssertNil(weakServerChildChannel, "weakServerChildChannel not nil, looks like we leaked it!")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testAskForLocalAndRemoteAddressesAfterChannelIsClosed() throws {
|
||||||
|
let group = MultiThreadedEventLoopGroup(numThreads: 1)
|
||||||
|
defer {
|
||||||
|
XCTAssertNoThrow(try group.syncShutdownGracefully())
|
||||||
|
}
|
||||||
|
|
||||||
|
let serverChannel = try ServerBootstrap(group: group)
|
||||||
|
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
|
||||||
|
.bind(host: "127.0.0.1", port: 0).wait()
|
||||||
|
|
||||||
|
let clientChannel = try ClientBootstrap(group: group)
|
||||||
|
.connect(to: serverChannel.localAddress!).wait()
|
||||||
|
|
||||||
|
|
||||||
|
// Start shutting stuff down.
|
||||||
|
try serverChannel.syncCloseAcceptingAlreadyClosed()
|
||||||
|
try clientChannel.syncCloseAcceptingAlreadyClosed()
|
||||||
|
|
||||||
|
for f in [ serverChannel.remoteAddress, serverChannel.localAddress, clientChannel.remoteAddress, clientChannel.localAddress ] {
|
||||||
|
XCTAssertNil(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,6 +69,17 @@ func openTemporaryFile() -> (CInt, String) {
|
||||||
return (fd, String(decoding: templateBytes, as: UTF8.self))
|
return (fd, String(decoding: templateBytes, as: UTF8.self))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal extension Channel {
|
||||||
|
func syncCloseAcceptingAlreadyClosed() throws {
|
||||||
|
do {
|
||||||
|
try self.close().wait()
|
||||||
|
} catch ChannelError.alreadyClosed {
|
||||||
|
/* we're happy with this one */
|
||||||
|
} catch let e {
|
||||||
|
throw e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
final class ByteCountingHandler : ChannelInboundHandler {
|
final class ByteCountingHandler : ChannelInboundHandler {
|
||||||
typealias InboundIn = ByteBuffer
|
typealias InboundIn = ByteBuffer
|
||||||
|
|
Loading…
Reference in New Issue