From 810544ec414265b9783ffa78ac1f34695f70acf7 Mon Sep 17 00:00:00 2001 From: David Nadoba Date: Thu, 1 Dec 2022 15:35:04 +0100 Subject: [PATCH] Add `RawSocketBootstrap` (#2320) --- Sources/NIOCore/BSDSocketAPI.swift | 7 + Sources/NIOCore/ChannelOption.swift | 5 + Sources/NIOPosix/BSDSocketAPICommon.swift | 8 + Sources/NIOPosix/RawSocketBootstrap.swift | 197 ++++++++++ Sources/NIOPosix/Socket.swift | 2 +- Tests/LinuxMain.swift | 1 + .../NIOPosixTests/DatagramChannelTests.swift | 4 +- Tests/NIOPosixTests/IPv4Header.swift | 365 ++++++++++++++++++ .../RawSocketBootstrapTests+XCTest.swift | 36 ++ .../RawSocketBootstrapTests.swift | 164 ++++++++ 10 files changed, 786 insertions(+), 3 deletions(-) create mode 100644 Sources/NIOPosix/RawSocketBootstrap.swift create mode 100644 Tests/NIOPosixTests/IPv4Header.swift create mode 100644 Tests/NIOPosixTests/RawSocketBootstrapTests+XCTest.swift create mode 100644 Tests/NIOPosixTests/RawSocketBootstrapTests.swift diff --git a/Sources/NIOCore/BSDSocketAPI.swift b/Sources/NIOCore/BSDSocketAPI.swift index 37764986..23cf4d02 100644 --- a/Sources/NIOCore/BSDSocketAPI.swift +++ b/Sources/NIOCore/BSDSocketAPI.swift @@ -267,6 +267,13 @@ extension NIOBSDSocket.Option { /// Control multicast time-to-live. public static let ip_multicast_ttl: NIOBSDSocket.Option = NIOBSDSocket.Option(rawValue: IP_MULTICAST_TTL) + + /// The IPv4 layer generates an IP header when sending a packet + /// unless the ``ip_hdrincl`` socket option is enabled on the socket. + /// When it is enabled, the packet must contain an IP header. For + /// receiving, the IP header is always included in the packet. + public static let ip_hdrincl: NIOBSDSocket.Option = + NIOBSDSocket.Option(rawValue: IP_HDRINCL) } // IPv6 Options diff --git a/Sources/NIOCore/ChannelOption.swift b/Sources/NIOCore/ChannelOption.swift index f79e126a..826ef021 100644 --- a/Sources/NIOCore/ChannelOption.swift +++ b/Sources/NIOCore/ChannelOption.swift @@ -276,6 +276,11 @@ public struct ChannelOptions { public static let socketOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in .init(level: .socket, name: name) } + + /// - seealso: `SocketOption`. + public static let ipOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in + .init(level: .ip, name: name) + } /// - seealso: `SocketOption`. public static let tcpOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in diff --git a/Sources/NIOPosix/BSDSocketAPICommon.swift b/Sources/NIOPosix/BSDSocketAPICommon.swift index 01c92198..087ee30a 100644 --- a/Sources/NIOPosix/BSDSocketAPICommon.swift +++ b/Sources/NIOPosix/BSDSocketAPICommon.swift @@ -83,6 +83,14 @@ extension NIOBSDSocket.SocketType { internal static let stream: NIOBSDSocket.SocketType = NIOBSDSocket.SocketType(rawValue: SOCK_STREAM) #endif + + #if os(Linux) + internal static let raw: NIOBSDSocket.SocketType = + NIOBSDSocket.SocketType(rawValue: CInt(SOCK_RAW.rawValue)) + #else + internal static let raw: NIOBSDSocket.SocketType = + NIOBSDSocket.SocketType(rawValue: SOCK_RAW) + #endif } // IPv4 Options diff --git a/Sources/NIOPosix/RawSocketBootstrap.swift b/Sources/NIOPosix/RawSocketBootstrap.swift new file mode 100644 index 00000000..2ab19640 --- /dev/null +++ b/Sources/NIOPosix/RawSocketBootstrap.swift @@ -0,0 +1,197 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import NIOCore + +/// A `RawSocketBootstrap` is an easy way to interact with IP based protocols other then TCP and UDP. +/// +/// Example: +/// +/// ```swift +/// let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) +/// defer { +/// try! group.syncShutdownGracefully() +/// } +/// let bootstrap = RawSocketBootstrap(group: group) +/// .channelInitializer { channel in +/// channel.pipeline.addHandler(MyChannelHandler()) +/// } +/// let channel = try! bootstrap.bind(host: "127.0.0.1", ipProtocol: .icmp).wait() +/// /* the Channel is now ready to send/receive IP packets */ +/// +/// try channel.closeFuture.wait() // Wait until the channel un-binds. +/// ``` +/// +/// The `Channel` will operate on `AddressedEnvelope` as inbound and outbound messages. +public final class NIORawSocketBootstrap { + + private let group: EventLoopGroup + private var channelInitializer: Optional + @usableFromInline + internal var _channelOptions: ChannelOptions.Storage + + /// Create a `RawSocketBootstrap` on the `EventLoopGroup` `group`. + /// + /// The `EventLoopGroup` `group` must be compatible, otherwise the program will crash. `RawSocketBootstrap` is + /// compatible only with `MultiThreadedEventLoopGroup` as well as the `EventLoop`s returned by + /// `MultiThreadedEventLoopGroup.next`. See `init(validatingGroup:)` for a fallible initializer for + /// situations where it's impossible to tell ahead of time if the `EventLoopGroup` is compatible or not. + /// + /// - parameters: + /// - group: The `EventLoopGroup` to use. + public convenience init(group: EventLoopGroup) { + guard NIOOnSocketsBootstraps.isCompatible(group: group) else { + preconditionFailure("RawSocketBootstrap is only compatible with MultiThreadedEventLoopGroup and " + + "SelectableEventLoop. You tried constructing one with \(group) which is incompatible.") + } + self.init(validatingGroup: group)! + } + + /// Create a `RawSocketBootstrap` on the `EventLoopGroup` `group`, validating that `group` is compatible. + /// + /// - parameters: + /// - group: The `EventLoopGroup` to use. + public init?(validatingGroup group: EventLoopGroup) { + guard NIOOnSocketsBootstraps.isCompatible(group: group) else { + return nil + } + self._channelOptions = ChannelOptions.Storage() + self.group = group + self.channelInitializer = nil + } + + /// Initialize the bound `Channel` with `initializer`. The most common task in initializer is to add + /// `ChannelHandler`s to the `ChannelPipeline`. + /// + /// - parameters: + /// - handler: A closure that initializes the provided `Channel`. + public func channelInitializer(_ handler: @escaping @Sendable (Channel) -> EventLoopFuture) -> Self { + self.channelInitializer = handler + return self + } + + /// Specifies a `ChannelOption` to be applied to the `Channel`. + /// + /// - parameters: + /// - option: The option to be applied. + /// - value: The value for the option. + @inlinable + public func channelOption(_ option: Option, value: Option.Value) -> Self { + self._channelOptions.append(key: option, value: value) + return self + } + + /// Bind the `Channel` to `host`. + /// All packets or errors matching the `ipProtocol` specified are passed to the resulting `Channel`. + /// + /// - parameters: + /// - host: The host to bind on. + /// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field. + public func bind(host: String, ipProtocol: NIOIPProtocol) -> EventLoopFuture { + return bind0(ipProtocol: ipProtocol) { + return try SocketAddress.makeAddressResolvingHost(host, port: 0) + } + } + + private func bind0(ipProtocol: NIOIPProtocol, _ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture { + let address: SocketAddress + do { + address = try makeSocketAddress() + } catch { + return group.next().makeFailedFuture(error) + } + precondition(address.port == nil || address.port == 0, "port must be 0 or not set") + func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { + return try DatagramChannel(eventLoop: eventLoop, + protocolFamily: address.protocol, + protocolSubtype: .init(ipProtocol), + socketType: .raw) + } + return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in + channel.register().flatMap { + channel.bind(to: address) + } + } + } + + /// Connect the `Channel` to `host`. + /// + /// - parameters: + /// - host: The host to connect to. + /// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field. + public func connect(host: String, ipProtocol: NIOIPProtocol) -> EventLoopFuture { + return connect0(ipProtocol: ipProtocol) { + return try SocketAddress.makeAddressResolvingHost(host, port: 0) + } + } + + private func connect0(ipProtocol: NIOIPProtocol, _ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture { + let address: SocketAddress + do { + address = try makeSocketAddress() + } catch { + return group.next().makeFailedFuture(error) + } + func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { + return try DatagramChannel(eventLoop: eventLoop, + protocolFamily: address.protocol, + protocolSubtype: .init(ipProtocol), + socketType: .raw) + } + return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in + channel.register().flatMap { + channel.connect(to: address) + } + } + } + + private func withNewChannel(makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, _ bringup: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture) -> EventLoopFuture { + let eventLoop = self.group.next() + let channelInitializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) } + let channelOptions = self._channelOptions + + let channel: DatagramChannel + do { + channel = try makeChannel(eventLoop as! SelectableEventLoop) + } catch { + return eventLoop.makeFailedFuture(error) + } + + func setupChannel() -> EventLoopFuture { + eventLoop.assertInEventLoop() + return channelOptions.applyAllChannelOptions(to: channel).flatMap { + channelInitializer(channel) + }.flatMap { + eventLoop.assertInEventLoop() + return bringup(eventLoop, channel) + }.map { + channel + }.flatMapError { error in + eventLoop.makeFailedFuture(error) + } + } + + if eventLoop.inEventLoop { + return setupChannel() + } else { + return eventLoop.flatSubmit { + setupChannel() + } + } + } +} + +#if swift(>=5.6) +@available(*, unavailable) +extension NIORawSocketBootstrap: Sendable {} +#endif diff --git a/Sources/NIOPosix/Socket.swift b/Sources/NIOPosix/Socket.swift index da9f4a0e..e55cab83 100644 --- a/Sources/NIOPosix/Socket.swift +++ b/Sources/NIOPosix/Socket.swift @@ -32,7 +32,7 @@ typealias IOVector = iovec /// - parameters: /// - protocolFamily: The protocol family to use (usually `AF_INET6` or `AF_INET`). /// - type: The type of the socket to create. - /// - protocolSubtype: The subtype of the protocol, corresponding to the `protocol` + /// - protocolSubtype: The subtype of the protocol, corresponding to the `protocolSubtype` /// argument to the socket syscall. Defaults to 0. /// - setNonBlocking: Set non-blocking mode on the socket. /// - throws: An `IOError` if creation of the socket failed. diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index bd01f799..d5daf092 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -120,6 +120,7 @@ class LinuxMainRunner { testCase(PendingDatagramWritesManagerTests.allTests), testCase(PipeChannelTest.allTests), testCase(PriorityQueueTest.allTests), + testCase(RawSocketBootstrapTests.allTests), testCase(SALChannelTest.allTests), testCase(SALEventLoopTests.allTests), testCase(SNIHandlerTest.allTests), diff --git a/Tests/NIOPosixTests/DatagramChannelTests.swift b/Tests/NIOPosixTests/DatagramChannelTests.swift index a0d3941f..8c1b8e4e 100644 --- a/Tests/NIOPosixTests/DatagramChannelTests.swift +++ b/Tests/NIOPosixTests/DatagramChannelTests.swift @@ -17,7 +17,7 @@ import NIOCore @testable import NIOPosix import XCTest -private extension Channel { +extension Channel { func waitForDatagrams(count: Int) throws -> [AddressedEnvelope] { return try self.pipeline.context(name: "ByteReadRecorder").flatMap { context in if let future = (context.handler as? DatagramReadRecorder)?.notifyForDatagrams(count) { @@ -47,7 +47,7 @@ private extension Channel { /// A class that records datagrams received and forwards them on. /// /// Used extensively in tests to validate messaging expectations. -private class DatagramReadRecorder: ChannelInboundHandler { +final class DatagramReadRecorder: ChannelInboundHandler { typealias InboundIn = AddressedEnvelope typealias InboundOut = AddressedEnvelope diff --git a/Tests/NIOPosixTests/IPv4Header.swift b/Tests/NIOPosixTests/IPv4Header.swift new file mode 100644 index 00000000..1bda7fa5 --- /dev/null +++ b/Tests/NIOPosixTests/IPv4Header.swift @@ -0,0 +1,365 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore + +struct IPv4Address: Hashable { + var rawValue: UInt32 +} + +extension IPv4Address { + init(_ v1: UInt8, _ v2: UInt8, _ v3: UInt8, _ v4: UInt8) { + rawValue = UInt32(v1) << 24 | UInt32(v2) << 16 | UInt32(v3) << 8 | UInt32(v4) + } +} + +extension IPv4Address: CustomStringConvertible { + var description: String { + let v1 = rawValue >> 24 + let v2 = rawValue >> 16 & 0b1111_1111 + let v3 = rawValue >> 8 & 0b1111_1111 + let v4 = rawValue & 0b1111_1111 + return "\(v1).\(v2).\(v3).\(v4)" + } +} + +struct IPv4Header: Hashable { + static let size: Int = 20 + + fileprivate var versionAndIhl: UInt8 + var version: UInt8 { + get { + versionAndIhl >> 4 + } + set { + precondition(newValue & 0b1111_0000 == 0) + versionAndIhl = newValue << 4 | (0b0000_1111 & versionAndIhl) + assert(newValue == version, "\(newValue) != \(version) \(versionAndIhl)") + } + } + var internetHeaderLength: UInt8 { + get { + versionAndIhl & 0b0000_1111 + } + set { + precondition(newValue & 0b1111_0000 == 0) + versionAndIhl = newValue | (0b1111_0000 & versionAndIhl) + assert(newValue == internetHeaderLength) + } + } + fileprivate var dscpAndEcn: UInt8 + var differentiatedServicesCodePoint: UInt8 { + get { + dscpAndEcn >> 2 + } + set { + precondition(newValue & 0b0000_0011 == 0) + dscpAndEcn = newValue << 2 | (0b0000_0011 & dscpAndEcn) + assert(newValue == differentiatedServicesCodePoint) + } + } + var explicitCongestionNotification: UInt8 { + get { + dscpAndEcn & 0b0000_0011 + } + set { + precondition(newValue & 0b0000_0011 == 0) + dscpAndEcn = newValue | (0b1111_1100 & dscpAndEcn) + assert(newValue == explicitCongestionNotification) + } + } + var totalLength: UInt16 + var identification: UInt16 + fileprivate var flagsAndFragmentOffset: UInt16 + var flags: UInt8 { + get { + UInt8(flagsAndFragmentOffset >> 13) + } + set { + precondition(newValue & 0b0000_0111 == 0) + flagsAndFragmentOffset = UInt16(newValue) << 13 | (0b0001_1111_1111_1111 & flagsAndFragmentOffset) + assert(newValue == flags) + } + } + var fragmentOffset: UInt16 { + get { + flagsAndFragmentOffset & 0b0001_1111_1111_1111 + } + set { + precondition(newValue & 0b1110_0000_0000_0000 == 0) + flagsAndFragmentOffset = newValue | (0b1110_0000_0000_0000 & flagsAndFragmentOffset) + assert(newValue == fragmentOffset) + } + } + var timeToLive: UInt8 + var `protocol`: NIOIPProtocol + var headerChecksum: UInt16 + var sourceIpAddress: IPv4Address + var destinationIpAddress: IPv4Address + + fileprivate init( + versionAndIhl: UInt8, + dscpAndEcn: UInt8, + totalLength: UInt16, + identification: UInt16, + flagsAndFragmentOffset: UInt16, + timeToLive: UInt8, + `protocol`: NIOIPProtocol, + headerChecksum: UInt16, + sourceIpAddress: IPv4Address, + destinationIpAddress: IPv4Address + ) { + self.versionAndIhl = versionAndIhl + self.dscpAndEcn = dscpAndEcn + self.totalLength = totalLength + self.identification = identification + self.flagsAndFragmentOffset = flagsAndFragmentOffset + self.timeToLive = timeToLive + self.`protocol` = `protocol` + self.headerChecksum = headerChecksum + self.sourceIpAddress = sourceIpAddress + self.destinationIpAddress = destinationIpAddress + } + + init() { + self.versionAndIhl = 0 + self.dscpAndEcn = 0 + self.totalLength = 0 + self.identification = 0 + self.flagsAndFragmentOffset = 0 + self.timeToLive = 0 + self.`protocol` = .init(rawValue: 0) + self.headerChecksum = 0 + self.sourceIpAddress = .init(rawValue: 0) + self.destinationIpAddress = .init(rawValue: 0) + } +} + +extension FixedWidthInteger { + func convertEndianness(to endianness: Endianness) -> Self { + switch endianness { + case .little: + return self.littleEndian + case .big: + return self.bigEndian + } + } +} + + + +extension ByteBuffer { + mutating func readIPv4Header() -> IPv4Header? { + guard let ( + versionAndIhl, + dscpAndEcn, + totalLength, + identification, + flagsAndFragmentOffset, + timeToLive, + `protocol`, + headerChecksum, + sourceIpAddress, + destinationIpAddress + ) = self.readMultipleIntegers(as: ( + UInt8, + UInt8, + UInt16, + UInt16, + UInt16, + UInt8, + UInt8, + UInt16, + UInt32, + UInt32 + ).self) else { return nil } + return .init( + versionAndIhl: versionAndIhl, + dscpAndEcn: dscpAndEcn, + totalLength: totalLength, + identification: identification, + flagsAndFragmentOffset: flagsAndFragmentOffset, + timeToLive: timeToLive, + protocol: .init(rawValue: `protocol`), + headerChecksum: headerChecksum, + sourceIpAddress: .init(rawValue: sourceIpAddress), + destinationIpAddress: .init(rawValue: destinationIpAddress) + ) + } + + mutating func readIPv4HeaderFromBSDRawSocket() -> IPv4Header? { + guard var header = self.readIPv4Header() else { return nil } + // On BSD, the total length is in host byte order + header.totalLength = header.totalLength.convertEndianness(to: .big) + // TODO: fragmentOffset is in host byte order as well but it is always zero in our tests + // and fragmentOffset is 13 bits in size so we can't just use readInteger(endianness: .host) + return header + } + + mutating func readIPv4HeaderFromOSRawSocket() -> IPv4Header? { + #if canImport(Darwin) + return self.readIPv4HeaderFromBSDRawSocket() + #else + return self.readIPv4Header() + #endif + } +} + +extension ByteBuffer { + @discardableResult + mutating func writeIPv4Header(_ header: IPv4Header) -> Int { + assert({ + var buffer = ByteBuffer() + buffer._writeIPv4Header(header) + let writtenHeader = buffer.readIPv4Header() + return header == writtenHeader + }()) + return self._writeIPv4Header(header) + } + + @discardableResult + private mutating func _writeIPv4Header(_ header: IPv4Header) -> Int { + return self.writeMultipleIntegers( + header.versionAndIhl, + header.dscpAndEcn, + header.totalLength, + header.identification, + header.flagsAndFragmentOffset, + header.timeToLive, + header.`protocol`.rawValue, + header.headerChecksum, + header.sourceIpAddress.rawValue, + header.destinationIpAddress.rawValue + ) + } + + @discardableResult + mutating func writeIPv4HeaderToBSDRawSocket(_ header: IPv4Header) -> Int { + assert({ + var buffer = ByteBuffer() + buffer._writeIPv4HeaderToBSDRawSocket(header) + let writtenHeader = buffer.readIPv4HeaderFromBSDRawSocket() + return header == writtenHeader + }()) + return self._writeIPv4HeaderToBSDRawSocket(header) + } + + @discardableResult + private mutating func _writeIPv4HeaderToBSDRawSocket(_ header: IPv4Header) -> Int { + var header = header + // On BSD, the total length needs to be in host byte order + header.totalLength = header.totalLength.convertEndianness(to: .big) + // TODO: fragmentOffset needs to be in host byte order as well but it is always zero in our tests + // and fragmentOffset is 13 bits in size so we can't just use writeInteger(endianness: .host) + return self._writeIPv4Header(header) + } + + @discardableResult + mutating func writeIPv4HeaderToOSRawSocket(_ header: IPv4Header) -> Int { + #if canImport(Darwin) + self.writeIPv4HeaderToBSDRawSocket(header) + #else + self.writeIPv4Header(header) + #endif + } +} + +extension IPv4Header { + func computeChecksum() -> UInt16 { + let checksum = ~[ + UInt16(versionAndIhl) << 8 | UInt16(dscpAndEcn), + totalLength, + identification, + flagsAndFragmentOffset, + UInt16(timeToLive) << 8 | UInt16(`protocol`.rawValue), + UInt16(sourceIpAddress.rawValue >> 16), + UInt16(sourceIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111), + UInt16(destinationIpAddress.rawValue >> 16), + UInt16(destinationIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111), + ].reduce(UInt16(0), onesComplementAdd) + assert(isValidChecksum(checksum)) + return checksum + } + mutating func setChecksum() { + self.headerChecksum = computeChecksum() + } + func isValidChecksum(_ headerChecksum: UInt16) -> Bool { + let sum = ~[ + UInt16(versionAndIhl) << 8 | UInt16(dscpAndEcn), + totalLength, + identification, + flagsAndFragmentOffset, + UInt16(timeToLive) << 8 | UInt16(`protocol`.rawValue), + headerChecksum, + UInt16(sourceIpAddress.rawValue >> 16), + UInt16(sourceIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111), + UInt16(destinationIpAddress.rawValue >> 16), + UInt16(destinationIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111), + ].reduce(UInt16(0), onesComplementAdd) + return sum == 0 + } + func isValidChecksum() -> Bool { + isValidChecksum(headerChecksum) + } +} + +private func onesComplementAdd(lhs: Integer, rhs: Integer) -> Integer { + var (sum, overflowed) = lhs.addingReportingOverflow(rhs) + if overflowed { + sum &+= 1 + } + return sum +} + +extension IPv4Header { + var platformIndependentTotalLengthForReceivedPacketFromRawSocket: UInt16 { + #if canImport(Darwin) + // On BSD the IP header will only contain the size of the ip packet body, not the header. + // This is known bug which can't be fixed without breaking old apps which already workaround the issue + // like e.g. we do now too. + return totalLength + 20 + #elseif os(Linux) + return totalLength + #endif + } + var platformIndependentChecksumForReceivedPacketFromRawSocket: UInt16 { + #if canImport(Darwin) + // On BSD the checksum is always zero and we need to compute it + precondition(headerChecksum == 0) + return computeChecksum() + #elseif os(Linux) + return headerChecksum + #endif + } +} + +extension IPv4Header: CustomStringConvertible { + var description: String { + """ + Version: \(version) + Header Length: \(internetHeaderLength * 4) bytes + Differentiated Services: \(String(differentiatedServicesCodePoint, radix: 2)) + Explicit Congestion Notification: \(String(explicitCongestionNotification, radix: 2)) + Total Length: \(totalLength) bytes + Identification: \(identification) + Flags: \(String(flags, radix: 2)) + Fragment Offset: \(fragmentOffset) bytes + Time to Live: \(timeToLive) + Protocol: \(`protocol`) + Header Checksum: \(headerChecksum) (\(isValidChecksum() ? "valid" : "*not* valid")) + Source IP Address: \(sourceIpAddress) + Destination IP Address: \(destinationIpAddress) + """ + } +} diff --git a/Tests/NIOPosixTests/RawSocketBootstrapTests+XCTest.swift b/Tests/NIOPosixTests/RawSocketBootstrapTests+XCTest.swift new file mode 100644 index 00000000..186cd6fc --- /dev/null +++ b/Tests/NIOPosixTests/RawSocketBootstrapTests+XCTest.swift @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// RawSocketBootstrapTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension RawSocketBootstrapTests { + + @available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings") + static var allTests : [(String, (RawSocketBootstrapTests) -> () throws -> Void)] { + return [ + ("testBindWithRecevMmsg", testBindWithRecevMmsg), + ("testConnect", testConnect), + ("testIpHdrincl", testIpHdrincl), + ] + } +} + diff --git a/Tests/NIOPosixTests/RawSocketBootstrapTests.swift b/Tests/NIOPosixTests/RawSocketBootstrapTests.swift new file mode 100644 index 00000000..ff721482 --- /dev/null +++ b/Tests/NIOPosixTests/RawSocketBootstrapTests.swift @@ -0,0 +1,164 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import XCTest +import NIOCore +@testable import NIOPosix + +extension NIOIPProtocol { + static let reservedForTesting = Self(rawValue: 253) +} + +// lazily try's to create a raw socket and caches the error if it fails +private let cachedRawSocketAPICheck = Result { + let socket = try Socket(protocolFamily: .inet, type: .raw, protocolSubtype: .init(NIOIPProtocol.reservedForTesting), setNonBlocking: true) + try socket.close() +} + +func XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI(file: StaticString = #filePath, line: UInt = #line) throws { + do { + try cachedRawSocketAPICheck.get() + } catch let error as IOError where error.errnoCode == EPERM { + throw XCTSkip("Raw Socket API requires higher privileges: \(error)", file: file, line: line) + } +} + +final class RawSocketBootstrapTests: XCTestCase { + func testBindWithRecevMmsg() throws { + try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI() + + let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + let channel = try NIORawSocketBootstrap(group: elg) + .channelInitializer { + $0.pipeline.addHandler(DatagramReadRecorder(), name: "ByteReadRecorder") + } + .bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait() + defer { XCTAssertNoThrow(try channel.close().wait()) } + try channel.configureForRecvMmsg(messageCount: 10) + let expectedMessages = (1...10).map { "Hello World \($0)" } + for message in expectedMessages { + _ = try channel.write(AddressedEnvelope( + remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0), + data: ByteBuffer(string: message) + )) + } + channel.flush() + + let receivedMessages = Set(try channel.waitForDatagrams(count: 10).map { envelop -> String in + var data = envelop.data + let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket()) + XCTAssertEqual(header.version, 4) + XCTAssertEqual(header.protocol, .reservedForTesting) + XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes) + XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)") + XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1)) + XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1)) + return String(buffer: data) + }) + + XCTAssertEqual(receivedMessages, Set(expectedMessages)) + } + + func testConnect() throws { + try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI() + + let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + let readChannel = try NIORawSocketBootstrap(group: elg) + .channelInitializer { + $0.pipeline.addHandler(DatagramReadRecorder(), name: "ByteReadRecorder") + } + .bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait() + defer { XCTAssertNoThrow(try readChannel.close().wait()) } + + let writeChannel = try NIORawSocketBootstrap(group: elg) + .channelInitializer { + $0.pipeline.addHandler(DatagramReadRecorder(), name: "ByteReadRecorder") + } + .bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait() + defer { XCTAssertNoThrow(try writeChannel.close().wait()) } + + let expectedMessages = (1...10).map { "Hello World \($0)" } + for message in expectedMessages { + _ = try writeChannel.write(AddressedEnvelope( + remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0), + data: ByteBuffer(string: message) + )) + } + writeChannel.flush() + + let receivedMessages = Set(try readChannel.waitForDatagrams(count: 10).map { envelop -> String in + var data = envelop.data + let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket()) + XCTAssertEqual(header.version, 4) + XCTAssertEqual(header.protocol, .reservedForTesting) + XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes) + XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)") + XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1)) + XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1)) + return String(buffer: data) + }) + + XCTAssertEqual(receivedMessages, Set(expectedMessages)) + } + + func testIpHdrincl() throws { + try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI() + + let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + let channel = try NIORawSocketBootstrap(group: elg) + .channelOption(ChannelOptions.ipOption(.ip_hdrincl), value: 1) + .channelInitializer { + $0.pipeline.addHandler(DatagramReadRecorder(), name: "ByteReadRecorder") + } + .bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait() + defer { XCTAssertNoThrow(try channel.close().wait()) } + try channel.configureForRecvMmsg(messageCount: 10) + let expectedMessages = (1...10).map { "Hello World \($0)" } + for message in expectedMessages.map(ByteBuffer.init(string:)) { + var packet = ByteBuffer() + var header = IPv4Header() + header.version = 4 + header.internetHeaderLength = 5 + header.totalLength = UInt16(IPv4Header.size + message.readableBytes) + header.protocol = .reservedForTesting + header.timeToLive = 64 + header.destinationIpAddress = .init(127, 0, 0, 1) + header.sourceIpAddress = .init(127, 0, 0, 1) + header.setChecksum() + packet.writeIPv4HeaderToOSRawSocket(header) + packet.writeImmutableBuffer(message) + try channel.writeAndFlush(AddressedEnvelope( + remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0), + data: packet + )).wait() + } + + let receivedMessages = Set(try channel.waitForDatagrams(count: 10).map { envelop -> String in + var data = envelop.data + let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket()) + XCTAssertEqual(header.version, 4) + XCTAssertEqual(header.protocol, .reservedForTesting) + XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes) + XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)") + XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1)) + XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1)) + return String(buffer: data) + }) + + XCTAssertEqual(receivedMessages, Set(expectedMessages)) + } +}