Cleanup unix socket pathname on server socket close or bind (#1637)
* Cleanup unix socket pathname on server socket close or bind * cleaner use of swift syntax * using stat and unlink via syscal wrappers * separate error type when UDS path is not a socket file * track if socket needs cleanup to avoid extra syscall * struct instead of enum for a single new error type Co-authored-by: Cory Benfield <lukasa@apple.com>
This commit is contained in:
parent
bc72ee7537
commit
b2f3de8ed5
|
@ -18,6 +18,9 @@ import NIOConcurrencyHelpers
|
|||
import let WinSDK.INVALID_SOCKET
|
||||
#endif
|
||||
|
||||
/// The requested UDS path exists and has wrong type (not a socket).
|
||||
public struct UnixDomainSocketPathWrongType: Error {}
|
||||
|
||||
/// A Registration on a `Selector`, which is interested in an `SelectorEventSet`.
|
||||
protocol Registration {
|
||||
/// The `SelectorEventSet` in which the `Registration` is interested.
|
||||
|
@ -301,6 +304,35 @@ class BaseSocket: BaseSocketProtocol {
|
|||
}
|
||||
return sock
|
||||
}
|
||||
|
||||
/// Cleanup the unix domain socket.
|
||||
///
|
||||
/// Deletes the associated file if it exists and has socket type. Does nothing if pathname does not exist.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - unixDomainSocketPath: The pathname of the UDS.
|
||||
/// - throws: An `UnixDomainSocketPathWrongType` if the pathname exists and is not a socket.
|
||||
static func cleanupSocket(unixDomainSocketPath: String) throws {
|
||||
do {
|
||||
var sb: stat = stat()
|
||||
try withUnsafeMutablePointer(to: &sb) { sbPtr in
|
||||
try Posix.stat(pathname: unixDomainSocketPath, outStat: sbPtr)
|
||||
}
|
||||
|
||||
// Only unlink the existing file if it is a socket
|
||||
if sb.st_mode & S_IFSOCK == S_IFSOCK {
|
||||
try Posix.unlink(pathname: unixDomainSocketPath)
|
||||
} else {
|
||||
throw UnixDomainSocketPathWrongType()
|
||||
}
|
||||
} catch let err as IOError {
|
||||
// If the filepath did not exist, we consider it cleaned up
|
||||
if err.errnoCode == ENOENT {
|
||||
return
|
||||
}
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new instance.
|
||||
///
|
||||
|
|
|
@ -224,6 +224,23 @@ public final class ServerBootstrap {
|
|||
try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
|
||||
}
|
||||
}
|
||||
|
||||
/// Bind the `ServerSocketChannel` to a UNIX Domain Socket.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - unixDomainSocketPath: The _Unix domain socket_ path to bind to. `unixDomainSocketPath` must not exist, it will be created by the system.
|
||||
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `path`.
|
||||
public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> {
|
||||
if cleanupExistingSocketFile {
|
||||
do {
|
||||
try BaseSocket.cleanupSocket(unixDomainSocketPath: unixDomainSocketPath)
|
||||
} catch {
|
||||
return group.next().makeFailedFuture(error)
|
||||
}
|
||||
}
|
||||
|
||||
return self.bind(unixDomainSocketPath: unixDomainSocketPath)
|
||||
}
|
||||
|
||||
#if !os(Windows)
|
||||
/// Use the existing bound socket file descriptor.
|
||||
|
@ -860,6 +877,23 @@ public final class DatagramBootstrap {
|
|||
return try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
|
||||
}
|
||||
}
|
||||
|
||||
/// Bind the `DatagramChannel` to a UNIX Domain Socket.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. `path` must not exist, it will be created by the system.
|
||||
/// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `path`.
|
||||
public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool) -> EventLoopFuture<Channel> {
|
||||
if cleanupExistingSocketFile {
|
||||
do {
|
||||
try BaseSocket.cleanupSocket(unixDomainSocketPath: unixDomainSocketPath)
|
||||
} catch {
|
||||
return group.next().makeFailedFuture(error)
|
||||
}
|
||||
}
|
||||
|
||||
return self.bind(unixDomainSocketPath: unixDomainSocketPath)
|
||||
}
|
||||
|
||||
private func bind0(_ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture<Channel> {
|
||||
let address: SocketAddress
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
/// A server socket that can accept new connections.
|
||||
/* final but tests */ class ServerSocket: BaseSocket, ServerSocketProtocol {
|
||||
typealias SocketType = ServerSocket
|
||||
private let cleanupOnClose: Bool
|
||||
|
||||
public final class func bootstrap(protocolFamily: NIOBSDSocket.ProtocolFamily, host: String, port: Int) throws -> ServerSocket {
|
||||
let socket = try ServerSocket(protocolFamily: protocolFamily)
|
||||
|
@ -31,6 +32,12 @@
|
|||
/// - throws: An `IOError` if creation of the socket failed.
|
||||
init(protocolFamily: NIOBSDSocket.ProtocolFamily, setNonBlocking: Bool = false) throws {
|
||||
let sock = try BaseSocket.makeSocket(protocolFamily: protocolFamily, type: .stream, setNonBlocking: setNonBlocking)
|
||||
switch protocolFamily {
|
||||
case .unix:
|
||||
cleanupOnClose = true
|
||||
default:
|
||||
cleanupOnClose = false
|
||||
}
|
||||
try super.init(socket: sock)
|
||||
}
|
||||
|
||||
|
@ -54,6 +61,7 @@
|
|||
/// - setNonBlocking: Set non-blocking mode on the socket.
|
||||
/// - throws: An `IOError` if socket is invalid.
|
||||
init(socket: NIOBSDSocket.Handle, setNonBlocking: Bool = false) throws {
|
||||
cleanupOnClose = false // socket already bound, owner must clean up
|
||||
try super.init(socket: socket)
|
||||
if setNonBlocking {
|
||||
try self.setNonBlocking()
|
||||
|
@ -109,4 +117,17 @@
|
|||
return sock
|
||||
}
|
||||
}
|
||||
|
||||
/// Close the socket.
|
||||
///
|
||||
/// After the socket was closed all other methods will throw an `IOError` when called.
|
||||
///
|
||||
/// - throws: An `IOError` if the operation failed.
|
||||
override func close() throws {
|
||||
let maybePathname = self.cleanupOnClose ? (try? self.localAddress().pathname) : nil
|
||||
try super.close()
|
||||
if let socketPath = maybePathname {
|
||||
try BaseSocket.cleanupSocket(unixDomainSocketPath: socketPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -103,20 +103,10 @@ public enum SocketAddress: CustomStringConvertible {
|
|||
addressString = try! descriptionForAddress(family: .inet6, bytes: &mutAddr, length: Int(INET6_ADDRSTRLEN))
|
||||
|
||||
port = "\(self.port!)"
|
||||
case .unixDomainSocket(let addr):
|
||||
let address = addr.address
|
||||
case .unixDomainSocket(_):
|
||||
host = nil
|
||||
type = "UDS"
|
||||
addressString = ""
|
||||
|
||||
// This is a static assert that exists just to verify the safety of the assumption below.
|
||||
assert(Swift.type(of: address.sun_path.0) == CChar.self)
|
||||
port = withUnsafePointer(to: address.sun_path) { ptr in
|
||||
// Homogeneous tuples are always implicitly also bound to their element type, so this assumption below is safe.
|
||||
let charPtr = UnsafeRawPointer(ptr).assumingMemoryBound(to: CChar.self)
|
||||
return String(cString: charPtr)
|
||||
}
|
||||
return "[\(type)]\(port)"
|
||||
return "[\(type)]\(self.pathname ?? "")"
|
||||
}
|
||||
|
||||
return "[\(type)]\(host.map { "\($0)/\(addressString):" } ?? "\(addressString):")\(port)"
|
||||
|
@ -187,6 +177,25 @@ public enum SocketAddress: CustomStringConvertible {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the pathname of a UNIX domain socket as a string
|
||||
public var pathname: String? {
|
||||
switch self {
|
||||
case .v4:
|
||||
return nil
|
||||
case .v6:
|
||||
return nil
|
||||
case .unixDomainSocket(let addr):
|
||||
// This is a static assert that exists just to verify the safety of the assumption below.
|
||||
assert(Swift.type(of: addr.address.sun_path.0) == CChar.self)
|
||||
let pathname: String = withUnsafePointer(to: addr.address.sun_path) { ptr in
|
||||
// Homogeneous tuples are always implicitly also bound to their element type, so this assumption below is safe.
|
||||
let charPtr = UnsafeRawPointer(ptr).assumingMemoryBound(to: CChar.self)
|
||||
return String(cString: charPtr)
|
||||
}
|
||||
return pathname
|
||||
}
|
||||
}
|
||||
|
||||
/// Calls the given function with a pointer to a `sockaddr` structure and the associated size
|
||||
/// of that structure.
|
||||
|
|
|
@ -95,6 +95,8 @@ private let sysSocketpair: @convention(c) (CInt, CInt, CInt, UnsafeMutablePointe
|
|||
|
||||
#if os(Linux)
|
||||
private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer<stat>) -> CInt = fstat
|
||||
private let sysStat: @convention(c) (UnsafePointer<CChar>, UnsafeMutablePointer<stat>) -> CInt = stat
|
||||
private let sysUnlink: @convention(c) (UnsafePointer<CChar>) -> CInt = unlink
|
||||
private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIOLinux_mmsghdr>?, CUnsignedInt, CInt) -> CInt = CNIOLinux_sendmmsg
|
||||
private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIOLinux_mmsghdr>?, CUnsignedInt, CInt, UnsafeMutablePointer<timespec>?) -> CInt = CNIOLinux_recvmmsg
|
||||
private let sysCmsgFirstHdr: @convention(c) (UnsafePointer<msghdr>?) -> UnsafeMutablePointer<cmsghdr>? =
|
||||
|
@ -108,6 +110,8 @@ private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_SPA
|
|||
private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_LEN
|
||||
#elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
|
||||
private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer<stat>?) -> CInt = fstat
|
||||
private let sysStat: @convention(c) (UnsafePointer<CChar>?, UnsafeMutablePointer<stat>?) -> CInt = stat
|
||||
private let sysUnlink: @convention(c) (UnsafePointer<CChar>?) -> CInt = unlink
|
||||
private let sysKevent = kevent
|
||||
private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIODarwin_mmsghdr>?, CUnsignedInt, CInt) -> CInt = CNIODarwin_sendmmsg
|
||||
private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIODarwin_mmsghdr>?, CUnsignedInt, CInt, UnsafeMutablePointer<timespec>?) -> CInt = CNIODarwin_recvmmsg
|
||||
|
@ -518,6 +522,20 @@ internal enum Posix {
|
|||
}
|
||||
}
|
||||
|
||||
@inline(never)
|
||||
public static func stat(pathname: String, outStat: UnsafeMutablePointer<stat>) throws {
|
||||
_ = try syscall(blocking: false) {
|
||||
sysStat(pathname, outStat)
|
||||
}
|
||||
}
|
||||
|
||||
@inline(never)
|
||||
public static func unlink(pathname: String) throws {
|
||||
_ = try syscall(blocking: false) {
|
||||
sysUnlink(pathname)
|
||||
}
|
||||
}
|
||||
|
||||
@inline(never)
|
||||
public static func socketpair(domain: NIOBSDSocket.ProtocolFamily,
|
||||
type: NIOBSDSocket.SocketType,
|
||||
|
|
|
@ -31,6 +31,8 @@ extension EchoServerClientTest {
|
|||
("testLotsOfUnflushedWrites", testLotsOfUnflushedWrites),
|
||||
("testEchoUnixDomainSocket", testEchoUnixDomainSocket),
|
||||
("testConnectUnixDomainSocket", testConnectUnixDomainSocket),
|
||||
("testCleanupUnixDomainSocket", testCleanupUnixDomainSocket),
|
||||
("testBootstrapUnixDomainSocketNameClash", testBootstrapUnixDomainSocketNameClash),
|
||||
("testChannelActiveOnConnect", testChannelActiveOnConnect),
|
||||
("testWriteThenRead", testWriteThenRead),
|
||||
("testCloseInInactive", testCloseInInactive),
|
||||
|
|
|
@ -172,6 +172,44 @@ class EchoServerClientTest : XCTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
func testCleanupUnixDomainSocket() throws {
|
||||
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
defer {
|
||||
XCTAssertNoThrow(try group.syncShutdownGracefully())
|
||||
}
|
||||
|
||||
try withTemporaryUnixDomainSocketPathName { udsPath in
|
||||
let bootstrap = ServerBootstrap(group: group)
|
||||
|
||||
let serverChannel = try assertNoThrowWithValue(
|
||||
bootstrap.bind(unixDomainSocketPath: udsPath).wait())
|
||||
|
||||
XCTAssertNoThrow(try serverChannel.close().wait())
|
||||
|
||||
let reusedPathServerChannel = try assertNoThrowWithValue(
|
||||
bootstrap.bind(unixDomainSocketPath: udsPath,
|
||||
cleanupExistingSocketFile: true).wait())
|
||||
|
||||
XCTAssertNoThrow(try reusedPathServerChannel.close().wait())
|
||||
}
|
||||
}
|
||||
|
||||
func testBootstrapUnixDomainSocketNameClash() throws {
|
||||
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
defer {
|
||||
XCTAssertNoThrow(try group.syncShutdownGracefully())
|
||||
}
|
||||
|
||||
try withTemporaryUnixDomainSocketPathName { udsPath in
|
||||
// Bootstrap should not overwrite an existing file unless it is a socket
|
||||
FileManager.default.createFile(atPath: udsPath, contents: nil, attributes: nil)
|
||||
let bootstrap = ServerBootstrap(group: group)
|
||||
|
||||
XCTAssertThrowsError(
|
||||
try bootstrap.bind(unixDomainSocketPath: udsPath).wait())
|
||||
}
|
||||
}
|
||||
|
||||
func testChannelActiveOnConnect() throws {
|
||||
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
defer {
|
||||
|
|
Loading…
Reference in New Issue