fix Channel thread-safety (local/remoteAddress)

Motivation:

There were a couple of places in the Channel implementations that just
weren't thread-safe at all, namely:

- localAddress
- remoetAddress
- parent

those are now fixed. I also re-ordered the code so it should be easier
to maintain in the future (the `// MARK` markers were totally
incorrect).

Modifications:

- made Channel.{local,remote}Address return a future
- made Channel.parent a `let`
- unified more of `SocketChannel` and `DatagramSocketChannel`
- fixed the `// MARK`s by reordering code
- annotated methods/properties that are in the Channel API and need to
be thread-safe

Result:

slightly more thread-safety :)
This commit is contained in:
Johannes Weiss 2018-02-20 15:18:51 +00:00
parent 273932ae36
commit 68724be6e6
19 changed files with 638 additions and 275 deletions

View File

@ -21,6 +21,19 @@ protocol SockAddrProtocol {
mutating func withMutableSockAddr<R>(_ fn: (UnsafeMutablePointer<sockaddr>, Int) throws -> R) rethrows -> R mutating func withMutableSockAddr<R>(_ fn: (UnsafeMutablePointer<sockaddr>, Int) throws -> R) rethrows -> R
} }
private func descriptionForAddress(family: CInt, bytes: UnsafeRawPointer, length byteCount: Int) -> String {
var addressBytes: [Int8] = Array(repeating: 0, count: byteCount)
return addressBytes.withUnsafeMutableBufferPointer { (addressBytesPtr: inout UnsafeMutableBufferPointer<Int8>) -> String in
try! Posix.inet_ntop(addressFamily: family,
addressBytes: bytes,
addressDescription: addressBytesPtr.baseAddress!,
addressDescriptionLength: socklen_t(byteCount))
return addressBytesPtr.baseAddress!.withMemoryRebound(to: UInt8.self, capacity: byteCount) { addressBytesPtr -> String in
return String(decoding: UnsafeBufferPointer<UInt8>(start: addressBytesPtr, count: byteCount), as: Unicode.ASCII.self)
}
}
}
extension sockaddr_in: SockAddrProtocol { extension sockaddr_in: SockAddrProtocol {
mutating func withSockAddr<R>(_ fn: (UnsafePointer<sockaddr>, Int) throws -> R) rethrows -> R { mutating func withSockAddr<R>(_ fn: (UnsafePointer<sockaddr>, Int) throws -> R) rethrows -> R {
var me = self var me = self
@ -35,6 +48,12 @@ extension sockaddr_in: SockAddrProtocol {
try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count) try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count)
} }
} }
mutating func addressDescription() -> String {
return withUnsafePointer(to: &self.sin_addr) { addrPtr in
descriptionForAddress(family: AF_INET, bytes: addrPtr, length: Int(INET_ADDRSTRLEN))
}
}
} }
extension sockaddr_in6: SockAddrProtocol { extension sockaddr_in6: SockAddrProtocol {
@ -51,6 +70,12 @@ extension sockaddr_in6: SockAddrProtocol {
try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count) try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count)
} }
} }
mutating func addressDescription() -> String {
return withUnsafePointer(to: &self.sin6_addr) { addrPtr in
descriptionForAddress(family: AF_INET6, bytes: addrPtr, length: Int(INET6_ADDRSTRLEN))
}
}
} }
extension sockaddr_un: SockAddrProtocol { extension sockaddr_un: SockAddrProtocol {
@ -119,54 +144,45 @@ extension sockaddr_storage {
} }
} }
} }
mutating func convert() -> SocketAddress {
switch self.ss_family {
case Posix.AF_INET:
var sockAddr: sockaddr_in = self.convert()
return SocketAddress(sockAddr, host: sockAddr.addressDescription())
case Posix.AF_INET6:
var sockAddr: sockaddr_in6 = self.convert()
return SocketAddress(sockAddr, host: sockAddr.addressDescription())
case Posix.AF_UNIX:
return SocketAddress(self.convert() as sockaddr_un)
default:
fatalError("unknown sockaddr family \(self.ss_family)")
}
}
} }
class BaseSocket: Selectable { class BaseSocket: Selectable {
public let descriptor: Int32 public let descriptor: Int32
public private(set) var open: Bool public private(set) var open: Bool
final var localAddress: SocketAddress? { final func localAddress() throws -> SocketAddress {
get { return try get_addr { try Posix.getsockname(socket: $0, address: $1, addressLength: $2) }
return get_addr { getsockname($0, $1, $2) }
}
} }
final var remoteAddress: SocketAddress? { final func remoteAddress() throws -> SocketAddress {
get { return try get_addr { try Posix.getpeername(socket: $0, address: $1, addressLength: $2) }
return get_addr { getpeername($0, $1, $2) }
}
} }
private func get_addr(_ fn: (Int32, UnsafeMutablePointer<sockaddr>, UnsafeMutablePointer<socklen_t>) -> Int32) -> SocketAddress? { private func get_addr(_ fn: (Int32, UnsafeMutablePointer<sockaddr>, UnsafeMutablePointer<socklen_t>) throws -> Void) throws -> SocketAddress {
var addr = sockaddr_storage() var addr = sockaddr_storage()
var len: socklen_t = socklen_t(MemoryLayout<sockaddr_storage>.size)
try addr.withMutableSockAddr { addressPtr, size in
return withUnsafeMutablePointer(to: &addr) { var size = socklen_t(size)
$0.withMemoryRebound(to: sockaddr.self, capacity: 1, { address in try fn(self.descriptor, addressPtr, &size)
guard fn(descriptor, address, &len) == 0 else {
return nil
}
switch Int32(address.pointee.sa_family) {
case AF_INET:
return address.withMemoryRebound(to: sockaddr_in.self, capacity: 1, { ipv4 in
var ipAddressString = [CChar](repeating: 0, count: Int(INET_ADDRSTRLEN))
return SocketAddress(ipv4.pointee, host: String(cString: inet_ntop(AF_INET, &ipv4.pointee.sin_addr, &ipAddressString, socklen_t(INET_ADDRSTRLEN))))
})
case AF_INET6:
return address.withMemoryRebound(to: sockaddr_in6.self, capacity: 1, { ipv6 in
var ipAddressString = [CChar](repeating: 0, count: Int(INET6_ADDRSTRLEN))
return SocketAddress(ipv6.pointee, host: String(cString: inet_ntop(AF_INET6, &ipv6.pointee.sin6_addr, &ipAddressString, socklen_t(INET6_ADDRSTRLEN))))
})
case AF_UNIX:
return address.withMemoryRebound(to: sockaddr_un.self, capacity: 1) { uds in
return SocketAddress(uds.pointee)
}
default:
fatalError("address family \(address.pointee.sa_family) not supported")
}
})
} }
return addr.convert()
} }
static func newSocket(protocolFamily: Int32, type: CInt) throws -> Int32 { static func newSocket(protocolFamily: Int32, type: CInt) throws -> Int32 {
let sock = try Posix.socket(domain: protocolFamily, let sock = try Posix.socket(domain: protocolFamily,
type: type, type: type,

View File

@ -18,6 +18,8 @@ import NIOConcurrencyHelpers
/// ///
/// - note: All methods must be called from the EventLoop thread /// - note: All methods must be called from the EventLoop thread
public protocol ChannelCore : class { public protocol ChannelCore : class {
func localAddress0() throws -> SocketAddress
func remoteAddress0() throws -> SocketAddress
func register0(promise: EventLoopPromise<Void>?) func register0(promise: EventLoopPromise<Void>?)
func bind0(to: SocketAddress, promise: EventLoopPromise<Void>?) func bind0(to: SocketAddress, promise: EventLoopPromise<Void>?)
func connect0(to: SocketAddress, promise: EventLoopPromise<Void>?) func connect0(to: SocketAddress, promise: EventLoopPromise<Void>?)
@ -82,7 +84,9 @@ public protocol Channel : class, ChannelOutboundInvoker {
/// A `SelectableChannel` is a `Channel` that can be used with a `Selector` which notifies a user when certain events /// A `SelectableChannel` is a `Channel` that can be used with a `Selector` which notifies a user when certain events
/// before possible. On UNIX a `Selector` is usually an abstraction of `select`, `poll`, `epoll` or `kqueue`. /// before possible. On UNIX a `Selector` is usually an abstraction of `select`, `poll`, `epoll` or `kqueue`.
protocol SelectableChannel : Channel { ///
/// - warning: `SelectableChannel` methods and properties are _not_ thread-safe (unless they also belong to `Channel`).
internal protocol SelectableChannel : Channel {
/// The type of the `Selectable`. A `Selectable` is usually wrapping a file descriptor that can be registered in a /// The type of the `Selectable`. A `Selectable` is usually wrapping a file descriptor that can be registered in a
/// `Selector`. /// `Selector`.
associatedtype SelectableType: Selectable associatedtype SelectableType: Selectable

View File

@ -809,6 +809,14 @@ public final class ChannelHandlerContext : ChannelInvoker {
return self.inboundHandler ?? self.outboundHandler! return self.inboundHandler ?? self.outboundHandler!
} }
public var remoteAddress: SocketAddress? {
return try? self.channel._unsafe.remoteAddress0()
}
public var localAddress: SocketAddress? {
return try? self.channel._unsafe.localAddress0()
}
public let name: String public let name: String
public let eventLoop: EventLoop public let eventLoop: EventLoop
private let inboundHandler: _ChannelInboundHandler? private let inboundHandler: _ChannelInboundHandler?

View File

@ -16,6 +16,14 @@
/// the original `Channel` is closed. Given that the original `Channel` is closed the `DeadChannelCore` should fail /// the original `Channel` is closed. Given that the original `Channel` is closed the `DeadChannelCore` should fail
/// all operations. /// all operations.
private final class DeadChannelCore: ChannelCore { private final class DeadChannelCore: ChannelCore {
func localAddress0() throws -> SocketAddress {
throw ChannelError.ioOnClosedChannel
}
func remoteAddress0() throws -> SocketAddress {
throw ChannelError.ioOnClosedChannel
}
func register0(promise: EventLoopPromise<Void>?) { func register0(promise: EventLoopPromise<Void>?) {
promise?.fail(error: ChannelError.ioOnClosedChannel) promise?.fail(error: ChannelError.ioOnClosedChannel)
} }
@ -65,27 +73,30 @@ private final class DeadChannelCore: ChannelCore {
/// channel as it only holds an unowned reference to the original `Channel`. `DeadChannel` serves as a replacement /// channel as it only holds an unowned reference to the original `Channel`. `DeadChannel` serves as a replacement
/// that can be used when the original `Channel` might no longer be valid. /// that can be used when the original `Channel` might no longer be valid.
internal final class DeadChannel: Channel { internal final class DeadChannel: Channel {
let eventLoop: EventLoop
let pipeline: ChannelPipeline let pipeline: ChannelPipeline
var eventLoop: EventLoop { public var closeFuture: EventLoopFuture<()> {
return self.pipeline.eventLoop return self.eventLoop.newSucceededFuture(result: ())
} }
internal init(pipeline: ChannelPipeline) { internal init(pipeline: ChannelPipeline) {
self.pipeline = pipeline self.pipeline = pipeline
self.eventLoop = pipeline.eventLoop
} }
// This is `Channel` API so must be thread-safe.
var allocator: ByteBufferAllocator { var allocator: ByteBufferAllocator {
return ByteBufferAllocator() return ByteBufferAllocator()
} }
var closeFuture: EventLoopFuture<Void> { var localAddress: SocketAddress? {
return self.pipeline.eventLoop.newSucceededFuture(result: ()) return nil
} }
let localAddress: SocketAddress? = nil var remoteAddress: SocketAddress? {
return nil
let remoteAddress: SocketAddress? = nil }
let parent: Channel? = nil let parent: Channel? = nil

View File

@ -140,7 +140,6 @@ public class EmbeddedEventLoop: EventLoop {
class EmbeddedChannelCore : ChannelCore { class EmbeddedChannelCore : ChannelCore {
var closed: Bool = false var closed: Bool = false
var isActive: Bool = false var isActive: Bool = false
var eventLoop: EventLoop var eventLoop: EventLoop
var closePromise: EventLoopPromise<Void> var closePromise: EventLoopPromise<Void>
@ -163,6 +162,14 @@ class EmbeddedChannelCore : ChannelCore {
var outboundBuffer: [IOData] = [] var outboundBuffer: [IOData] = []
var inboundBuffer: [NIOAny] = [] var inboundBuffer: [NIOAny] = []
func localAddress0() throws -> SocketAddress {
throw ChannelError.operationUnsupported
}
func remoteAddress0() throws -> SocketAddress {
throw ChannelError.operationUnsupported
}
func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) { func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
if closed { if closed {
promise?.fail(error: ChannelError.alreadyClosed) promise?.fail(error: ChannelError.alreadyClosed)
@ -269,8 +276,8 @@ public class EmbeddedChannel : Channel {
public var allocator: ByteBufferAllocator = ByteBufferAllocator() public var allocator: ByteBufferAllocator = ByteBufferAllocator()
public var eventLoop: EventLoop = EmbeddedEventLoop() public var eventLoop: EventLoop = EmbeddedEventLoop()
public var localAddress: SocketAddress? = nil public let localAddress: SocketAddress? = nil
public var remoteAddress: SocketAddress? = nil public let remoteAddress: SocketAddress? = nil
// Embedded channels never have parents. // Embedded channels never have parents.
public let parent: Channel? = nil public let parent: Channel? = nil

View File

@ -465,6 +465,7 @@ internal protocol PendingWritesManager {
} }
extension PendingWritesManager { extension PendingWritesManager {
// This is called from `Channel` API so must be thread-safe.
var isWritable: Bool { var isWritable: Bool {
return self.channelWritabilityFlag.load() return self.channelWritabilityFlag.load()
} }

View File

@ -13,6 +13,8 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Represents a selectable resource which can be registered to a `Selector` to be notified once there are some events ready for it. /// Represents a selectable resource which can be registered to a `Selector` to be notified once there are some events ready for it.
///
/// - warning: `Selectable`s are not thread-safe, only to be used on the appropriate `EventLoop`.
protocol Selectable { protocol Selectable {
/// The file descriptor itself. /// The file descriptor itself.

View File

@ -43,6 +43,38 @@ private extension ByteBuffer {
class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore { class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
typealias SelectableType = T typealias SelectableType = T
// MARK: Stored Properties
// Visible to access from EventLoop directly
public let parent: Channel?
internal let socket: T
private let closePromise: EventLoopPromise<Void>
private let selectableEventLoop: SelectableEventLoop
private let localAddressCached: AtomicBox<Box<SocketAddress?>> = AtomicBox(value: Box(nil))
private let remoteAddressCached: AtomicBox<Box<SocketAddress?>> = AtomicBox(value: Box(nil))
private let bufferAllocatorCached: AtomicBox<Box<ByteBufferAllocator>>
internal var interestedEvent: IOEvent = .none
fileprivate var readPending = false
fileprivate var pendingConnect: EventLoopPromise<Void>?
fileprivate var recvAllocator: RecvByteBufferAllocator
fileprivate var maxMessagesPerRead: UInt = 4
private var inFlushNow: Bool = false // Guard against re-entrance of flushNow() method.
private var neverRegistered = true
private var active: Atomic<Bool> = Atomic(value: false)
private var _closed: Bool = false
private var autoRead: Bool = true
private var _pipeline: ChannelPipeline!
private var bufferAllocator: ByteBufferAllocator = ByteBufferAllocator() {
didSet {
assert(self.eventLoop.inEventLoop)
self.bufferAllocatorCached.store(Box(self.bufferAllocator))
}
}
// MARK: Datatypes
/// Indicates if a selectable should registered or not for IO notifications. /// Indicates if a selectable should registered or not for IO notifications.
enum IONotificationState { enum IONotificationState {
/// We should be registered for IO notifications. /// We should be registered for IO notifications.
@ -52,7 +84,26 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
case unregister case unregister
} }
// MARK: Methods to override in subclasses. fileprivate enum ReadResult {
/// Nothing was read by the read operation.
case none
/// Some data was read by the read operation.
case some
}
// MARK: Computed Properties
public final var _unsafe: ChannelCore { return self }
// This is `Channel` API so must be thread-safe.
public final var localAddress: SocketAddress? {
return self.localAddressCached.load().value
}
// This is `Channel` API so must be thread-safe.
public final var remoteAddress: SocketAddress? {
return self.remoteAddressCached.load().value
}
/// `true` if the whole `Channel` is closed and so no more IO operation can be done. /// `true` if the whole `Channel` is closed and so no more IO operation can be done.
public var closed: Bool { public var closed: Bool {
@ -60,6 +111,48 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
return _closed return _closed
} }
internal var selectable: T {
return self.socket
}
// This is `Channel` API so must be thread-safe.
public var isActive: Bool {
return self.active.load()
}
// This is `Channel` API so must be thread-safe.
public final var closeFuture: EventLoopFuture<Void> {
return self.closePromise.futureResult
}
public final var eventLoop: EventLoop {
return selectableEventLoop
}
// This is `Channel` API so must be thread-safe.
public var isWritable: Bool {
return true
}
// This is `Channel` API so must be thread-safe.
public final var allocator: ByteBufferAllocator {
if eventLoop.inEventLoop {
return bufferAllocator
} else {
return self.bufferAllocatorCached.load().value
}
}
// This is `Channel` API so must be thread-safe.
public final var pipeline: ChannelPipeline {
return _pipeline
}
// MARK: Methods to override in subclasses.
func writeToSocket() throws -> OverallWriteResult {
fatalError("must be overridden")
}
/// Provides the registration for this selector. Must be implemented by subclasses. /// Provides the registration for this selector. Must be implemented by subclasses.
func registrationFor(interested: IOEvent) -> NIORegistration { func registrationFor(interested: IOEvent) -> NIORegistration {
fatalError("must override") fatalError("must override")
@ -72,14 +165,6 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
fatalError("this must be overridden by sub class") fatalError("this must be overridden by sub class")
} }
fileprivate enum ReadResult {
/// Nothing was read by the read operation.
case none
/// Some data was read by the read operation.
case some
}
/// Begin connection of the underlying socket. /// Begin connection of the underlying socket.
/// ///
/// - parameters: /// - parameters:
@ -111,70 +196,78 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
fatalError("this must be overridden by sub class") fatalError("this must be overridden by sub class")
} }
// MARK: Common base socket logic.
fileprivate init(socket: T, parent: Channel? = nil, eventLoop: SelectableEventLoop, recvAllocator: RecvByteBufferAllocator) throws {
self.bufferAllocatorCached = AtomicBox(value: Box(self.bufferAllocator))
self.socket = socket
self.selectableEventLoop = eventLoop
self.closePromise = eventLoop.newPromise()
self.parent = parent
self.active.store(false)
self.recvAllocator = recvAllocator
self._pipeline = ChannelPipeline(channel: self)
}
deinit {
assert(self._closed, "leak of open Channel")
}
public final func localAddress0() throws -> SocketAddress {
assert(self.eventLoop.inEventLoop)
guard self.open else {
throw ChannelError.ioOnClosedChannel
}
return try self.socket.localAddress()
}
public final func remoteAddress0() throws -> SocketAddress {
assert(self.eventLoop.inEventLoop)
guard self.open else {
throw ChannelError.ioOnClosedChannel
}
return try self.socket.remoteAddress()
}
/// Flush data to the underlying socket and return if this socket needs to be registered for write notifications. /// Flush data to the underlying socket and return if this socket needs to be registered for write notifications.
/// ///
/// - returns: If this socket should be registered for write notifications. Ie. `IONotificationState.register` if _not_ all data could be written, so notifications are necessary; and `IONotificationState.unregister` if everything was written and we don't need to be notified about writability at the moment. /// - returns: If this socket should be registered for write notifications. Ie. `IONotificationState.register` if _not_ all data could be written, so notifications are necessary; and `IONotificationState.unregister` if everything was written and we don't need to be notified about writability at the moment.
fileprivate func flushNow() -> IONotificationState { fileprivate func flushNow() -> IONotificationState {
fatalError("this must be overridden by sub class") // Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by
} // `writeToSocket`.
guard !inFlushNow && !closed else {
return .unregister
}
// MARK: Common base socket logic. defer {
inFlushNow = false
}
inFlushNow = true
var selectable: T { do {
return self.socket switch try self.writeToSocket() {
} case .couldNotWriteEverything:
return .register
case .writtenCompletely:
return .unregister
}
} catch let err {
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
if readIfNeeded0() {
public final var _unsafe: ChannelCore { return self } // We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
while let read = try? readFromSocket(), read == .some {
pipeline.fireChannelReadComplete()
}
}
// Visible to access from EventLoop directly close0(error: err, mode: .all, promise: nil)
let socket: T
public var interestedEvent: IOEvent = .none
fileprivate var readPending = false // we handled all writes
private var neverRegistered = true return .unregister
fileprivate var pendingConnect: EventLoopPromise<Void>?
private let closePromise: EventLoopPromise<Void>
private var active: Atomic<Bool> = Atomic(value: false)
private var _closed: Bool = false
public var isActive: Bool {
return active.load()
}
public var parent: Channel? = nil
public final var closeFuture: EventLoopFuture<Void> {
return closePromise.futureResult
}
private let selectableEventLoop: SelectableEventLoop
public final var eventLoop: EventLoop {
return selectableEventLoop
}
public var isWritable: Bool {
return true
}
public final var allocator: ByteBufferAllocator {
if eventLoop.inEventLoop {
return bufferAllocator
} else {
return try! eventLoop.submit{ self.bufferAllocator }.wait()
} }
} }
private var bufferAllocator: ByteBufferAllocator = ByteBufferAllocator()
fileprivate var recvAllocator: RecvByteBufferAllocator
fileprivate var autoRead: Bool = true
fileprivate var maxMessagesPerRead: UInt = 4
// We don't use lazy var here as this is more expensive then doing this :/
public final var pipeline: ChannelPipeline {
return _pipeline
}
private var _pipeline: ChannelPipeline!
public final func setOption<T: ChannelOption>(option: T, value: T.OptionType) -> EventLoopFuture<Void> { public final func setOption<T: ChannelOption>(option: T, value: T.OptionType) -> EventLoopFuture<Void> {
if eventLoop.inEventLoop { if eventLoop.inEventLoop {
@ -240,18 +333,6 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
} }
} }
public final var localAddress: SocketAddress? {
get {
return socket.localAddress
}
}
public final var remoteAddress: SocketAddress? {
get {
return socket.remoteAddress
}
}
/// Triggers a `ChannelPipeline.read()` if `autoRead` is enabled.` /// Triggers a `ChannelPipeline.read()` if `autoRead` is enabled.`
/// ///
/// - returns: `true` if `readPending` is `true`, `false` otherwise. /// - returns: `true` if `readPending` is `true`, `false` otherwise.
@ -275,6 +356,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
executeAndComplete(promise) { executeAndComplete(promise) {
try socket.bind(to: address) try socket.bind(to: address)
self.updateCachedAddressesFromSocket(updateRemote: false)
} }
} }
@ -401,6 +483,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
// Fail all pending writes and so ensure all pending promises are notified // Fail all pending writes and so ensure all pending promises are notified
self._closed = true self._closed = true
self.unsetCachedAddressesFromSocket()
self.cancelWritesOnClose(error: error) self.cancelWritesOnClose(error: error)
becomeInactive0() becomeInactive0()
@ -515,6 +598,22 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
readIfNeeded0() readIfNeeded0()
} }
internal final func updateCachedAddressesFromSocket(updateLocal: Bool = true, updateRemote: Bool = true) {
assert(self.eventLoop.inEventLoop)
if updateLocal {
self.localAddressCached.store(Box(try? self.localAddress0()))
}
if updateRemote {
self.remoteAddressCached.store(Box(try? self.remoteAddress0()))
}
}
internal final func unsetCachedAddressesFromSocket() {
assert(self.eventLoop.inEventLoop)
self.localAddressCached.store(Box(nil))
self.remoteAddressCached.store(Box(nil))
}
public final func connect0(to address: SocketAddress, promise: EventLoopPromise<Void>?) { public final func connect0(to address: SocketAddress, promise: EventLoopPromise<Void>?) {
assert(eventLoop.inEventLoop) assert(eventLoop.inEventLoop)
@ -529,6 +628,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
} }
do { do {
if try !connectSocket(to: address) { if try !connectSocket(to: address) {
self.updateCachedAddressesFromSocket()
if promise != nil { if promise != nil {
pendingConnect = promise pendingConnect = promise
} else { } else {
@ -536,6 +636,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
} }
registerForWritable() registerForWritable()
} else { } else {
self.updateCachedAddressesFromSocket()
promise?.succeed(result: ()) promise?.succeed(result: ())
} }
} catch let error { } catch let error {
@ -602,19 +703,6 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
active.store(false) active.store(false)
pipeline.fireChannelInactive0() pipeline.fireChannelInactive0()
} }
fileprivate init(socket: T, eventLoop: SelectableEventLoop, recvAllocator: RecvByteBufferAllocator) throws {
self.socket = socket
self.selectableEventLoop = eventLoop
self.closePromise = eventLoop.newPromise()
active.store(false)
self.recvAllocator = recvAllocator
self._pipeline = ChannelPipeline(channel: self)
}
deinit {
assert(self._closed, "leak of open Channel")
}
} }
/// A `Channel` for a client socket. /// A `Channel` for a client socket.
@ -628,10 +716,9 @@ final class SocketChannel: BaseSocketChannel<Socket> {
private var inputShutdown: Bool = false private var inputShutdown: Bool = false
private var outputShutdown: Bool = false private var outputShutdown: Bool = false
// Guard against re-entrance of flushNow() method.
private var inFlushNow: Bool = false
private let pendingWrites: PendingStreamWritesManager private let pendingWrites: PendingStreamWritesManager
// This is `Channel` API so must be thread-safe.
override public var isWritable: Bool { override public var isWritable: Bool {
return pendingWrites.isWritable return pendingWrites.isWritable
} }
@ -694,15 +781,10 @@ final class SocketChannel: BaseSocketChannel<Socket> {
return .socketChannel(self, interested) return .socketChannel(self, interested)
} }
fileprivate init(socket: Socket, eventLoop: SelectableEventLoop) throws { fileprivate init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws {
try socket.setNonBlocking() try socket.setNonBlocking()
self.pendingWrites = PendingStreamWritesManager(iovecs: eventLoop.iovecs, storageRefs: eventLoop.storageRefs) self.pendingWrites = PendingStreamWritesManager(iovecs: eventLoop.iovecs, storageRefs: eventLoop.storageRefs)
try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator()) try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator())
}
fileprivate convenience init(socket: Socket, eventLoop: SelectableEventLoop, parent: Channel) throws {
try self.init(socket: socket, eventLoop: eventLoop)
self.parent = parent
} }
override fileprivate func readFromSocket() throws -> ReadResult { override fileprivate func readFromSocket() throws -> ReadResult {
@ -749,8 +831,8 @@ final class SocketChannel: BaseSocketChannel<Socket> {
return result return result
} }
private func writeToSocket(pendingWrites: PendingStreamWritesManager) throws -> OverallWriteResult { override func writeToSocket() throws -> OverallWriteResult {
let result = try pendingWrites.triggerAppropriateWriteOperations(scalarBufferWriteOperation: { ptr in let result = try self.pendingWrites.triggerAppropriateWriteOperations(scalarBufferWriteOperation: { ptr in
guard ptr.count > 0 else { guard ptr.count > 0 else {
// No need to call write if the buffer is empty. // No need to call write if the buffer is empty.
return .processed(0) return .processed(0)
@ -882,43 +964,6 @@ final class SocketChannel: BaseSocketChannel<Socket> {
pipeline.fireChannelWritabilityChanged0() pipeline.fireChannelWritabilityChanged0()
} }
} }
override fileprivate func flushNow() -> IONotificationState {
// Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by
// `writeToSocket`.
guard !inFlushNow && !closed else {
return .unregister
}
defer {
inFlushNow = false
}
inFlushNow = true
do {
switch try self.writeToSocket(pendingWrites: pendingWrites) {
case .couldNotWriteEverything:
return .register
case .writtenCompletely:
return .unregister
}
} catch let err {
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
if readIfNeeded0() {
// We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
while let read = try? readFromSocket(), read == .some {
pipeline.fireChannelReadComplete()
}
}
close0(error: err, mode: .all, promise: nil)
// we handled all writes
return .unregister
}
}
} }
/// A `Channel` for a server socket. /// A `Channel` for a server socket.
@ -930,6 +975,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
private let group: EventLoopGroup private let group: EventLoopGroup
/// The server socket channel is never writable. /// The server socket channel is never writable.
// This is `Channel` API so must be thread-safe.
override public var isWritable: Bool { return false } override public var isWritable: Bool { return false }
init(eventLoop: SelectableEventLoop, group: EventLoopGroup, protocolFamily: Int32) throws { init(eventLoop: SelectableEventLoop, group: EventLoopGroup, protocolFamily: Int32) throws {
@ -987,6 +1033,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
} }
executeAndComplete(p) { executeAndComplete(p) {
try socket.bind(to: address) try socket.bind(to: address)
self.updateCachedAddressesFromSocket(updateRemote: false)
try self.socket.listen(backlog: backlog) try self.socket.listen(backlog: backlog)
} }
} }
@ -1009,7 +1056,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
readPending = false readPending = false
result = .some result = .some
do { do {
let chan = try SocketChannel(socket: accepted, eventLoop: group.next() as! SelectableEventLoop, parent: self) let chan = try SocketChannel(socket: accepted, parent: self, eventLoop: group.next() as! SelectableEventLoop)
pipeline.fireChannelRead0(NIOAny(chan)) pipeline.fireChannelRead0(NIOAny(chan))
} catch let err { } catch let err {
_ = try? accepted.close() _ = try? accepted.close()
@ -1054,9 +1101,9 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
final class DatagramChannel: BaseSocketChannel<Socket> { final class DatagramChannel: BaseSocketChannel<Socket> {
// Guard against re-entrance of flushNow() method. // Guard against re-entrance of flushNow() method.
private var inFlushNow: Bool = false
private let pendingWrites: PendingDatagramWritesManager private let pendingWrites: PendingDatagramWritesManager
// This is `Channel` API so must be thread-safe.
override public var isWritable: Bool { override public var isWritable: Bool {
return pendingWrites.isWritable return pendingWrites.isWritable
} }
@ -1082,18 +1129,13 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048)) try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
} }
fileprivate init(socket: Socket, eventLoop: SelectableEventLoop) throws { fileprivate init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws {
try socket.setNonBlocking() try socket.setNonBlocking()
self.pendingWrites = PendingDatagramWritesManager(msgs: eventLoop.msgs, self.pendingWrites = PendingDatagramWritesManager(msgs: eventLoop.msgs,
iovecs: eventLoop.iovecs, iovecs: eventLoop.iovecs,
addresses: eventLoop.addresses, addresses: eventLoop.addresses,
storageRefs: eventLoop.storageRefs) storageRefs: eventLoop.storageRefs)
try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048)) try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
}
fileprivate convenience init(socket: Socket, eventLoop: SelectableEventLoop, parent: Channel) throws {
try self.init(socket: socket, eventLoop: eventLoop)
self.parent = parent
} }
// MARK: Datagram Channel overrides required by BaseSocketChannel // MARK: Datagram Channel overrides required by BaseSocketChannel
@ -1156,22 +1198,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
let mayGrow = recvAllocator.record(actualReadBytes: bytesRead) let mayGrow = recvAllocator.record(actualReadBytes: bytesRead)
readPending = false readPending = false
let sourceAddress: SocketAddress let msg = AddressedEnvelope(remoteAddress: rawAddress.convert(), data: buffer)
switch Int(rawAddressLength) {
case MemoryLayout<sockaddr_in>.size:
let addr: sockaddr_in = rawAddress.convert()
sourceAddress = .init(addr, host: "")
case MemoryLayout<sockaddr_in6>.size:
let addr: sockaddr_in6 = rawAddress.convert()
sourceAddress = .init(addr, host: "")
case MemoryLayout<sockaddr_un>.size:
let addr: sockaddr_un = rawAddress.convert()
sourceAddress = .init(addr)
default:
fatalError("Unexpected sockaddr size")
}
let msg = AddressedEnvelope(remoteAddress: sourceAddress, data: buffer)
pipeline.fireChannelRead0(NIOAny(msg)) pipeline.fireChannelRead0(NIOAny(msg))
if mayGrow && i < maxMessagesPerRead { if mayGrow && i < maxMessagesPerRead {
buffer = recvAllocator.buffer(allocator: allocator) buffer = recvAllocator.buffer(allocator: allocator)
@ -1211,45 +1238,8 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
self.pendingWrites.failAll(error: error, close: true) self.pendingWrites.failAll(error: error, close: true)
} }
override fileprivate func flushNow() -> IONotificationState { override func writeToSocket() throws -> OverallWriteResult {
// Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by let result = try self.pendingWrites.triggerAppropriateWriteOperations(scalarWriteOperation: { (ptr, destinationPtr, destinationSize) in
// `writeToSocket`.
guard !inFlushNow && !closed else {
return .unregister
}
defer {
inFlushNow = false
}
inFlushNow = true
do {
switch try self.writeToSocket(pendingWrites: pendingWrites) {
case .couldNotWriteEverything:
return .register
case .writtenCompletely:
return .unregister
}
} catch let err {
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
if readIfNeeded0() {
// We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
while let read = try? readFromSocket(), read == .some {
pipeline.fireChannelReadComplete()
}
}
close0(error: err, mode: .all, promise: nil)
// we handled all writes
return .unregister
}
}
private func writeToSocket(pendingWrites: PendingDatagramWritesManager) throws -> OverallWriteResult {
let result = try pendingWrites.triggerAppropriateWriteOperations(scalarWriteOperation: { (ptr, destinationPtr, destinationSize) in
guard ptr.count > 0 else { guard ptr.count > 0 else {
// No need to call write if the buffer is empty. // No need to call write if the buffer is empty.
return .processed(0) return .processed(0)
@ -1273,6 +1263,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
assert(self.eventLoop.inEventLoop) assert(self.eventLoop.inEventLoop)
do { do {
try socket.bind(to: address) try socket.bind(to: address)
self.updateCachedAddressesFromSocket(updateRemote: false)
promise?.succeed(result: ()) promise?.succeed(result: ())
becomeActive0() becomeActive0()
readIfNeeded0() readIfNeeded0()

View File

@ -50,6 +50,12 @@ private let sysLseek = lseek
private let sysRecvFrom = recvfrom private let sysRecvFrom = recvfrom
private let sysSendTo = sendto private let sysSendTo = sendto
private let sysDup = dup private let sysDup = dup
private let sysGetpeername = getpeername
private let sysGetsockname = getsockname
private let sysAF_INET = AF_INET
private let sysAF_INET6 = AF_INET6
private let sysAF_UNIX = AF_UNIX
private let sysInet_ntop = inet_ntop
#if os(Linux) #if os(Linux)
private let sysSendMmsg = CNIOLinux_sendmmsg private let sysSendMmsg = CNIOLinux_sendmmsg
@ -110,6 +116,23 @@ internal func wrapSyscall<T: FixedWidthInteger>(where function: StaticString = #
} }
} }
/* Sorry, we really try hard to not use underscored attributes. In this case however we seem to break the inlining threshold which makes a system call take twice the time, ie. we need this exception. */
@inline(__always)
internal func wrapErrorIsNullReturnCall(where function: StaticString = #function, _ fn: () throws -> UnsafePointer<CChar>?) throws -> UnsafePointer<CChar>? {
while true {
let res = try fn()
if res == nil {
let err = errno
if err == EINTR {
continue
}
assert(!isBlacklistedErrno(err), "blacklisted errno \(err) \(strerror(err)!)")
throw IOError(errnoCode: err, function: function)
}
return res
}
}
enum Shutdown { enum Shutdown {
case RD case RD
case WR case WR
@ -148,6 +171,9 @@ internal enum Posix {
} }
#endif #endif
static let AF_INET = sa_family_t(sysAF_INET)
static let AF_INET6 = sa_family_t(sysAF_INET6)
static let AF_UNIX = sa_family_t(sysAF_UNIX)
@inline(never) @inline(never)
public static func shutdown(descriptor: CInt, how: Shutdown) throws { public static func shutdown(descriptor: CInt, how: Shutdown) throws {
@ -321,6 +347,14 @@ internal enum Posix {
} }
} }
@discardableResult
@inline(never)
public static func inet_ntop(addressFamily: CInt, addressBytes: UnsafeRawPointer, addressDescription: UnsafeMutablePointer<CChar>, addressDescriptionLength: socklen_t) throws -> UnsafePointer<CChar>? {
return try wrapErrorIsNullReturnCall {
sysInet_ntop(addressFamily, addressBytes, addressDescription, addressDescriptionLength)
}
}
// Its not really posix but exists on Linux and MacOS / BSD so just put it here for now to keep it simple // Its not really posix but exists on Linux and MacOS / BSD so just put it here for now to keep it simple
@inline(never) @inline(never)
public static func sendfile(descriptor: CInt, fd: CInt, offset: off_t, count: size_t) throws -> IOResult<Int> { public static func sendfile(descriptor: CInt, fd: CInt, offset: off_t, count: size_t) throws -> IOResult<Int> {
@ -365,6 +399,21 @@ internal enum Posix {
Int(sysRecvMmsg(sockfd, msgvec, vlen, flags, timeout)) Int(sysRecvMmsg(sockfd, msgvec, vlen, flags, timeout))
} }
} }
@inline(never)
public static func getpeername(socket: CInt, address: UnsafeMutablePointer<sockaddr>, addressLength: UnsafeMutablePointer<socklen_t>) throws {
_ = try wrapSyscall {
return sysGetpeername(socket, address, addressLength)
}
}
@inline(never)
public static func getsockname(socket: CInt, address: UnsafeMutablePointer<sockaddr>, addressLength: UnsafeMutablePointer<socklen_t>) throws {
_ = try wrapSyscall {
return sysGetsockname(socket, address, addressLength)
}
}
} }
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)

View File

@ -74,7 +74,7 @@ final class ChatHandler: ChannelInboundHandler {
} }
public func channelActive(ctx: ChannelHandlerContext) { public func channelActive(ctx: ChannelHandlerContext) {
let remoteAddress = ctx.channel.remoteAddress! let remoteAddress = ctx.remoteAddress!
let channel = ctx.channel let channel = ctx.channel
self.channelsSyncQueue.async { self.channelsSyncQueue.async {
// broadcast the message to all the connected clients except the one that just became active. // broadcast the message to all the connected clients except the one that just became active.
@ -84,7 +84,7 @@ final class ChatHandler: ChannelInboundHandler {
} }
var buffer = channel.allocator.buffer(capacity: 64) var buffer = channel.allocator.buffer(capacity: 64)
buffer.write(string: "(ChatServer) - Welcome to: \(channel.localAddress!)\n") buffer.write(string: "(ChatServer) - Welcome to: \(ctx.localAddress!)\n")
ctx.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil) ctx.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
} }

View File

@ -307,3 +307,91 @@ extension UInt: AtomicPrimitive {
public static let atomic_load = catmc_atomic_unsigned_long_load public static let atomic_load = catmc_atomic_unsigned_long_load
public static let atomic_store = catmc_atomic_unsigned_long_store public static let atomic_store = catmc_atomic_unsigned_long_store
} }
/// `AtomicBox` is a heap-allocated box which allows atomic access to an instance of a Swift class.
///
/// It behaves very much like `Atomic<T>` but for objects, maintaining the correct retain counts.
public class AtomicBox<T: AnyObject> {
private let storage: Atomic<Int>
public init(value: T) {
let ptr = Unmanaged<T>.passRetained(value)
self.storage = Atomic(value: Int(bitPattern: ptr.toOpaque()))
}
deinit {
let oldPtrBits = self.storage.exchange(with: 0xdeadbeef)
let oldPtr = Unmanaged<T>.fromOpaque(UnsafeRawPointer(bitPattern: oldPtrBits)!)
oldPtr.release()
}
/// Atomically compares the value against `expected` and, if they are equal,
/// replaces the value with `desired`.
///
/// This implementation conforms to C11's `atomic_compare_exchange_strong`. This
/// means that the compare-and-swap will always succeed if `expected` is equal to
/// value. Additionally, it uses a *sequentially consistent ordering*. For more
/// details on atomic memory models, check the documentation for C11's
/// `stdatomic.h`.
///
/// - Parameter expected: The value that this object must currently hold for the
/// compare-and-swap to succeed.
/// - Parameter desired: The new value that this object will hold if the compare
/// succeeds.
/// - Returns: `True` if the exchange occurred, or `False` if `expected` did not
/// match the current value and so no exchange occurred.
public func compareAndExchange(expected: T, desired: T) -> Bool {
return withExtendedLifetime(desired) {
let expectedPtr = Unmanaged<T>.passUnretained(expected)
let desiredPtr = Unmanaged<T>.passUnretained(desired)
if self.storage.compareAndExchange(expected: Int(bitPattern: expectedPtr.toOpaque()),
desired: Int(bitPattern: desiredPtr.toOpaque())) {
_ = desiredPtr.retain()
expectedPtr.release()
return true
} else {
return false
}
}
}
/// Atomically exchanges `value` for the current value of this object.
///
/// This implementation uses a *relaxed* memory ordering. This guarantees nothing
/// more than that this operation is atomic: there is no guarantee that any other
/// event will be ordered before or after this one.
///
/// - Parameter value: The new value to set this object to.
/// - Returns: The value previously held by this object.
public func exchange(with value: T) -> T {
let newPtr = Unmanaged<T>.passRetained(value)
let oldPtrBits = self.storage.exchange(with: Int(bitPattern: newPtr.toOpaque()))
let oldPtr = Unmanaged<T>.fromOpaque(UnsafeRawPointer(bitPattern: oldPtrBits)!)
return oldPtr.takeRetainedValue()
}
/// Atomically loads and returns the value of this object.
///
/// This implementation uses a *relaxed* memory ordering. This guarantees nothing
/// more than that this operation is atomic: there is no guarantee that any other
/// event will be ordered before or after this one.
///
/// - Returns: The value of this object
public func load() -> T {
let ptrBits = self.storage.load()
let ptr = Unmanaged<T>.fromOpaque(UnsafeRawPointer(bitPattern: ptrBits)!)
return ptr.takeUnretainedValue()
}
/// Atomically replaces the value of this object with `value`.
///
/// This implementation uses a *relaxed* memory ordering. This guarantees nothing
/// more than that this operation is atomic: there is no guarantee that any other
/// event will be ordered before or after this one.
///
/// - Parameter value: The new value to set the object to.
public func store(_ value: T) -> Void {
_ = self.exchange(with: value)
}
}

View File

@ -41,7 +41,7 @@ private final class EchoHandler: ChannelInboundHandler {
} }
public func channelActive(ctx: ChannelHandlerContext) { public func channelActive(ctx: ChannelHandlerContext) {
print("Client connected to \(ctx.channel.remoteAddress!)") print("Client connected to \(ctx.remoteAddress!)")
// We are connected its time to send the message to the server to initialize the ping-pong sequence. // We are connected its time to send the message to the server to initialize the ping-pong sequence.
var buffer = ctx.channel.allocator.buffer(capacity: line.utf8.count) var buffer = ctx.channel.allocator.buffer(capacity: line.utf8.count)

View File

@ -72,7 +72,7 @@ private final class HTTPHandler: ChannelInboundHandler {
URL: \(self.infoSavedRequestHead!.uri)\r URL: \(self.infoSavedRequestHead!.uri)\r
body length: \(self.infoSavedBodyBytes)\r body length: \(self.infoSavedBodyBytes)\r
headers: \(self.infoSavedRequestHead!.headers)\r headers: \(self.infoSavedRequestHead!.headers)\r
client: \(ctx.channel.remoteAddress?.description ?? "zombie")\r client: \(ctx.remoteAddress?.description ?? "zombie")\r
IO: SwiftNIO Electric Boogaloo\r\n IO: SwiftNIO Electric Boogaloo\r\n
""" """
self.buffer.clear() self.buffer.clear()
@ -200,7 +200,7 @@ private final class HTTPHandler: ChannelInboundHandler {
case "/dynamic/count-to-ten": case "/dynamic/count-to-ten":
return { self.handleMultipleWrites(ctx: $0, request: $1, strings: (1...10).map { "\($0)\r\n" }, delay: .milliseconds(100)) } return { self.handleMultipleWrites(ctx: $0, request: $1, strings: (1...10).map { "\($0)\r\n" }, delay: .milliseconds(100)) }
case "/dynamic/client-ip": case "/dynamic/client-ip":
return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, string: "\(ctx.channel.remoteAddress.debugDescription)") } return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, string: "\(ctx.remoteAddress.debugDescription)") }
default: default:
return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, statusCode: .notFound, string: "not found") } return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, statusCode: .notFound, string: "not found") }
} }

View File

@ -39,6 +39,11 @@ extension NIOConcurrencyHelpersTests {
("testConditionLockMutualExclusion", testConditionLockMutualExclusion), ("testConditionLockMutualExclusion", testConditionLockMutualExclusion),
("testConditionLock", testConditionLock), ("testConditionLock", testConditionLock),
("testConditionLockWithDifferentConditions", testConditionLockWithDifferentConditions), ("testConditionLockWithDifferentConditions", testConditionLockWithDifferentConditions),
("testAtomicBoxDoesNotTriviallyLeak", testAtomicBoxDoesNotTriviallyLeak),
("testAtomicBoxCompareAndExchangeWorksIfEqual", testAtomicBoxCompareAndExchangeWorksIfEqual),
("testAtomicBoxCompareAndExchangeWorksIfNotEqual", testAtomicBoxCompareAndExchangeWorksIfNotEqual),
("testAtomicBoxStoreWorks", testAtomicBoxStoreWorks),
("testAtomicBoxCompareAndExchangeOntoItselfWorks", testAtomicBoxCompareAndExchangeOntoItselfWorks),
] ]
} }
} }

View File

@ -401,4 +401,149 @@ class NIOConcurrencyHelpersTests: XCTestCase {
doneSem.wait() /* job on 'q2' is done */ doneSem.wait() /* job on 'q2' is done */
} }
} }
func testAtomicBoxDoesNotTriviallyLeak() throws {
class SomeClass {}
weak var weakSomeInstance1: SomeClass? = nil
weak var weakSomeInstance2: SomeClass? = nil
({
let someInstance = SomeClass()
weakSomeInstance1 = someInstance
let someAtomic = AtomicBox(value: someInstance)
let loadedFromAtomic = someAtomic.load()
weakSomeInstance2 = loadedFromAtomic
XCTAssertNotNil(weakSomeInstance1)
XCTAssertNotNil(weakSomeInstance2)
XCTAssert(someInstance === loadedFromAtomic)
})()
XCTAssertNil(weakSomeInstance1)
XCTAssertNil(weakSomeInstance2)
}
func testAtomicBoxCompareAndExchangeWorksIfEqual() throws {
class SomeClass {}
weak var weakSomeInstance1: SomeClass? = nil
weak var weakSomeInstance2: SomeClass? = nil
weak var weakSomeInstance3: SomeClass? = nil
({
let someInstance1 = SomeClass()
let someInstance2 = SomeClass()
weakSomeInstance1 = someInstance1
let atomic = AtomicBox(value: someInstance1)
var loadedFromAtomic = atomic.load()
XCTAssert(someInstance1 === loadedFromAtomic)
weakSomeInstance2 = loadedFromAtomic
XCTAssertTrue(atomic.compareAndExchange(expected: loadedFromAtomic, desired: someInstance2))
loadedFromAtomic = atomic.load()
weakSomeInstance3 = loadedFromAtomic
XCTAssert(someInstance1 !== loadedFromAtomic)
XCTAssert(someInstance2 === loadedFromAtomic)
XCTAssertNotNil(weakSomeInstance1)
XCTAssertNotNil(weakSomeInstance2)
XCTAssertNotNil(weakSomeInstance3)
XCTAssert(weakSomeInstance1 === weakSomeInstance2 && weakSomeInstance2 !== weakSomeInstance3)
})()
XCTAssertNil(weakSomeInstance1)
XCTAssertNil(weakSomeInstance2)
XCTAssertNil(weakSomeInstance3)
}
func testAtomicBoxCompareAndExchangeWorksIfNotEqual() throws {
class SomeClass {}
weak var weakSomeInstance1: SomeClass? = nil
weak var weakSomeInstance2: SomeClass? = nil
weak var weakSomeInstance3: SomeClass? = nil
({
let someInstance1 = SomeClass()
let someInstance2 = SomeClass()
weakSomeInstance1 = someInstance1
let atomic = AtomicBox(value: someInstance1)
var loadedFromAtomic = atomic.load()
XCTAssert(someInstance1 === loadedFromAtomic)
weakSomeInstance2 = loadedFromAtomic
XCTAssertFalse(atomic.compareAndExchange(expected: someInstance2, desired: someInstance2))
XCTAssertFalse(atomic.compareAndExchange(expected: SomeClass(), desired: someInstance2))
XCTAssertTrue(atomic.load() === someInstance1)
loadedFromAtomic = atomic.load()
weakSomeInstance3 = someInstance2
XCTAssert(someInstance1 === loadedFromAtomic)
XCTAssert(someInstance2 !== loadedFromAtomic)
XCTAssertNotNil(weakSomeInstance1)
XCTAssertNotNil(weakSomeInstance2)
XCTAssertNotNil(weakSomeInstance3)
})()
XCTAssertNil(weakSomeInstance1)
XCTAssertNil(weakSomeInstance2)
XCTAssertNil(weakSomeInstance3)
}
func testAtomicBoxStoreWorks() throws {
class SomeClass {}
weak var weakSomeInstance1: SomeClass? = nil
weak var weakSomeInstance2: SomeClass? = nil
weak var weakSomeInstance3: SomeClass? = nil
({
let someInstance1 = SomeClass()
let someInstance2 = SomeClass()
weakSomeInstance1 = someInstance1
let atomic = AtomicBox(value: someInstance1)
var loadedFromAtomic = atomic.load()
XCTAssert(someInstance1 === loadedFromAtomic)
weakSomeInstance2 = loadedFromAtomic
atomic.store(someInstance2)
loadedFromAtomic = atomic.load()
weakSomeInstance3 = loadedFromAtomic
XCTAssert(someInstance1 !== loadedFromAtomic)
XCTAssert(someInstance2 === loadedFromAtomic)
XCTAssertNotNil(weakSomeInstance1)
XCTAssertNotNil(weakSomeInstance2)
XCTAssertNotNil(weakSomeInstance3)
})()
XCTAssertNil(weakSomeInstance1)
XCTAssertNil(weakSomeInstance2)
XCTAssertNil(weakSomeInstance3)
}
func testAtomicBoxCompareAndExchangeOntoItselfWorks() {
let q = DispatchQueue(label: "q")
let g = DispatchGroup()
let sem1 = DispatchSemaphore(value: 0)
let sem2 = DispatchSemaphore(value: 0)
class SomeClass {}
weak var weakInstance: SomeClass?
({
let instance = SomeClass()
weakInstance = instance
let atomic = AtomicBox(value: instance)
q.async(group: g) {
sem1.signal()
sem2.wait()
for _ in 0..<1000 {
XCTAssertTrue(atomic.compareAndExchange(expected: instance, desired: instance))
}
}
sem2.signal()
sem1.wait()
for _ in 0..<1000 {
XCTAssertTrue(atomic.compareAndExchange(expected: instance, desired: instance))
}
g.wait()
let v = atomic.load()
XCTAssert(v === instance)
})()
XCTAssertNil(weakInstance)
}
} }

View File

@ -19,6 +19,18 @@ import NIOFoundationCompat
import Dispatch import Dispatch
@testable import NIOHTTP1 @testable import NIOHTTP1
internal extension Channel {
func syncCloseAcceptingAlreadyClosed() throws {
do {
try self.close().wait()
} catch ChannelError.alreadyClosed {
/* we're happy with this one */
} catch let e {
throw e
}
}
}
extension Array where Array.Element == ByteBuffer { extension Array where Array.Element == ByteBuffer {
public func allAsBytes() -> [UInt8] { public func allAsBytes() -> [UInt8] {
var out: [UInt8] = [] var out: [UInt8] = []
@ -61,17 +73,6 @@ internal class ArrayAccumulationHandler<T>: ChannelInboundHandler {
} }
class HTTPServerClientTest : XCTestCase { class HTTPServerClientTest : XCTestCase {
private func syncCloseAcceptingAlreadyClosed(channel: Channel) throws {
do {
try channel.close().wait()
} catch ChannelError.alreadyClosed {
/* we're happy with this one */
} catch let e {
throw e
}
}
/* needs to be something reasonably large and odd so it has good odds producing incomplete writes even on the loopback interface */ /* needs to be something reasonably large and odd so it has good odds producing incomplete writes even on the loopback interface */
private static let massiveResponseLength = 5 * 1024 * 1024 + 7 private static let massiveResponseLength = 5 * 1024 * 1024 + 7
private static let massiveResponseBytes: [UInt8] = { private static let massiveResponseBytes: [UInt8] = {
@ -374,7 +375,7 @@ class HTTPServerClientTest : XCTestCase {
}.bind(host: "127.0.0.1", port: 0).wait() }.bind(host: "127.0.0.1", port: 0).wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
} }
let clientChannel = try ClientBootstrap(group: group) let clientChannel = try ClientBootstrap(group: group)
@ -387,7 +388,7 @@ class HTTPServerClientTest : XCTestCase {
.wait() .wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
} }
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/helloworld") var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/helloworld")
@ -432,7 +433,7 @@ class HTTPServerClientTest : XCTestCase {
}.bind(host: "127.0.0.1", port: 0).wait() }.bind(host: "127.0.0.1", port: 0).wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
} }
let clientChannel = try ClientBootstrap(group: group) let clientChannel = try ClientBootstrap(group: group)
@ -445,7 +446,7 @@ class HTTPServerClientTest : XCTestCase {
.wait() .wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
} }
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/count-to-ten") var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/count-to-ten")
@ -490,7 +491,7 @@ class HTTPServerClientTest : XCTestCase {
}.bind(host: "127.0.0.1", port: 0).wait() }.bind(host: "127.0.0.1", port: 0).wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
} }
let clientChannel = try ClientBootstrap(group: group) let clientChannel = try ClientBootstrap(group: group)
@ -503,7 +504,7 @@ class HTTPServerClientTest : XCTestCase {
.wait() .wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
} }
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/trailers") var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/trailers")
@ -550,7 +551,7 @@ class HTTPServerClientTest : XCTestCase {
}.bind(host: "127.0.0.1", port: 0).wait() }.bind(host: "127.0.0.1", port: 0).wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
} }
let clientChannel = try ClientBootstrap(group: group) let clientChannel = try ClientBootstrap(group: group)
@ -559,7 +560,7 @@ class HTTPServerClientTest : XCTestCase {
.wait() .wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
} }
var buffer = clientChannel.allocator.buffer(capacity: numBytes) var buffer = clientChannel.allocator.buffer(capacity: numBytes)
@ -591,7 +592,7 @@ class HTTPServerClientTest : XCTestCase {
} }
}.bind(host: "127.0.0.1", port: 0).wait() }.bind(host: "127.0.0.1", port: 0).wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
} }
let clientChannel = try ClientBootstrap(group: group) let clientChannel = try ClientBootstrap(group: group)
@ -604,7 +605,7 @@ class HTTPServerClientTest : XCTestCase {
.wait() .wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
} }
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .HEAD, uri: "/head") var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .HEAD, uri: "/head")
@ -636,7 +637,7 @@ class HTTPServerClientTest : XCTestCase {
} }
}.bind(host: "127.0.0.1", port: 0).wait() }.bind(host: "127.0.0.1", port: 0).wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
} }
let clientChannel = try ClientBootstrap(group: group) let clientChannel = try ClientBootstrap(group: group)
@ -649,7 +650,7 @@ class HTTPServerClientTest : XCTestCase {
.wait() .wait()
defer { defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
} }
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/204") var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/204")

