diff --git a/Sources/NIOCore/Channel.swift b/Sources/NIOCore/Channel.swift index 6daf6b24..2b447041 100644 --- a/Sources/NIOCore/Channel.swift +++ b/Sources/NIOCore/Channel.swift @@ -378,6 +378,25 @@ extension ChannelError: Equatable { } /// The removal of a `ChannelHandler` using `ChannelPipeline.removeHandler` has been attempted more than once. public struct NIOAttemptedToRemoveHandlerMultipleTimesError: Error {} +public enum DatagramChannelError { + public struct WriteOnUnconnectedSocketWithoutAddress: Error { + public init() {} + } + + public struct WriteOnConnectedSocketWithInvalidAddress: Error { + let envelopeRemoteAddress: SocketAddress + let connectedRemoteAddress: SocketAddress + + public init( + envelopeRemoteAddress: SocketAddress, + connectedRemoteAddress: SocketAddress + ) { + self.envelopeRemoteAddress = envelopeRemoteAddress + self.connectedRemoteAddress = connectedRemoteAddress + } + } +} + /// An `Channel` related event that is passed through the `ChannelPipeline` to notify the user. public enum ChannelEvent: Equatable, NIOSendable { /// `ChannelOptions.allowRemoteHalfClosure` is `true` and input portion of the `Channel` was closed. diff --git a/Sources/NIOPosix/BaseSocketChannel.swift b/Sources/NIOPosix/BaseSocketChannel.swift index 801f5af5..47ae117d 100644 --- a/Sources/NIOPosix/BaseSocketChannel.swift +++ b/Sources/NIOPosix/BaseSocketChannel.swift @@ -42,6 +42,9 @@ private struct SocketChannelLifecycleManager { // note: this can be `false` on a deactivated channel, we might just have torn it down. var hasSeenEOFNotification: Bool = false + // Should we support transition from `active` to `active`, used by datagram sockets. + let supportsReconnect: Bool + private var currentState: State = .fresh { didSet { self.eventLoop.assertInEventLoop() @@ -58,9 +61,14 @@ private struct SocketChannelLifecycleManager { // MARK: API // isActiveAtomic needs to be injected as it's accessed from arbitrary threads and `SocketChannelLifecycleManager` is usually held mutable - internal init(eventLoop: EventLoop, isActiveAtomic: NIOAtomic) { + internal init( + eventLoop: EventLoop, + isActiveAtomic: NIOAtomic, + supportReconnect: Bool + ) { self.eventLoop = eventLoop self.isActiveAtomic = isActiveAtomic + self.supportsReconnect = supportReconnect } // this is called from Channel's deinit, so don't assert we're on the EventLoop! @@ -140,6 +148,12 @@ private struct SocketChannelLifecycleManager { pipeline.syncOperations.fireChannelUnregistered() } + // origin: .activated + case (.activated, .activate) where self.supportsReconnect: + return { promise, pipeline in + promise?.succeed(()) + } + // bad transitions case (.fresh, .activate), // should go through .registered first (.preRegistered, .activate), // need to first be fully registered @@ -439,7 +453,13 @@ class BaseSocketChannel: SelectableChannel, Chan } // MARK: Common base socket logic. - init(socket: SocketType, parent: Channel?, eventLoop: SelectableEventLoop, recvAllocator: RecvByteBufferAllocator) throws { + init( + socket: SocketType, + parent: Channel?, + eventLoop: SelectableEventLoop, + recvAllocator: RecvByteBufferAllocator, + supportReconnect: Bool + ) throws { self._bufferAllocatorCache = self.bufferAllocator self.socket = socket self.selectableEventLoop = eventLoop @@ -448,7 +468,11 @@ class BaseSocketChannel: SelectableChannel, Chan self.recvAllocator = recvAllocator // As the socket may already be connected we should ensure we start with the correct addresses cached. self._addressCache = .init(local: try? socket.localAddress(), remote: try? socket.remoteAddress()) - self.lifecycleManager = SocketChannelLifecycleManager(eventLoop: eventLoop, isActiveAtomic: self.isActiveAtomic) + self.lifecycleManager = SocketChannelLifecycleManager( + eventLoop: eventLoop, + isActiveAtomic: self.isActiveAtomic, + supportReconnect: supportReconnect + ) self.socketDescription = socket.description self.pendingConnect = nil self._pipeline = ChannelPipeline(channel: self) diff --git a/Sources/NIOPosix/BaseStreamSocketChannel.swift b/Sources/NIOPosix/BaseStreamSocketChannel.swift index 0f09c01a..45f433ec 100644 --- a/Sources/NIOPosix/BaseStreamSocketChannel.swift +++ b/Sources/NIOPosix/BaseStreamSocketChannel.swift @@ -20,13 +20,21 @@ class BaseStreamSocketChannel: BaseSocketChannel private var outputShutdown: Bool = false private let pendingWrites: PendingStreamWritesManager - override init(socket: Socket, - parent: Channel?, - eventLoop: SelectableEventLoop, - recvAllocator: RecvByteBufferAllocator) throws { + init( + socket: Socket, + parent: Channel?, + eventLoop: SelectableEventLoop, + recvAllocator: RecvByteBufferAllocator + ) throws { self.pendingWrites = PendingStreamWritesManager(iovecs: eventLoop.iovecs, storageRefs: eventLoop.storageRefs) self.connectTimeoutScheduled = nil - try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: recvAllocator) + try super.init( + socket: socket, + parent: parent, + eventLoop: eventLoop, + recvAllocator: recvAllocator, + supportReconnect: false + ) } deinit { diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index c0f5a0e2..8788178c 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -843,7 +843,7 @@ public final class DatagramBootstrap { func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { return try DatagramChannel(eventLoop: eventLoop, socket: socket) } - return bind0(makeChannel: makeChannel) { (eventLoop, channel) in + return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in let promise = eventLoop.makePromise(of: Void.self) channel.registerAlreadyConfigured0(promise: promise) return promise.futureResult @@ -907,14 +907,61 @@ public final class DatagramBootstrap { return try DatagramChannel(eventLoop: eventLoop, protocolFamily: address.protocol) } - return bind0(makeChannel: makeChannel) { (eventLoop, channel) in + return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in channel.register().flatMap { channel.bind(to: address) } } } - private func bind0(makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, _ registerAndBind: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture) -> EventLoopFuture { + /// Connect the `DatagramChannel` to `host` and `port`. + /// + /// - parameters: + /// - host: The host to connect to. + /// - port: The port to connect to. + public func connect(host: String, port: Int) -> EventLoopFuture { + return connect0 { + return try SocketAddress.makeAddressResolvingHost(host, port: port) + } + } + + /// Connect the `DatagramChannel` to `address`. + /// + /// - parameters: + /// - address: The `SocketAddress` to connect to. + public func connect(to address: SocketAddress) -> EventLoopFuture { + return connect0 { address } + } + + /// Connect the `DatagramChannel` to a UNIX Domain Socket. + /// + /// - parameters: + /// - unixDomainSocketPath: The path of the UNIX Domain Socket to connect to. `path` must not exist, it will be created by the system. + public func connect(unixDomainSocketPath: String) -> EventLoopFuture { + return connect0 { + return try SocketAddress(unixDomainSocketPath: unixDomainSocketPath) + } + } + + private func connect0(_ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture { + let address: SocketAddress + do { + address = try makeSocketAddress() + } catch { + return group.next().makeFailedFuture(error) + } + func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { + return try DatagramChannel(eventLoop: eventLoop, + protocolFamily: address.protocol) + } + return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in + channel.register().flatMap { + channel.connect(to: address) + } + } + } + + private func withNewChannel(makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, _ bringup: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture) -> EventLoopFuture { let eventLoop = self.group.next() let channelInitializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) } let channelOptions = self._channelOptions @@ -932,7 +979,7 @@ public final class DatagramBootstrap { channelInitializer(channel) }.flatMap { eventLoop.assertInEventLoop() - return registerAndBind(eventLoop, channel) + return bringup(eventLoop, channel) }.map { channel }.flatMapError { error in diff --git a/Sources/NIOPosix/PendingDatagramWritesManager.swift b/Sources/NIOPosix/PendingDatagramWritesManager.swift index ed7529ca..7849f0e9 100644 --- a/Sources/NIOPosix/PendingDatagramWritesManager.swift +++ b/Sources/NIOPosix/PendingDatagramWritesManager.swift @@ -17,7 +17,7 @@ import NIOConcurrencyHelpers private struct PendingDatagramWrite { var data: ByteBuffer var promise: Optional> - let address: SocketAddress + let address: SocketAddress? var metadata: AddressedEnvelope.Metadata? /// A helper function that copies the underlying sockaddr structure into temporary storage, @@ -31,7 +31,9 @@ private struct PendingDatagramWrite { func copySocketAddress(_ target: UnsafeMutablePointer) -> socklen_t { let erased = UnsafeMutableRawPointer(target) - switch address { + switch self.address { + case .none: + preconditionFailure("copySocketAddress called on write that has no address") case .v4(let innerAddress): erased.storeBytes(of: innerAddress.address, as: sockaddr_in.self) return socklen_t(MemoryLayout.size(ofValue: innerAddress.address)) @@ -99,14 +101,38 @@ private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWrite p.data.withUnsafeReadableBytesWithStorageManagement { ptr, storageRef in storageRefs[c] = storageRef.retain() - let addressLen = p.copySocketAddress(addresses.baseAddress! + c) + + /// From man page of `sendmsg(2)`: + /// + /// > The `msg_name` field is used on an unconnected socket to specify + /// > the target address for a datagram. It points to a buffer + /// > containing the address; the `msg_namelen` field should be set to + /// > the size of the address. For a connected socket, these fields + /// > should be specified as `NULL` and 0, respectively. + let address: UnsafeMutablePointer? + let addressLen: socklen_t + let protocolFamily: NIOBSDSocket.ProtocolFamily + if let envelopeAddress = p.address { + precondition(pending.remoteAddress == nil, "Pending write with address on connected socket.") + address = addresses.baseAddress! + c + addressLen = p.copySocketAddress(address!) + protocolFamily = envelopeAddress.protocol + } else { + guard let connectedRemoteAddress = pending.remoteAddress else { + preconditionFailure("Pending write without address on unconnected socket.") + } + address = nil + addressLen = 0 + protocolFamily = connectedRemoteAddress.protocol + } + iovecs[c] = iovec(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer)) var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[c]) - controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: p.address.protocol) + controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: protocolFamily) let controlMessageBytePointer = controlBytes.validControlBytes - let msg = msghdr(msg_name: addresses.baseAddress! + c, + let msg = msghdr(msg_name: address, msg_namelen: addressLen, msg_iov: iovecs.baseAddress! + c, msg_iovlen: 1, @@ -140,6 +166,7 @@ private struct PendingDatagramWritesState { private var pendingWrites = MarkedCircularBuffer(initialCapacity: 16) private var chunks: Int = 0 public private(set) var bytes: Int64 = 0 + private(set) var remoteAddress: SocketAddress? = nil public var nextWrite: PendingDatagramWrite? { return self.pendingWrites.first @@ -194,6 +221,10 @@ private struct PendingDatagramWritesState { self.pendingWrites.mark() } + mutating func markConnected(to remoteAddress: SocketAddress) { + self.remoteAddress = remoteAddress + } + /// Indicate that a write has happened, this may be a write of multiple outstanding writes (using for example `sendmmsg`). /// /// - warning: The closure will simply fulfill all the promises in order. If one of those promises does for example close the `Channel` we might see subsequent writes fail out of order. Example: Imagine the user issues three writes: `A`, `B` and `C`. Imagine that `A` and `B` both get successfully written in one write operation but the user closes the `Channel` in `A`'s callback. Then overall the promises will be fulfilled in this order: 1) `A`: success 2) `C`: error 3) `B`: success. Note how `B` and `C` get fulfilled out of order. @@ -402,6 +433,11 @@ final class PendingDatagramWritesManager: PendingWritesManager { self.state.markFlushCheckpoint() } + /// Mark that the socket is connected. + func markConnected(to remoteAddress: SocketAddress) { + self.state.markConnected(to: remoteAddress) + } + /// Is there a flush pending? var isFlushPending: Bool { return self.state.isFlushPending @@ -412,18 +448,9 @@ final class PendingDatagramWritesManager: PendingWritesManager { return self.state.isEmpty } - /// Add a pending write. - /// - /// - parameters: - /// - envelope: The `AddressedEnvelope` to write. - /// - promise: Optionally an `EventLoopPromise` that will get the write operation's result - /// - result: If the `Channel` is still writable after adding the write of `data`. - func add(envelope: AddressedEnvelope, promise: EventLoopPromise?) -> Bool { + private func add(_ pendingWrite: PendingDatagramWrite) -> Bool { assert(self.isOpen) - self.state.append(.init(data: envelope.data, - promise: promise, - address: envelope.remoteAddress, - metadata: envelope.metadata)) + self.state.append(pendingWrite) if self.state.bytes > waterMark.high && channelWritabilityFlag.compareAndExchange(expected: true, desired: false) { // Returns false to signal the Channel became non-writable and we need to notify the user. @@ -433,6 +460,48 @@ final class PendingDatagramWritesManager: PendingWritesManager { return true } + /// Add a pending write, with an `AddressedEnvelope`, usually on an unconnected socket. + /// + /// - parameters: + /// - envelope: The `AddressedEnvelope` to write. + /// - promise: Optionally an `EventLoopPromise` that will get the write operation's result + /// - returns: If the `Channel` is still writable after adding the write of `data`. + /// + /// - warning: If the socket is connected, then the `envelope.remoteAddress` _must_ match the + /// address of the connected peer, otherwise this function will throw a fatal error. + func add(envelope: AddressedEnvelope, promise: EventLoopPromise?) -> Bool { + if let remoteAddress = self.state.remoteAddress { + precondition(envelope.remoteAddress == remoteAddress, """ + Remote address of AddressedEnvelope does not match remote address of connected socket. + """) + return self.add(PendingDatagramWrite( + data: envelope.data, + promise: promise, + address: nil, + metadata: envelope.metadata)) + } else { + return self.add(PendingDatagramWrite( + data: envelope.data, + promise: promise, + address: envelope.remoteAddress, + metadata: envelope.metadata)) + } + } + + /// Add a pending write, without an `AddressedEnvelope`, on a connected socket. + /// + /// - parameters: + /// - data: The `ByteBuffer` to write. + /// - promise: Optionally an `EventLoopPromise` that will get the write operation's result + /// - returns: If the `Channel` is still writable after adding the write of `data`. + func add(data: ByteBuffer, promise: EventLoopPromise?) -> Bool { + return self.add(PendingDatagramWrite( + data: data, + promise: promise, + address: nil, + metadata: nil)) + } + /// Returns the best mechanism to write pending data at the current point in time. var currentBestWriteMechanism: WriteMechanism { return self.state.currentBestWriteMechanism @@ -442,10 +511,10 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// On platforms that do not support a gathering write operation, /// /// - parameters: - /// - scalarWriteOperation: An operation that writes a single, contiguous array of bytes (usually `sendto`). + /// - scalarWriteOperation: An operation that writes a single, contiguous array of bytes (usually `sendmsg`). /// - vectorWriteOperation: An operation that writes multiple contiguous arrays of bytes (usually `sendmmsg`). /// - returns: The `WriteResult` and whether the `Channel` is now writable. - func triggerAppropriateWriteOperations(scalarWriteOperation: (UnsafeRawBufferPointer, UnsafePointer, socklen_t, AddressedEnvelope.Metadata?) throws -> IOResult, + func triggerAppropriateWriteOperations(scalarWriteOperation: (UnsafeRawBufferPointer, UnsafePointer?, socklen_t, AddressedEnvelope.Metadata?) throws -> IOResult, vectorWriteOperation: (UnsafeMutableBufferPointer) throws -> IOResult) throws -> OverallWriteResult { return try self.triggerWriteOperations { writeMechanism in switch writeMechanism { @@ -515,14 +584,31 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// /// - parameters: /// - scalarWriteOperation: An operation that writes a single, contiguous array of bytes (usually `sendto`). - private func triggerScalarBufferWrite(scalarWriteOperation: (UnsafeRawBufferPointer, UnsafePointer, socklen_t, AddressedEnvelope.Metadata?) throws -> IOResult) rethrows -> OneWriteOperationResult { + private func triggerScalarBufferWrite(scalarWriteOperation: (UnsafeRawBufferPointer, UnsafePointer?, socklen_t, AddressedEnvelope.Metadata?) throws -> IOResult) rethrows -> OneWriteOperationResult { assert(self.state.isFlushPending && self.isOpen && !self.state.isEmpty, "illegal state for scalar datagram write operation: flushPending: \(self.state.isFlushPending), isOpen: \(self.isOpen), empty: \(self.state.isEmpty)") let pending = self.state.nextWrite! do { - let writeResult = try pending.address.withSockAddr { (addrPtr, addrSize) in - try pending.data.withUnsafeReadableBytes { - try scalarWriteOperation($0, addrPtr, socklen_t(addrSize), pending.metadata) + let writeResult: IOResult + + if let address = pending.address { + assert(self.state.remoteAddress == nil, "Pending write with address on connected socket.") + writeResult = try address.withSockAddr { (addrPtr, addrSize) in + try pending.data.withUnsafeReadableBytes { + try scalarWriteOperation($0, addrPtr, socklen_t(addrSize), pending.metadata) + } + } + } else { + /// From man page of `sendmsg(2)`: + /// + /// > The `msg_name` field is used on an unconnected socket to specify + /// > the target address for a datagram. It points to a buffer + /// > containing the address; the `msg_namelen` field should be set to + /// > the size of the address. For a connected socket, these fields + /// > should be specified as `NULL` and 0, respectively. + assert(self.state.remoteAddress != nil, "Pending write without address on unconnected socket.") + writeResult = try pending.data.withUnsafeReadableBytes { + try scalarWriteOperation($0, nil, 0, pending.metadata) } } return self.didWrite(writeResult, messages: nil) diff --git a/Sources/NIOPosix/PipePair.swift b/Sources/NIOPosix/PipePair.swift index e24c518b..921fdab3 100644 --- a/Sources/NIOPosix/PipePair.swift +++ b/Sources/NIOPosix/PipePair.swift @@ -98,7 +98,7 @@ final class PipePair: SocketProtocol { } func sendmsg(pointer: UnsafeRawBufferPointer, - destinationPtr: UnsafePointer, + destinationPtr: UnsafePointer?, destinationSize: socklen_t, controlBytes: UnsafeMutableRawBufferPointer) throws -> IOResult { throw ChannelError.operationUnsupported diff --git a/Sources/NIOPosix/Socket.swift b/Sources/NIOPosix/Socket.swift index 03932867..2c01bc0a 100644 --- a/Sources/NIOPosix/Socket.swift +++ b/Sources/NIOPosix/Socket.swift @@ -151,7 +151,7 @@ typealias IOVector = iovec /// (because the socket is in non-blocking mode). /// - throws: An `IOError` if the operation failed. func sendmsg(pointer: UnsafeRawBufferPointer, - destinationPtr: UnsafePointer, + destinationPtr: UnsafePointer?, destinationSize: socklen_t, controlBytes: UnsafeMutableRawBufferPointer) throws -> IOResult { // Dubious const casts - it should be OK as there is no reason why this should get mutated diff --git a/Sources/NIOPosix/SocketChannel.swift b/Sources/NIOPosix/SocketChannel.swift index 7e9eca34..e7750096 100644 --- a/Sources/NIOPosix/SocketChannel.swift +++ b/Sources/NIOPosix/SocketChannel.swift @@ -155,10 +155,13 @@ final class ServerSocketChannel: BaseSocketChannel { init(serverSocket: ServerSocket, eventLoop: SelectableEventLoop, group: EventLoopGroup) throws { self.group = group - try super.init(socket: serverSocket, - parent: nil, - eventLoop: eventLoop, - recvAllocator: AdaptiveRecvByteBufferAllocator()) + try super.init( + socket: serverSocket, + parent: nil, + eventLoop: eventLoop, + recvAllocator: AdaptiveRecvByteBufferAllocator(), + supportReconnect: false + ) } convenience init(socket: NIOBSDSocket.Handle, eventLoop: SelectableEventLoop, group: EventLoopGroup) throws { @@ -398,10 +401,13 @@ final class DatagramChannel: BaseSocketChannel { storageRefs: eventLoop.storageRefs, controlMessageStorage: eventLoop.controlMessageStorage) - try super.init(socket: socket, - parent: nil, - eventLoop: eventLoop, - recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048)) + try super.init( + socket: socket, + parent: nil, + eventLoop: eventLoop, + recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048), + supportReconnect: true + ) } init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws { @@ -412,7 +418,13 @@ final class DatagramChannel: BaseSocketChannel { addresses: eventLoop.addresses, storageRefs: eventLoop.storageRefs, controlMessageStorage: eventLoop.controlMessageStorage) - try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048)) + try super.init( + socket: socket, + parent: parent, + eventLoop: eventLoop, + recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048), + supportReconnect: true + ) } // MARK: Datagram Channel overrides required by BaseSocketChannel @@ -526,12 +538,24 @@ final class DatagramChannel: BaseSocketChannel { } override func connectSocket(to address: SocketAddress) throws -> Bool { - // For now we don't support operating in connected mode for datagram channels. - throw ChannelError.operationUnsupported + // TODO: this could be a channel option to do other things instead here, e.g. fail the connect + if !self.pendingWrites.isEmpty { + self.pendingWrites.failAll( + error: IOError( + errnoCode: EISCONN, + reason: "Socket was connected before flushing pending write."), + close: false) + } + if try self.socket.connect(to: address) { + self.pendingWrites.markConnected(to: address) + return true + } else { + preconditionFailure("Connect of datagram socket did not complete synchronously.") + } } override func finishConnectSocket() throws { - // For now we don't support operating in connected mode for datagram channels. + // This is not required for connected datagram channels connect is a synchronous operation. throw ChannelError.operationUnsupported } @@ -668,11 +692,52 @@ final class DatagramChannel: BaseSocketChannel { return true } } - /// Buffer a write in preparation for a flush. - override func bufferPendingWrite(data: NIOAny, promise: EventLoopPromise?) { - let data = self.unwrapData(data, as: AddressedEnvelope.self) - if !self.pendingWrites.add(envelope: data, promise: promise) { + /// Buffer a write in preparation for a flush. + /// + /// When the channel is unconnected, `data` _must_ be of type `AddressedEnvelope`. + /// + /// When the channel is connected, `data` _should_ be of type `ByteBuffer`, but _may_ be of type + /// `AddressedEnvelope` to allow users to provide protocol control messages via + /// `AddressedEnvelope.metadata`. In this case, `AddressedEnvelope.remoteAddress` _must_ match + /// the address of the connected peer. + override func bufferPendingWrite(data: NIOAny, promise: EventLoopPromise?) { + if let envelope = self.tryUnwrapData(data, as: AddressedEnvelope.self) { + return bufferPendingAddressedWrite(envelope: envelope, promise: promise) + } + // If it's not an `AddressedEnvelope` then it must be a `ByteBuffer` so we let the common + // `unwrapData(_:as:)` throw the fatal error if it's some other type. + let data = self.unwrapData(data, as: ByteBuffer.self) + return bufferPendingUnaddressedWrite(data: data, promise: promise) + } + + /// Buffer a write in preparation for a flush. + private func bufferPendingUnaddressedWrite(data: ByteBuffer, promise: EventLoopPromise?) { + // It is only appropriate to not use an AddressedEnvelope if the socket is connected. + guard self.remoteAddress != nil else { + promise?.fail(DatagramChannelError.WriteOnUnconnectedSocketWithoutAddress()) + return + } + + if !self.pendingWrites.add(data: data, promise: promise) { + assert(self.isActive) + self.pipeline.syncOperations.fireChannelWritabilityChanged() + } + } + + /// Buffer a write in preparation for a flush. + private func bufferPendingAddressedWrite(envelope: AddressedEnvelope, promise: EventLoopPromise?) { + // If the socket is connected, check the remote provided matches the connected address. + if let connectedRemoteAddress = self.remoteAddress { + guard envelope.remoteAddress == connectedRemoteAddress else { + promise?.fail(DatagramChannelError.WriteOnConnectedSocketWithInvalidAddress( + envelopeRemoteAddress: envelope.remoteAddress, + connectedRemoteAddress: connectedRemoteAddress)) + return + } + } + + if !self.pendingWrites.add(envelope: envelope, promise: promise) { assert(self.isActive) self.pipeline.syncOperations.fireChannelWritabilityChanged() } diff --git a/Sources/NIOPosix/SocketProtocols.swift b/Sources/NIOPosix/SocketProtocols.swift index bd3e86de..ab713021 100644 --- a/Sources/NIOPosix/SocketProtocols.swift +++ b/Sources/NIOPosix/SocketProtocols.swift @@ -54,7 +54,7 @@ protocol SocketProtocol: BaseSocketProtocol { controlBytes: inout UnsafeReceivedControlBytes) throws -> IOResult func sendmsg(pointer: UnsafeRawBufferPointer, - destinationPtr: UnsafePointer, + destinationPtr: UnsafePointer?, destinationSize: socklen_t, controlBytes: UnsafeMutableRawBufferPointer) throws -> IOResult diff --git a/Sources/NIOUDPEchoClient/main.swift b/Sources/NIOUDPEchoClient/main.swift index 6e3be25d..28f2eb91 100644 --- a/Sources/NIOUDPEchoClient/main.swift +++ b/Sources/NIOUDPEchoClient/main.swift @@ -73,7 +73,16 @@ private final class EchoHandler: ChannelInboundHandler { } // First argument is the program path -let arguments = CommandLine.arguments +var arguments = CommandLine.arguments +// Support for `--connect` if it appears as the first argument. +let connectedMode: Bool +if let connectedModeFlagIndex = arguments.firstIndex(where: { $0 == "--connect" }) { + connectedMode = true + arguments.remove(at: connectedModeFlagIndex) +} else { + connectedMode = false +} +// Now process the positional arguments. let arg1 = arguments.dropFirst().first let arg2 = arguments.dropFirst(2).first let arg3 = arguments.dropFirst(3).first @@ -133,7 +142,13 @@ let channel = try { () -> Channel in case .unixDomainSocket(_, let listeningPath): return try bootstrap.bind(unixDomainSocketPath: listeningPath).wait() } - }() +}() + +if connectedMode { + let remoteAddress = try remoteAddress() + print("Connecting to remote: \(remoteAddress)") + try channel.connect(to: remoteAddress).wait() +} // Will be closed after we echo-ed back to the server. try channel.closeFuture.wait() diff --git a/Tests/NIOPosixTests/DatagramChannelTests+XCTest.swift b/Tests/NIOPosixTests/DatagramChannelTests+XCTest.swift index 7c740ecf..6b45080d 100644 --- a/Tests/NIOPosixTests/DatagramChannelTests+XCTest.swift +++ b/Tests/NIOPosixTests/DatagramChannelTests+XCTest.swift @@ -2,7 +2,7 @@ // // This source file is part of the SwiftNIO open source project // -// Copyright (c) 2018-2021 Apple Inc. and the SwiftNIO project authors +// Copyright (c) 2018-2022 Apple Inc. and the SwiftNIO project authors // Licensed under Apache License v2.0 // // See LICENSE.txt for license information @@ -29,7 +29,6 @@ extension DatagramChannelTests { return [ ("testBasicChannelCommunication", testBasicChannelCommunication), ("testManyWrites", testManyWrites), - ("testConnectionFails", testConnectionFails), ("testDatagramChannelHasWatermark", testDatagramChannelHasWatermark), ("testWriteFuturesFailWhenChannelClosed", testWriteFuturesFailWhenChannelClosed), ("testManyManyDatagramWrites", testManyManyDatagramWrites), @@ -67,6 +66,14 @@ extension DatagramChannelTests { ("testReceiveEcnAndPacketInfoIPV6VectorRead", testReceiveEcnAndPacketInfoIPV6VectorRead), ("testReceiveEcnAndPacketInfoIPV4VectorReadVectorWrite", testReceiveEcnAndPacketInfoIPV4VectorReadVectorWrite), ("testReceiveEcnAndPacketInfoIPV6VectorReadVectorWrite", testReceiveEcnAndPacketInfoIPV6VectorReadVectorWrite), + ("testSendingAddressedEnvelopeOnUnconnectedSocketSucceeds", testSendingAddressedEnvelopeOnUnconnectedSocketSucceeds), + ("testSendingByteBufferOnUnconnectedSocketFails", testSendingByteBufferOnUnconnectedSocketFails), + ("testSendingByteBufferOnConnectedSocketSucceeds", testSendingByteBufferOnConnectedSocketSucceeds), + ("testSendingAddressedEnvelopeOnConnectedSocketSucceeds", testSendingAddressedEnvelopeOnConnectedSocketSucceeds), + ("testSendingAddressedEnvelopeOnConnectedSocketWithDifferentAddressFails", testSendingAddressedEnvelopeOnConnectedSocketWithDifferentAddressFails), + ("testConnectingSocketAfterFlushingExistingMessages", testConnectingSocketAfterFlushingExistingMessages), + ("testConnectingSocketFailsBufferedWrites", testConnectingSocketFailsBufferedWrites), + ("testReconnectingSocketFailsBufferedWrites", testReconnectingSocketFailsBufferedWrites), ] } } diff --git a/Tests/NIOPosixTests/DatagramChannelTests.swift b/Tests/NIOPosixTests/DatagramChannelTests.swift index 90816619..53ffbff4 100644 --- a/Tests/NIOPosixTests/DatagramChannelTests.swift +++ b/Tests/NIOPosixTests/DatagramChannelTests.swift @@ -102,10 +102,11 @@ private class DatagramReadRecorder: ChannelInboundHandler { } } -final class DatagramChannelTests: XCTestCase { +class DatagramChannelTests: XCTestCase { private var group: MultiThreadedEventLoopGroup! = nil private var firstChannel: Channel! = nil private var secondChannel: Channel! = nil + private var thirdChannel: Channel! = nil private func buildChannel(group: EventLoopGroup, host: String = "127.0.0.1") throws -> Channel { return try DatagramBootstrap(group: group) @@ -128,9 +129,11 @@ final class DatagramChannelTests: XCTestCase { override func setUp() { super.setUp() + self.continueAfterFailure = false self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) self.firstChannel = try! buildChannel(group: group) self.secondChannel = try! buildChannel(group: group) + self.thirdChannel = try! buildChannel(group: group) } override func tearDown() { @@ -173,12 +176,6 @@ final class DatagramChannelTests: XCTestCase { } } - func testConnectionFails() throws { - XCTAssertThrowsError(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait()) { error in - XCTAssertEqual(.operationUnsupported, error as? ChannelError) - } - } - func testDatagramChannelHasWatermark() throws { _ = try self.firstChannel.setOption(ChannelOptions.writeBufferWaterMark, value: ChannelOptions.Types.WriteBufferWaterMark(low: 1, high: 1024)).wait() @@ -916,4 +913,214 @@ final class DatagramChannelTests: XCTestCase { } testEcnAndPacketInfoReceive(address: "::1", vectorRead: true, vectorSend: true, receivePacketInfo: true) } + + func assertSending( + data: ByteBuffer, + from sourceChannel: Channel, + to destinationChannel: Channel, + wrappingInAddressedEnvelope shouldWrapInAddressedEnvelope: Bool, + resultsIn expectedResult: Result, + file: StaticString = #file, + line: UInt = #line + ) throws { + // Wrap data in AddressedEnvelope if required. + let writePayload: NIOAny + if shouldWrapInAddressedEnvelope { + let envelope = AddressedEnvelope(remoteAddress: destinationChannel.localAddress!, data: data) + writePayload = NIOAny(envelope) + } else { + writePayload = NIOAny(data) + } + + // Write and flush. + let writeResult = sourceChannel.writeAndFlush(writePayload) + + // Check the expected result. + switch expectedResult { + case .success: + // Check the write succeeded. + XCTAssertNoThrow(try writeResult.wait()) + + // Check the destination received the sent payload. + let reads = try destinationChannel.waitForDatagrams(count: 1) + XCTAssertEqual(reads.count, 1) + let read = reads.first! + XCTAssertEqual(read.data, data) + XCTAssertEqual(read.remoteAddress, sourceChannel.localAddress!) + + case .failure(let expectedError): + // Check the error is of the expected type. + XCTAssertThrowsError(try writeResult.wait()) { error in + guard type(of: error) == type(of: expectedError) else { + XCTFail("expected error of type \(type(of: expectedError)), but caught other error of type (\(type(of: error)): \(error)") + return + } + } + } + } + + func assertSendingHelloWorld( + from sourceChannel: Channel, + to destinationChannel: Channel, + wrappingInAddressedEnvelope shouldWrapInAddressedEnvelope: Bool, + resultsIn expectedResult: Result, + file: StaticString = #file, + line: UInt = #line + ) throws { + try self.assertSending( + data: sourceChannel.allocator.buffer(staticString: "hello, world!"), + from: sourceChannel, + to: destinationChannel, + wrappingInAddressedEnvelope: shouldWrapInAddressedEnvelope, + resultsIn: expectedResult, + file: file, + line: line + ) + } + + func bufferWrite( + of data: ByteBuffer, + from sourceChannel: Channel, + to destinationChannel: Channel, + wrappingInAddressedEnvelope shouldWrapInAddressedEnvelope: Bool + ) -> EventLoopFuture { + if shouldWrapInAddressedEnvelope { + let envelope = AddressedEnvelope(remoteAddress: destinationChannel.localAddress!, data: data) + return sourceChannel.write(envelope) + } else { + return sourceChannel.write(data) + } + } + + func bufferWriteOfHelloWorld( + from sourceChannel: Channel, + to destinationChannel: Channel, + wrappingInAddressedEnvelope shouldWrapInAddressedEnvelope: Bool + ) -> EventLoopFuture { + self.bufferWrite( + of: sourceChannel.allocator.buffer(staticString: "hello, world!"), + from: sourceChannel, + to: destinationChannel, + wrappingInAddressedEnvelope: shouldWrapInAddressedEnvelope + ) + } + + func testSendingAddressedEnvelopeOnUnconnectedSocketSucceeds() throws { + try self.assertSendingHelloWorld( + from: self.firstChannel, + to: self.secondChannel, + wrappingInAddressedEnvelope: true, + resultsIn: .success(()) + ) + } + + func testSendingByteBufferOnUnconnectedSocketFails() throws { + try self.assertSendingHelloWorld( + from: self.firstChannel, + to: self.secondChannel, + wrappingInAddressedEnvelope: false, + resultsIn: .failure(DatagramChannelError.WriteOnUnconnectedSocketWithoutAddress()) + ) + } + + func testSendingByteBufferOnConnectedSocketSucceeds() throws { + XCTAssertNoThrow(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait()) + + try self.assertSendingHelloWorld( + from: self.firstChannel, + to: self.secondChannel, + wrappingInAddressedEnvelope: false, + resultsIn: .success(()) + ) + } + + func testSendingAddressedEnvelopeOnConnectedSocketSucceeds() throws { + XCTAssertNoThrow(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait()) + + try self.assertSendingHelloWorld( + from: self.firstChannel, + to: self.secondChannel, + wrappingInAddressedEnvelope: true, + resultsIn: .success(()) + ) + } + + func testSendingAddressedEnvelopeOnConnectedSocketWithDifferentAddressFails() throws { + XCTAssertNoThrow(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait()) + + try self.assertSendingHelloWorld( + from: self.firstChannel, + to: self.thirdChannel, + wrappingInAddressedEnvelope: true, + resultsIn: .failure(DatagramChannelError.WriteOnConnectedSocketWithInvalidAddress( + envelopeRemoteAddress: self.thirdChannel.localAddress!, + connectedRemoteAddress: self.secondChannel.localAddress!)) + ) + } + + func testConnectingSocketAfterFlushingExistingMessages() throws { + // Send message from firstChannel to secondChannel. + try self.assertSendingHelloWorld( + from: self.firstChannel, + to: self.secondChannel, + wrappingInAddressedEnvelope: true, + resultsIn: .success(()) + ) + + // Connect firstChannel to thirdChannel. + XCTAssertNoThrow(try self.firstChannel.connect(to: self.thirdChannel.localAddress!).wait()) + + // Send message from firstChannel to thirdChannel. + try self.assertSendingHelloWorld( + from: self.firstChannel, + to: self.thirdChannel, + wrappingInAddressedEnvelope: false, + resultsIn: .success(()) + ) + } + + func testConnectingSocketFailsBufferedWrites() throws { + // Buffer message from firstChannel to secondChannel. + let bufferedWrite = bufferWriteOfHelloWorld(from: self.firstChannel, to: self.secondChannel, wrappingInAddressedEnvelope: true) + + // Connect firstChannel to thirdChannel. + XCTAssertNoThrow(try self.firstChannel.connect(to: self.thirdChannel.localAddress!).wait()) + + // Check that the buffered write was failed. + XCTAssertThrowsError(try bufferedWrite.wait()) { error in + XCTAssertEqual((error as? IOError)?.errnoCode, EISCONN, "expected EISCONN, but caught other error: \(error)") + } + + // Send message from firstChannel to thirdChannel. + try self.assertSendingHelloWorld( + from: self.firstChannel, + to: self.thirdChannel, + wrappingInAddressedEnvelope: false, + resultsIn: .success(()) + ) + } + + func testReconnectingSocketFailsBufferedWrites() throws { + // Connect firstChannel to secondChannel. + XCTAssertNoThrow(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait()) + + // Buffer message from firstChannel to secondChannel. + let bufferedWrite = bufferWriteOfHelloWorld(from: self.firstChannel, to: self.secondChannel, wrappingInAddressedEnvelope: false) + + // Connect firstChannel to thirdChannel. + XCTAssertNoThrow(try self.firstChannel.connect(to: self.thirdChannel.localAddress!).wait()) + + // Check that the buffered write was failed. + XCTAssertThrowsError(try bufferedWrite.wait()) { error in + XCTAssertEqual((error as? IOError)?.errnoCode, EISCONN, "expected EISCONN, but caught other error: \(error)") + } + + // Send message from firstChannel to thirdChannel. + try self.assertSendingHelloWorld( + from: self.firstChannel, + to: self.thirdChannel, + wrappingInAddressedEnvelope: false, + resultsIn: .success(()) + ) + } } diff --git a/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift b/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift index 7a2595d0..cac74b83 100644 --- a/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift +++ b/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift @@ -126,7 +126,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { if expected.count > singleState { XCTAssertGreaterThan(returns.count, everythingState) XCTAssertEqual(expected[singleState].0, buf.count, "in single write \(singleState) (overall \(everythingState)), \(expected[singleState].0) bytes expected but \(buf.count) actual", file: (file), line: line) - XCTAssertEqual(expected[singleState].1, SocketAddress(addr), "in single write \(singleState) (overall \(everythingState)), \(expected[singleState].1) address expected but \(SocketAddress(addr)) received", file: (file), line: line) + XCTAssertEqual(expected[singleState].1, addr.map(SocketAddress.init), "in single write \(singleState) (overall \(everythingState)), \(expected[singleState].1) address expected but \(String(describing: addr.map(SocketAddress.init))) received", file: (file), line: line) XCTAssertEqual(expected[singleState].1.expectedSize, len, "in single write \(singleState) (overall \(everythingState)), \(expected[singleState].1.expectedSize) socklen expected but \(len) received", file: (file), line: line) switch returns[everythingState] {