ClientBootstrap: allow binding sockets (#1490)
This commit is contained in:
parent
5fc487bf7e
commit
120acb15c3
|
@ -337,7 +337,7 @@ class BaseSocket: BaseSocketProtocol {
|
|||
/// - name: The name of the option to set.
|
||||
/// - value: The value for the option.
|
||||
/// - throws: An `IOError` if the operation failed.
|
||||
final func setOption<T>(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: T) throws {
|
||||
func setOption<T>(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: T) throws {
|
||||
if level == .tcp && name == .tcp_nodelay && (try? self.localAddress().protocol) == Optional<NIOBSDSocket.ProtocolFamily>.some(.unix) {
|
||||
// setting TCP_NODELAY on UNIX domain sockets will fail. Previously we had a bug where we would ignore
|
||||
// most socket options settings so for the time being we'll just ignore this. Let's revisit for NIO 2.0.
|
||||
|
@ -387,7 +387,7 @@ class BaseSocket: BaseSocketProtocol {
|
|||
/// - parameters:
|
||||
/// - address: The `SocketAddress` to which the socket should be bound.
|
||||
/// - throws: An `IOError` if the operation failed.
|
||||
final func bind(to address: SocketAddress) throws {
|
||||
func bind(to address: SocketAddress) throws {
|
||||
try self.withUnsafeHandle { fd in
|
||||
func doBind(ptr: UnsafePointer<sockaddr>, bytes: Int) throws {
|
||||
try Posix.bind(descriptor: fd, ptr: ptr, bytes: bytes)
|
||||
|
|
|
@ -645,10 +645,6 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
|
|||
promise?.fail(ChannelError.ioOnClosedChannel)
|
||||
return
|
||||
}
|
||||
guard self.lifecycleManager.isPreRegistered else {
|
||||
promise?.fail(ChannelError.inappropriateOperationForState)
|
||||
return
|
||||
}
|
||||
|
||||
executeAndComplete(promise) {
|
||||
try socket.bind(to: address)
|
||||
|
|
|
@ -426,6 +426,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
|
|||
internal var _channelOptions: ChannelOptions.Storage
|
||||
private var connectTimeout: TimeAmount = TimeAmount.seconds(10)
|
||||
private var resolver: Optional<Resolver>
|
||||
private var bindTarget: Optional<SocketAddress>
|
||||
|
||||
/// Create a `ClientBootstrap` on the `EventLoopGroup` `group`.
|
||||
///
|
||||
|
@ -458,6 +459,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
|
|||
self._channelInitializer = { channel in channel.eventLoop.makeSucceededFuture(()) }
|
||||
self.protocolHandlers = nil
|
||||
self.resolver = nil
|
||||
self.bindTarget = nil
|
||||
}
|
||||
|
||||
/// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add
|
||||
|
@ -523,6 +525,24 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
|
|||
return self
|
||||
}
|
||||
|
||||
/// Bind the `SocketChannel` to `address`.
|
||||
///
|
||||
/// Using `bind` is not necessary unless you need the local address to be bound to a specific address.
|
||||
///
|
||||
/// - note: Using `bind` will disable Happy Eyeballs on this `Channel`.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - address: The `SocketAddress` to bind on.
|
||||
public func bind(to address: SocketAddress) -> ClientBootstrap {
|
||||
self.bindTarget = address
|
||||
return self
|
||||
}
|
||||
|
||||
func makeSocketChannel(eventLoop: EventLoop,
|
||||
protocolFamily: NIOBSDSocket.ProtocolFamily) throws -> SocketChannel {
|
||||
return try SocketChannel(eventLoop: eventLoop as! SelectableEventLoop, protocolFamily: protocolFamily)
|
||||
}
|
||||
|
||||
/// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established.
|
||||
///
|
||||
/// - parameters:
|
||||
|
@ -531,34 +551,51 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
|
|||
/// - returns: An `EventLoopFuture<Channel>` to deliver the `Channel` when connected.
|
||||
public func connect(host: String, port: Int) -> EventLoopFuture<Channel> {
|
||||
let loop = self.group.next()
|
||||
let connector = HappyEyeballsConnector(resolver: resolver ?? GetaddrinfoResolver(loop: loop, aiSocktype: .stream, aiProtocol: CInt(IPPROTO_TCP)),
|
||||
let resolver = self.resolver ?? GetaddrinfoResolver(loop: loop,
|
||||
aiSocktype: .stream,
|
||||
aiProtocol: CInt(IPPROTO_TCP))
|
||||
let connector = HappyEyeballsConnector(resolver: resolver,
|
||||
loop: loop,
|
||||
host: host,
|
||||
port: port,
|
||||
connectTimeout: self.connectTimeout) { eventLoop, protocolFamily in
|
||||
return self.execute(eventLoop: eventLoop, protocolFamily: protocolFamily) { $0.eventLoop.makeSucceededFuture(()) }
|
||||
return self.initializeAndRegisterNewChannel(eventLoop: eventLoop, protocolFamily: protocolFamily) {
|
||||
$0.eventLoop.makeSucceededFuture(())
|
||||
}
|
||||
}
|
||||
return connector.resolveAndConnect()
|
||||
}
|
||||
|
||||
private func connect(freshChannel channel: Channel, address: SocketAddress) -> EventLoopFuture<Void> {
|
||||
let connectPromise = channel.eventLoop.makePromise(of: Void.self)
|
||||
channel.connect(to: address, promise: connectPromise)
|
||||
let cancelTask = channel.eventLoop.scheduleTask(in: self.connectTimeout) {
|
||||
connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout))
|
||||
channel.close(promise: nil)
|
||||
}
|
||||
|
||||
connectPromise.futureResult.whenComplete { (_: Result<Void, Error>) in
|
||||
cancelTask.cancel()
|
||||
}
|
||||
return connectPromise.futureResult
|
||||
}
|
||||
|
||||
internal func testOnly_connect(injectedChannel: SocketChannel,
|
||||
to address: SocketAddress) -> EventLoopFuture<Channel> {
|
||||
return self.initializeAndRegisterChannel(injectedChannel) { channel in
|
||||
return self.connect(freshChannel: channel, address: address)
|
||||
}
|
||||
}
|
||||
|
||||
/// Specify the `address` to connect to for the TCP `Channel` that will be established.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - address: The address to connect to.
|
||||
/// - returns: An `EventLoopFuture<Channel>` to deliver the `Channel` when connected.
|
||||
public func connect(to address: SocketAddress) -> EventLoopFuture<Channel> {
|
||||
return execute(eventLoop: group.next(), protocolFamily: address.protocol) { channel in
|
||||
let connectPromise = channel.eventLoop.makePromise(of: Void.self)
|
||||
channel.connect(to: address, promise: connectPromise)
|
||||
let cancelTask = channel.eventLoop.scheduleTask(in: self.connectTimeout) {
|
||||
connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout))
|
||||
channel.close(promise: nil)
|
||||
}
|
||||
|
||||
connectPromise.futureResult.whenComplete { (_: Result<Void, Error>) in
|
||||
cancelTask.cancel()
|
||||
}
|
||||
return connectPromise.futureResult
|
||||
return self.initializeAndRegisterNewChannel(eventLoop: self.group.next(),
|
||||
protocolFamily: address.protocol) { channel in
|
||||
return self.connect(freshChannel: channel, address: address)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -570,9 +607,9 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
|
|||
public func connect(unixDomainSocketPath: String) -> EventLoopFuture<Channel> {
|
||||
do {
|
||||
let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
|
||||
return connect(to: address)
|
||||
return self.connect(to: address)
|
||||
} catch {
|
||||
return group.next().makeFailedFuture(error)
|
||||
return self.group.next().makeFailedFuture(error)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -615,24 +652,35 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
|
|||
}
|
||||
}
|
||||
|
||||
private func execute(eventLoop: EventLoop,
|
||||
protocolFamily: NIOBSDSocket.ProtocolFamily,
|
||||
_ body: @escaping (Channel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
|
||||
let channelInitializer = self.channelInitializer
|
||||
let channelOptions = self._channelOptions
|
||||
|
||||
private func initializeAndRegisterNewChannel(eventLoop: EventLoop,
|
||||
protocolFamily: NIOBSDSocket.ProtocolFamily,
|
||||
_ body: @escaping (Channel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
|
||||
let channel: SocketChannel
|
||||
do {
|
||||
channel = try SocketChannel(eventLoop: eventLoop as! SelectableEventLoop, protocolFamily: protocolFamily)
|
||||
channel = try self.makeSocketChannel(eventLoop: eventLoop, protocolFamily: protocolFamily)
|
||||
} catch {
|
||||
return eventLoop.makeFailedFuture(error)
|
||||
}
|
||||
return self.initializeAndRegisterChannel(channel, body)
|
||||
}
|
||||
|
||||
private func initializeAndRegisterChannel(_ channel: SocketChannel,
|
||||
_ body: @escaping (Channel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
|
||||
let channelInitializer = self.channelInitializer
|
||||
let channelOptions = self._channelOptions
|
||||
let eventLoop = channel.eventLoop
|
||||
|
||||
@inline(__always)
|
||||
func setupChannel() -> EventLoopFuture<Channel> {
|
||||
eventLoop.assertInEventLoop()
|
||||
return channelOptions.applyAllChannelOptions(to: channel).flatMap {
|
||||
channelInitializer(channel)
|
||||
if let bindTarget = self.bindTarget {
|
||||
return channel.bind(to: bindTarget).flatMap {
|
||||
channelInitializer(channel)
|
||||
}
|
||||
} else {
|
||||
return channelInitializer(channel)
|
||||
}
|
||||
}.flatMap {
|
||||
eventLoop.assertInEventLoop()
|
||||
return channel.registerAndDoSynchronously(body)
|
||||
|
@ -647,7 +695,9 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
|
|||
if eventLoop.inEventLoop {
|
||||
return setupChannel()
|
||||
} else {
|
||||
return eventLoop.submit(setupChannel).flatMap { $0 }
|
||||
return eventLoop.flatSubmit {
|
||||
setupChannel()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,6 +48,7 @@ extension BootstrapTest {
|
|||
("testDatagramBootstrapRejectsNotWorkingELGsCorrectly", testDatagramBootstrapRejectsNotWorkingELGsCorrectly),
|
||||
("testNIOPipeBootstrapValidatesWorkingELGsCorrectly", testNIOPipeBootstrapValidatesWorkingELGsCorrectly),
|
||||
("testNIOPipeBootstrapRejectsNotWorkingELGsCorrectly", testNIOPipeBootstrapRejectsNotWorkingELGsCorrectly),
|
||||
("testClientBindWorksOnSocketsBoundToEitherIPv4OrIPv6Only", testClientBindWorksOnSocketsBoundToEitherIPv4OrIPv6Only),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -529,6 +529,88 @@ class BootstrapTest: XCTestCase {
|
|||
XCTAssertNil(NIOPipeBootstrap(validatingGroup: elg))
|
||||
XCTAssertNil(NIOPipeBootstrap(validatingGroup: el))
|
||||
}
|
||||
|
||||
func testClientBindWorksOnSocketsBoundToEitherIPv4OrIPv6Only() {
|
||||
for isIPv4 in [true, false] {
|
||||
guard System.supportsIPv6 || isIPv4 else {
|
||||
continue // need to skip IPv6 tests if we don't support it.
|
||||
}
|
||||
let localIP = isIPv4 ? "127.0.0.1" : "::1"
|
||||
guard let serverLocalAddressChoice = try? SocketAddress(ipAddress: localIP, port: 0),
|
||||
let clientLocalAddressWholeInterface = try? SocketAddress(ipAddress: localIP, port: 0),
|
||||
let server1 = (try? ServerBootstrap(group: self.group)
|
||||
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||
.serverChannelOption(ChannelOptions.maxMessagesPerRead, value: 1)
|
||||
.bind(to: serverLocalAddressChoice)
|
||||
.wait()),
|
||||
let server2 = (try? ServerBootstrap(group: self.group)
|
||||
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||
.serverChannelOption(ChannelOptions.maxMessagesPerRead, value: 1)
|
||||
.bind(to: serverLocalAddressChoice)
|
||||
.wait()),
|
||||
let server1LocalAddress = server1.localAddress,
|
||||
let server2LocalAddress = server2.localAddress else {
|
||||
XCTFail("can't boot servers even")
|
||||
return
|
||||
}
|
||||
defer {
|
||||
XCTAssertNoThrow(try server1.close().wait())
|
||||
XCTAssertNoThrow(try server2.close().wait())
|
||||
}
|
||||
|
||||
// Try 1: Directly connect to 127.0.0.1, this won't do Happy Eyeballs.
|
||||
XCTAssertNoThrow(try ClientBootstrap(group: self.group)
|
||||
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||
.bind(to: clientLocalAddressWholeInterface)
|
||||
.connect(to: server1LocalAddress)
|
||||
.wait()
|
||||
.close()
|
||||
.wait())
|
||||
|
||||
var maybeChannel1: Channel? = nil
|
||||
// Try 2: Connect to "localhost", this will do Happy Eyeballs.
|
||||
XCTAssertNoThrow(maybeChannel1 = try ClientBootstrap(group: self.group)
|
||||
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||
.bind(to: clientLocalAddressWholeInterface)
|
||||
.connect(host: "localhost", port: server1LocalAddress.port!)
|
||||
.wait())
|
||||
guard let myChannel1 = maybeChannel1, let myChannel1Address = myChannel1.localAddress else {
|
||||
XCTFail("can't connect channel 1")
|
||||
return
|
||||
}
|
||||
XCTAssertEqual(localIP, maybeChannel1?.localAddress?.ipAddress)
|
||||
// Try 3: Bind the client to the same address/port as in try 2 but to server 2.
|
||||
XCTAssertNoThrow(try ClientBootstrap(group: self.group)
|
||||
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||
.connectTimeout(.hours(2))
|
||||
.bind(to: myChannel1Address)
|
||||
.connect(to: server2LocalAddress)
|
||||
.map { channel -> Channel in
|
||||
XCTAssertEqual(myChannel1Address, channel.localAddress)
|
||||
return channel
|
||||
}
|
||||
.wait()
|
||||
.close()
|
||||
.wait())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private final class WriteStringOnChannelActive: ChannelInboundHandler {
|
||||
typealias InboundIn = Never
|
||||
typealias OutboundOut = ByteBuffer
|
||||
|
||||
let string: String
|
||||
|
||||
init(_ string: String) {
|
||||
self.string = string
|
||||
}
|
||||
|
||||
func channelActive(context: ChannelHandlerContext) {
|
||||
var buffer = context.channel.allocator.buffer(capacity: self.string.utf8.count)
|
||||
buffer.writeString(string)
|
||||
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
private final class MakeSureAutoReadIsOffInChannelInitializer: ChannelInboundHandler {
|
||||
|
|
|
@ -2012,7 +2012,7 @@ public final class ChannelTests: XCTestCase {
|
|||
XCTAssertFalse(channel.isWritable)
|
||||
}
|
||||
|
||||
withChannel { channel in
|
||||
withChannel(skipStream: true) { channel in
|
||||
checkThatItThrowsInappropriateOperationForState {
|
||||
XCTAssertEqual(0, channel.localAddress?.port ?? 0xffff)
|
||||
XCTAssertNil(channel.remoteAddress)
|
||||
|
|
|
@ -49,15 +49,6 @@ final class MulticastTest: XCTestCase {
|
|||
|
||||
struct ReceivedDatagramError: Error { }
|
||||
|
||||
private var supportsIPv6: Bool {
|
||||
do {
|
||||
let ipv6Loopback = try SocketAddress(ipAddress: "::1", port: 0)
|
||||
return try System.enumerateInterfaces().contains(where: { $0.address == ipv6Loopback })
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func interfaceForAddress(address: String) throws -> NIONetworkInterface {
|
||||
let targetAddress = try SocketAddress(ipAddress: address, port: 0)
|
||||
guard let interface = try System.enumerateInterfaces().lazy.filter({ $0.address == targetAddress }).first else {
|
||||
|
@ -220,7 +211,7 @@ final class MulticastTest: XCTestCase {
|
|||
}
|
||||
|
||||
func testCanJoinBasicMulticastGroupIPv6() throws {
|
||||
guard self.supportsIPv6 else {
|
||||
guard System.supportsIPv6 else {
|
||||
// Skip on non-IPv6 systems
|
||||
return
|
||||
}
|
||||
|
@ -317,7 +308,7 @@ final class MulticastTest: XCTestCase {
|
|||
}
|
||||
|
||||
func testCanLeaveAnIPv6MulticastGroup() throws {
|
||||
guard self.supportsIPv6 else {
|
||||
guard System.supportsIPv6 else {
|
||||
// Skip on non-IPv6 systems
|
||||
return
|
||||
}
|
||||
|
|
|
@ -31,6 +31,8 @@ extension SALChannelTest {
|
|||
("testWritesFromWritabilityNotificationsDoNotGetLostIfWePreviouslyWroteEverything", testWritesFromWritabilityNotificationsDoNotGetLostIfWePreviouslyWroteEverything),
|
||||
("testWeSurviveIfIgnoringSIGPIPEFails", testWeSurviveIfIgnoringSIGPIPEFails),
|
||||
("testBasicRead", testBasicRead),
|
||||
("testBasicConnectWithClientBootstrap", testBasicConnectWithClientBootstrap),
|
||||
("testClientBootstrapBindIsDoneAfterSocketOptions", testClientBootstrapBindIsDoneAfterSocketOptions),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -235,4 +235,86 @@ final class SALChannelTest: XCTestCase, SALTest {
|
|||
|
||||
g.wait()
|
||||
}
|
||||
|
||||
func testBasicConnectWithClientBootstrap() {
|
||||
guard let channel = try? self.makeSocketChannel() else {
|
||||
XCTFail("couldn't make a channel")
|
||||
return
|
||||
}
|
||||
let localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 5)
|
||||
let serverAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 5)
|
||||
XCTAssertNoThrow(try channel.eventLoop.runSAL(syscallAssertions: {
|
||||
try self.assertSetOption(expectedLevel: .tcp, expectedOption: .tcp_nodelay) { value in
|
||||
return (value as? SocketOptionValue) == 1
|
||||
}
|
||||
try self.assertConnect(expectedAddress: serverAddress, result: true)
|
||||
try self.assertLocalAddress(address: localAddress)
|
||||
try self.assertRemoteAddress(address: localAddress)
|
||||
try self.assertRegister { selectable, event, Registration in
|
||||
XCTAssertEqual([.reset], event)
|
||||
return true
|
||||
}
|
||||
try self.assertReregister { selectable, event in
|
||||
XCTAssertEqual([.reset, .readEOF], event)
|
||||
return true
|
||||
}
|
||||
try self.assertDeregister { selectable in
|
||||
return true
|
||||
}
|
||||
try self.assertClose(expectedFD: .max)
|
||||
}) {
|
||||
ClientBootstrap(group: channel.eventLoop)
|
||||
.channelOption(ChannelOptions.autoRead, value: false)
|
||||
.testOnly_connect(injectedChannel: channel, to: serverAddress)
|
||||
.flatMap { channel in
|
||||
channel.close()
|
||||
}
|
||||
}.salWait())
|
||||
}
|
||||
|
||||
func testClientBootstrapBindIsDoneAfterSocketOptions() {
|
||||
guard let channel = try? self.makeSocketChannel() else {
|
||||
XCTFail("couldn't make a channel")
|
||||
return
|
||||
}
|
||||
let localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 5)
|
||||
let serverAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 5)
|
||||
XCTAssertNoThrow(try channel.eventLoop.runSAL(syscallAssertions: {
|
||||
try self.assertSetOption(expectedLevel: .tcp, expectedOption: .tcp_nodelay) { value in
|
||||
return (value as? SocketOptionValue) == 1
|
||||
}
|
||||
// This is the important bit: We need to apply the socket options _before_ ...
|
||||
try self.assertSetOption(expectedLevel: .socket, expectedOption: .so_reuseaddr) { value in
|
||||
return (value as? SocketOptionValue) == 1
|
||||
}
|
||||
// ... we call bind.
|
||||
try self.assertBind(expectedAddress: localAddress)
|
||||
try self.assertLocalAddress(address: nil) // this is an inefficiency in `bind0`.
|
||||
try self.assertConnect(expectedAddress: serverAddress, result: true)
|
||||
try self.assertLocalAddress(address: localAddress)
|
||||
try self.assertRemoteAddress(address: localAddress)
|
||||
try self.assertRegister { selectable, event, Registration in
|
||||
XCTAssertEqual([.reset], event)
|
||||
return true
|
||||
}
|
||||
try self.assertReregister { selectable, event in
|
||||
XCTAssertEqual([.reset, .readEOF], event)
|
||||
return true
|
||||
}
|
||||
try self.assertDeregister { selectable in
|
||||
return true
|
||||
}
|
||||
try self.assertClose(expectedFD: .max)
|
||||
}) {
|
||||
ClientBootstrap(group: channel.eventLoop)
|
||||
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||
.channelOption(ChannelOptions.autoRead, value: false)
|
||||
.bind(to: localAddress)
|
||||
.testOnly_connect(injectedChannel: channel, to: serverAddress)
|
||||
.flatMap { channel in
|
||||
channel.close()
|
||||
}
|
||||
}.salWait())
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -127,6 +127,8 @@ enum UserToKernel {
|
|||
case disableSIGPIPE(CInt)
|
||||
case write(CInt, ByteBuffer)
|
||||
case writev(CInt, [ByteBuffer])
|
||||
case bind(SocketAddress)
|
||||
case setOption(NIOBSDSocket.OptionLevel, NIOBSDSocket.Option, Any)
|
||||
}
|
||||
|
||||
enum KernelToUser {
|
||||
|
@ -342,6 +344,26 @@ class HookedSocket: Socket, UserKernelInterface {
|
|||
throw UnexpectedKernelReturn(ret)
|
||||
}
|
||||
}
|
||||
|
||||
override func setOption<T>(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: T) throws {
|
||||
try self.userToKernel.waitForEmptyAndSet(.setOption(level, name, value))
|
||||
let ret = try self.waitForKernelReturn()
|
||||
if case .returnVoid = ret {
|
||||
return
|
||||
} else {
|
||||
throw UnexpectedKernelReturn(ret)
|
||||
}
|
||||
}
|
||||
|
||||
override func bind(to address: SocketAddress) throws {
|
||||
try self.userToKernel.waitForEmptyAndSet(.bind(address))
|
||||
let ret = try self.waitForKernelReturn()
|
||||
if case .returnVoid = ret {
|
||||
return
|
||||
} else {
|
||||
throw UnexpectedKernelReturn(ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extension HookedSelector {
|
||||
|
@ -495,13 +517,17 @@ extension SALTest {
|
|||
return channel
|
||||
}
|
||||
|
||||
func makeSocketChannel(file: StaticString = #file, line: UInt = #line) throws -> SocketChannel {
|
||||
return try self.makeSocketChannel(eventLoop: self.loop, file: file, line: line)
|
||||
}
|
||||
|
||||
func makeConnectedSocketChannel(localAddress: SocketAddress?,
|
||||
remoteAddress: SocketAddress,
|
||||
file: StaticString = (#file),
|
||||
line: UInt = #line) throws -> SocketChannel {
|
||||
let channel = try self.makeSocketChannel(eventLoop: self.loop)
|
||||
let connectFuture = try channel.eventLoop.runSAL(syscallAssertions: {
|
||||
try self.assertConnect(result: true, { $0 == remoteAddress })
|
||||
try self.assertConnect(expectedAddress: remoteAddress, result: true)
|
||||
try self.assertLocalAddress(address: localAddress)
|
||||
try self.assertRemoteAddress(address: remoteAddress)
|
||||
try self.assertRegister { selectable, eventSet, registration in
|
||||
|
@ -608,8 +634,9 @@ extension SALTest {
|
|||
|
||||
func assertLocalAddress(address: SocketAddress?, file: StaticString = (#file), line: UInt = #line) throws {
|
||||
SAL.printIfDebug("\(#function)")
|
||||
try self.selector.assertSyscallAndReturn(address.map { .returnSocketAddress($0) } ??
|
||||
/* */ .error(.init(errnoCode: EOPNOTSUPP, reason: "nil passed")),
|
||||
try self.selector.assertSyscallAndReturn(address.map {
|
||||
.returnSocketAddress($0)
|
||||
/* */ } ?? .error(.init(errnoCode: EOPNOTSUPP, reason: "nil passed")),
|
||||
file: file, line: line) { syscall in
|
||||
if case .localAddress = syscall {
|
||||
return true
|
||||
|
@ -632,11 +659,22 @@ extension SALTest {
|
|||
}
|
||||
}
|
||||
|
||||
func assertConnect(result: Bool, file: StaticString = (#file), line: UInt = #line, _ matcher: (SocketAddress) -> Bool = { _ in true }) throws {
|
||||
func assertConnect(expectedAddress: SocketAddress, result: Bool, file: StaticString = (#file), line: UInt = #line, _ matcher: (SocketAddress) -> Bool = { _ in true }) throws {
|
||||
SAL.printIfDebug("\(#function)")
|
||||
try self.selector.assertSyscallAndReturn(.returnBool(result), file: file, line: line) { syscall in
|
||||
if case .connect(let address) = syscall {
|
||||
return matcher(address)
|
||||
return address == expectedAddress
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertBind(expectedAddress: SocketAddress, file: StaticString = #file, line: UInt = #line) throws {
|
||||
SAL.printIfDebug("\(#function)")
|
||||
try self.selector.assertSyscallAndReturn(.returnVoid, file: file, line: line) { syscall in
|
||||
if case .bind(let address) = syscall {
|
||||
return address == expectedAddress
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
@ -655,6 +693,19 @@ extension SALTest {
|
|||
}
|
||||
}
|
||||
|
||||
func assertSetOption(expectedLevel: NIOBSDSocket.OptionLevel,
|
||||
expectedOption: NIOBSDSocket.Option,
|
||||
file: StaticString = #file, line: UInt = #line,
|
||||
_ valueMatcher: (Any) -> Bool = { _ in true }) throws {
|
||||
SAL.printIfDebug("\(#function)")
|
||||
try self.selector.assertSyscallAndReturn(.returnVoid, file: file, line: line) { syscall in
|
||||
if case .setOption(expectedLevel, expectedOption, let value) = syscall {
|
||||
return valueMatcher(value)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertRegister(file: StaticString = (#file), line: UInt = #line, _ matcher: (Selectable, SelectorEventSet, NIORegistration) throws -> Bool) throws {
|
||||
SAL.printIfDebug("\(#function)")
|
||||
|
|
|
@ -16,6 +16,17 @@ import XCTest
|
|||
@testable import NIO
|
||||
import NIOConcurrencyHelpers
|
||||
|
||||
extension System {
|
||||
static var supportsIPv6: Bool {
|
||||
do {
|
||||
let ipv6Loopback = try SocketAddress.makeAddressResolvingHost("::1", port: 0)
|
||||
return try System.enumerateInterfaces().filter { $0.address == ipv6Loopback }.first != nil
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func withPipe(_ body: (NIO.NIOFileHandle, NIO.NIOFileHandle) throws -> [NIO.NIOFileHandle]) throws {
|
||||
var fds: [Int32] = [-1, -1]
|
||||
fds.withUnsafeMutableBufferPointer { ptr in
|
||||
|
|
Loading…
Reference in New Issue