fix Channel thread-safety (local/remoteAddress)

Motivation:

There were a couple of places in the Channel implementations that just
weren't thread-safe at all, namely:

- localAddress
- remoetAddress
- parent

those are now fixed. I also re-ordered the code so it should be easier
to maintain in the future (the `// MARK` markers were totally
incorrect).

Modifications:

- made Channel.{local,remote}Address return a future
- made Channel.parent a `let`
- unified more of `SocketChannel` and `DatagramSocketChannel`
- fixed the `// MARK`s by reordering code
- annotated methods/properties that are in the Channel API and need to
be thread-safe

Result:

slightly more thread-safety :)
This commit is contained in:
Johannes Weiss 2018-02-20 15:18:51 +00:00
parent 273932ae36
commit 68724be6e6
19 changed files with 638 additions and 275 deletions

View File

@ -21,6 +21,19 @@ protocol SockAddrProtocol {
mutating func withMutableSockAddr<R>(_ fn: (UnsafeMutablePointer<sockaddr>, Int) throws -> R) rethrows -> R
}
private func descriptionForAddress(family: CInt, bytes: UnsafeRawPointer, length byteCount: Int) -> String {
var addressBytes: [Int8] = Array(repeating: 0, count: byteCount)
return addressBytes.withUnsafeMutableBufferPointer { (addressBytesPtr: inout UnsafeMutableBufferPointer<Int8>) -> String in
try! Posix.inet_ntop(addressFamily: family,
addressBytes: bytes,
addressDescription: addressBytesPtr.baseAddress!,
addressDescriptionLength: socklen_t(byteCount))
return addressBytesPtr.baseAddress!.withMemoryRebound(to: UInt8.self, capacity: byteCount) { addressBytesPtr -> String in
return String(decoding: UnsafeBufferPointer<UInt8>(start: addressBytesPtr, count: byteCount), as: Unicode.ASCII.self)
}
}
}
extension sockaddr_in: SockAddrProtocol {
mutating func withSockAddr<R>(_ fn: (UnsafePointer<sockaddr>, Int) throws -> R) rethrows -> R {
var me = self
@ -35,6 +48,12 @@ extension sockaddr_in: SockAddrProtocol {
try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count)
}
}
mutating func addressDescription() -> String {
return withUnsafePointer(to: &self.sin_addr) { addrPtr in
descriptionForAddress(family: AF_INET, bytes: addrPtr, length: Int(INET_ADDRSTRLEN))
}
}
}
extension sockaddr_in6: SockAddrProtocol {
@ -51,6 +70,12 @@ extension sockaddr_in6: SockAddrProtocol {
try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count)
}
}
mutating func addressDescription() -> String {
return withUnsafePointer(to: &self.sin6_addr) { addrPtr in
descriptionForAddress(family: AF_INET6, bytes: addrPtr, length: Int(INET6_ADDRSTRLEN))
}
}
}
extension sockaddr_un: SockAddrProtocol {
@ -119,54 +144,45 @@ extension sockaddr_storage {
}
}
}
mutating func convert() -> SocketAddress {
switch self.ss_family {
case Posix.AF_INET:
var sockAddr: sockaddr_in = self.convert()
return SocketAddress(sockAddr, host: sockAddr.addressDescription())
case Posix.AF_INET6:
var sockAddr: sockaddr_in6 = self.convert()
return SocketAddress(sockAddr, host: sockAddr.addressDescription())
case Posix.AF_UNIX:
return SocketAddress(self.convert() as sockaddr_un)
default:
fatalError("unknown sockaddr family \(self.ss_family)")
}
}
}
class BaseSocket: Selectable {
public let descriptor: Int32
public private(set) var open: Bool
final var localAddress: SocketAddress? {
get {
return get_addr { getsockname($0, $1, $2) }
}
final func localAddress() throws -> SocketAddress {
return try get_addr { try Posix.getsockname(socket: $0, address: $1, addressLength: $2) }
}
final var remoteAddress: SocketAddress? {
get {
return get_addr { getpeername($0, $1, $2) }
}
final func remoteAddress() throws -> SocketAddress {
return try get_addr { try Posix.getpeername(socket: $0, address: $1, addressLength: $2) }
}
private func get_addr(_ fn: (Int32, UnsafeMutablePointer<sockaddr>, UnsafeMutablePointer<socklen_t>) -> Int32) -> SocketAddress? {
private func get_addr(_ fn: (Int32, UnsafeMutablePointer<sockaddr>, UnsafeMutablePointer<socklen_t>) throws -> Void) throws -> SocketAddress {
var addr = sockaddr_storage()
var len: socklen_t = socklen_t(MemoryLayout<sockaddr_storage>.size)
return withUnsafeMutablePointer(to: &addr) {
$0.withMemoryRebound(to: sockaddr.self, capacity: 1, { address in
guard fn(descriptor, address, &len) == 0 else {
return nil
}
switch Int32(address.pointee.sa_family) {
case AF_INET:
return address.withMemoryRebound(to: sockaddr_in.self, capacity: 1, { ipv4 in
var ipAddressString = [CChar](repeating: 0, count: Int(INET_ADDRSTRLEN))
return SocketAddress(ipv4.pointee, host: String(cString: inet_ntop(AF_INET, &ipv4.pointee.sin_addr, &ipAddressString, socklen_t(INET_ADDRSTRLEN))))
})
case AF_INET6:
return address.withMemoryRebound(to: sockaddr_in6.self, capacity: 1, { ipv6 in
var ipAddressString = [CChar](repeating: 0, count: Int(INET6_ADDRSTRLEN))
return SocketAddress(ipv6.pointee, host: String(cString: inet_ntop(AF_INET6, &ipv6.pointee.sin6_addr, &ipAddressString, socklen_t(INET6_ADDRSTRLEN))))
})
case AF_UNIX:
return address.withMemoryRebound(to: sockaddr_un.self, capacity: 1) { uds in
return SocketAddress(uds.pointee)
}
default:
fatalError("address family \(address.pointee.sa_family) not supported")
}
})
try addr.withMutableSockAddr { addressPtr, size in
var size = socklen_t(size)
try fn(self.descriptor, addressPtr, &size)
}
return addr.convert()
}
static func newSocket(protocolFamily: Int32, type: CInt) throws -> Int32 {
let sock = try Posix.socket(domain: protocolFamily,
type: type,

View File

@ -18,6 +18,8 @@ import NIOConcurrencyHelpers
///
/// - note: All methods must be called from the EventLoop thread
public protocol ChannelCore : class {
func localAddress0() throws -> SocketAddress
func remoteAddress0() throws -> SocketAddress
func register0(promise: EventLoopPromise<Void>?)
func bind0(to: SocketAddress, promise: EventLoopPromise<Void>?)
func connect0(to: SocketAddress, promise: EventLoopPromise<Void>?)
@ -82,7 +84,9 @@ public protocol Channel : class, ChannelOutboundInvoker {
/// A `SelectableChannel` is a `Channel` that can be used with a `Selector` which notifies a user when certain events
/// before possible. On UNIX a `Selector` is usually an abstraction of `select`, `poll`, `epoll` or `kqueue`.
protocol SelectableChannel : Channel {
///
/// - warning: `SelectableChannel` methods and properties are _not_ thread-safe (unless they also belong to `Channel`).
internal protocol SelectableChannel : Channel {
/// The type of the `Selectable`. A `Selectable` is usually wrapping a file descriptor that can be registered in a
/// `Selector`.
associatedtype SelectableType: Selectable

View File

@ -809,6 +809,14 @@ public final class ChannelHandlerContext : ChannelInvoker {
return self.inboundHandler ?? self.outboundHandler!
}
public var remoteAddress: SocketAddress? {
return try? self.channel._unsafe.remoteAddress0()
}
public var localAddress: SocketAddress? {
return try? self.channel._unsafe.localAddress0()
}
public let name: String
public let eventLoop: EventLoop
private let inboundHandler: _ChannelInboundHandler?

View File

@ -16,6 +16,14 @@
/// the original `Channel` is closed. Given that the original `Channel` is closed the `DeadChannelCore` should fail
/// all operations.
private final class DeadChannelCore: ChannelCore {
func localAddress0() throws -> SocketAddress {
throw ChannelError.ioOnClosedChannel
}
func remoteAddress0() throws -> SocketAddress {
throw ChannelError.ioOnClosedChannel
}
func register0(promise: EventLoopPromise<Void>?) {
promise?.fail(error: ChannelError.ioOnClosedChannel)
}
@ -65,27 +73,30 @@ private final class DeadChannelCore: ChannelCore {
/// channel as it only holds an unowned reference to the original `Channel`. `DeadChannel` serves as a replacement
/// that can be used when the original `Channel` might no longer be valid.
internal final class DeadChannel: Channel {
let eventLoop: EventLoop
let pipeline: ChannelPipeline
var eventLoop: EventLoop {
return self.pipeline.eventLoop
public var closeFuture: EventLoopFuture<()> {
return self.eventLoop.newSucceededFuture(result: ())
}
internal init(pipeline: ChannelPipeline) {
self.pipeline = pipeline
self.eventLoop = pipeline.eventLoop
}
// This is `Channel` API so must be thread-safe.
var allocator: ByteBufferAllocator {
return ByteBufferAllocator()
}
var closeFuture: EventLoopFuture<Void> {
return self.pipeline.eventLoop.newSucceededFuture(result: ())
var localAddress: SocketAddress? {
return nil
}
let localAddress: SocketAddress? = nil
let remoteAddress: SocketAddress? = nil
var remoteAddress: SocketAddress? {
return nil
}
let parent: Channel? = nil

View File

@ -140,7 +140,6 @@ public class EmbeddedEventLoop: EventLoop {
class EmbeddedChannelCore : ChannelCore {
var closed: Bool = false
var isActive: Bool = false
var eventLoop: EventLoop
var closePromise: EventLoopPromise<Void>
@ -163,6 +162,14 @@ class EmbeddedChannelCore : ChannelCore {
var outboundBuffer: [IOData] = []
var inboundBuffer: [NIOAny] = []
func localAddress0() throws -> SocketAddress {
throw ChannelError.operationUnsupported
}
func remoteAddress0() throws -> SocketAddress {
throw ChannelError.operationUnsupported
}
func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
if closed {
promise?.fail(error: ChannelError.alreadyClosed)
@ -269,8 +276,8 @@ public class EmbeddedChannel : Channel {
public var allocator: ByteBufferAllocator = ByteBufferAllocator()
public var eventLoop: EventLoop = EmbeddedEventLoop()
public var localAddress: SocketAddress? = nil
public var remoteAddress: SocketAddress? = nil
public let localAddress: SocketAddress? = nil
public let remoteAddress: SocketAddress? = nil
// Embedded channels never have parents.
public let parent: Channel? = nil

View File

@ -465,6 +465,7 @@ internal protocol PendingWritesManager {
}
extension PendingWritesManager {
// This is called from `Channel` API so must be thread-safe.
var isWritable: Bool {
return self.channelWritabilityFlag.load()
}

View File

@ -13,6 +13,8 @@
//===----------------------------------------------------------------------===//
/// Represents a selectable resource which can be registered to a `Selector` to be notified once there are some events ready for it.
///
/// - warning: `Selectable`s are not thread-safe, only to be used on the appropriate `EventLoop`.
protocol Selectable {
/// The file descriptor itself.

View File

@ -43,6 +43,38 @@ private extension ByteBuffer {
class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
typealias SelectableType = T
// MARK: Stored Properties
// Visible to access from EventLoop directly
public let parent: Channel?
internal let socket: T
private let closePromise: EventLoopPromise<Void>
private let selectableEventLoop: SelectableEventLoop
private let localAddressCached: AtomicBox<Box<SocketAddress?>> = AtomicBox(value: Box(nil))
private let remoteAddressCached: AtomicBox<Box<SocketAddress?>> = AtomicBox(value: Box(nil))
private let bufferAllocatorCached: AtomicBox<Box<ByteBufferAllocator>>
internal var interestedEvent: IOEvent = .none
fileprivate var readPending = false
fileprivate var pendingConnect: EventLoopPromise<Void>?
fileprivate var recvAllocator: RecvByteBufferAllocator
fileprivate var maxMessagesPerRead: UInt = 4
private var inFlushNow: Bool = false // Guard against re-entrance of flushNow() method.
private var neverRegistered = true
private var active: Atomic<Bool> = Atomic(value: false)
private var _closed: Bool = false
private var autoRead: Bool = true
private var _pipeline: ChannelPipeline!
private var bufferAllocator: ByteBufferAllocator = ByteBufferAllocator() {
didSet {
assert(self.eventLoop.inEventLoop)
self.bufferAllocatorCached.store(Box(self.bufferAllocator))
}
}
// MARK: Datatypes
/// Indicates if a selectable should registered or not for IO notifications.
enum IONotificationState {
/// We should be registered for IO notifications.
@ -52,7 +84,26 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
case unregister
}
// MARK: Methods to override in subclasses.
fileprivate enum ReadResult {
/// Nothing was read by the read operation.
case none
/// Some data was read by the read operation.
case some
}
// MARK: Computed Properties
public final var _unsafe: ChannelCore { return self }
// This is `Channel` API so must be thread-safe.
public final var localAddress: SocketAddress? {
return self.localAddressCached.load().value
}
// This is `Channel` API so must be thread-safe.
public final var remoteAddress: SocketAddress? {
return self.remoteAddressCached.load().value
}
/// `true` if the whole `Channel` is closed and so no more IO operation can be done.
public var closed: Bool {
@ -60,6 +111,48 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
return _closed
}
internal var selectable: T {
return self.socket
}
// This is `Channel` API so must be thread-safe.
public var isActive: Bool {
return self.active.load()
}
// This is `Channel` API so must be thread-safe.
public final var closeFuture: EventLoopFuture<Void> {
return self.closePromise.futureResult
}
public final var eventLoop: EventLoop {
return selectableEventLoop
}
// This is `Channel` API so must be thread-safe.
public var isWritable: Bool {
return true
}
// This is `Channel` API so must be thread-safe.
public final var allocator: ByteBufferAllocator {
if eventLoop.inEventLoop {
return bufferAllocator
} else {
return self.bufferAllocatorCached.load().value
}
}
// This is `Channel` API so must be thread-safe.
public final var pipeline: ChannelPipeline {
return _pipeline
}
// MARK: Methods to override in subclasses.
func writeToSocket() throws -> OverallWriteResult {
fatalError("must be overridden")
}
/// Provides the registration for this selector. Must be implemented by subclasses.
func registrationFor(interested: IOEvent) -> NIORegistration {
fatalError("must override")
@ -72,14 +165,6 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
fatalError("this must be overridden by sub class")
}
fileprivate enum ReadResult {
/// Nothing was read by the read operation.
case none
/// Some data was read by the read operation.
case some
}
/// Begin connection of the underlying socket.
///
/// - parameters:
@ -111,70 +196,78 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
fatalError("this must be overridden by sub class")
}
// MARK: Common base socket logic.
fileprivate init(socket: T, parent: Channel? = nil, eventLoop: SelectableEventLoop, recvAllocator: RecvByteBufferAllocator) throws {
self.bufferAllocatorCached = AtomicBox(value: Box(self.bufferAllocator))
self.socket = socket
self.selectableEventLoop = eventLoop
self.closePromise = eventLoop.newPromise()
self.parent = parent
self.active.store(false)
self.recvAllocator = recvAllocator
self._pipeline = ChannelPipeline(channel: self)
}
deinit {
assert(self._closed, "leak of open Channel")
}
public final func localAddress0() throws -> SocketAddress {
assert(self.eventLoop.inEventLoop)
guard self.open else {
throw ChannelError.ioOnClosedChannel
}
return try self.socket.localAddress()
}
public final func remoteAddress0() throws -> SocketAddress {
assert(self.eventLoop.inEventLoop)
guard self.open else {
throw ChannelError.ioOnClosedChannel
}
return try self.socket.remoteAddress()
}
/// Flush data to the underlying socket and return if this socket needs to be registered for write notifications.
///
/// - returns: If this socket should be registered for write notifications. Ie. `IONotificationState.register` if _not_ all data could be written, so notifications are necessary; and `IONotificationState.unregister` if everything was written and we don't need to be notified about writability at the moment.
fileprivate func flushNow() -> IONotificationState {
fatalError("this must be overridden by sub class")
}
// Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by
// `writeToSocket`.
guard !inFlushNow && !closed else {
return .unregister
}
// MARK: Common base socket logic.
defer {
inFlushNow = false
}
inFlushNow = true
var selectable: T {
return self.socket
}
do {
switch try self.writeToSocket() {
case .couldNotWriteEverything:
return .register
case .writtenCompletely:
return .unregister
}
} catch let err {
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
if readIfNeeded0() {
public final var _unsafe: ChannelCore { return self }
// We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
while let read = try? readFromSocket(), read == .some {
pipeline.fireChannelReadComplete()
}
}
// Visible to access from EventLoop directly
let socket: T
public var interestedEvent: IOEvent = .none
close0(error: err, mode: .all, promise: nil)
fileprivate var readPending = false
private var neverRegistered = true
fileprivate var pendingConnect: EventLoopPromise<Void>?
private let closePromise: EventLoopPromise<Void>
private var active: Atomic<Bool> = Atomic(value: false)
private var _closed: Bool = false
public var isActive: Bool {
return active.load()
}
public var parent: Channel? = nil
public final var closeFuture: EventLoopFuture<Void> {
return closePromise.futureResult
}
private let selectableEventLoop: SelectableEventLoop
public final var eventLoop: EventLoop {
return selectableEventLoop
}
public var isWritable: Bool {
return true
}
public final var allocator: ByteBufferAllocator {
if eventLoop.inEventLoop {
return bufferAllocator
} else {
return try! eventLoop.submit{ self.bufferAllocator }.wait()
// we handled all writes
return .unregister
}
}
private var bufferAllocator: ByteBufferAllocator = ByteBufferAllocator()
fileprivate var recvAllocator: RecvByteBufferAllocator
fileprivate var autoRead: Bool = true
fileprivate var maxMessagesPerRead: UInt = 4
// We don't use lazy var here as this is more expensive then doing this :/
public final var pipeline: ChannelPipeline {
return _pipeline
}
private var _pipeline: ChannelPipeline!
public final func setOption<T: ChannelOption>(option: T, value: T.OptionType) -> EventLoopFuture<Void> {
if eventLoop.inEventLoop {
@ -240,18 +333,6 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
}
public final var localAddress: SocketAddress? {
get {
return socket.localAddress
}
}
public final var remoteAddress: SocketAddress? {
get {
return socket.remoteAddress
}
}
/// Triggers a `ChannelPipeline.read()` if `autoRead` is enabled.`
///
/// - returns: `true` if `readPending` is `true`, `false` otherwise.
@ -275,6 +356,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
executeAndComplete(promise) {
try socket.bind(to: address)
self.updateCachedAddressesFromSocket(updateRemote: false)
}
}
@ -401,6 +483,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
// Fail all pending writes and so ensure all pending promises are notified
self._closed = true
self.unsetCachedAddressesFromSocket()
self.cancelWritesOnClose(error: error)
becomeInactive0()
@ -515,6 +598,22 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
readIfNeeded0()
}
internal final func updateCachedAddressesFromSocket(updateLocal: Bool = true, updateRemote: Bool = true) {
assert(self.eventLoop.inEventLoop)
if updateLocal {
self.localAddressCached.store(Box(try? self.localAddress0()))
}
if updateRemote {
self.remoteAddressCached.store(Box(try? self.remoteAddress0()))
}
}
internal final func unsetCachedAddressesFromSocket() {
assert(self.eventLoop.inEventLoop)
self.localAddressCached.store(Box(nil))
self.remoteAddressCached.store(Box(nil))
}
public final func connect0(to address: SocketAddress, promise: EventLoopPromise<Void>?) {
assert(eventLoop.inEventLoop)
@ -529,6 +628,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
do {
if try !connectSocket(to: address) {
self.updateCachedAddressesFromSocket()
if promise != nil {
pendingConnect = promise
} else {
@ -536,6 +636,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
registerForWritable()
} else {
self.updateCachedAddressesFromSocket()
promise?.succeed(result: ())
}
} catch let error {
@ -602,19 +703,6 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
active.store(false)
pipeline.fireChannelInactive0()
}
fileprivate init(socket: T, eventLoop: SelectableEventLoop, recvAllocator: RecvByteBufferAllocator) throws {
self.socket = socket
self.selectableEventLoop = eventLoop
self.closePromise = eventLoop.newPromise()
active.store(false)
self.recvAllocator = recvAllocator
self._pipeline = ChannelPipeline(channel: self)
}
deinit {
assert(self._closed, "leak of open Channel")
}
}
/// A `Channel` for a client socket.
@ -628,10 +716,9 @@ final class SocketChannel: BaseSocketChannel<Socket> {
private var inputShutdown: Bool = false
private var outputShutdown: Bool = false
// Guard against re-entrance of flushNow() method.
private var inFlushNow: Bool = false
private let pendingWrites: PendingStreamWritesManager
// This is `Channel` API so must be thread-safe.
override public var isWritable: Bool {
return pendingWrites.isWritable
}
@ -694,15 +781,10 @@ final class SocketChannel: BaseSocketChannel<Socket> {
return .socketChannel(self, interested)
}
fileprivate init(socket: Socket, eventLoop: SelectableEventLoop) throws {
fileprivate init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws {
try socket.setNonBlocking()
self.pendingWrites = PendingStreamWritesManager(iovecs: eventLoop.iovecs, storageRefs: eventLoop.storageRefs)
try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator())
}
fileprivate convenience init(socket: Socket, eventLoop: SelectableEventLoop, parent: Channel) throws {
try self.init(socket: socket, eventLoop: eventLoop)
self.parent = parent
try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator())
}
override fileprivate func readFromSocket() throws -> ReadResult {
@ -749,8 +831,8 @@ final class SocketChannel: BaseSocketChannel<Socket> {
return result
}
private func writeToSocket(pendingWrites: PendingStreamWritesManager) throws -> OverallWriteResult {
let result = try pendingWrites.triggerAppropriateWriteOperations(scalarBufferWriteOperation: { ptr in
override func writeToSocket() throws -> OverallWriteResult {
let result = try self.pendingWrites.triggerAppropriateWriteOperations(scalarBufferWriteOperation: { ptr in
guard ptr.count > 0 else {
// No need to call write if the buffer is empty.
return .processed(0)
@ -882,43 +964,6 @@ final class SocketChannel: BaseSocketChannel<Socket> {
pipeline.fireChannelWritabilityChanged0()
}
}
override fileprivate func flushNow() -> IONotificationState {
// Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by
// `writeToSocket`.
guard !inFlushNow && !closed else {
return .unregister
}
defer {
inFlushNow = false
}
inFlushNow = true
do {
switch try self.writeToSocket(pendingWrites: pendingWrites) {
case .couldNotWriteEverything:
return .register
case .writtenCompletely:
return .unregister
}
} catch let err {
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
if readIfNeeded0() {
// We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
while let read = try? readFromSocket(), read == .some {
pipeline.fireChannelReadComplete()
}
}
close0(error: err, mode: .all, promise: nil)
// we handled all writes
return .unregister
}
}
}
/// A `Channel` for a server socket.
@ -930,6 +975,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
private let group: EventLoopGroup
/// The server socket channel is never writable.
// This is `Channel` API so must be thread-safe.
override public var isWritable: Bool { return false }
init(eventLoop: SelectableEventLoop, group: EventLoopGroup, protocolFamily: Int32) throws {
@ -987,6 +1033,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
}
executeAndComplete(p) {
try socket.bind(to: address)
self.updateCachedAddressesFromSocket(updateRemote: false)
try self.socket.listen(backlog: backlog)
}
}
@ -1009,7 +1056,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
readPending = false
result = .some
do {
let chan = try SocketChannel(socket: accepted, eventLoop: group.next() as! SelectableEventLoop, parent: self)
let chan = try SocketChannel(socket: accepted, parent: self, eventLoop: group.next() as! SelectableEventLoop)
pipeline.fireChannelRead0(NIOAny(chan))
} catch let err {
_ = try? accepted.close()
@ -1054,9 +1101,9 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
final class DatagramChannel: BaseSocketChannel<Socket> {
// Guard against re-entrance of flushNow() method.
private var inFlushNow: Bool = false
private let pendingWrites: PendingDatagramWritesManager
// This is `Channel` API so must be thread-safe.
override public var isWritable: Bool {
return pendingWrites.isWritable
}
@ -1082,18 +1129,13 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
}
fileprivate init(socket: Socket, eventLoop: SelectableEventLoop) throws {
fileprivate init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws {
try socket.setNonBlocking()
self.pendingWrites = PendingDatagramWritesManager(msgs: eventLoop.msgs,
iovecs: eventLoop.iovecs,
addresses: eventLoop.addresses,
storageRefs: eventLoop.storageRefs)
try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
}
fileprivate convenience init(socket: Socket, eventLoop: SelectableEventLoop, parent: Channel) throws {
try self.init(socket: socket, eventLoop: eventLoop)
self.parent = parent
try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
}
// MARK: Datagram Channel overrides required by BaseSocketChannel
@ -1156,22 +1198,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
let mayGrow = recvAllocator.record(actualReadBytes: bytesRead)
readPending = false
let sourceAddress: SocketAddress
switch Int(rawAddressLength) {
case MemoryLayout<sockaddr_in>.size:
let addr: sockaddr_in = rawAddress.convert()
sourceAddress = .init(addr, host: "")
case MemoryLayout<sockaddr_in6>.size:
let addr: sockaddr_in6 = rawAddress.convert()
sourceAddress = .init(addr, host: "")
case MemoryLayout<sockaddr_un>.size:
let addr: sockaddr_un = rawAddress.convert()
sourceAddress = .init(addr)
default:
fatalError("Unexpected sockaddr size")
}
let msg = AddressedEnvelope(remoteAddress: sourceAddress, data: buffer)
let msg = AddressedEnvelope(remoteAddress: rawAddress.convert(), data: buffer)
pipeline.fireChannelRead0(NIOAny(msg))
if mayGrow && i < maxMessagesPerRead {
buffer = recvAllocator.buffer(allocator: allocator)
@ -1211,45 +1238,8 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
self.pendingWrites.failAll(error: error, close: true)
}
override fileprivate func flushNow() -> IONotificationState {
// Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by
// `writeToSocket`.
guard !inFlushNow && !closed else {
return .unregister
}
defer {
inFlushNow = false
}
inFlushNow = true
do {
switch try self.writeToSocket(pendingWrites: pendingWrites) {
case .couldNotWriteEverything:
return .register
case .writtenCompletely:
return .unregister
}
} catch let err {
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
if readIfNeeded0() {
// We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
while let read = try? readFromSocket(), read == .some {
pipeline.fireChannelReadComplete()
}
}
close0(error: err, mode: .all, promise: nil)
// we handled all writes
return .unregister
}
}
private func writeToSocket(pendingWrites: PendingDatagramWritesManager) throws -> OverallWriteResult {
let result = try pendingWrites.triggerAppropriateWriteOperations(scalarWriteOperation: { (ptr, destinationPtr, destinationSize) in
override func writeToSocket() throws -> OverallWriteResult {
let result = try self.pendingWrites.triggerAppropriateWriteOperations(scalarWriteOperation: { (ptr, destinationPtr, destinationSize) in
guard ptr.count > 0 else {
// No need to call write if the buffer is empty.
return .processed(0)
@ -1273,6 +1263,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
assert(self.eventLoop.inEventLoop)
do {
try socket.bind(to: address)
self.updateCachedAddressesFromSocket(updateRemote: false)
promise?.succeed(result: ())
becomeActive0()
readIfNeeded0()

View File

@ -50,6 +50,12 @@ private let sysLseek = lseek
private let sysRecvFrom = recvfrom
private let sysSendTo = sendto
private let sysDup = dup
private let sysGetpeername = getpeername
private let sysGetsockname = getsockname
private let sysAF_INET = AF_INET
private let sysAF_INET6 = AF_INET6
private let sysAF_UNIX = AF_UNIX
private let sysInet_ntop = inet_ntop
#if os(Linux)
private let sysSendMmsg = CNIOLinux_sendmmsg
@ -110,6 +116,23 @@ internal func wrapSyscall<T: FixedWidthInteger>(where function: StaticString = #
}
}
/* Sorry, we really try hard to not use underscored attributes. In this case however we seem to break the inlining threshold which makes a system call take twice the time, ie. we need this exception. */
@inline(__always)
internal func wrapErrorIsNullReturnCall(where function: StaticString = #function, _ fn: () throws -> UnsafePointer<CChar>?) throws -> UnsafePointer<CChar>? {
while true {
let res = try fn()
if res == nil {
let err = errno
if err == EINTR {
continue
}
assert(!isBlacklistedErrno(err), "blacklisted errno \(err) \(strerror(err)!)")
throw IOError(errnoCode: err, function: function)
}
return res
}
}
enum Shutdown {
case RD
case WR
@ -148,6 +171,9 @@ internal enum Posix {
}
#endif
static let AF_INET = sa_family_t(sysAF_INET)
static let AF_INET6 = sa_family_t(sysAF_INET6)
static let AF_UNIX = sa_family_t(sysAF_UNIX)
@inline(never)
public static func shutdown(descriptor: CInt, how: Shutdown) throws {
@ -321,6 +347,14 @@ internal enum Posix {
}
}
@discardableResult
@inline(never)
public static func inet_ntop(addressFamily: CInt, addressBytes: UnsafeRawPointer, addressDescription: UnsafeMutablePointer<CChar>, addressDescriptionLength: socklen_t) throws -> UnsafePointer<CChar>? {
return try wrapErrorIsNullReturnCall {
sysInet_ntop(addressFamily, addressBytes, addressDescription, addressDescriptionLength)
}
}
// Its not really posix but exists on Linux and MacOS / BSD so just put it here for now to keep it simple
@inline(never)
public static func sendfile(descriptor: CInt, fd: CInt, offset: off_t, count: size_t) throws -> IOResult<Int> {
@ -365,6 +399,21 @@ internal enum Posix {
Int(sysRecvMmsg(sockfd, msgvec, vlen, flags, timeout))
}
}
@inline(never)
public static func getpeername(socket: CInt, address: UnsafeMutablePointer<sockaddr>, addressLength: UnsafeMutablePointer<socklen_t>) throws {
_ = try wrapSyscall {
return sysGetpeername(socket, address, addressLength)
}
}
@inline(never)
public static func getsockname(socket: CInt, address: UnsafeMutablePointer<sockaddr>, addressLength: UnsafeMutablePointer<socklen_t>) throws {
_ = try wrapSyscall {
return sysGetsockname(socket, address, addressLength)
}
}
}
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)

View File

@ -74,7 +74,7 @@ final class ChatHandler: ChannelInboundHandler {
}
public func channelActive(ctx: ChannelHandlerContext) {
let remoteAddress = ctx.channel.remoteAddress!
let remoteAddress = ctx.remoteAddress!
let channel = ctx.channel
self.channelsSyncQueue.async {
// broadcast the message to all the connected clients except the one that just became active.
@ -84,7 +84,7 @@ final class ChatHandler: ChannelInboundHandler {
}
var buffer = channel.allocator.buffer(capacity: 64)
buffer.write(string: "(ChatServer) - Welcome to: \(channel.localAddress!)\n")
buffer.write(string: "(ChatServer) - Welcome to: \(ctx.localAddress!)\n")
ctx.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
}

View File

@ -307,3 +307,91 @@ extension UInt: AtomicPrimitive {
public static let atomic_load = catmc_atomic_unsigned_long_load
public static let atomic_store = catmc_atomic_unsigned_long_store
}
/// `AtomicBox` is a heap-allocated box which allows atomic access to an instance of a Swift class.
///
/// It behaves very much like `Atomic<T>` but for objects, maintaining the correct retain counts.
public class AtomicBox<T: AnyObject> {
private let storage: Atomic<Int>
public init(value: T) {
let ptr = Unmanaged<T>.passRetained(value)
self.storage = Atomic(value: Int(bitPattern: ptr.toOpaque()))
}
deinit {
let oldPtrBits = self.storage.exchange(with: 0xdeadbeef)
let oldPtr = Unmanaged<T>.fromOpaque(UnsafeRawPointer(bitPattern: oldPtrBits)!)
oldPtr.release()
}
/// Atomically compares the value against `expected` and, if they are equal,
/// replaces the value with `desired`.
///
/// This implementation conforms to C11's `atomic_compare_exchange_strong`. This
/// means that the compare-and-swap will always succeed if `expected` is equal to
/// value. Additionally, it uses a *sequentially consistent ordering*. For more
/// details on atomic memory models, check the documentation for C11's
/// `stdatomic.h`.
///
/// - Parameter expected: The value that this object must currently hold for the
/// compare-and-swap to succeed.
/// - Parameter desired: The new value that this object will hold if the compare
/// succeeds.
/// - Returns: `True` if the exchange occurred, or `False` if `expected` did not
/// match the current value and so no exchange occurred.
public func compareAndExchange(expected: T, desired: T) -> Bool {
return withExtendedLifetime(desired) {
let expectedPtr = Unmanaged<T>.passUnretained(expected)
let desiredPtr = Unmanaged<T>.passUnretained(desired)
if self.storage.compareAndExchange(expected: Int(bitPattern: expectedPtr.toOpaque()),
desired: Int(bitPattern: desiredPtr.toOpaque())) {
_ = desiredPtr.retain()
expectedPtr.release()
return true
} else {
return false
}
}
}
/// Atomically exchanges `value` for the current value of this object.
///
/// This implementation uses a *relaxed* memory ordering. This guarantees nothing
/// more than that this operation is atomic: there is no guarantee that any other
/// event will be ordered before or after this one.
///
/// - Parameter value: The new value to set this object to.
/// - Returns: The value previously held by this object.
public func exchange(with value: T) -> T {
let newPtr = Unmanaged<T>.passRetained(value)
let oldPtrBits = self.storage.exchange(with: Int(bitPattern: newPtr.toOpaque()))
let oldPtr = Unmanaged<T>.fromOpaque(UnsafeRawPointer(bitPattern: oldPtrBits)!)
return oldPtr.takeRetainedValue()
}
/// Atomically loads and returns the value of this object.
///
/// This implementation uses a *relaxed* memory ordering. This guarantees nothing
/// more than that this operation is atomic: there is no guarantee that any other
/// event will be ordered before or after this one.
///
/// - Returns: The value of this object
public func load() -> T {
let ptrBits = self.storage.load()
let ptr = Unmanaged<T>.fromOpaque(UnsafeRawPointer(bitPattern: ptrBits)!)
return ptr.takeUnretainedValue()
}
/// Atomically replaces the value of this object with `value`.
///
/// This implementation uses a *relaxed* memory ordering. This guarantees nothing
/// more than that this operation is atomic: there is no guarantee that any other
/// event will be ordered before or after this one.
///
/// - Parameter value: The new value to set the object to.
public func store(_ value: T) -> Void {
_ = self.exchange(with: value)
}
}

View File

@ -41,7 +41,7 @@ private final class EchoHandler: ChannelInboundHandler {
}
public func channelActive(ctx: ChannelHandlerContext) {
print("Client connected to \(ctx.channel.remoteAddress!)")
print("Client connected to \(ctx.remoteAddress!)")
// We are connected its time to send the message to the server to initialize the ping-pong sequence.
var buffer = ctx.channel.allocator.buffer(capacity: line.utf8.count)

View File

@ -72,7 +72,7 @@ private final class HTTPHandler: ChannelInboundHandler {
URL: \(self.infoSavedRequestHead!.uri)\r
body length: \(self.infoSavedBodyBytes)\r
headers: \(self.infoSavedRequestHead!.headers)\r
client: \(ctx.channel.remoteAddress?.description ?? "zombie")\r
client: \(ctx.remoteAddress?.description ?? "zombie")\r
IO: SwiftNIO Electric Boogaloo\r\n
"""
self.buffer.clear()
@ -200,7 +200,7 @@ private final class HTTPHandler: ChannelInboundHandler {
case "/dynamic/count-to-ten":
return { self.handleMultipleWrites(ctx: $0, request: $1, strings: (1...10).map { "\($0)\r\n" }, delay: .milliseconds(100)) }
case "/dynamic/client-ip":
return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, string: "\(ctx.channel.remoteAddress.debugDescription)") }
return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, string: "\(ctx.remoteAddress.debugDescription)") }
default:
return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, statusCode: .notFound, string: "not found") }
}

View File

@ -39,6 +39,11 @@ extension NIOConcurrencyHelpersTests {
("testConditionLockMutualExclusion", testConditionLockMutualExclusion),
("testConditionLock", testConditionLock),
("testConditionLockWithDifferentConditions", testConditionLockWithDifferentConditions),
("testAtomicBoxDoesNotTriviallyLeak", testAtomicBoxDoesNotTriviallyLeak),
("testAtomicBoxCompareAndExchangeWorksIfEqual", testAtomicBoxCompareAndExchangeWorksIfEqual),
("testAtomicBoxCompareAndExchangeWorksIfNotEqual", testAtomicBoxCompareAndExchangeWorksIfNotEqual),
("testAtomicBoxStoreWorks", testAtomicBoxStoreWorks),
("testAtomicBoxCompareAndExchangeOntoItselfWorks", testAtomicBoxCompareAndExchangeOntoItselfWorks),
]
}
}

View File

@ -401,4 +401,149 @@ class NIOConcurrencyHelpersTests: XCTestCase {
doneSem.wait() /* job on 'q2' is done */
}
}
func testAtomicBoxDoesNotTriviallyLeak() throws {
class SomeClass {}
weak var weakSomeInstance1: SomeClass? = nil
weak var weakSomeInstance2: SomeClass? = nil
({
let someInstance = SomeClass()
weakSomeInstance1 = someInstance
let someAtomic = AtomicBox(value: someInstance)
let loadedFromAtomic = someAtomic.load()
weakSomeInstance2 = loadedFromAtomic
XCTAssertNotNil(weakSomeInstance1)
XCTAssertNotNil(weakSomeInstance2)
XCTAssert(someInstance === loadedFromAtomic)
})()
XCTAssertNil(weakSomeInstance1)
XCTAssertNil(weakSomeInstance2)
}
func testAtomicBoxCompareAndExchangeWorksIfEqual() throws {
class SomeClass {}
weak var weakSomeInstance1: SomeClass? = nil
weak var weakSomeInstance2: SomeClass? = nil
weak var weakSomeInstance3: SomeClass? = nil
({
let someInstance1 = SomeClass()
let someInstance2 = SomeClass()
weakSomeInstance1 = someInstance1
let atomic = AtomicBox(value: someInstance1)
var loadedFromAtomic = atomic.load()
XCTAssert(someInstance1 === loadedFromAtomic)
weakSomeInstance2 = loadedFromAtomic
XCTAssertTrue(atomic.compareAndExchange(expected: loadedFromAtomic, desired: someInstance2))
loadedFromAtomic = atomic.load()
weakSomeInstance3 = loadedFromAtomic
XCTAssert(someInstance1 !== loadedFromAtomic)
XCTAssert(someInstance2 === loadedFromAtomic)
XCTAssertNotNil(weakSomeInstance1)
XCTAssertNotNil(weakSomeInstance2)
XCTAssertNotNil(weakSomeInstance3)
XCTAssert(weakSomeInstance1 === weakSomeInstance2 && weakSomeInstance2 !== weakSomeInstance3)
})()
XCTAssertNil(weakSomeInstance1)
XCTAssertNil(weakSomeInstance2)
XCTAssertNil(weakSomeInstance3)
}
func testAtomicBoxCompareAndExchangeWorksIfNotEqual() throws {
class SomeClass {}
weak var weakSomeInstance1: SomeClass? = nil
weak var weakSomeInstance2: SomeClass? = nil
weak var weakSomeInstance3: SomeClass? = nil
({
let someInstance1 = SomeClass()
let someInstance2 = SomeClass()
weakSomeInstance1 = someInstance1
let atomic = AtomicBox(value: someInstance1)
var loadedFromAtomic = atomic.load()
XCTAssert(someInstance1 === loadedFromAtomic)
weakSomeInstance2 = loadedFromAtomic
XCTAssertFalse(atomic.compareAndExchange(expected: someInstance2, desired: someInstance2))
XCTAssertFalse(atomic.compareAndExchange(expected: SomeClass(), desired: someInstance2))
XCTAssertTrue(atomic.load() === someInstance1)
loadedFromAtomic = atomic.load()
weakSomeInstance3 = someInstance2
XCTAssert(someInstance1 === loadedFromAtomic)
XCTAssert(someInstance2 !== loadedFromAtomic)
XCTAssertNotNil(weakSomeInstance1)
XCTAssertNotNil(weakSomeInstance2)
XCTAssertNotNil(weakSomeInstance3)
})()
XCTAssertNil(weakSomeInstance1)
XCTAssertNil(weakSomeInstance2)
XCTAssertNil(weakSomeInstance3)
}
func testAtomicBoxStoreWorks() throws {
class SomeClass {}
weak var weakSomeInstance1: SomeClass? = nil
weak var weakSomeInstance2: SomeClass? = nil
weak var weakSomeInstance3: SomeClass? = nil
({
let someInstance1 = SomeClass()
let someInstance2 = SomeClass()
weakSomeInstance1 = someInstance1
let atomic = AtomicBox(value: someInstance1)
var loadedFromAtomic = atomic.load()
XCTAssert(someInstance1 === loadedFromAtomic)
weakSomeInstance2 = loadedFromAtomic
atomic.store(someInstance2)
loadedFromAtomic = atomic.load()
weakSomeInstance3 = loadedFromAtomic
XCTAssert(someInstance1 !== loadedFromAtomic)
XCTAssert(someInstance2 === loadedFromAtomic)
XCTAssertNotNil(weakSomeInstance1)
XCTAssertNotNil(weakSomeInstance2)
XCTAssertNotNil(weakSomeInstance3)
})()
XCTAssertNil(weakSomeInstance1)
XCTAssertNil(weakSomeInstance2)
XCTAssertNil(weakSomeInstance3)
}
func testAtomicBoxCompareAndExchangeOntoItselfWorks() {
let q = DispatchQueue(label: "q")
let g = DispatchGroup()
let sem1 = DispatchSemaphore(value: 0)
let sem2 = DispatchSemaphore(value: 0)
class SomeClass {}
weak var weakInstance: SomeClass?
({
let instance = SomeClass()
weakInstance = instance
let atomic = AtomicBox(value: instance)
q.async(group: g) {
sem1.signal()
sem2.wait()
for _ in 0..<1000 {
XCTAssertTrue(atomic.compareAndExchange(expected: instance, desired: instance))
}
}
sem2.signal()
sem1.wait()
for _ in 0..<1000 {
XCTAssertTrue(atomic.compareAndExchange(expected: instance, desired: instance))
}
g.wait()
let v = atomic.load()
XCTAssert(v === instance)
})()
XCTAssertNil(weakInstance)
}
}

View File

@ -19,6 +19,18 @@ import NIOFoundationCompat
import Dispatch
@testable import NIOHTTP1
internal extension Channel {
func syncCloseAcceptingAlreadyClosed() throws {
do {
try self.close().wait()
} catch ChannelError.alreadyClosed {
/* we're happy with this one */
} catch let e {
throw e
}
}
}
extension Array where Array.Element == ByteBuffer {
public func allAsBytes() -> [UInt8] {
var out: [UInt8] = []
@ -61,17 +73,6 @@ internal class ArrayAccumulationHandler<T>: ChannelInboundHandler {
}
class HTTPServerClientTest : XCTestCase {
private func syncCloseAcceptingAlreadyClosed(channel: Channel) throws {
do {
try channel.close().wait()
} catch ChannelError.alreadyClosed {
/* we're happy with this one */
} catch let e {
throw e
}
}
/* needs to be something reasonably large and odd so it has good odds producing incomplete writes even on the loopback interface */
private static let massiveResponseLength = 5 * 1024 * 1024 + 7
private static let massiveResponseBytes: [UInt8] = {
@ -374,7 +375,7 @@ class HTTPServerClientTest : XCTestCase {
}.bind(host: "127.0.0.1", port: 0).wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
}
let clientChannel = try ClientBootstrap(group: group)
@ -387,7 +388,7 @@ class HTTPServerClientTest : XCTestCase {
.wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
}
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/helloworld")
@ -432,7 +433,7 @@ class HTTPServerClientTest : XCTestCase {
}.bind(host: "127.0.0.1", port: 0).wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
}
let clientChannel = try ClientBootstrap(group: group)
@ -445,7 +446,7 @@ class HTTPServerClientTest : XCTestCase {
.wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
}
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/count-to-ten")
@ -490,7 +491,7 @@ class HTTPServerClientTest : XCTestCase {
}.bind(host: "127.0.0.1", port: 0).wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
}
let clientChannel = try ClientBootstrap(group: group)
@ -503,7 +504,7 @@ class HTTPServerClientTest : XCTestCase {
.wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
}
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/trailers")
@ -550,7 +551,7 @@ class HTTPServerClientTest : XCTestCase {
}.bind(host: "127.0.0.1", port: 0).wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
}
let clientChannel = try ClientBootstrap(group: group)
@ -559,7 +560,7 @@ class HTTPServerClientTest : XCTestCase {
.wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
}
var buffer = clientChannel.allocator.buffer(capacity: numBytes)
@ -591,7 +592,7 @@ class HTTPServerClientTest : XCTestCase {
}
}.bind(host: "127.0.0.1", port: 0).wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
}
let clientChannel = try ClientBootstrap(group: group)
@ -604,7 +605,7 @@ class HTTPServerClientTest : XCTestCase {
.wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
}
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .HEAD, uri: "/head")
@ -636,7 +637,7 @@ class HTTPServerClientTest : XCTestCase {
}
}.bind(host: "127.0.0.1", port: 0).wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel))
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
}
let clientChannel = try ClientBootstrap(group: group)
@ -649,7 +650,7 @@ class HTTPServerClientTest : XCTestCase {
.wait()
defer {
XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel))
XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed())
}
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/204")

View File

@ -55,6 +55,7 @@ extension ChannelTests {
("testHalfClosure", testHalfClosure),
("testRejectsInvalidData", testRejectsInvalidData),
("testWeDontCrashIfChannelReleasesBeforePipeline", testWeDontCrashIfChannelReleasesBeforePipeline),
("testAskForLocalAndRemoteAddressesAfterChannelIsClosed", testAskForLocalAndRemoteAddressesAfterChannelIsClosed),
]
}
}

View File

@ -1157,7 +1157,7 @@ public class ChannelTests: XCTestCase {
return channel.pipeline.add(handler: byteCountingHandler)
}
}
.connect(to: server.localAddress!)
.connect(to: try! server.localAddress())
let accepted = try server.accept()!
defer {
XCTAssertNoThrow(try accepted.close())
@ -1219,7 +1219,7 @@ public class ChannelTests: XCTestCase {
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
}
}
.connect(to: server.localAddress!)
.connect(to: try! server.localAddress())
let accepted = try server.accept()!
defer {
XCTAssertNoThrow(try accepted.close())
@ -1269,7 +1269,7 @@ public class ChannelTests: XCTestCase {
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
}
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
.connect(to: server.localAddress!)
.connect(to: try! server.localAddress())
let accepted = try server.accept()!
defer {
XCTAssertNoThrow(try accepted.close())
@ -1412,4 +1412,27 @@ public class ChannelTests: XCTestCase {
XCTAssertNil(weakServerChannel, "weakServerChannel not nil, looks like we leaked it!")
XCTAssertNil(weakServerChildChannel, "weakServerChildChannel not nil, looks like we leaked it!")
}
func testAskForLocalAndRemoteAddressesAfterChannelIsClosed() throws {
let group = MultiThreadedEventLoopGroup(numThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
let serverChannel = try ServerBootstrap(group: group)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.bind(host: "127.0.0.1", port: 0).wait()
let clientChannel = try ClientBootstrap(group: group)
.connect(to: serverChannel.localAddress!).wait()
// Start shutting stuff down.
try serverChannel.syncCloseAcceptingAlreadyClosed()
try clientChannel.syncCloseAcceptingAlreadyClosed()
for f in [ serverChannel.remoteAddress, serverChannel.localAddress, clientChannel.remoteAddress, clientChannel.localAddress ] {
XCTAssertNil(f)
}
}
}

View File

@ -69,6 +69,17 @@ func openTemporaryFile() -> (CInt, String) {
return (fd, String(decoding: templateBytes, as: UTF8.self))
}
internal extension Channel {
func syncCloseAcceptingAlreadyClosed() throws {
do {
try self.close().wait()
} catch ChannelError.alreadyClosed {
/* we're happy with this one */
} catch let e {
throw e
}
}
}
final class ByteCountingHandler : ChannelInboundHandler {
typealias InboundIn = ByteBuffer