View File

@ -55,6 +55,7 @@ extension ChannelTests {
("testHalfClosure", testHalfClosure), ("testHalfClosure", testHalfClosure),
("testRejectsInvalidData", testRejectsInvalidData), ("testRejectsInvalidData", testRejectsInvalidData),
("testWeDontCrashIfChannelReleasesBeforePipeline", testWeDontCrashIfChannelReleasesBeforePipeline), ("testWeDontCrashIfChannelReleasesBeforePipeline", testWeDontCrashIfChannelReleasesBeforePipeline),
("testAskForLocalAndRemoteAddressesAfterChannelIsClosed", testAskForLocalAndRemoteAddressesAfterChannelIsClosed),
] ]
} }
} }

View File

@ -1157,7 +1157,7 @@ public class ChannelTests: XCTestCase {
return channel.pipeline.add(handler: byteCountingHandler) return channel.pipeline.add(handler: byteCountingHandler)
} }
} }
.connect(to: server.localAddress!) .connect(to: try! server.localAddress())
let accepted = try server.accept()! let accepted = try server.accept()!
defer { defer {
XCTAssertNoThrow(try accepted.close()) XCTAssertNoThrow(try accepted.close())
@ -1219,7 +1219,7 @@ public class ChannelTests: XCTestCase {
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false)) return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
} }
} }
.connect(to: server.localAddress!) .connect(to: try! server.localAddress())
let accepted = try server.accept()! let accepted = try server.accept()!
defer { defer {
XCTAssertNoThrow(try accepted.close()) XCTAssertNoThrow(try accepted.close())
@ -1269,7 +1269,7 @@ public class ChannelTests: XCTestCase {
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false)) return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
} }
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
.connect(to: server.localAddress!) .connect(to: try! server.localAddress())
let accepted = try server.accept()! let accepted = try server.accept()!
defer { defer {
XCTAssertNoThrow(try accepted.close()) XCTAssertNoThrow(try accepted.close())
@ -1412,4 +1412,27 @@ public class ChannelTests: XCTestCase {
XCTAssertNil(weakServerChannel, "weakServerChannel not nil, looks like we leaked it!") XCTAssertNil(weakServerChannel, "weakServerChannel not nil, looks like we leaked it!")
XCTAssertNil(weakServerChildChannel, "weakServerChildChannel not nil, looks like we leaked it!") XCTAssertNil(weakServerChildChannel, "weakServerChildChannel not nil, looks like we leaked it!")
} }
func testAskForLocalAndRemoteAddressesAfterChannelIsClosed() throws {
let group = MultiThreadedEventLoopGroup(numThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
let serverChannel = try ServerBootstrap(group: group)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.bind(host: "127.0.0.1", port: 0).wait()
let clientChannel = try ClientBootstrap(group: group)
.connect(to: serverChannel.localAddress!).wait()
// Start shutting stuff down.
try serverChannel.syncCloseAcceptingAlreadyClosed()
try clientChannel.syncCloseAcceptingAlreadyClosed()
for f in [ serverChannel.remoteAddress, serverChannel.localAddress, clientChannel.remoteAddress, clientChannel.localAddress ] {
XCTAssertNil(f)
}
}
} }

View File

@ -69,6 +69,17 @@ func openTemporaryFile() -> (CInt, String) {
return (fd, String(decoding: templateBytes, as: UTF8.self)) return (fd, String(decoding: templateBytes, as: UTF8.self))
} }
internal extension Channel {
func syncCloseAcceptingAlreadyClosed() throws {
do {
try self.close().wait()
} catch ChannelError.alreadyClosed {
/* we're happy with this one */
} catch let e {
throw e
}
}
}
final class ByteCountingHandler : ChannelInboundHandler { final class ByteCountingHandler : ChannelInboundHandler {
typealias InboundIn = ByteBuffer typealias InboundIn = ByteBuffer