diff --git a/Sources/NIO/BaseSocket.swift b/Sources/NIO/BaseSocket.swift index 91ff3505..a0a9e1f5 100644 --- a/Sources/NIO/BaseSocket.swift +++ b/Sources/NIO/BaseSocket.swift @@ -21,6 +21,19 @@ protocol SockAddrProtocol { mutating func withMutableSockAddr(_ fn: (UnsafeMutablePointer, Int) throws -> R) rethrows -> R } +private func descriptionForAddress(family: CInt, bytes: UnsafeRawPointer, length byteCount: Int) -> String { + var addressBytes: [Int8] = Array(repeating: 0, count: byteCount) + return addressBytes.withUnsafeMutableBufferPointer { (addressBytesPtr: inout UnsafeMutableBufferPointer) -> String in + try! Posix.inet_ntop(addressFamily: family, + addressBytes: bytes, + addressDescription: addressBytesPtr.baseAddress!, + addressDescriptionLength: socklen_t(byteCount)) + return addressBytesPtr.baseAddress!.withMemoryRebound(to: UInt8.self, capacity: byteCount) { addressBytesPtr -> String in + return String(decoding: UnsafeBufferPointer(start: addressBytesPtr, count: byteCount), as: Unicode.ASCII.self) + } + } +} + extension sockaddr_in: SockAddrProtocol { mutating func withSockAddr(_ fn: (UnsafePointer, Int) throws -> R) rethrows -> R { var me = self @@ -35,6 +48,12 @@ extension sockaddr_in: SockAddrProtocol { try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count) } } + + mutating func addressDescription() -> String { + return withUnsafePointer(to: &self.sin_addr) { addrPtr in + descriptionForAddress(family: AF_INET, bytes: addrPtr, length: Int(INET_ADDRSTRLEN)) + } + } } extension sockaddr_in6: SockAddrProtocol { @@ -51,6 +70,12 @@ extension sockaddr_in6: SockAddrProtocol { try fn(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count) } } + + mutating func addressDescription() -> String { + return withUnsafePointer(to: &self.sin6_addr) { addrPtr in + descriptionForAddress(family: AF_INET6, bytes: addrPtr, length: Int(INET6_ADDRSTRLEN)) + } + } } extension sockaddr_un: SockAddrProtocol { @@ -119,54 +144,45 @@ extension sockaddr_storage { } } } + + mutating func convert() -> SocketAddress { + switch self.ss_family { + case Posix.AF_INET: + var sockAddr: sockaddr_in = self.convert() + return SocketAddress(sockAddr, host: sockAddr.addressDescription()) + case Posix.AF_INET6: + var sockAddr: sockaddr_in6 = self.convert() + return SocketAddress(sockAddr, host: sockAddr.addressDescription()) + case Posix.AF_UNIX: + return SocketAddress(self.convert() as sockaddr_un) + default: + fatalError("unknown sockaddr family \(self.ss_family)") + } + } } class BaseSocket: Selectable { public let descriptor: Int32 public private(set) var open: Bool - final var localAddress: SocketAddress? { - get { - return get_addr { getsockname($0, $1, $2) } - } + final func localAddress() throws -> SocketAddress { + return try get_addr { try Posix.getsockname(socket: $0, address: $1, addressLength: $2) } } - final var remoteAddress: SocketAddress? { - get { - return get_addr { getpeername($0, $1, $2) } - } + final func remoteAddress() throws -> SocketAddress { + return try get_addr { try Posix.getpeername(socket: $0, address: $1, addressLength: $2) } } - private func get_addr(_ fn: (Int32, UnsafeMutablePointer, UnsafeMutablePointer) -> Int32) -> SocketAddress? { + private func get_addr(_ fn: (Int32, UnsafeMutablePointer, UnsafeMutablePointer) throws -> Void) throws -> SocketAddress { var addr = sockaddr_storage() - var len: socklen_t = socklen_t(MemoryLayout.size) - - return withUnsafeMutablePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1, { address in - guard fn(descriptor, address, &len) == 0 else { - return nil - } - switch Int32(address.pointee.sa_family) { - case AF_INET: - return address.withMemoryRebound(to: sockaddr_in.self, capacity: 1, { ipv4 in - var ipAddressString = [CChar](repeating: 0, count: Int(INET_ADDRSTRLEN)) - return SocketAddress(ipv4.pointee, host: String(cString: inet_ntop(AF_INET, &ipv4.pointee.sin_addr, &ipAddressString, socklen_t(INET_ADDRSTRLEN)))) - }) - case AF_INET6: - return address.withMemoryRebound(to: sockaddr_in6.self, capacity: 1, { ipv6 in - var ipAddressString = [CChar](repeating: 0, count: Int(INET6_ADDRSTRLEN)) - return SocketAddress(ipv6.pointee, host: String(cString: inet_ntop(AF_INET6, &ipv6.pointee.sin6_addr, &ipAddressString, socklen_t(INET6_ADDRSTRLEN)))) - }) - case AF_UNIX: - return address.withMemoryRebound(to: sockaddr_un.self, capacity: 1) { uds in - return SocketAddress(uds.pointee) - } - default: - fatalError("address family \(address.pointee.sa_family) not supported") - } - }) + + try addr.withMutableSockAddr { addressPtr, size in + var size = socklen_t(size) + try fn(self.descriptor, addressPtr, &size) } + return addr.convert() } + static func newSocket(protocolFamily: Int32, type: CInt) throws -> Int32 { let sock = try Posix.socket(domain: protocolFamily, type: type, diff --git a/Sources/NIO/Channel.swift b/Sources/NIO/Channel.swift index 05afffce..45120d7a 100644 --- a/Sources/NIO/Channel.swift +++ b/Sources/NIO/Channel.swift @@ -18,6 +18,8 @@ import NIOConcurrencyHelpers /// /// - note: All methods must be called from the EventLoop thread public protocol ChannelCore : class { + func localAddress0() throws -> SocketAddress + func remoteAddress0() throws -> SocketAddress func register0(promise: EventLoopPromise?) func bind0(to: SocketAddress, promise: EventLoopPromise?) func connect0(to: SocketAddress, promise: EventLoopPromise?) @@ -82,7 +84,9 @@ public protocol Channel : class, ChannelOutboundInvoker { /// A `SelectableChannel` is a `Channel` that can be used with a `Selector` which notifies a user when certain events /// before possible. On UNIX a `Selector` is usually an abstraction of `select`, `poll`, `epoll` or `kqueue`. -protocol SelectableChannel : Channel { +/// +/// - warning: `SelectableChannel` methods and properties are _not_ thread-safe (unless they also belong to `Channel`). +internal protocol SelectableChannel : Channel { /// The type of the `Selectable`. A `Selectable` is usually wrapping a file descriptor that can be registered in a /// `Selector`. associatedtype SelectableType: Selectable diff --git a/Sources/NIO/ChannelPipeline.swift b/Sources/NIO/ChannelPipeline.swift index 23d767c1..8afdf653 100644 --- a/Sources/NIO/ChannelPipeline.swift +++ b/Sources/NIO/ChannelPipeline.swift @@ -809,6 +809,14 @@ public final class ChannelHandlerContext : ChannelInvoker { return self.inboundHandler ?? self.outboundHandler! } + public var remoteAddress: SocketAddress? { + return try? self.channel._unsafe.remoteAddress0() + } + + public var localAddress: SocketAddress? { + return try? self.channel._unsafe.localAddress0() + } + public let name: String public let eventLoop: EventLoop private let inboundHandler: _ChannelInboundHandler? diff --git a/Sources/NIO/DeadChannel.swift b/Sources/NIO/DeadChannel.swift index 902943b2..c2fd4e99 100644 --- a/Sources/NIO/DeadChannel.swift +++ b/Sources/NIO/DeadChannel.swift @@ -16,6 +16,14 @@ /// the original `Channel` is closed. Given that the original `Channel` is closed the `DeadChannelCore` should fail /// all operations. private final class DeadChannelCore: ChannelCore { + func localAddress0() throws -> SocketAddress { + throw ChannelError.ioOnClosedChannel + } + + func remoteAddress0() throws -> SocketAddress { + throw ChannelError.ioOnClosedChannel + } + func register0(promise: EventLoopPromise?) { promise?.fail(error: ChannelError.ioOnClosedChannel) } @@ -65,27 +73,30 @@ private final class DeadChannelCore: ChannelCore { /// channel as it only holds an unowned reference to the original `Channel`. `DeadChannel` serves as a replacement /// that can be used when the original `Channel` might no longer be valid. internal final class DeadChannel: Channel { + let eventLoop: EventLoop let pipeline: ChannelPipeline - var eventLoop: EventLoop { - return self.pipeline.eventLoop + public var closeFuture: EventLoopFuture<()> { + return self.eventLoop.newSucceededFuture(result: ()) } internal init(pipeline: ChannelPipeline) { self.pipeline = pipeline + self.eventLoop = pipeline.eventLoop } + // This is `Channel` API so must be thread-safe. var allocator: ByteBufferAllocator { return ByteBufferAllocator() } - var closeFuture: EventLoopFuture { - return self.pipeline.eventLoop.newSucceededFuture(result: ()) + var localAddress: SocketAddress? { + return nil } - let localAddress: SocketAddress? = nil - - let remoteAddress: SocketAddress? = nil + var remoteAddress: SocketAddress? { + return nil + } let parent: Channel? = nil diff --git a/Sources/NIO/Embedded.swift b/Sources/NIO/Embedded.swift index b01f5433..9c2bf897 100644 --- a/Sources/NIO/Embedded.swift +++ b/Sources/NIO/Embedded.swift @@ -140,7 +140,6 @@ public class EmbeddedEventLoop: EventLoop { class EmbeddedChannelCore : ChannelCore { var closed: Bool = false var isActive: Bool = false - var eventLoop: EventLoop var closePromise: EventLoopPromise @@ -163,6 +162,14 @@ class EmbeddedChannelCore : ChannelCore { var outboundBuffer: [IOData] = [] var inboundBuffer: [NIOAny] = [] + func localAddress0() throws -> SocketAddress { + throw ChannelError.operationUnsupported + } + + func remoteAddress0() throws -> SocketAddress { + throw ChannelError.operationUnsupported + } + func close0(error: Error, mode: CloseMode, promise: EventLoopPromise?) { if closed { promise?.fail(error: ChannelError.alreadyClosed) @@ -269,8 +276,8 @@ public class EmbeddedChannel : Channel { public var allocator: ByteBufferAllocator = ByteBufferAllocator() public var eventLoop: EventLoop = EmbeddedEventLoop() - public var localAddress: SocketAddress? = nil - public var remoteAddress: SocketAddress? = nil + public let localAddress: SocketAddress? = nil + public let remoteAddress: SocketAddress? = nil // Embedded channels never have parents. public let parent: Channel? = nil diff --git a/Sources/NIO/PendingWritesManager.swift b/Sources/NIO/PendingWritesManager.swift index 7ed1cd82..c3f84a8c 100644 --- a/Sources/NIO/PendingWritesManager.swift +++ b/Sources/NIO/PendingWritesManager.swift @@ -465,6 +465,7 @@ internal protocol PendingWritesManager { } extension PendingWritesManager { + // This is called from `Channel` API so must be thread-safe. var isWritable: Bool { return self.channelWritabilityFlag.load() } diff --git a/Sources/NIO/Selectable.swift b/Sources/NIO/Selectable.swift index 66db2595..ab611970 100644 --- a/Sources/NIO/Selectable.swift +++ b/Sources/NIO/Selectable.swift @@ -13,6 +13,8 @@ //===----------------------------------------------------------------------===// /// Represents a selectable resource which can be registered to a `Selector` to be notified once there are some events ready for it. +/// +/// - warning: `Selectable`s are not thread-safe, only to be used on the appropriate `EventLoop`. protocol Selectable { /// The file descriptor itself. diff --git a/Sources/NIO/SocketChannel.swift b/Sources/NIO/SocketChannel.swift index 42f447b5..9bbbe5d2 100644 --- a/Sources/NIO/SocketChannel.swift +++ b/Sources/NIO/SocketChannel.swift @@ -43,6 +43,38 @@ private extension ByteBuffer { class BaseSocketChannel : SelectableChannel, ChannelCore { typealias SelectableType = T + // MARK: Stored Properties + // Visible to access from EventLoop directly + public let parent: Channel? + internal let socket: T + private let closePromise: EventLoopPromise + private let selectableEventLoop: SelectableEventLoop + private let localAddressCached: AtomicBox> = AtomicBox(value: Box(nil)) + private let remoteAddressCached: AtomicBox> = AtomicBox(value: Box(nil)) + private let bufferAllocatorCached: AtomicBox> + + internal var interestedEvent: IOEvent = .none + + fileprivate var readPending = false + fileprivate var pendingConnect: EventLoopPromise? + fileprivate var recvAllocator: RecvByteBufferAllocator + fileprivate var maxMessagesPerRead: UInt = 4 + + private var inFlushNow: Bool = false // Guard against re-entrance of flushNow() method. + private var neverRegistered = true + private var active: Atomic = Atomic(value: false) + private var _closed: Bool = false + private var autoRead: Bool = true + private var _pipeline: ChannelPipeline! + private var bufferAllocator: ByteBufferAllocator = ByteBufferAllocator() { + didSet { + assert(self.eventLoop.inEventLoop) + self.bufferAllocatorCached.store(Box(self.bufferAllocator)) + } + } + + // MARK: Datatypes + /// Indicates if a selectable should registered or not for IO notifications. enum IONotificationState { /// We should be registered for IO notifications. @@ -52,7 +84,26 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { case unregister } - // MARK: Methods to override in subclasses. + fileprivate enum ReadResult { + /// Nothing was read by the read operation. + case none + + /// Some data was read by the read operation. + case some + } + + // MARK: Computed Properties + public final var _unsafe: ChannelCore { return self } + + // This is `Channel` API so must be thread-safe. + public final var localAddress: SocketAddress? { + return self.localAddressCached.load().value + } + + // This is `Channel` API so must be thread-safe. + public final var remoteAddress: SocketAddress? { + return self.remoteAddressCached.load().value + } /// `true` if the whole `Channel` is closed and so no more IO operation can be done. public var closed: Bool { @@ -60,6 +111,48 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { return _closed } + internal var selectable: T { + return self.socket + } + + // This is `Channel` API so must be thread-safe. + public var isActive: Bool { + return self.active.load() + } + + // This is `Channel` API so must be thread-safe. + public final var closeFuture: EventLoopFuture { + return self.closePromise.futureResult + } + + public final var eventLoop: EventLoop { + return selectableEventLoop + } + + // This is `Channel` API so must be thread-safe. + public var isWritable: Bool { + return true + } + + // This is `Channel` API so must be thread-safe. + public final var allocator: ByteBufferAllocator { + if eventLoop.inEventLoop { + return bufferAllocator + } else { + return self.bufferAllocatorCached.load().value + } + } + + // This is `Channel` API so must be thread-safe. + public final var pipeline: ChannelPipeline { + return _pipeline + } + + // MARK: Methods to override in subclasses. + func writeToSocket() throws -> OverallWriteResult { + fatalError("must be overridden") + } + /// Provides the registration for this selector. Must be implemented by subclasses. func registrationFor(interested: IOEvent) -> NIORegistration { fatalError("must override") @@ -72,14 +165,6 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { fatalError("this must be overridden by sub class") } - fileprivate enum ReadResult { - /// Nothing was read by the read operation. - case none - - /// Some data was read by the read operation. - case some - } - /// Begin connection of the underlying socket. /// /// - parameters: @@ -111,70 +196,78 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { fatalError("this must be overridden by sub class") } + // MARK: Common base socket logic. + fileprivate init(socket: T, parent: Channel? = nil, eventLoop: SelectableEventLoop, recvAllocator: RecvByteBufferAllocator) throws { + self.bufferAllocatorCached = AtomicBox(value: Box(self.bufferAllocator)) + self.socket = socket + self.selectableEventLoop = eventLoop + self.closePromise = eventLoop.newPromise() + self.parent = parent + self.active.store(false) + self.recvAllocator = recvAllocator + self._pipeline = ChannelPipeline(channel: self) + } + + deinit { + assert(self._closed, "leak of open Channel") + } + + public final func localAddress0() throws -> SocketAddress { + assert(self.eventLoop.inEventLoop) + guard self.open else { + throw ChannelError.ioOnClosedChannel + } + return try self.socket.localAddress() + } + + public final func remoteAddress0() throws -> SocketAddress { + assert(self.eventLoop.inEventLoop) + guard self.open else { + throw ChannelError.ioOnClosedChannel + } + return try self.socket.remoteAddress() + } + /// Flush data to the underlying socket and return if this socket needs to be registered for write notifications. /// /// - returns: If this socket should be registered for write notifications. Ie. `IONotificationState.register` if _not_ all data could be written, so notifications are necessary; and `IONotificationState.unregister` if everything was written and we don't need to be notified about writability at the moment. fileprivate func flushNow() -> IONotificationState { - fatalError("this must be overridden by sub class") - } + // Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by + // `writeToSocket`. + guard !inFlushNow && !closed else { + return .unregister + } - // MARK: Common base socket logic. + defer { + inFlushNow = false + } + inFlushNow = true - var selectable: T { - return self.socket - } + do { + switch try self.writeToSocket() { + case .couldNotWriteEverything: + return .register + case .writtenCompletely: + return .unregister + } + } catch let err { + // If there is a write error we should try drain the inbound before closing the socket as there may be some data pending. + // We ignore any error that is thrown as we will use the original err to close the channel and notify the user. + if readIfNeeded0() { - public final var _unsafe: ChannelCore { return self } + // We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it. + while let read = try? readFromSocket(), read == .some { + pipeline.fireChannelReadComplete() + } + } - // Visible to access from EventLoop directly - let socket: T - public var interestedEvent: IOEvent = .none + close0(error: err, mode: .all, promise: nil) - fileprivate var readPending = false - private var neverRegistered = true - fileprivate var pendingConnect: EventLoopPromise? - private let closePromise: EventLoopPromise - private var active: Atomic = Atomic(value: false) - private var _closed: Bool = false - public var isActive: Bool { - return active.load() - } - - public var parent: Channel? = nil - - public final var closeFuture: EventLoopFuture { - return closePromise.futureResult - } - - private let selectableEventLoop: SelectableEventLoop - - public final var eventLoop: EventLoop { - return selectableEventLoop - } - - public var isWritable: Bool { - return true - } - - public final var allocator: ByteBufferAllocator { - if eventLoop.inEventLoop { - return bufferAllocator - } else { - return try! eventLoop.submit{ self.bufferAllocator }.wait() + // we handled all writes + return .unregister } } - private var bufferAllocator: ByteBufferAllocator = ByteBufferAllocator() - fileprivate var recvAllocator: RecvByteBufferAllocator - fileprivate var autoRead: Bool = true - fileprivate var maxMessagesPerRead: UInt = 4 - - // We don't use lazy var here as this is more expensive then doing this :/ - public final var pipeline: ChannelPipeline { - return _pipeline - } - - private var _pipeline: ChannelPipeline! public final func setOption(option: T, value: T.OptionType) -> EventLoopFuture { if eventLoop.inEventLoop { @@ -240,18 +333,6 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } } - public final var localAddress: SocketAddress? { - get { - return socket.localAddress - } - } - - public final var remoteAddress: SocketAddress? { - get { - return socket.remoteAddress - } - } - /// Triggers a `ChannelPipeline.read()` if `autoRead` is enabled.` /// /// - returns: `true` if `readPending` is `true`, `false` otherwise. @@ -275,6 +356,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { executeAndComplete(promise) { try socket.bind(to: address) + self.updateCachedAddressesFromSocket(updateRemote: false) } } @@ -401,6 +483,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { // Fail all pending writes and so ensure all pending promises are notified self._closed = true + self.unsetCachedAddressesFromSocket() self.cancelWritesOnClose(error: error) becomeInactive0() @@ -515,6 +598,22 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { readIfNeeded0() } + internal final func updateCachedAddressesFromSocket(updateLocal: Bool = true, updateRemote: Bool = true) { + assert(self.eventLoop.inEventLoop) + if updateLocal { + self.localAddressCached.store(Box(try? self.localAddress0())) + } + if updateRemote { + self.remoteAddressCached.store(Box(try? self.remoteAddress0())) + } + } + + internal final func unsetCachedAddressesFromSocket() { + assert(self.eventLoop.inEventLoop) + self.localAddressCached.store(Box(nil)) + self.remoteAddressCached.store(Box(nil)) + } + public final func connect0(to address: SocketAddress, promise: EventLoopPromise?) { assert(eventLoop.inEventLoop) @@ -529,6 +628,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } do { if try !connectSocket(to: address) { + self.updateCachedAddressesFromSocket() if promise != nil { pendingConnect = promise } else { @@ -536,6 +636,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } registerForWritable() } else { + self.updateCachedAddressesFromSocket() promise?.succeed(result: ()) } } catch let error { @@ -602,19 +703,6 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { active.store(false) pipeline.fireChannelInactive0() } - - fileprivate init(socket: T, eventLoop: SelectableEventLoop, recvAllocator: RecvByteBufferAllocator) throws { - self.socket = socket - self.selectableEventLoop = eventLoop - self.closePromise = eventLoop.newPromise() - active.store(false) - self.recvAllocator = recvAllocator - self._pipeline = ChannelPipeline(channel: self) - } - - deinit { - assert(self._closed, "leak of open Channel") - } } /// A `Channel` for a client socket. @@ -628,10 +716,9 @@ final class SocketChannel: BaseSocketChannel { private var inputShutdown: Bool = false private var outputShutdown: Bool = false - // Guard against re-entrance of flushNow() method. - private var inFlushNow: Bool = false private let pendingWrites: PendingStreamWritesManager + // This is `Channel` API so must be thread-safe. override public var isWritable: Bool { return pendingWrites.isWritable } @@ -694,15 +781,10 @@ final class SocketChannel: BaseSocketChannel { return .socketChannel(self, interested) } - fileprivate init(socket: Socket, eventLoop: SelectableEventLoop) throws { + fileprivate init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws { try socket.setNonBlocking() self.pendingWrites = PendingStreamWritesManager(iovecs: eventLoop.iovecs, storageRefs: eventLoop.storageRefs) - try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator()) - } - - fileprivate convenience init(socket: Socket, eventLoop: SelectableEventLoop, parent: Channel) throws { - try self.init(socket: socket, eventLoop: eventLoop) - self.parent = parent + try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator()) } override fileprivate func readFromSocket() throws -> ReadResult { @@ -749,8 +831,8 @@ final class SocketChannel: BaseSocketChannel { return result } - private func writeToSocket(pendingWrites: PendingStreamWritesManager) throws -> OverallWriteResult { - let result = try pendingWrites.triggerAppropriateWriteOperations(scalarBufferWriteOperation: { ptr in + override func writeToSocket() throws -> OverallWriteResult { + let result = try self.pendingWrites.triggerAppropriateWriteOperations(scalarBufferWriteOperation: { ptr in guard ptr.count > 0 else { // No need to call write if the buffer is empty. return .processed(0) @@ -882,43 +964,6 @@ final class SocketChannel: BaseSocketChannel { pipeline.fireChannelWritabilityChanged0() } } - - override fileprivate func flushNow() -> IONotificationState { - // Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by - // `writeToSocket`. - guard !inFlushNow && !closed else { - return .unregister - } - - defer { - inFlushNow = false - } - inFlushNow = true - - do { - switch try self.writeToSocket(pendingWrites: pendingWrites) { - case .couldNotWriteEverything: - return .register - case .writtenCompletely: - return .unregister - } - } catch let err { - // If there is a write error we should try drain the inbound before closing the socket as there may be some data pending. - // We ignore any error that is thrown as we will use the original err to close the channel and notify the user. - if readIfNeeded0() { - - // We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it. - while let read = try? readFromSocket(), read == .some { - pipeline.fireChannelReadComplete() - } - } - - close0(error: err, mode: .all, promise: nil) - - // we handled all writes - return .unregister - } - } } /// A `Channel` for a server socket. @@ -930,6 +975,7 @@ final class ServerSocketChannel : BaseSocketChannel { private let group: EventLoopGroup /// The server socket channel is never writable. + // This is `Channel` API so must be thread-safe. override public var isWritable: Bool { return false } init(eventLoop: SelectableEventLoop, group: EventLoopGroup, protocolFamily: Int32) throws { @@ -987,6 +1033,7 @@ final class ServerSocketChannel : BaseSocketChannel { } executeAndComplete(p) { try socket.bind(to: address) + self.updateCachedAddressesFromSocket(updateRemote: false) try self.socket.listen(backlog: backlog) } } @@ -1009,7 +1056,7 @@ final class ServerSocketChannel : BaseSocketChannel { readPending = false result = .some do { - let chan = try SocketChannel(socket: accepted, eventLoop: group.next() as! SelectableEventLoop, parent: self) + let chan = try SocketChannel(socket: accepted, parent: self, eventLoop: group.next() as! SelectableEventLoop) pipeline.fireChannelRead0(NIOAny(chan)) } catch let err { _ = try? accepted.close() @@ -1054,9 +1101,9 @@ final class ServerSocketChannel : BaseSocketChannel { final class DatagramChannel: BaseSocketChannel { // Guard against re-entrance of flushNow() method. - private var inFlushNow: Bool = false private let pendingWrites: PendingDatagramWritesManager + // This is `Channel` API so must be thread-safe. override public var isWritable: Bool { return pendingWrites.isWritable } @@ -1082,18 +1129,13 @@ final class DatagramChannel: BaseSocketChannel { try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048)) } - fileprivate init(socket: Socket, eventLoop: SelectableEventLoop) throws { + fileprivate init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws { try socket.setNonBlocking() self.pendingWrites = PendingDatagramWritesManager(msgs: eventLoop.msgs, iovecs: eventLoop.iovecs, addresses: eventLoop.addresses, storageRefs: eventLoop.storageRefs) - try super.init(socket: socket, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048)) - } - - fileprivate convenience init(socket: Socket, eventLoop: SelectableEventLoop, parent: Channel) throws { - try self.init(socket: socket, eventLoop: eventLoop) - self.parent = parent + try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048)) } // MARK: Datagram Channel overrides required by BaseSocketChannel @@ -1156,22 +1198,7 @@ final class DatagramChannel: BaseSocketChannel { let mayGrow = recvAllocator.record(actualReadBytes: bytesRead) readPending = false - let sourceAddress: SocketAddress - switch Int(rawAddressLength) { - case MemoryLayout.size: - let addr: sockaddr_in = rawAddress.convert() - sourceAddress = .init(addr, host: "") - case MemoryLayout.size: - let addr: sockaddr_in6 = rawAddress.convert() - sourceAddress = .init(addr, host: "") - case MemoryLayout.size: - let addr: sockaddr_un = rawAddress.convert() - sourceAddress = .init(addr) - default: - fatalError("Unexpected sockaddr size") - } - - let msg = AddressedEnvelope(remoteAddress: sourceAddress, data: buffer) + let msg = AddressedEnvelope(remoteAddress: rawAddress.convert(), data: buffer) pipeline.fireChannelRead0(NIOAny(msg)) if mayGrow && i < maxMessagesPerRead { buffer = recvAllocator.buffer(allocator: allocator) @@ -1211,45 +1238,8 @@ final class DatagramChannel: BaseSocketChannel { self.pendingWrites.failAll(error: error, close: true) } - override fileprivate func flushNow() -> IONotificationState { - // Guard against re-entry as data that will be put into `pendingWrites` will just be picked up by - // `writeToSocket`. - guard !inFlushNow && !closed else { - return .unregister - } - - defer { - inFlushNow = false - } - inFlushNow = true - - do { - switch try self.writeToSocket(pendingWrites: pendingWrites) { - case .couldNotWriteEverything: - return .register - case .writtenCompletely: - return .unregister - } - } catch let err { - // If there is a write error we should try drain the inbound before closing the socket as there may be some data pending. - // We ignore any error that is thrown as we will use the original err to close the channel and notify the user. - if readIfNeeded0() { - - // We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it. - while let read = try? readFromSocket(), read == .some { - pipeline.fireChannelReadComplete() - } - } - - close0(error: err, mode: .all, promise: nil) - - // we handled all writes - return .unregister - } - } - - private func writeToSocket(pendingWrites: PendingDatagramWritesManager) throws -> OverallWriteResult { - let result = try pendingWrites.triggerAppropriateWriteOperations(scalarWriteOperation: { (ptr, destinationPtr, destinationSize) in + override func writeToSocket() throws -> OverallWriteResult { + let result = try self.pendingWrites.triggerAppropriateWriteOperations(scalarWriteOperation: { (ptr, destinationPtr, destinationSize) in guard ptr.count > 0 else { // No need to call write if the buffer is empty. return .processed(0) @@ -1273,6 +1263,7 @@ final class DatagramChannel: BaseSocketChannel { assert(self.eventLoop.inEventLoop) do { try socket.bind(to: address) + self.updateCachedAddressesFromSocket(updateRemote: false) promise?.succeed(result: ()) becomeActive0() readIfNeeded0() diff --git a/Sources/NIO/System.swift b/Sources/NIO/System.swift index 3fe493b8..c9658716 100644 --- a/Sources/NIO/System.swift +++ b/Sources/NIO/System.swift @@ -50,6 +50,12 @@ private let sysLseek = lseek private let sysRecvFrom = recvfrom private let sysSendTo = sendto private let sysDup = dup +private let sysGetpeername = getpeername +private let sysGetsockname = getsockname +private let sysAF_INET = AF_INET +private let sysAF_INET6 = AF_INET6 +private let sysAF_UNIX = AF_UNIX +private let sysInet_ntop = inet_ntop #if os(Linux) private let sysSendMmsg = CNIOLinux_sendmmsg @@ -110,6 +116,23 @@ internal func wrapSyscall(where function: StaticString = # } } +/* Sorry, we really try hard to not use underscored attributes. In this case however we seem to break the inlining threshold which makes a system call take twice the time, ie. we need this exception. */ +@inline(__always) +internal func wrapErrorIsNullReturnCall(where function: StaticString = #function, _ fn: () throws -> UnsafePointer?) throws -> UnsafePointer? { + while true { + let res = try fn() + if res == nil { + let err = errno + if err == EINTR { + continue + } + assert(!isBlacklistedErrno(err), "blacklisted errno \(err) \(strerror(err)!)") + throw IOError(errnoCode: err, function: function) + } + return res + } +} + enum Shutdown { case RD case WR @@ -148,6 +171,9 @@ internal enum Posix { } #endif + static let AF_INET = sa_family_t(sysAF_INET) + static let AF_INET6 = sa_family_t(sysAF_INET6) + static let AF_UNIX = sa_family_t(sysAF_UNIX) @inline(never) public static func shutdown(descriptor: CInt, how: Shutdown) throws { @@ -321,6 +347,14 @@ internal enum Posix { } } + @discardableResult + @inline(never) + public static func inet_ntop(addressFamily: CInt, addressBytes: UnsafeRawPointer, addressDescription: UnsafeMutablePointer, addressDescriptionLength: socklen_t) throws -> UnsafePointer? { + return try wrapErrorIsNullReturnCall { + sysInet_ntop(addressFamily, addressBytes, addressDescription, addressDescriptionLength) + } + } + // Its not really posix but exists on Linux and MacOS / BSD so just put it here for now to keep it simple @inline(never) public static func sendfile(descriptor: CInt, fd: CInt, offset: off_t, count: size_t) throws -> IOResult { @@ -365,6 +399,21 @@ internal enum Posix { Int(sysRecvMmsg(sockfd, msgvec, vlen, flags, timeout)) } } + + @inline(never) + public static func getpeername(socket: CInt, address: UnsafeMutablePointer, addressLength: UnsafeMutablePointer) throws { + _ = try wrapSyscall { + return sysGetpeername(socket, address, addressLength) + } + } + + @inline(never) + public static func getsockname(socket: CInt, address: UnsafeMutablePointer, addressLength: UnsafeMutablePointer) throws { + _ = try wrapSyscall { + return sysGetsockname(socket, address, addressLength) + } + } + } #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) diff --git a/Sources/NIOChatServer/main.swift b/Sources/NIOChatServer/main.swift index ba30ee82..0804f8ad 100644 --- a/Sources/NIOChatServer/main.swift +++ b/Sources/NIOChatServer/main.swift @@ -74,7 +74,7 @@ final class ChatHandler: ChannelInboundHandler { } public func channelActive(ctx: ChannelHandlerContext) { - let remoteAddress = ctx.channel.remoteAddress! + let remoteAddress = ctx.remoteAddress! let channel = ctx.channel self.channelsSyncQueue.async { // broadcast the message to all the connected clients except the one that just became active. @@ -84,7 +84,7 @@ final class ChatHandler: ChannelInboundHandler { } var buffer = channel.allocator.buffer(capacity: 64) - buffer.write(string: "(ChatServer) - Welcome to: \(channel.localAddress!)\n") + buffer.write(string: "(ChatServer) - Welcome to: \(ctx.localAddress!)\n") ctx.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil) } diff --git a/Sources/NIOConcurrencyHelpers/atomics.swift b/Sources/NIOConcurrencyHelpers/atomics.swift index ef8e9893..669962d7 100644 --- a/Sources/NIOConcurrencyHelpers/atomics.swift +++ b/Sources/NIOConcurrencyHelpers/atomics.swift @@ -307,3 +307,91 @@ extension UInt: AtomicPrimitive { public static let atomic_load = catmc_atomic_unsigned_long_load public static let atomic_store = catmc_atomic_unsigned_long_store } + +/// `AtomicBox` is a heap-allocated box which allows atomic access to an instance of a Swift class. +/// +/// It behaves very much like `Atomic` but for objects, maintaining the correct retain counts. +public class AtomicBox { + private let storage: Atomic + + public init(value: T) { + let ptr = Unmanaged.passRetained(value) + self.storage = Atomic(value: Int(bitPattern: ptr.toOpaque())) + } + + deinit { + let oldPtrBits = self.storage.exchange(with: 0xdeadbeef) + let oldPtr = Unmanaged.fromOpaque(UnsafeRawPointer(bitPattern: oldPtrBits)!) + oldPtr.release() + } + + /// Atomically compares the value against `expected` and, if they are equal, + /// replaces the value with `desired`. + /// + /// This implementation conforms to C11's `atomic_compare_exchange_strong`. This + /// means that the compare-and-swap will always succeed if `expected` is equal to + /// value. Additionally, it uses a *sequentially consistent ordering*. For more + /// details on atomic memory models, check the documentation for C11's + /// `stdatomic.h`. + /// + /// - Parameter expected: The value that this object must currently hold for the + /// compare-and-swap to succeed. + /// - Parameter desired: The new value that this object will hold if the compare + /// succeeds. + /// - Returns: `True` if the exchange occurred, or `False` if `expected` did not + /// match the current value and so no exchange occurred. + public func compareAndExchange(expected: T, desired: T) -> Bool { + return withExtendedLifetime(desired) { + let expectedPtr = Unmanaged.passUnretained(expected) + let desiredPtr = Unmanaged.passUnretained(desired) + + if self.storage.compareAndExchange(expected: Int(bitPattern: expectedPtr.toOpaque()), + desired: Int(bitPattern: desiredPtr.toOpaque())) { + _ = desiredPtr.retain() + expectedPtr.release() + return true + } else { + return false + } + } + } + + /// Atomically exchanges `value` for the current value of this object. + /// + /// This implementation uses a *relaxed* memory ordering. This guarantees nothing + /// more than that this operation is atomic: there is no guarantee that any other + /// event will be ordered before or after this one. + /// + /// - Parameter value: The new value to set this object to. + /// - Returns: The value previously held by this object. + public func exchange(with value: T) -> T { + let newPtr = Unmanaged.passRetained(value) + let oldPtrBits = self.storage.exchange(with: Int(bitPattern: newPtr.toOpaque())) + let oldPtr = Unmanaged.fromOpaque(UnsafeRawPointer(bitPattern: oldPtrBits)!) + return oldPtr.takeRetainedValue() + } + + /// Atomically loads and returns the value of this object. + /// + /// This implementation uses a *relaxed* memory ordering. This guarantees nothing + /// more than that this operation is atomic: there is no guarantee that any other + /// event will be ordered before or after this one. + /// + /// - Returns: The value of this object + public func load() -> T { + let ptrBits = self.storage.load() + let ptr = Unmanaged.fromOpaque(UnsafeRawPointer(bitPattern: ptrBits)!) + return ptr.takeUnretainedValue() + } + + /// Atomically replaces the value of this object with `value`. + /// + /// This implementation uses a *relaxed* memory ordering. This guarantees nothing + /// more than that this operation is atomic: there is no guarantee that any other + /// event will be ordered before or after this one. + /// + /// - Parameter value: The new value to set the object to. + public func store(_ value: T) -> Void { + _ = self.exchange(with: value) + } +} diff --git a/Sources/NIOEchoClient/main.swift b/Sources/NIOEchoClient/main.swift index 576a30d5..49ec0c18 100644 --- a/Sources/NIOEchoClient/main.swift +++ b/Sources/NIOEchoClient/main.swift @@ -41,7 +41,7 @@ private final class EchoHandler: ChannelInboundHandler { } public func channelActive(ctx: ChannelHandlerContext) { - print("Client connected to \(ctx.channel.remoteAddress!)") + print("Client connected to \(ctx.remoteAddress!)") // We are connected its time to send the message to the server to initialize the ping-pong sequence. var buffer = ctx.channel.allocator.buffer(capacity: line.utf8.count) diff --git a/Sources/NIOHTTP1Server/main.swift b/Sources/NIOHTTP1Server/main.swift index 040f7c1c..bed2e1a7 100644 --- a/Sources/NIOHTTP1Server/main.swift +++ b/Sources/NIOHTTP1Server/main.swift @@ -72,7 +72,7 @@ private final class HTTPHandler: ChannelInboundHandler { URL: \(self.infoSavedRequestHead!.uri)\r body length: \(self.infoSavedBodyBytes)\r headers: \(self.infoSavedRequestHead!.headers)\r - client: \(ctx.channel.remoteAddress?.description ?? "zombie")\r + client: \(ctx.remoteAddress?.description ?? "zombie")\r IO: SwiftNIO Electric Boogaloo™️\r\n """ self.buffer.clear() @@ -200,7 +200,7 @@ private final class HTTPHandler: ChannelInboundHandler { case "/dynamic/count-to-ten": return { self.handleMultipleWrites(ctx: $0, request: $1, strings: (1...10).map { "\($0)\r\n" }, delay: .milliseconds(100)) } case "/dynamic/client-ip": - return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, string: "\(ctx.channel.remoteAddress.debugDescription)") } + return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, string: "\(ctx.remoteAddress.debugDescription)") } default: return { ctx, req in self.handleJustWrite(ctx: ctx, request: req, statusCode: .notFound, string: "not found") } } diff --git a/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests+XCTest.swift b/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests+XCTest.swift index 2b9dd4ca..6928eee7 100644 --- a/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests+XCTest.swift +++ b/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests+XCTest.swift @@ -39,6 +39,11 @@ extension NIOConcurrencyHelpersTests { ("testConditionLockMutualExclusion", testConditionLockMutualExclusion), ("testConditionLock", testConditionLock), ("testConditionLockWithDifferentConditions", testConditionLockWithDifferentConditions), + ("testAtomicBoxDoesNotTriviallyLeak", testAtomicBoxDoesNotTriviallyLeak), + ("testAtomicBoxCompareAndExchangeWorksIfEqual", testAtomicBoxCompareAndExchangeWorksIfEqual), + ("testAtomicBoxCompareAndExchangeWorksIfNotEqual", testAtomicBoxCompareAndExchangeWorksIfNotEqual), + ("testAtomicBoxStoreWorks", testAtomicBoxStoreWorks), + ("testAtomicBoxCompareAndExchangeOntoItselfWorks", testAtomicBoxCompareAndExchangeOntoItselfWorks), ] } } diff --git a/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests.swift b/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests.swift index 606cacb3..dff9f18c 100644 --- a/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests.swift +++ b/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests.swift @@ -401,4 +401,149 @@ class NIOConcurrencyHelpersTests: XCTestCase { doneSem.wait() /* job on 'q2' is done */ } } + + func testAtomicBoxDoesNotTriviallyLeak() throws { + class SomeClass {} + weak var weakSomeInstance1: SomeClass? = nil + weak var weakSomeInstance2: SomeClass? = nil + ({ + let someInstance = SomeClass() + weakSomeInstance1 = someInstance + let someAtomic = AtomicBox(value: someInstance) + let loadedFromAtomic = someAtomic.load() + weakSomeInstance2 = loadedFromAtomic + XCTAssertNotNil(weakSomeInstance1) + XCTAssertNotNil(weakSomeInstance2) + XCTAssert(someInstance === loadedFromAtomic) + })() + XCTAssertNil(weakSomeInstance1) + XCTAssertNil(weakSomeInstance2) + } + + func testAtomicBoxCompareAndExchangeWorksIfEqual() throws { + class SomeClass {} + weak var weakSomeInstance1: SomeClass? = nil + weak var weakSomeInstance2: SomeClass? = nil + weak var weakSomeInstance3: SomeClass? = nil + ({ + let someInstance1 = SomeClass() + let someInstance2 = SomeClass() + weakSomeInstance1 = someInstance1 + + let atomic = AtomicBox(value: someInstance1) + var loadedFromAtomic = atomic.load() + XCTAssert(someInstance1 === loadedFromAtomic) + weakSomeInstance2 = loadedFromAtomic + + XCTAssertTrue(atomic.compareAndExchange(expected: loadedFromAtomic, desired: someInstance2)) + + loadedFromAtomic = atomic.load() + weakSomeInstance3 = loadedFromAtomic + XCTAssert(someInstance1 !== loadedFromAtomic) + XCTAssert(someInstance2 === loadedFromAtomic) + + XCTAssertNotNil(weakSomeInstance1) + XCTAssertNotNil(weakSomeInstance2) + XCTAssertNotNil(weakSomeInstance3) + XCTAssert(weakSomeInstance1 === weakSomeInstance2 && weakSomeInstance2 !== weakSomeInstance3) + })() + XCTAssertNil(weakSomeInstance1) + XCTAssertNil(weakSomeInstance2) + XCTAssertNil(weakSomeInstance3) + } + + func testAtomicBoxCompareAndExchangeWorksIfNotEqual() throws { + class SomeClass {} + weak var weakSomeInstance1: SomeClass? = nil + weak var weakSomeInstance2: SomeClass? = nil + weak var weakSomeInstance3: SomeClass? = nil + ({ + let someInstance1 = SomeClass() + let someInstance2 = SomeClass() + weakSomeInstance1 = someInstance1 + + let atomic = AtomicBox(value: someInstance1) + var loadedFromAtomic = atomic.load() + XCTAssert(someInstance1 === loadedFromAtomic) + weakSomeInstance2 = loadedFromAtomic + + XCTAssertFalse(atomic.compareAndExchange(expected: someInstance2, desired: someInstance2)) + XCTAssertFalse(atomic.compareAndExchange(expected: SomeClass(), desired: someInstance2)) + XCTAssertTrue(atomic.load() === someInstance1) + + loadedFromAtomic = atomic.load() + weakSomeInstance3 = someInstance2 + XCTAssert(someInstance1 === loadedFromAtomic) + XCTAssert(someInstance2 !== loadedFromAtomic) + + XCTAssertNotNil(weakSomeInstance1) + XCTAssertNotNil(weakSomeInstance2) + XCTAssertNotNil(weakSomeInstance3) + })() + XCTAssertNil(weakSomeInstance1) + XCTAssertNil(weakSomeInstance2) + XCTAssertNil(weakSomeInstance3) + } + + func testAtomicBoxStoreWorks() throws { + class SomeClass {} + weak var weakSomeInstance1: SomeClass? = nil + weak var weakSomeInstance2: SomeClass? = nil + weak var weakSomeInstance3: SomeClass? = nil + ({ + let someInstance1 = SomeClass() + let someInstance2 = SomeClass() + weakSomeInstance1 = someInstance1 + + let atomic = AtomicBox(value: someInstance1) + var loadedFromAtomic = atomic.load() + XCTAssert(someInstance1 === loadedFromAtomic) + weakSomeInstance2 = loadedFromAtomic + + atomic.store(someInstance2) + + loadedFromAtomic = atomic.load() + weakSomeInstance3 = loadedFromAtomic + XCTAssert(someInstance1 !== loadedFromAtomic) + XCTAssert(someInstance2 === loadedFromAtomic) + + XCTAssertNotNil(weakSomeInstance1) + XCTAssertNotNil(weakSomeInstance2) + XCTAssertNotNil(weakSomeInstance3) + })() + XCTAssertNil(weakSomeInstance1) + XCTAssertNil(weakSomeInstance2) + XCTAssertNil(weakSomeInstance3) + } + + func testAtomicBoxCompareAndExchangeOntoItselfWorks() { + let q = DispatchQueue(label: "q") + let g = DispatchGroup() + let sem1 = DispatchSemaphore(value: 0) + let sem2 = DispatchSemaphore(value: 0) + class SomeClass {} + weak var weakInstance: SomeClass? + ({ + let instance = SomeClass() + weakInstance = instance + + let atomic = AtomicBox(value: instance) + q.async(group: g) { + sem1.signal() + sem2.wait() + for _ in 0..<1000 { + XCTAssertTrue(atomic.compareAndExchange(expected: instance, desired: instance)) + } + } + sem2.signal() + sem1.wait() + for _ in 0..<1000 { + XCTAssertTrue(atomic.compareAndExchange(expected: instance, desired: instance)) + } + g.wait() + let v = atomic.load() + XCTAssert(v === instance) + })() + XCTAssertNil(weakInstance) + } } diff --git a/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift b/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift index 3c3f52c7..9c2ff4a8 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift @@ -19,6 +19,18 @@ import NIOFoundationCompat import Dispatch @testable import NIOHTTP1 +internal extension Channel { + func syncCloseAcceptingAlreadyClosed() throws { + do { + try self.close().wait() + } catch ChannelError.alreadyClosed { + /* we're happy with this one */ + } catch let e { + throw e + } + } +} + extension Array where Array.Element == ByteBuffer { public func allAsBytes() -> [UInt8] { var out: [UInt8] = [] @@ -61,17 +73,6 @@ internal class ArrayAccumulationHandler: ChannelInboundHandler { } class HTTPServerClientTest : XCTestCase { - - private func syncCloseAcceptingAlreadyClosed(channel: Channel) throws { - do { - try channel.close().wait() - } catch ChannelError.alreadyClosed { - /* we're happy with this one */ - } catch let e { - throw e - } - } - /* needs to be something reasonably large and odd so it has good odds producing incomplete writes even on the loopback interface */ private static let massiveResponseLength = 5 * 1024 * 1024 + 7 private static let massiveResponseBytes: [UInt8] = { @@ -374,7 +375,7 @@ class HTTPServerClientTest : XCTestCase { }.bind(host: "127.0.0.1", port: 0).wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } let clientChannel = try ClientBootstrap(group: group) @@ -387,7 +388,7 @@ class HTTPServerClientTest : XCTestCase { .wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) + XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) } var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/helloworld") @@ -432,7 +433,7 @@ class HTTPServerClientTest : XCTestCase { }.bind(host: "127.0.0.1", port: 0).wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } let clientChannel = try ClientBootstrap(group: group) @@ -445,7 +446,7 @@ class HTTPServerClientTest : XCTestCase { .wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) + XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) } var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/count-to-ten") @@ -490,7 +491,7 @@ class HTTPServerClientTest : XCTestCase { }.bind(host: "127.0.0.1", port: 0).wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } let clientChannel = try ClientBootstrap(group: group) @@ -503,7 +504,7 @@ class HTTPServerClientTest : XCTestCase { .wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) + XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) } var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/trailers") @@ -550,7 +551,7 @@ class HTTPServerClientTest : XCTestCase { }.bind(host: "127.0.0.1", port: 0).wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } let clientChannel = try ClientBootstrap(group: group) @@ -559,7 +560,7 @@ class HTTPServerClientTest : XCTestCase { .wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) + XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) } var buffer = clientChannel.allocator.buffer(capacity: numBytes) @@ -591,7 +592,7 @@ class HTTPServerClientTest : XCTestCase { } }.bind(host: "127.0.0.1", port: 0).wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } let clientChannel = try ClientBootstrap(group: group) @@ -604,7 +605,7 @@ class HTTPServerClientTest : XCTestCase { .wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) + XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) } var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .HEAD, uri: "/head") @@ -636,7 +637,7 @@ class HTTPServerClientTest : XCTestCase { } }.bind(host: "127.0.0.1", port: 0).wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: serverChannel)) + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } let clientChannel = try ClientBootstrap(group: group) @@ -649,7 +650,7 @@ class HTTPServerClientTest : XCTestCase { .wait() defer { - XCTAssertNoThrow(try syncCloseAcceptingAlreadyClosed(channel: clientChannel)) + XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) } var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/204") diff --git a/Tests/NIOTests/ChannelTests+XCTest.swift b/Tests/NIOTests/ChannelTests+XCTest.swift index cb7ac677..d70a8e51 100644 --- a/Tests/NIOTests/ChannelTests+XCTest.swift +++ b/Tests/NIOTests/ChannelTests+XCTest.swift @@ -55,6 +55,7 @@ extension ChannelTests { ("testHalfClosure", testHalfClosure), ("testRejectsInvalidData", testRejectsInvalidData), ("testWeDontCrashIfChannelReleasesBeforePipeline", testWeDontCrashIfChannelReleasesBeforePipeline), + ("testAskForLocalAndRemoteAddressesAfterChannelIsClosed", testAskForLocalAndRemoteAddressesAfterChannelIsClosed), ] } } diff --git a/Tests/NIOTests/ChannelTests.swift b/Tests/NIOTests/ChannelTests.swift index 85624d5b..e0aaacf3 100644 --- a/Tests/NIOTests/ChannelTests.swift +++ b/Tests/NIOTests/ChannelTests.swift @@ -1157,7 +1157,7 @@ public class ChannelTests: XCTestCase { return channel.pipeline.add(handler: byteCountingHandler) } } - .connect(to: server.localAddress!) + .connect(to: try! server.localAddress()) let accepted = try server.accept()! defer { XCTAssertNoThrow(try accepted.close()) @@ -1219,7 +1219,7 @@ public class ChannelTests: XCTestCase { return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false)) } } - .connect(to: server.localAddress!) + .connect(to: try! server.localAddress()) let accepted = try server.accept()! defer { XCTAssertNoThrow(try accepted.close()) @@ -1269,7 +1269,7 @@ public class ChannelTests: XCTestCase { return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false)) } .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) - .connect(to: server.localAddress!) + .connect(to: try! server.localAddress()) let accepted = try server.accept()! defer { XCTAssertNoThrow(try accepted.close()) @@ -1412,4 +1412,27 @@ public class ChannelTests: XCTestCase { XCTAssertNil(weakServerChannel, "weakServerChannel not nil, looks like we leaked it!") XCTAssertNil(weakServerChildChannel, "weakServerChildChannel not nil, looks like we leaked it!") } + + func testAskForLocalAndRemoteAddressesAfterChannelIsClosed() throws { + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let serverChannel = try ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .bind(host: "127.0.0.1", port: 0).wait() + + let clientChannel = try ClientBootstrap(group: group) + .connect(to: serverChannel.localAddress!).wait() + + + // Start shutting stuff down. + try serverChannel.syncCloseAcceptingAlreadyClosed() + try clientChannel.syncCloseAcceptingAlreadyClosed() + + for f in [ serverChannel.remoteAddress, serverChannel.localAddress, clientChannel.remoteAddress, clientChannel.localAddress ] { + XCTAssertNil(f) + } + } } diff --git a/Tests/NIOTests/TestUtils.swift b/Tests/NIOTests/TestUtils.swift index 5b57a46a..b8d6c70f 100644 --- a/Tests/NIOTests/TestUtils.swift +++ b/Tests/NIOTests/TestUtils.swift @@ -69,6 +69,17 @@ func openTemporaryFile() -> (CInt, String) { return (fd, String(decoding: templateBytes, as: UTF8.self)) } +internal extension Channel { + func syncCloseAcceptingAlreadyClosed() throws { + do { + try self.close().wait() + } catch ChannelError.alreadyClosed { + /* we're happy with this one */ + } catch let e { + throw e + } + } +} final class ByteCountingHandler : ChannelInboundHandler { typealias InboundIn = ByteBuffer