BaseSocketChannel: accept immediately closed socket (#1121)

Motivation:

In the grpc-swift test suite, we saw a case where the server would
always immediately close the accepted socket. This lead NIO to misbehave
badly because kqueue would send us the `readEOF` before the `writable`
event that finishes an asynchronous `connect`.

What happened is that we just dropped the `readEOF` on the floor so we
would never actually tell the user if the channel ever went away.

Modifications:

Only register for `readEOF` after becoming active.

Result:

- we're happy with servers that immediately close the socket
This commit is contained in:
Johannes Weiss 2019-08-24 12:23:16 +01:00 committed by GitHub
parent ae8e2919b6
commit 32760eae40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 255 additions and 3 deletions

View File

@ -216,7 +216,9 @@ class BaseSocketChannel<T: BaseSocket>: SelectableChannel, ChannelCore {
private let isActiveAtomic: Atomic<Bool> = Atomic(value: false)
private var _pipeline: ChannelPipeline! = nil // this is really a constant (set in .init) but needs `self` to be constructed and therefore a `var`. Do not change as this needs to accessed from arbitrary threads
internal var interestedEvent: SelectorEventSet = [.readEOF, .reset] {
// We start with the invalid empty set of selector events we're interested in. This is to make sure we later on
// (in `becomeFullyRegistered0`) seed the initial event correctly.
internal var interestedEvent: SelectorEventSet = [] {
didSet {
assert(self.interestedEvent.contains(.reset), "impossible to unregister for reset")
}
@ -673,6 +675,22 @@ class BaseSocketChannel<T: BaseSocket>: SelectableChannel, ChannelCore {
self.safeReregister(interested: self.interestedEvent.union(.read))
}
private final func registerForReadEOF() {
self.eventLoop.assertInEventLoop()
assert(self.lifecycleManager.isRegisteredFully)
guard !self.lifecycleManager.hasSeenEOFNotification else {
// we have seen an EOF notification before so there's no point in registering for reads
return
}
guard !self.interestedEvent.contains(.readEOF) else {
return
}
self.safeReregister(interested: self.interestedEvent.union(.readEOF))
}
internal final func unregisterForReadable() {
self.eventLoop.assertInEventLoop()
assert(self.lifecycleManager.isRegisteredFully)
@ -1129,8 +1147,10 @@ class BaseSocketChannel<T: BaseSocket>: SelectableChannel, ChannelCore {
assert(self.lifecycleManager.isPreRegistered)
assert(!self.lifecycleManager.isRegisteredFully)
// We always register with interested .none and will just trigger readIfNeeded0() later to re-register if needed.
try self.safeRegister(interested: [.readEOF, .reset])
// The initial set of interested events must not contain `.readEOF` because when connect doesn't return
// synchronously, kevent might send us a `readEOF` because the `writable` event that marks the connect as completed.
// See SocketChannelTest.testServerClosesTheConnectionImmediately for a regression test.
try self.safeRegister(interested: [.reset])
self.lifecycleManager.finishRegistration()(nil, self.pipeline)
}
@ -1140,12 +1160,18 @@ class BaseSocketChannel<T: BaseSocket>: SelectableChannel, ChannelCore {
if !self.lifecycleManager.isRegisteredFully {
do {
try self.becomeFullyRegistered0()
assert(self.lifecycleManager.isRegisteredFully)
} catch {
self.close0(error: error, mode: .all, promise: promise)
return
}
}
self.lifecycleManager.activate()(promise, self.pipeline)
guard self.lifecycleManager.isOpen else {
// in the user callout for `channelActive` the channel got closed.
return
}
self.registerForReadEOF()
self.readIfNeeded0()
}
}

View File

@ -50,6 +50,9 @@ extension SocketChannelTest {
("testUnprocessedOutboundUserEventFailsOnSocketChannel", testUnprocessedOutboundUserEventFailsOnSocketChannel),
("testSetSockOptDoesNotOverrideExistingFlags", testSetSockOptDoesNotOverrideExistingFlags),
("testServerChannelDoesNotBreakIfAcceptingFailsWithEINVAL", testServerChannelDoesNotBreakIfAcceptingFailsWithEINVAL),
("testWeAreInterestedInReadEOFWhenChannelIsConnectedOnTheServerSide", testWeAreInterestedInReadEOFWhenChannelIsConnectedOnTheServerSide),
("testWeAreInterestedInReadEOFWhenChannelIsConnectedOnTheClientSide", testWeAreInterestedInReadEOFWhenChannelIsConnectedOnTheClientSide),
("testServerClosesTheConnectionImmediately", testServerClosesTheConnectionImmediately),
]
}
}

View File

@ -745,4 +745,227 @@ public final class SocketChannelTest : XCTestCase {
XCTAssertEqual(1, eventCounter.channelReadCalls)
XCTAssertEqual(1, eventCounter.channelReadCompleteCalls)
}
func testWeAreInterestedInReadEOFWhenChannelIsConnectedOnTheServerSide() throws {
// This test makes sure that we notice EOFs early, even if we never register for read (by dropping all the reads
// on the floor. This is the same test as below but this one is for TCP servers.
for mode in [DropAllReadsOnTheFloorHandler.Mode.halfClosureEnabled, .halfClosureDisabled] {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
let channelInactivePromise = group.next().makePromise(of: Void.self)
let channelHalfClosedPromise = group.next().makePromise(of: Void.self)
let waitUntilWriteFailedPromise = group.next().makePromise(of: Void.self)
let channelActivePromise = group.next().makePromise(of: Void.self)
if mode == .halfClosureDisabled {
// if we don't support half-closure these two promises would otherwise never be fulfilled
channelInactivePromise.futureResult.cascade(to: waitUntilWriteFailedPromise)
channelInactivePromise.futureResult.cascade(to: channelHalfClosedPromise)
}
let eventCounter = EventCounterHandler()
var numberOfAcceptedChannels = 0
let server = try assertNoThrowWithValue(ServerBootstrap(group: group)
.childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: mode == .halfClosureEnabled)
.childChannelInitializer { channel in
numberOfAcceptedChannels += 1
XCTAssertEqual(1, numberOfAcceptedChannels)
let drop = DropAllReadsOnTheFloorHandler(mode: mode,
channelInactivePromise: channelInactivePromise,
channelHalfClosedPromise: channelHalfClosedPromise,
waitUntilWriteFailedPromise: waitUntilWriteFailedPromise,
channelActivePromise: channelActivePromise)
return channel.pipeline.addHandlers([eventCounter, drop])
}
.bind(to: .init(ipAddress: "127.0.0.1", port: 0)).wait())
let client = try assertNoThrowWithValue(ClientBootstrap(group: group)
.connect(to: server.localAddress!).wait())
XCTAssertNoThrow(try channelActivePromise.futureResult.flatMap { () -> EventLoopFuture<Void> in
XCTAssertTrue(client.isActive)
XCTAssertEqual(["register", "channelActive", "channelRegistered"], eventCounter.allTriggeredEvents())
XCTAssertEqual(1, eventCounter.channelActiveCalls)
XCTAssertEqual(1, eventCounter.channelRegisteredCalls)
return client.close()
}.wait())
XCTAssertNoThrow(try channelHalfClosedPromise.futureResult.wait())
XCTAssertNoThrow(try channelInactivePromise.futureResult.wait())
XCTAssertNoThrow(try waitUntilWriteFailedPromise.futureResult.wait())
}
}
func testWeAreInterestedInReadEOFWhenChannelIsConnectedOnTheClientSide() throws {
// This test makes sure that we notice EOFs early, even if we never register for read (by dropping all the reads
// on the floor. This is the same test as above but this one is for TCP clients.
enum Mode {
case halfClosureEnabled
case halfClosureDisabled
}
for mode in [DropAllReadsOnTheFloorHandler.Mode.halfClosureEnabled, .halfClosureDisabled] {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
let acceptedServerChannel: EventLoopPromise<Channel> = group.next().makePromise()
let channelInactivePromise = group.next().makePromise(of: Void.self)
let channelHalfClosedPromise = group.next().makePromise(of: Void.self)
let waitUntilWriteFailedPromise = group.next().makePromise(of: Void.self)
if mode == .halfClosureDisabled {
// if we don't support half-closure these two promises would otherwise never be fulfilled
channelInactivePromise.futureResult.cascade(to: waitUntilWriteFailedPromise)
channelInactivePromise.futureResult.cascade(to: channelHalfClosedPromise)
}
let eventCounter = EventCounterHandler()
let server = try assertNoThrowWithValue(ServerBootstrap(group: group)
.childChannelInitializer { channel in
acceptedServerChannel.succeed(channel)
return channel.eventLoop.makeSucceededFuture(())
}
.bind(to: .init(ipAddress: "127.0.0.1", port: 0))
.wait())
let client = try assertNoThrowWithValue(ClientBootstrap(group: group)
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: mode == .halfClosureEnabled)
.channelInitializer { channel in
channel.pipeline.addHandlers([eventCounter,
DropAllReadsOnTheFloorHandler(mode: mode,
channelInactivePromise: channelInactivePromise,
channelHalfClosedPromise: channelHalfClosedPromise,
waitUntilWriteFailedPromise: waitUntilWriteFailedPromise)])
}
.connect(to: server.localAddress!).wait())
XCTAssertNoThrow(try acceptedServerChannel.futureResult.flatMap { channel -> EventLoopFuture<Void> in
XCTAssertEqual(["register", "channelActive", "channelRegistered", "connect"],
eventCounter.allTriggeredEvents())
XCTAssertEqual(1, eventCounter.channelActiveCalls)
XCTAssertEqual(1, eventCounter.channelRegisteredCalls)
return channel.close()
}.wait())
XCTAssertNoThrow(try channelHalfClosedPromise.futureResult.wait())
XCTAssertNoThrow(try channelInactivePromise.futureResult.wait())
XCTAssertNoThrow(try client.closeFuture.wait())
XCTAssertNoThrow(try waitUntilWriteFailedPromise.futureResult.wait())
}
}
func testServerClosesTheConnectionImmediately() throws {
// This is a regression test for a problem that the grpc-swift compatibility tests hit where everything would
// get stuck on a server that just insta-closes every accepted connection.
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
class WaitForChannelInactiveHandler: ChannelInboundHandler {
typealias InboundIn = Never
typealias OutboundOut = ByteBuffer
let channelInactivePromise: EventLoopPromise<Void>
init(channelInactivePromise: EventLoopPromise<Void>) {
self.channelInactivePromise = channelInactivePromise
}
func channelActive(context: ChannelHandlerContext) {
var buffer = context.channel.allocator.buffer(capacity: 128)
buffer.writeString(String(repeating: "x", count: 517))
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
}
func channelInactive(context: ChannelHandlerContext) {
self.channelInactivePromise.succeed(())
context.fireChannelInactive()
}
}
let serverSocket = try assertNoThrowWithValue(ServerSocket(protocolFamily: PF_INET))
XCTAssertNoThrow(try serverSocket.bind(to: .init(ipAddress: "127.0.0.1", port: 0)))
XCTAssertNoThrow(try serverSocket.listen())
let g = DispatchGroup()
DispatchQueue(label: "accept one client").async(group: g) {
if let socket = try! serverSocket.accept() {
try! socket.close()
}
}
let channelInactivePromise = group.next().makePromise(of: Void.self)
let eventCounter = EventCounterHandler()
let client = try assertNoThrowWithValue(ClientBootstrap(group: group)
.channelInitializer { channel in
channel.pipeline.addHandlers([eventCounter,
WaitForChannelInactiveHandler(channelInactivePromise: channelInactivePromise)])
}
.connect(to: try serverSocket.localAddress())
.wait())
XCTAssertNoThrow(try channelInactivePromise.futureResult.map { _ in
XCTAssertEqual(1, eventCounter.channelInactiveCalls)
}.wait())
XCTAssertNoThrow(try client.closeFuture.wait())
g.wait()
XCTAssertNoThrow(try serverSocket.close())
}
}
class DropAllReadsOnTheFloorHandler: ChannelDuplexHandler {
typealias InboundIn = Never
typealias OutboundIn = Never
typealias OutboundOut = ByteBuffer
enum Mode {
case halfClosureEnabled
case halfClosureDisabled
}
let channelInactivePromise: EventLoopPromise<Void>
let channelHalfClosedPromise: EventLoopPromise<Void>
let waitUntilWriteFailedPromise: EventLoopPromise<Void>
let channelActivePromise: EventLoopPromise<Void>?
let mode: Mode
init(mode: Mode,
channelInactivePromise: EventLoopPromise<Void>,
channelHalfClosedPromise: EventLoopPromise<Void>,
waitUntilWriteFailedPromise: EventLoopPromise<Void>,
channelActivePromise: EventLoopPromise<Void>? = nil) {
self.mode = mode
self.channelInactivePromise = channelInactivePromise
self.channelHalfClosedPromise = channelHalfClosedPromise
self.waitUntilWriteFailedPromise = waitUntilWriteFailedPromise
self.channelActivePromise = channelActivePromise
}
func channelActive(context: ChannelHandlerContext) {
self.channelActivePromise?.succeed(())
}
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
if let event = event as? ChannelEvent, event == .inputClosed {
XCTAssertEqual(.halfClosureEnabled, self.mode)
self.channelHalfClosedPromise.succeed(())
var buffer = context.channel.allocator.buffer(capacity: 1_000_000)
buffer.writeBytes(Array(repeating: UInt8(ascii: "x"),
count: 1_000_000))
// What we're trying to do here is forcing a close without calling `close`. We know that the other side of
// the connection is fully closed but because we support half-closure, we need to write to 'learn' that the
// other side has actually fully closed the socket.
func writeUntilError() {
context.writeAndFlush(self.wrapOutboundOut(buffer)).map {
writeUntilError()
}.whenFailure { (_: Error) in
self.waitUntilWriteFailedPromise.succeed(())
}
}
writeUntilError()
}
context.fireUserInboundEventTriggered(event)
}
func channelInactive(context: ChannelHandlerContext) {
self.channelInactivePromise.succeed(())
context.fireChannelInactive()
}
func read(context: ChannelHandlerContext) {}
}