ClientBootstrap: allow binding sockets (#1490)

This commit is contained in:
Johannes Weiss 2020-06-08 10:46:20 +01:00 committed by GitHub
parent 5fc487bf7e
commit 120acb15c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 314 additions and 48 deletions

View File

@ -337,7 +337,7 @@ class BaseSocket: BaseSocketProtocol {
/// - name: The name of the option to set.
/// - value: The value for the option.
/// - throws: An `IOError` if the operation failed.
final func setOption<T>(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: T) throws {
func setOption<T>(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: T) throws {
if level == .tcp && name == .tcp_nodelay && (try? self.localAddress().protocol) == Optional<NIOBSDSocket.ProtocolFamily>.some(.unix) {
// setting TCP_NODELAY on UNIX domain sockets will fail. Previously we had a bug where we would ignore
// most socket options settings so for the time being we'll just ignore this. Let's revisit for NIO 2.0.
@ -387,7 +387,7 @@ class BaseSocket: BaseSocketProtocol {
/// - parameters:
/// - address: The `SocketAddress` to which the socket should be bound.
/// - throws: An `IOError` if the operation failed.
final func bind(to address: SocketAddress) throws {
func bind(to address: SocketAddress) throws {
try self.withUnsafeHandle { fd in
func doBind(ptr: UnsafePointer<sockaddr>, bytes: Int) throws {
try Posix.bind(descriptor: fd, ptr: ptr, bytes: bytes)

View File

@ -645,10 +645,6 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
promise?.fail(ChannelError.ioOnClosedChannel)
return
}
guard self.lifecycleManager.isPreRegistered else {
promise?.fail(ChannelError.inappropriateOperationForState)
return
}
executeAndComplete(promise) {
try socket.bind(to: address)

View File

@ -426,6 +426,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
internal var _channelOptions: ChannelOptions.Storage
private var connectTimeout: TimeAmount = TimeAmount.seconds(10)
private var resolver: Optional<Resolver>
private var bindTarget: Optional<SocketAddress>
/// Create a `ClientBootstrap` on the `EventLoopGroup` `group`.
///
@ -458,6 +459,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
self._channelInitializer = { channel in channel.eventLoop.makeSucceededFuture(()) }
self.protocolHandlers = nil
self.resolver = nil
self.bindTarget = nil
}
/// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add
@ -523,6 +525,24 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
return self
}
/// Bind the `SocketChannel` to `address`.
///
/// Using `bind` is not necessary unless you need the local address to be bound to a specific address.
///
/// - note: Using `bind` will disable Happy Eyeballs on this `Channel`.
///
/// - parameters:
/// - address: The `SocketAddress` to bind on.
public func bind(to address: SocketAddress) -> ClientBootstrap {
self.bindTarget = address
return self
}
func makeSocketChannel(eventLoop: EventLoop,
protocolFamily: NIOBSDSocket.ProtocolFamily) throws -> SocketChannel {
return try SocketChannel(eventLoop: eventLoop as! SelectableEventLoop, protocolFamily: protocolFamily)
}
/// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established.
///
/// - parameters:
@ -531,34 +551,51 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
/// - returns: An `EventLoopFuture<Channel>` to deliver the `Channel` when connected.
public func connect(host: String, port: Int) -> EventLoopFuture<Channel> {
let loop = self.group.next()
let connector = HappyEyeballsConnector(resolver: resolver ?? GetaddrinfoResolver(loop: loop, aiSocktype: .stream, aiProtocol: CInt(IPPROTO_TCP)),
let resolver = self.resolver ?? GetaddrinfoResolver(loop: loop,
aiSocktype: .stream,
aiProtocol: CInt(IPPROTO_TCP))
let connector = HappyEyeballsConnector(resolver: resolver,
loop: loop,
host: host,
port: port,
connectTimeout: self.connectTimeout) { eventLoop, protocolFamily in
return self.execute(eventLoop: eventLoop, protocolFamily: protocolFamily) { $0.eventLoop.makeSucceededFuture(()) }
return self.initializeAndRegisterNewChannel(eventLoop: eventLoop, protocolFamily: protocolFamily) {
$0.eventLoop.makeSucceededFuture(())
}
}
return connector.resolveAndConnect()
}
private func connect(freshChannel channel: Channel, address: SocketAddress) -> EventLoopFuture<Void> {
let connectPromise = channel.eventLoop.makePromise(of: Void.self)
channel.connect(to: address, promise: connectPromise)
let cancelTask = channel.eventLoop.scheduleTask(in: self.connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout))
channel.close(promise: nil)
}
connectPromise.futureResult.whenComplete { (_: Result<Void, Error>) in
cancelTask.cancel()
}
return connectPromise.futureResult
}
internal func testOnly_connect(injectedChannel: SocketChannel,
to address: SocketAddress) -> EventLoopFuture<Channel> {
return self.initializeAndRegisterChannel(injectedChannel) { channel in
return self.connect(freshChannel: channel, address: address)
}
}
/// Specify the `address` to connect to for the TCP `Channel` that will be established.
///
/// - parameters:
/// - address: The address to connect to.
/// - returns: An `EventLoopFuture<Channel>` to deliver the `Channel` when connected.
public func connect(to address: SocketAddress) -> EventLoopFuture<Channel> {
return execute(eventLoop: group.next(), protocolFamily: address.protocol) { channel in
let connectPromise = channel.eventLoop.makePromise(of: Void.self)
channel.connect(to: address, promise: connectPromise)
let cancelTask = channel.eventLoop.scheduleTask(in: self.connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout))
channel.close(promise: nil)
}
connectPromise.futureResult.whenComplete { (_: Result<Void, Error>) in
cancelTask.cancel()
}
return connectPromise.futureResult
return self.initializeAndRegisterNewChannel(eventLoop: self.group.next(),
protocolFamily: address.protocol) { channel in
return self.connect(freshChannel: channel, address: address)
}
}
@ -570,9 +607,9 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
public func connect(unixDomainSocketPath: String) -> EventLoopFuture<Channel> {
do {
let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
return connect(to: address)
return self.connect(to: address)
} catch {
return group.next().makeFailedFuture(error)
return self.group.next().makeFailedFuture(error)
}
}
@ -615,24 +652,35 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
}
}
private func execute(eventLoop: EventLoop,
protocolFamily: NIOBSDSocket.ProtocolFamily,
_ body: @escaping (Channel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
let channelInitializer = self.channelInitializer
let channelOptions = self._channelOptions
private func initializeAndRegisterNewChannel(eventLoop: EventLoop,
protocolFamily: NIOBSDSocket.ProtocolFamily,
_ body: @escaping (Channel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
let channel: SocketChannel
do {
channel = try SocketChannel(eventLoop: eventLoop as! SelectableEventLoop, protocolFamily: protocolFamily)
channel = try self.makeSocketChannel(eventLoop: eventLoop, protocolFamily: protocolFamily)
} catch {
return eventLoop.makeFailedFuture(error)
}
return self.initializeAndRegisterChannel(channel, body)
}
private func initializeAndRegisterChannel(_ channel: SocketChannel,
_ body: @escaping (Channel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
let channelInitializer = self.channelInitializer
let channelOptions = self._channelOptions
let eventLoop = channel.eventLoop
@inline(__always)
func setupChannel() -> EventLoopFuture<Channel> {
eventLoop.assertInEventLoop()
return channelOptions.applyAllChannelOptions(to: channel).flatMap {
channelInitializer(channel)
if let bindTarget = self.bindTarget {
return channel.bind(to: bindTarget).flatMap {
channelInitializer(channel)
}
} else {
return channelInitializer(channel)
}
}.flatMap {
eventLoop.assertInEventLoop()
return channel.registerAndDoSynchronously(body)
@ -647,7 +695,9 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
if eventLoop.inEventLoop {
return setupChannel()
} else {
return eventLoop.submit(setupChannel).flatMap { $0 }
return eventLoop.flatSubmit {
setupChannel()
}
}
}
}

View File

@ -48,6 +48,7 @@ extension BootstrapTest {
("testDatagramBootstrapRejectsNotWorkingELGsCorrectly", testDatagramBootstrapRejectsNotWorkingELGsCorrectly),
("testNIOPipeBootstrapValidatesWorkingELGsCorrectly", testNIOPipeBootstrapValidatesWorkingELGsCorrectly),
("testNIOPipeBootstrapRejectsNotWorkingELGsCorrectly", testNIOPipeBootstrapRejectsNotWorkingELGsCorrectly),
("testClientBindWorksOnSocketsBoundToEitherIPv4OrIPv6Only", testClientBindWorksOnSocketsBoundToEitherIPv4OrIPv6Only),
]
}
}

View File

@ -529,6 +529,88 @@ class BootstrapTest: XCTestCase {
XCTAssertNil(NIOPipeBootstrap(validatingGroup: elg))
XCTAssertNil(NIOPipeBootstrap(validatingGroup: el))
}
func testClientBindWorksOnSocketsBoundToEitherIPv4OrIPv6Only() {
for isIPv4 in [true, false] {
guard System.supportsIPv6 || isIPv4 else {
continue // need to skip IPv6 tests if we don't support it.
}
let localIP = isIPv4 ? "127.0.0.1" : "::1"
guard let serverLocalAddressChoice = try? SocketAddress(ipAddress: localIP, port: 0),
let clientLocalAddressWholeInterface = try? SocketAddress(ipAddress: localIP, port: 0),
let server1 = (try? ServerBootstrap(group: self.group)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.serverChannelOption(ChannelOptions.maxMessagesPerRead, value: 1)
.bind(to: serverLocalAddressChoice)
.wait()),
let server2 = (try? ServerBootstrap(group: self.group)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.serverChannelOption(ChannelOptions.maxMessagesPerRead, value: 1)
.bind(to: serverLocalAddressChoice)
.wait()),
let server1LocalAddress = server1.localAddress,
let server2LocalAddress = server2.localAddress else {
XCTFail("can't boot servers even")
return
}
defer {
XCTAssertNoThrow(try server1.close().wait())
XCTAssertNoThrow(try server2.close().wait())
}
// Try 1: Directly connect to 127.0.0.1, this won't do Happy Eyeballs.
XCTAssertNoThrow(try ClientBootstrap(group: self.group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.bind(to: clientLocalAddressWholeInterface)
.connect(to: server1LocalAddress)
.wait()
.close()
.wait())
var maybeChannel1: Channel? = nil
// Try 2: Connect to "localhost", this will do Happy Eyeballs.
XCTAssertNoThrow(maybeChannel1 = try ClientBootstrap(group: self.group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.bind(to: clientLocalAddressWholeInterface)
.connect(host: "localhost", port: server1LocalAddress.port!)
.wait())
guard let myChannel1 = maybeChannel1, let myChannel1Address = myChannel1.localAddress else {
XCTFail("can't connect channel 1")
return
}
XCTAssertEqual(localIP, maybeChannel1?.localAddress?.ipAddress)
// Try 3: Bind the client to the same address/port as in try 2 but to server 2.
XCTAssertNoThrow(try ClientBootstrap(group: self.group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.connectTimeout(.hours(2))
.bind(to: myChannel1Address)
.connect(to: server2LocalAddress)
.map { channel -> Channel in
XCTAssertEqual(myChannel1Address, channel.localAddress)
return channel
}
.wait()
.close()
.wait())
}
}
}
private final class WriteStringOnChannelActive: ChannelInboundHandler {
typealias InboundIn = Never
typealias OutboundOut = ByteBuffer
let string: String
init(_ string: String) {
self.string = string
}
func channelActive(context: ChannelHandlerContext) {
var buffer = context.channel.allocator.buffer(capacity: self.string.utf8.count)
buffer.writeString(string)
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
}
}
private final class MakeSureAutoReadIsOffInChannelInitializer: ChannelInboundHandler {

View File

@ -2012,7 +2012,7 @@ public final class ChannelTests: XCTestCase {
XCTAssertFalse(channel.isWritable)
}
withChannel { channel in
withChannel(skipStream: true) { channel in
checkThatItThrowsInappropriateOperationForState {
XCTAssertEqual(0, channel.localAddress?.port ?? 0xffff)
XCTAssertNil(channel.remoteAddress)

View File

@ -49,15 +49,6 @@ final class MulticastTest: XCTestCase {
struct ReceivedDatagramError: Error { }
private var supportsIPv6: Bool {
do {
let ipv6Loopback = try SocketAddress(ipAddress: "::1", port: 0)
return try System.enumerateInterfaces().contains(where: { $0.address == ipv6Loopback })
} catch {
return false
}
}
private func interfaceForAddress(address: String) throws -> NIONetworkInterface {
let targetAddress = try SocketAddress(ipAddress: address, port: 0)
guard let interface = try System.enumerateInterfaces().lazy.filter({ $0.address == targetAddress }).first else {
@ -220,7 +211,7 @@ final class MulticastTest: XCTestCase {
}
func testCanJoinBasicMulticastGroupIPv6() throws {
guard self.supportsIPv6 else {
guard System.supportsIPv6 else {
// Skip on non-IPv6 systems
return
}
@ -317,7 +308,7 @@ final class MulticastTest: XCTestCase {
}
func testCanLeaveAnIPv6MulticastGroup() throws {
guard self.supportsIPv6 else {
guard System.supportsIPv6 else {
// Skip on non-IPv6 systems
return
}

View File

@ -31,6 +31,8 @@ extension SALChannelTest {
("testWritesFromWritabilityNotificationsDoNotGetLostIfWePreviouslyWroteEverything", testWritesFromWritabilityNotificationsDoNotGetLostIfWePreviouslyWroteEverything),
("testWeSurviveIfIgnoringSIGPIPEFails", testWeSurviveIfIgnoringSIGPIPEFails),
("testBasicRead", testBasicRead),
("testBasicConnectWithClientBootstrap", testBasicConnectWithClientBootstrap),
("testClientBootstrapBindIsDoneAfterSocketOptions", testClientBootstrapBindIsDoneAfterSocketOptions),
]
}
}

View File

@ -235,4 +235,86 @@ final class SALChannelTest: XCTestCase, SALTest {
g.wait()
}
func testBasicConnectWithClientBootstrap() {
guard let channel = try? self.makeSocketChannel() else {
XCTFail("couldn't make a channel")
return
}
let localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 5)
let serverAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 5)
XCTAssertNoThrow(try channel.eventLoop.runSAL(syscallAssertions: {
try self.assertSetOption(expectedLevel: .tcp, expectedOption: .tcp_nodelay) { value in
return (value as? SocketOptionValue) == 1
}
try self.assertConnect(expectedAddress: serverAddress, result: true)
try self.assertLocalAddress(address: localAddress)
try self.assertRemoteAddress(address: localAddress)
try self.assertRegister { selectable, event, Registration in
XCTAssertEqual([.reset], event)
return true
}
try self.assertReregister { selectable, event in
XCTAssertEqual([.reset, .readEOF], event)
return true
}
try self.assertDeregister { selectable in
return true
}
try self.assertClose(expectedFD: .max)
}) {
ClientBootstrap(group: channel.eventLoop)
.channelOption(ChannelOptions.autoRead, value: false)
.testOnly_connect(injectedChannel: channel, to: serverAddress)
.flatMap { channel in
channel.close()
}
}.salWait())
}
func testClientBootstrapBindIsDoneAfterSocketOptions() {
guard let channel = try? self.makeSocketChannel() else {
XCTFail("couldn't make a channel")
return
}
let localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 5)
let serverAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 5)
XCTAssertNoThrow(try channel.eventLoop.runSAL(syscallAssertions: {
try self.assertSetOption(expectedLevel: .tcp, expectedOption: .tcp_nodelay) { value in
return (value as? SocketOptionValue) == 1
}
// This is the important bit: We need to apply the socket options _before_ ...
try self.assertSetOption(expectedLevel: .socket, expectedOption: .so_reuseaddr) { value in
return (value as? SocketOptionValue) == 1
}
// ... we call bind.
try self.assertBind(expectedAddress: localAddress)
try self.assertLocalAddress(address: nil) // this is an inefficiency in `bind0`.
try self.assertConnect(expectedAddress: serverAddress, result: true)
try self.assertLocalAddress(address: localAddress)
try self.assertRemoteAddress(address: localAddress)
try self.assertRegister { selectable, event, Registration in
XCTAssertEqual([.reset], event)
return true
}
try self.assertReregister { selectable, event in
XCTAssertEqual([.reset, .readEOF], event)
return true
}
try self.assertDeregister { selectable in
return true
}
try self.assertClose(expectedFD: .max)
}) {
ClientBootstrap(group: channel.eventLoop)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.channelOption(ChannelOptions.autoRead, value: false)
.bind(to: localAddress)
.testOnly_connect(injectedChannel: channel, to: serverAddress)
.flatMap { channel in
channel.close()
}
}.salWait())
}
}

View File

@ -127,6 +127,8 @@ enum UserToKernel {
case disableSIGPIPE(CInt)
case write(CInt, ByteBuffer)
case writev(CInt, [ByteBuffer])
case bind(SocketAddress)
case setOption(NIOBSDSocket.OptionLevel, NIOBSDSocket.Option, Any)
}
enum KernelToUser {
@ -342,6 +344,26 @@ class HookedSocket: Socket, UserKernelInterface {
throw UnexpectedKernelReturn(ret)
}
}
override func setOption<T>(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: T) throws {
try self.userToKernel.waitForEmptyAndSet(.setOption(level, name, value))
let ret = try self.waitForKernelReturn()
if case .returnVoid = ret {
return
} else {
throw UnexpectedKernelReturn(ret)
}
}
override func bind(to address: SocketAddress) throws {
try self.userToKernel.waitForEmptyAndSet(.bind(address))
let ret = try self.waitForKernelReturn()
if case .returnVoid = ret {
return
} else {
throw UnexpectedKernelReturn(ret)
}
}
}
extension HookedSelector {
@ -495,13 +517,17 @@ extension SALTest {
return channel
}
func makeSocketChannel(file: StaticString = #file, line: UInt = #line) throws -> SocketChannel {
return try self.makeSocketChannel(eventLoop: self.loop, file: file, line: line)
}
func makeConnectedSocketChannel(localAddress: SocketAddress?,
remoteAddress: SocketAddress,
file: StaticString = (#file),
line: UInt = #line) throws -> SocketChannel {
let channel = try self.makeSocketChannel(eventLoop: self.loop)
let connectFuture = try channel.eventLoop.runSAL(syscallAssertions: {
try self.assertConnect(result: true, { $0 == remoteAddress })
try self.assertConnect(expectedAddress: remoteAddress, result: true)
try self.assertLocalAddress(address: localAddress)
try self.assertRemoteAddress(address: remoteAddress)
try self.assertRegister { selectable, eventSet, registration in
@ -608,8 +634,9 @@ extension SALTest {
func assertLocalAddress(address: SocketAddress?, file: StaticString = (#file), line: UInt = #line) throws {
SAL.printIfDebug("\(#function)")
try self.selector.assertSyscallAndReturn(address.map { .returnSocketAddress($0) } ??
/* */ .error(.init(errnoCode: EOPNOTSUPP, reason: "nil passed")),
try self.selector.assertSyscallAndReturn(address.map {
.returnSocketAddress($0)
/* */ } ?? .error(.init(errnoCode: EOPNOTSUPP, reason: "nil passed")),
file: file, line: line) { syscall in
if case .localAddress = syscall {
return true
@ -632,11 +659,22 @@ extension SALTest {
}
}
func assertConnect(result: Bool, file: StaticString = (#file), line: UInt = #line, _ matcher: (SocketAddress) -> Bool = { _ in true }) throws {
func assertConnect(expectedAddress: SocketAddress, result: Bool, file: StaticString = (#file), line: UInt = #line, _ matcher: (SocketAddress) -> Bool = { _ in true }) throws {
SAL.printIfDebug("\(#function)")
try self.selector.assertSyscallAndReturn(.returnBool(result), file: file, line: line) { syscall in
if case .connect(let address) = syscall {
return matcher(address)
return address == expectedAddress
} else {
return false
}
}
}
func assertBind(expectedAddress: SocketAddress, file: StaticString = #file, line: UInt = #line) throws {
SAL.printIfDebug("\(#function)")
try self.selector.assertSyscallAndReturn(.returnVoid, file: file, line: line) { syscall in
if case .bind(let address) = syscall {
return address == expectedAddress
} else {
return false
}
@ -655,6 +693,19 @@ extension SALTest {
}
}
func assertSetOption(expectedLevel: NIOBSDSocket.OptionLevel,
expectedOption: NIOBSDSocket.Option,
file: StaticString = #file, line: UInt = #line,
_ valueMatcher: (Any) -> Bool = { _ in true }) throws {
SAL.printIfDebug("\(#function)")
try self.selector.assertSyscallAndReturn(.returnVoid, file: file, line: line) { syscall in
if case .setOption(expectedLevel, expectedOption, let value) = syscall {
return valueMatcher(value)
} else {
return false
}
}
}
func assertRegister(file: StaticString = (#file), line: UInt = #line, _ matcher: (Selectable, SelectorEventSet, NIORegistration) throws -> Bool) throws {
SAL.printIfDebug("\(#function)")

View File

@ -16,6 +16,17 @@ import XCTest
@testable import NIO
import NIOConcurrencyHelpers
extension System {
static var supportsIPv6: Bool {
do {
let ipv6Loopback = try SocketAddress.makeAddressResolvingHost("::1", port: 0)
return try System.enumerateInterfaces().filter { $0.address == ipv6Loopback }.first != nil
} catch {
return false
}
}
}
func withPipe(_ body: (NIO.NIOFileHandle, NIO.NIOFileHandle) throws -> [NIO.NIOFileHandle]) throws {
var fds: [Int32] = [-1, -1]
fds.withUnsafeMutableBufferPointer { ptr in