Cleanup unix socket pathname on server socket close or bind (#1637)

* Cleanup unix socket pathname on server socket close or bind

* cleaner use of swift syntax

* using stat and unlink via syscal wrappers

* separate error type when UDS path is not a socket file

* track if socket needs cleanup to avoid extra syscall

* struct instead of enum for a single new error type

Co-authored-by: Cory Benfield <lukasa@apple.com>
This commit is contained in:
Andrius 2020-09-22 15:30:49 +01:00 committed by GitHub
parent bc72ee7537
commit b2f3de8ed5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 166 additions and 12 deletions

View File

@ -18,6 +18,9 @@ import NIOConcurrencyHelpers
import let WinSDK.INVALID_SOCKET
#endif
/// The requested UDS path exists and has wrong type (not a socket).
public struct UnixDomainSocketPathWrongType: Error {}
/// A Registration on a `Selector`, which is interested in an `SelectorEventSet`.
protocol Registration {
/// The `SelectorEventSet` in which the `Registration` is interested.
@ -301,6 +304,35 @@ class BaseSocket: BaseSocketProtocol {
}
return sock
}
/// Cleanup the unix domain socket.
///
/// Deletes the associated file if it exists and has socket type. Does nothing if pathname does not exist.
///
/// - parameters:
/// - unixDomainSocketPath: The pathname of the UDS.
/// - throws: An `UnixDomainSocketPathWrongType` if the pathname exists and is not a socket.
static func cleanupSocket(unixDomainSocketPath: String) throws {
do {
var sb: stat = stat()
try withUnsafeMutablePointer(to: &sb) { sbPtr in
try Posix.stat(pathname: unixDomainSocketPath, outStat: sbPtr)
}
// Only unlink the existing file if it is a socket
if sb.st_mode & S_IFSOCK == S_IFSOCK {
try Posix.unlink(pathname: unixDomainSocketPath)
} else {
throw UnixDomainSocketPathWrongType()
}
} catch let err as IOError {
// If the filepath did not exist, we consider it cleaned up
if err.errnoCode == ENOENT {
return
}
throw err
}
}
/// Create a new instance.
///

View File

@ -224,6 +224,23 @@ public final class ServerBootstrap {
try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
}
}
/// Bind the `ServerSocketChannel` to a UNIX Domain Socket.
///
/// - parameters:
/// - unixDomainSocketPath: The _Unix domain socket_ path to bind to. `unixDomainSocketPath` must not exist, it will be created by the system.
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `path`.
public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> {
if cleanupExistingSocketFile {
do {
try BaseSocket.cleanupSocket(unixDomainSocketPath: unixDomainSocketPath)
} catch {
return group.next().makeFailedFuture(error)
}
}
return self.bind(unixDomainSocketPath: unixDomainSocketPath)
}
#if !os(Windows)
/// Use the existing bound socket file descriptor.
@ -860,6 +877,23 @@ public final class DatagramBootstrap {
return try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
}
}
/// Bind the `DatagramChannel` to a UNIX Domain Socket.
///
/// - parameters:
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. `path` must not exist, it will be created by the system.
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `path`.
public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> {
if cleanupExistingSocketFile {
do {
try BaseSocket.cleanupSocket(unixDomainSocketPath: unixDomainSocketPath)
} catch {
return group.next().makeFailedFuture(error)
}
}
return self.bind(unixDomainSocketPath: unixDomainSocketPath)
}
private func bind0(_ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture<Channel> {
let address: SocketAddress

View File

@ -15,6 +15,7 @@
/// A server socket that can accept new connections.
/* final but tests */ class ServerSocket: BaseSocket, ServerSocketProtocol {
typealias SocketType = ServerSocket
private let cleanupOnClose: Bool
public final class func bootstrap(protocolFamily: NIOBSDSocket.ProtocolFamily, host: String, port: Int) throws -> ServerSocket {
let socket = try ServerSocket(protocolFamily: protocolFamily)
@ -31,6 +32,12 @@
/// - throws: An `IOError` if creation of the socket failed.
init(protocolFamily: NIOBSDSocket.ProtocolFamily, setNonBlocking: Bool = false) throws {
let sock = try BaseSocket.makeSocket(protocolFamily: protocolFamily, type: .stream, setNonBlocking: setNonBlocking)
switch protocolFamily {
case .unix:
cleanupOnClose = true
default:
cleanupOnClose = false
}
try super.init(socket: sock)
}
@ -54,6 +61,7 @@
/// - setNonBlocking: Set non-blocking mode on the socket.
/// - throws: An `IOError` if socket is invalid.
init(socket: NIOBSDSocket.Handle, setNonBlocking: Bool = false) throws {
cleanupOnClose = false // socket already bound, owner must clean up
try super.init(socket: socket)
if setNonBlocking {
try self.setNonBlocking()
@ -109,4 +117,17 @@
return sock
}
}
/// Close the socket.
///
/// After the socket was closed all other methods will throw an `IOError` when called.
///
/// - throws: An `IOError` if the operation failed.
override func close() throws {
let maybePathname = self.cleanupOnClose ? (try? self.localAddress().pathname) : nil
try super.close()
if let socketPath = maybePathname {
try BaseSocket.cleanupSocket(unixDomainSocketPath: socketPath)
}
}
}

View File

@ -103,20 +103,10 @@ public enum SocketAddress: CustomStringConvertible {
addressString = try! descriptionForAddress(family: .inet6, bytes: &mutAddr, length: Int(INET6_ADDRSTRLEN))
port = "\(self.port!)"
case .unixDomainSocket(let addr):
let address = addr.address
case .unixDomainSocket(_):
host = nil
type = "UDS"
addressString = ""
// This is a static assert that exists just to verify the safety of the assumption below.
assert(Swift.type(of: address.sun_path.0) == CChar.self)
port = withUnsafePointer(to: address.sun_path) { ptr in
// Homogeneous tuples are always implicitly also bound to their element type, so this assumption below is safe.
let charPtr = UnsafeRawPointer(ptr).assumingMemoryBound(to: CChar.self)
return String(cString: charPtr)
}
return "[\(type)]\(port)"
return "[\(type)]\(self.pathname ?? "")"
}
return "[\(type)]\(host.map { "\($0)/\(addressString):" } ?? "\(addressString):")\(port)"
@ -187,6 +177,25 @@ public enum SocketAddress: CustomStringConvertible {
}
}
}
/// Get the pathname of a UNIX domain socket as a string
public var pathname: String? {
switch self {
case .v4:
return nil
case .v6:
return nil
case .unixDomainSocket(let addr):
// This is a static assert that exists just to verify the safety of the assumption below.
assert(Swift.type(of: addr.address.sun_path.0) == CChar.self)
let pathname: String = withUnsafePointer(to: addr.address.sun_path) { ptr in
// Homogeneous tuples are always implicitly also bound to their element type, so this assumption below is safe.
let charPtr = UnsafeRawPointer(ptr).assumingMemoryBound(to: CChar.self)
return String(cString: charPtr)
}
return pathname
}
}
/// Calls the given function with a pointer to a `sockaddr` structure and the associated size
/// of that structure.

View File

@ -95,6 +95,8 @@ private let sysSocketpair: @convention(c) (CInt, CInt, CInt, UnsafeMutablePointe
#if os(Linux)
private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer<stat>) -> CInt = fstat
private let sysStat: @convention(c) (UnsafePointer<CChar>, UnsafeMutablePointer<stat>) -> CInt = stat
private let sysUnlink: @convention(c) (UnsafePointer<CChar>) -> CInt = unlink
private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIOLinux_mmsghdr>?, CUnsignedInt, CInt) -> CInt = CNIOLinux_sendmmsg
private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIOLinux_mmsghdr>?, CUnsignedInt, CInt, UnsafeMutablePointer<timespec>?) -> CInt = CNIOLinux_recvmmsg
private let sysCmsgFirstHdr: @convention(c) (UnsafePointer<msghdr>?) -> UnsafeMutablePointer<cmsghdr>? =
@ -108,6 +110,8 @@ private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_SPA
private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_LEN
#elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer<stat>?) -> CInt = fstat
private let sysStat: @convention(c) (UnsafePointer<CChar>?, UnsafeMutablePointer<stat>?) -> CInt = stat
private let sysUnlink: @convention(c) (UnsafePointer<CChar>?) -> CInt = unlink
private let sysKevent = kevent
private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIODarwin_mmsghdr>?, CUnsignedInt, CInt) -> CInt = CNIODarwin_sendmmsg
private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIODarwin_mmsghdr>?, CUnsignedInt, CInt, UnsafeMutablePointer<timespec>?) -> CInt = CNIODarwin_recvmmsg
@ -518,6 +522,20 @@ internal enum Posix {
}
}
@inline(never)
public static func stat(pathname: String, outStat: UnsafeMutablePointer<stat>) throws {
_ = try syscall(blocking: false) {
sysStat(pathname, outStat)
}
}
@inline(never)
public static func unlink(pathname: String) throws {
_ = try syscall(blocking: false) {
sysUnlink(pathname)
}
}
@inline(never)
public static func socketpair(domain: NIOBSDSocket.ProtocolFamily,
type: NIOBSDSocket.SocketType,

View File

@ -31,6 +31,8 @@ extension EchoServerClientTest {
("testLotsOfUnflushedWrites", testLotsOfUnflushedWrites),
("testEchoUnixDomainSocket", testEchoUnixDomainSocket),
("testConnectUnixDomainSocket", testConnectUnixDomainSocket),
("testCleanupUnixDomainSocket", testCleanupUnixDomainSocket),
("testBootstrapUnixDomainSocketNameClash", testBootstrapUnixDomainSocketNameClash),
("testChannelActiveOnConnect", testChannelActiveOnConnect),
("testWriteThenRead", testWriteThenRead),
("testCloseInInactive", testCloseInInactive),

View File

@ -172,6 +172,44 @@ class EchoServerClientTest : XCTestCase {
}
}
func testCleanupUnixDomainSocket() throws {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
try withTemporaryUnixDomainSocketPathName { udsPath in
let bootstrap = ServerBootstrap(group: group)
let serverChannel = try assertNoThrowWithValue(
bootstrap.bind(unixDomainSocketPath: udsPath).wait())
XCTAssertNoThrow(try serverChannel.close().wait())
let reusedPathServerChannel = try assertNoThrowWithValue(
bootstrap.bind(unixDomainSocketPath: udsPath,
cleanupExistingSocketFile: true).wait())
XCTAssertNoThrow(try reusedPathServerChannel.close().wait())
}
}
func testBootstrapUnixDomainSocketNameClash() throws {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
try withTemporaryUnixDomainSocketPathName { udsPath in
// Bootstrap should not overwrite an existing file unless it is a socket
FileManager.default.createFile(atPath: udsPath, contents: nil, attributes: nil)
let bootstrap = ServerBootstrap(group: group)
XCTAssertThrowsError(
try bootstrap.bind(unixDomainSocketPath: udsPath).wait())
}
}
func testChannelActiveOnConnect() throws {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {