Don't have channels stop reading on errors they tolerate. (#2408)

Motivation:

When an error is hit during a read loop, a channel is able to tolerate
that error without closing. This is done for a number of reasons, but
the most important one is accepting sockets for already-closed
connections, which can trigger all kinds of errors on the read path.

Unfortunately, the code that handles this situation had an edge case.
If one or more reads in the loop had succeeded before the error was
caught, the inner code would be expecting a call to readIfNeeded, but
the outer code wouldn't make it. This would lead to autoRead channels
being wedged open.

Modifications:

This patch extends the Syscall Abstraction Layer to add support for
server sockets. It adds two tests: one for the basic accept flow, and
then one for the case discussed above.

This patch also refactors the code in BaseSocketChannel.readable0 to
more clearly show the path through the error case. There were a number
of early returns and partial conditionals that led to us checking the
same condition in a number of places. This refactor makes it clearer
that it is possible to exit this code in the happy path, with a
tolerated error, which should be considered the same as reading
_something_.

Result:

Harder to wedge a channel open.
This commit is contained in:
Cory Benfield 2023-04-20 11:09:02 +01:00 committed by GitHub
parent ad859ae82e
commit 003fbadf51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 370 additions and 23 deletions

View File

@ -1091,36 +1091,40 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
// peer closed / shutdown the connection.
if let channelErr = err as? ChannelError, channelErr == ChannelError.eof {
readStreamState = .eof
// Directly call getOption0 as we are already on the EventLoop and so not need to create an extra future.
// getOption0 can only fail if the channel is not active anymore but we assert further up that it is. If
// that's not the case this is a precondition failure and we would like to know.
if self.lifecycleManager.isActive, try! self.getOption0(ChannelOptions.allowRemoteHalfClosure) {
// If we want to allow half closure we will just mark the input side of the Channel
// as closed.
assert(self.lifecycleManager.isActive)
if self.lifecycleManager.isActive {
// Directly call getOption0 as we are already on the EventLoop and so not need to create an extra future.
//
// getOption0 can only fail if the channel is not active anymore but we assert further up that it is. If
// that's not the case this is a precondition failure and we would like to know.
let allowRemoteHalfClosure = try! self.getOption0(ChannelOptions.allowRemoteHalfClosure)
// For EOF, we always fire read complete.
self.pipeline.syncOperations.fireChannelReadComplete()
if self.shouldCloseOnReadError(err) {
self.close0(error: err, mode: .input, promise: nil)
if allowRemoteHalfClosure {
// If we want to allow half closure we will just mark the input side of the Channel
// as closed.
if self.shouldCloseOnReadError(err) {
self.close0(error: err, mode: .input, promise: nil)
}
self.readPending = false
return .eof
}
self.readPending = false
return .eof
}
} else {
readStreamState = .error
self.pipeline.syncOperations.fireErrorCaught(err)
}
// Call before triggering the close of the Channel.
if readStreamState != .error, self.lifecycleManager.isActive {
self.pipeline.syncOperations.fireChannelReadComplete()
}
if self.shouldCloseOnReadError(err) {
self.close0(error: err, mode: .all, promise: nil)
return readStreamState
} else {
// This is non-fatal, so continue as normal.
// This constitutes "some" as we did get at least an error from the socket.
readResult = .some
}
return readStreamState
}
// This assert needs to be disabled for io_uring, as the io_uring backend does not have the implicit synchronisation between
// modifications to the poll mask and the actual returned events on the completion queue that kqueue and epoll has.

View File

@ -323,4 +323,137 @@ final class SALChannelTest: XCTestCase, SALTest {
}.salWait())
}
/// Drives a bound server socket channel through a single successful accept and checks that
/// the accepted connection is delivered down the server pipeline exactly once and that the
/// accepted channel is registered with the selector.
func testAcceptingInboundConnections() throws {
    /// Counts the `channelRead` events that pass through the server channel's pipeline.
    final class ConnectionRecorder: ChannelInboundHandler {
        typealias InboundIn = Any
        typealias InboundOut = Any

        let readCount = ManagedAtomic(0)

        func channelRead(context: ChannelHandlerContext, data: NIOAny) {
            readCount.wrappingIncrement(ordering: .sequentiallyConsistent)
            context.fireChannelRead(data)
        }
    }

    // This test method already throws, so propagate a parse failure as a test failure
    // instead of crashing the whole test process with `try!`.
    let localAddress = try SocketAddress(ipAddress: "1.2.3.4", port: 5)
    let remoteAddress = try SocketAddress(ipAddress: "5.6.7.8", port: 10)
    let channel = try self.makeBoundServerSocketChannel(localAddress: localAddress)
    let socket = try self.makeSocket()
    let readRecorder = ConnectionRecorder()

    XCTAssertNoThrow(try channel.eventLoop.runSAL(syscallAssertions: {
        let readEvent = SelectorEvent(io: [.read],
                                      registration: NIORegistration(channel: .serverSocketChannel(channel),
                                                                    interested: [.read],
                                                                    registrationID: .initialRegistrationID))
        try self.assertWaitingForNotification(result: readEvent)
        try self.assertAccept(expectedFD: .max, expectedNonBlocking: true, return: socket)
        try self.assertLocalAddress(address: localAddress)
        try self.assertRemoteAddress(address: remoteAddress)

        // This accept is expected: we delay inbound channel registration by one EL tick.
        try self.assertAccept(expectedFD: .max, expectedNonBlocking: true, return: nil)

        // Then we register the inbound channel.
        try self.assertRegister { _, eventSet, registration in
            guard case .socketChannel(let inboundChannel) = registration.channel else {
                return false
            }
            XCTAssertEqual(localAddress, inboundChannel.localAddress)
            XCTAssertEqual(remoteAddress, inboundChannel.remoteAddress)
            XCTAssertEqual(eventSet, registration.interested)
            XCTAssertEqual(.reset, eventSet)
            return true
        }
        try self.assertReregister { _, eventSet in
            XCTAssertEqual([.reset, .readEOF], eventSet)
            return true
        }
        // because autoRead is on by default
        try self.assertReregister { _, eventSet in
            XCTAssertEqual([.reset, .readEOF, .read], eventSet)
            return true
        }

        try self.assertParkedRightNow()
    }) {
        channel.pipeline.addHandler(readRecorder)
    })

    XCTAssertEqual(readRecorder.readCount.load(ordering: .sequentiallyConsistent), 1)
}
/// Regression test: one accept succeeds and the following accept in the same read loop throws
/// an error. The test passes only if the event loop parks afterwards without the server
/// channel issuing a reregister that drops its interest in `.read`.
func testAcceptingInboundConnectionsDoesntUnregisterForReadIfTheSecondAcceptErrors() throws {
    /// Counts the `channelRead` events that pass through the server channel's pipeline.
    final class ConnectionRecorder: ChannelInboundHandler {
        typealias InboundIn = Any
        typealias InboundOut = Any

        let readCount = ManagedAtomic(0)

        func channelRead(context: ChannelHandlerContext, data: NIOAny) {
            readCount.wrappingIncrement(ordering: .sequentiallyConsistent)
            context.fireChannelRead(data)
        }
    }

    // This test method already throws, so propagate a parse failure as a test failure
    // instead of crashing the whole test process with `try!`.
    let localAddress = try SocketAddress(ipAddress: "1.2.3.4", port: 5)
    let remoteAddress = try SocketAddress(ipAddress: "5.6.7.8", port: 10)
    let channel = try self.makeBoundServerSocketChannel(localAddress: localAddress)
    let socket = try self.makeSocket()
    let readRecorder = ConnectionRecorder()

    XCTAssertNoThrow(try channel.eventLoop.runSAL(syscallAssertions: {
        let readEvent = SelectorEvent(io: [.read],
                                      registration: NIORegistration(channel: .serverSocketChannel(channel),
                                                                    interested: [.read],
                                                                    registrationID: .initialRegistrationID))
        try self.assertWaitingForNotification(result: readEvent)
        try self.assertAccept(expectedFD: .max, expectedNonBlocking: true, return: socket)
        try self.assertLocalAddress(address: localAddress)
        try self.assertRemoteAddress(address: remoteAddress)

        // A second accept is expected, because inbound channel registration is delayed by one
        // EL tick. We deliberately make this one throw: that is what hits the buggy codepath.
        try self.assertAccept(expectedFD: .max, expectedNonBlocking: true, throwing: NIOFcntlFailedError())

        // Then we register the inbound channel from the first accept.
        try self.assertRegister { _, eventSet, registration in
            guard case .socketChannel(let inboundChannel) = registration.channel else {
                return false
            }
            XCTAssertEqual(localAddress, inboundChannel.localAddress)
            XCTAssertEqual(remoteAddress, inboundChannel.remoteAddress)
            XCTAssertEqual(eventSet, registration.interested)
            XCTAssertEqual(.reset, eventSet)
            return true
        }
        try self.assertReregister { _, eventSet in
            XCTAssertEqual([.reset, .readEOF], eventSet)
            return true
        }
        // because autoRead is on by default
        try self.assertReregister { _, eventSet in
            XCTAssertEqual([.reset, .readEOF, .read], eventSet)
            return true
        }

        // Importantly, we should now be _parked_. This test is mostly testing in the absence:
        // we expect not to see a reregister that removes readable.
        try self.assertParkedRightNow()
    }) {
        channel.pipeline.addHandler(readRecorder)
    })

    XCTAssertEqual(readRecorder.readCount.load(ordering: .sequentiallyConsistent), 1)
}
}

View File

@ -717,7 +717,8 @@ public final class SocketChannelTest : XCTestCase {
XCTAssertNoThrow(try serverChan.eventLoop.submit {
serverChan.readable()
}.wait())
XCTAssertEqual(["errorCaught"], eventCounter.allTriggeredEvents())
XCTAssertEqual(["channelReadComplete", "errorCaught"], eventCounter.allTriggeredEvents())
XCTAssertEqual(1, eventCounter.channelReadCompleteCalls)
XCTAssertEqual(1, eventCounter.errorCaughtCalls)
serverSock.shouldAcceptsFail.store(false, ordering: .relaxed)
@ -729,7 +730,7 @@ public final class SocketChannelTest : XCTestCase {
eventCounter.allTriggeredEvents())
XCTAssertEqual(1, eventCounter.errorCaughtCalls)
XCTAssertEqual(1, eventCounter.channelReadCalls)
XCTAssertEqual(1, eventCounter.channelReadCompleteCalls)
XCTAssertEqual(2, eventCounter.channelReadCompleteCalls)
}
func testWeAreInterestedInReadEOFWhenChannelIsConnectedOnTheServerSide() throws {

View File

@ -142,6 +142,8 @@ enum UserToKernel {
case writev(CInt, [ByteBuffer])
case bind(SocketAddress)
case setOption(NIOBSDSocket.OptionLevel, NIOBSDSocket.Option, Any)
case listen(CInt, CInt)
case accept(CInt, Bool)
}
enum KernelToUser {
@ -151,7 +153,8 @@ enum KernelToUser {
case returnVoid
case returnSelectorEvent(SelectorEvent<NIORegistration>?)
case returnIOResultInt(IOResult<Int>)
case error(IOError)
case returnSocket(Socket?)
case error(Error)
}
struct UnexpectedKernelReturn: Error {
@ -252,6 +255,98 @@ internal class HookedSelector: NIOPosix.Selector<NIORegistration>, UserKernelInt
}
}
/// A `ServerSocket` subclass used by the Syscall Abstraction Layer tests.
///
/// Each overridden operation posts a `UserToKernel` value describing the call into
/// `userToKernel`, then waits for a `KernelToUser` reply and converts that reply into the
/// operation's result. A reply of an unexpected shape is reported by throwing
/// `UnexpectedKernelReturn`.
class HookedServerSocket: ServerSocket, UserKernelInterface {
    fileprivate let userToKernel: LockedBox<UserToKernel>
    fileprivate let kernelToUser: LockedBox<KernelToUser>

    init(userToKernel: LockedBox<UserToKernel>, kernelToUser: LockedBox<KernelToUser>, socket: NIOBSDSocket.Handle) throws {
        self.userToKernel = userToKernel
        self.kernelToUser = kernelToUser
        try super.init(socket: socket)
    }

    /// Posts `.disableSIGPIPE` for this socket's descriptor and expects a void reply.
    override func ignoreSIGPIPE() throws {
        try self.withUnsafeHandle { fd in
            try self.userToKernel.waitForEmptyAndSet(.disableSIGPIPE(fd))
            let reply = try self.waitForKernelReturn()
            guard case .returnVoid = reply else {
                throw UnexpectedKernelReturn(reply)
            }
        }
    }

    /// Posts `.localAddress` and returns the address carried by the reply.
    override func localAddress() throws -> SocketAddress {
        try self.userToKernel.waitForEmptyAndSet(.localAddress)
        let reply = try self.waitForKernelReturn()
        guard case .returnSocketAddress(let address) = reply else {
            throw UnexpectedKernelReturn(reply)
        }
        return address
    }

    /// Posts `.remoteAddress` and returns the address carried by the reply.
    override func remoteAddress() throws -> SocketAddress {
        try self.userToKernel.waitForEmptyAndSet(.remoteAddress)
        let reply = try self.waitForKernelReturn()
        guard case .returnSocketAddress(let address) = reply else {
            throw UnexpectedKernelReturn(reply)
        }
        return address
    }

    /// Posts `.bind` for `address` and expects a void reply.
    override func bind(to address: SocketAddress) throws {
        try self.userToKernel.waitForEmptyAndSet(.bind(address))
        let reply = try self.waitForKernelReturn()
        guard case .returnVoid = reply else {
            throw UnexpectedKernelReturn(reply)
        }
    }

    /// Posts `.listen` with this socket's descriptor and `backlog`, and expects a void reply.
    override func listen(backlog: Int32 = 128) throws {
        try self.withUnsafeHandle { fd in
            try self.userToKernel.waitForEmptyAndSet(.listen(fd, backlog))
            let reply = try self.waitForKernelReturn()
            guard case .returnVoid = reply else {
                throw UnexpectedKernelReturn(reply)
            }
        }
    }

    /// Posts `.accept` with this socket's descriptor and the non-blocking flag.
    ///
    /// Returns the (possibly `nil`) socket carried by a `.returnSocket` reply and rethrows the
    /// error carried by an `.error` reply.
    override func accept(setNonBlocking: Bool = false) throws -> Socket? {
        return try self.withUnsafeHandle { fd in
            try self.userToKernel.waitForEmptyAndSet(.accept(fd, setNonBlocking))
            let reply = try self.waitForKernelReturn()
            if case .returnSocket(let accepted) = reply {
                return accepted
            }
            if case .error(let failure) = reply {
                throw failure
            }
            throw UnexpectedKernelReturn(reply)
        }
    }

    /// Takes ownership of the descriptor, posts `.close` for it and expects a void reply.
    override func close() throws {
        let fd = try self.takeDescriptorOwnership()
        try self.userToKernel.waitForEmptyAndSet(.close(fd))
        let reply = try self.waitForKernelReturn()
        guard case .returnVoid = reply else {
            throw UnexpectedKernelReturn(reply)
        }
    }
}
class HookedSocket: Socket, UserKernelInterface {
fileprivate let userToKernel: LockedBox<UserToKernel>
@ -526,6 +621,25 @@ extension SALTest {
return channel
}
/// Builds a `ServerSocketChannel` backed by a `HookedServerSocket` on `eventLoop`.
///
/// While the channel is constructed, the syscalls expected are: disabling SIGPIPE on the
/// socket, then a local-address and a remote-address lookup that both fail (`nil` is passed as
/// the address). After construction the event loop is asserted to be parked.
private func makeServerSocketChannel(eventLoop: SelectableEventLoop,
                                     group: MultiThreadedEventLoopGroup,
                                     file: StaticString = #filePath, line: UInt = #line) throws -> ServerSocketChannel {
    let serverChannel: ServerSocketChannel = try eventLoop.runSAL(syscallAssertions: {
        try self.assertdisableSIGPIPE(expectedFD: .max, result: .success(()))
        try self.assertLocalAddress(address: nil)
        try self.assertRemoteAddress(address: nil)
    }) {
        let hookedSocket = try HookedServerSocket(userToKernel: self.userToKernelBox,
                                                  kernelToUser: self.kernelToUserBox,
                                                  socket: .max)
        return try ServerSocketChannel(serverSocket: hookedSocket,
                                       eventLoop: eventLoop,
                                       group: group)
    }

    try self.assertParkedRightNow()
    return serverChannel
}
func makeSocketChannelInjectingFailures(disableSIGPIPEFailure: IOError?) throws -> SocketChannel {
let channel = try self.loop.runSAL(syscallAssertions: {
try self.assertdisableSIGPIPE(expectedFD: .max,
@ -552,6 +666,10 @@ extension SALTest {
return try self.makeSocketChannel(eventLoop: self.loop, file: (file), line: line)
}
/// Convenience overload: builds a hooked `ServerSocketChannel` on this test's own loop and group.
func makeServerSocketChannel(file: StaticString = #filePath, line: UInt = #line) throws -> ServerSocketChannel {
    try self.makeServerSocketChannel(eventLoop: self.loop,
                                     group: self.group,
                                     file: (file),
                                     line: line)
}
func makeConnectedSocketChannel(localAddress: SocketAddress?,
remoteAddress: SocketAddress,
file: StaticString = #filePath,
@ -592,6 +710,53 @@ extension SALTest {
return channel
}
/// Creates a hooked `ServerSocketChannel`, registers it and binds it to `localAddress`.
///
/// The syscalls expected while registering and binding are: bind, a local-address lookup,
/// listen with a backlog of 128, the selector registration, and two reregistrations (the
/// second one adds `.read` because autoRead is on by default).
///
/// - Parameters:
///   - localAddress: The address the server channel is bound to.
///   - file: Source file that XCTest failures raised by this helper are attributed to.
///   - line: Source line that XCTest failures raised by this helper are attributed to.
/// - Returns: The bound server channel.
func makeBoundServerSocketChannel(localAddress: SocketAddress,
                                  file: StaticString = #filePath,
                                  line: UInt = #line) throws -> ServerSocketChannel {
    // Forward the caller's location so that failures point at the test, not at this helper.
    let channel = try self.makeServerSocketChannel(eventLoop: self.loop,
                                                   group: self.group,
                                                   file: (file),
                                                   line: line)

    let bindFuture = try channel.eventLoop.runSAL(syscallAssertions: {
        try self.assertBind(expectedAddress: localAddress)
        try self.assertLocalAddress(address: localAddress)
        try self.assertListen(expectedFD: .max, expectedBacklog: 128, file: (file), line: line)
        try self.assertRegister { _, eventSet, registration in
            guard case .serverSocketChannel(let registeredChannel) = registration.channel else {
                return false
            }
            XCTAssertEqual(localAddress, registeredChannel.localAddress, file: (file), line: line)
            XCTAssertNil(registeredChannel.remoteAddress, file: (file), line: line)
            XCTAssertEqual(eventSet, registration.interested, file: (file), line: line)
            XCTAssertEqual(.reset, eventSet, file: (file), line: line)
            return true
        }
        try self.assertReregister { _, eventSet in
            XCTAssertEqual([.reset, .readEOF], eventSet, file: (file), line: line)
            return true
        }
        // because autoRead is on by default
        try self.assertReregister { _, eventSet in
            XCTAssertEqual([.reset, .readEOF, .read], eventSet, file: (file), line: line)
            return true
        }
    }) {
        channel.register().flatMap {
            channel.bind(to: localAddress)
        }
    }

    XCTAssertNoThrow(try bindFuture.salWait(), file: (file), line: line)
    return channel
}
/// Creates a `HookedSocket` on the test loop; the only syscall expected during construction
/// is disabling SIGPIPE on the new socket.
func makeSocket() throws -> HookedSocket {
    let hookedSocket: HookedSocket = try self.loop.runSAL(syscallAssertions: {
        try self.assertdisableSIGPIPE(expectedFD: .max, result: .success(()))
    }) {
        try HookedSocket(userToKernel: self.userToKernelBox,
                         kernelToUser: self.kernelToUserBox,
                         socket: .max)
    }
    return hookedSocket
}
func tearDownSAL() {
SAL.printIfDebug("=== TEAR DOWN ===")
XCTAssertNotNil(self.kernelToUserBox)
@ -663,7 +828,7 @@ extension SALTest {
SAL.printIfDebug("\(#function)")
try self.selector.assertSyscallAndReturn(address.map {
.returnSocketAddress($0)
/* */ } ?? .error(.init(errnoCode: EOPNOTSUPP, reason: "nil passed")),
/* */ } ?? .error(IOError(errnoCode: EOPNOTSUPP, reason: "nil passed")),
file: (file), line: line) { syscall in
if case .localAddress = syscall {
return true
@ -676,7 +841,7 @@ extension SALTest {
func assertRemoteAddress(address: SocketAddress?, file: StaticString = #filePath, line: UInt = #line) throws {
SAL.printIfDebug("\(#function)")
try self.selector.assertSyscallAndReturn(address.map { .returnSocketAddress($0) } ??
/* */ .error(.init(errnoCode: EOPNOTSUPP, reason: "nil passed")),
/* */ .error(IOError(errnoCode: EOPNOTSUPP, reason: "nil passed")),
file: (file), line: line) { syscall in
if case .remoteAddress = syscall {
return true
@ -803,6 +968,50 @@ extension SALTest {
}
}
/// Asserts that the next syscall is `listen` on `expectedFD` with `expectedBacklog`, and
/// answers it with a void return.
func assertListen(expectedFD: CInt, expectedBacklog: CInt, file: StaticString = #filePath, line: UInt = #line) throws {
    SAL.printIfDebug("\(#function)")
    try self.selector.assertSyscallAndReturn(.returnVoid, file: (file), line: line) { syscall in
        // Any other syscall is a mismatch.
        guard case .listen(let actualFD, let actualBacklog) = syscall else {
            return false
        }
        XCTAssertEqual(actualFD, expectedFD, file: (file), line: line)
        XCTAssertEqual(actualBacklog, expectedBacklog, file: (file), line: line)
        return true
    }
}
/// Asserts that the next syscall is `accept` on `expectedFD` with the given non-blocking flag,
/// and answers it with `return` (pass `nil` to answer with no socket).
func assertAccept(expectedFD: CInt, expectedNonBlocking: Bool, return: Socket?,
                  file: StaticString = #filePath, line: UInt = #line) throws {
    SAL.printIfDebug("\(#function)")
    try self.assertAcceptSyscall(expectedFD: expectedFD,
                                 expectedNonBlocking: expectedNonBlocking,
                                 replyingWith: .returnSocket(`return`),
                                 file: (file), line: line)
}

/// Asserts that the next syscall is `accept` on `expectedFD` with the given non-blocking flag,
/// and answers it with `error`.
func assertAccept(expectedFD: CInt, expectedNonBlocking: Bool, throwing error: Error,
                  file: StaticString = #filePath, line: UInt = #line) throws {
    SAL.printIfDebug("\(#function)")
    try self.assertAcceptSyscall(expectedFD: expectedFD,
                                 expectedNonBlocking: expectedNonBlocking,
                                 replyingWith: .error(error),
                                 file: (file), line: line)
}

/// Shared matcher for both `assertAccept` overloads: checks the arguments of the `accept`
/// syscall and hands `reply` back as its result.
private func assertAcceptSyscall(expectedFD: CInt, expectedNonBlocking: Bool, replyingWith reply: KernelToUser,
                                 file: StaticString, line: UInt) throws {
    try self.selector.assertSyscallAndReturn(reply, file: (file), line: line) { syscall in
        // Any other syscall is a mismatch.
        guard case .accept(let fd, let nonBlocking) = syscall else {
            return false
        }
        XCTAssertEqual(fd, expectedFD, file: (file), line: line)
        XCTAssertEqual(nonBlocking, expectedNonBlocking, file: (file), line: line)
        return true
    }
}
func waitForNextSyscall() throws -> UserToKernel {
return try self.userToKernelBox.waitForValue()
}