Add `RawSocketBootstrap` (#2320)

This commit is contained in:
David Nadoba 2022-12-01 15:35:04 +01:00 committed by GitHub
parent 00341c9277
commit 810544ec41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 786 additions and 3 deletions

View File

@ -267,6 +267,13 @@ extension NIOBSDSocket.Option {
/// Control multicast time-to-live. /// Control multicast time-to-live.
public static let ip_multicast_ttl: NIOBSDSocket.Option = public static let ip_multicast_ttl: NIOBSDSocket.Option =
NIOBSDSocket.Option(rawValue: IP_MULTICAST_TTL) NIOBSDSocket.Option(rawValue: IP_MULTICAST_TTL)
/// The IPv4 layer generates an IP header when sending a packet
/// unless the ``ip_hdrincl`` socket option is enabled on the socket.
/// When it is enabled, the packet must contain an IP header. For
/// receiving, the IP header is always included in the packet.
public static let ip_hdrincl: NIOBSDSocket.Option =
NIOBSDSocket.Option(rawValue: IP_HDRINCL)
} }
// IPv6 Options // IPv6 Options

View File

@ -276,6 +276,11 @@ public struct ChannelOptions {
public static let socketOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in public static let socketOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in
.init(level: .socket, name: name) .init(level: .socket, name: name)
} }
/// - seealso: `SocketOption`.
public static let ipOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in
.init(level: .ip, name: name)
}
/// - seealso: `SocketOption`. /// - seealso: `SocketOption`.
public static let tcpOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in public static let tcpOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in

View File

@ -83,6 +83,14 @@ extension NIOBSDSocket.SocketType {
internal static let stream: NIOBSDSocket.SocketType = internal static let stream: NIOBSDSocket.SocketType =
NIOBSDSocket.SocketType(rawValue: SOCK_STREAM) NIOBSDSocket.SocketType(rawValue: SOCK_STREAM)
#endif #endif
#if os(Linux)
internal static let raw: NIOBSDSocket.SocketType =
NIOBSDSocket.SocketType(rawValue: CInt(SOCK_RAW.rawValue))
#else
internal static let raw: NIOBSDSocket.SocketType =
NIOBSDSocket.SocketType(rawValue: SOCK_RAW)
#endif
} }
// IPv4 Options // IPv4 Options

View File

@ -0,0 +1,197 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIOCore
/// A `RawSocketBootstrap` is an easy way to interact with IP based protocols other then TCP and UDP.
///
/// Example:
///
/// ```swift
/// let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
/// defer {
/// try! group.syncShutdownGracefully()
/// }
/// let bootstrap = RawSocketBootstrap(group: group)
/// .channelInitializer { channel in
/// channel.pipeline.addHandler(MyChannelHandler())
/// }
/// let channel = try! bootstrap.bind(host: "127.0.0.1", ipProtocol: .icmp).wait()
/// /* the Channel is now ready to send/receive IP packets */
///
/// try channel.closeFuture.wait() // Wait until the channel un-binds.
/// ```
///
/// The `Channel` will operate on `AddressedEnvelope<ByteBuffer>` as inbound and outbound messages.
public final class NIORawSocketBootstrap {
private let group: EventLoopGroup
private var channelInitializer: Optional<ChannelInitializerCallback>
@usableFromInline
internal var _channelOptions: ChannelOptions.Storage
/// Create a `RawSocketBootstrap` on the `EventLoopGroup` `group`.
///
/// The `EventLoopGroup` `group` must be compatible, otherwise the program will crash. `RawSocketBootstrap` is
/// compatible only with `MultiThreadedEventLoopGroup` as well as the `EventLoop`s returned by
/// `MultiThreadedEventLoopGroup.next`. See `init(validatingGroup:)` for a fallible initializer for
/// situations where it's impossible to tell ahead of time if the `EventLoopGroup` is compatible or not.
///
/// - parameters:
/// - group: The `EventLoopGroup` to use.
public convenience init(group: EventLoopGroup) {
guard NIOOnSocketsBootstraps.isCompatible(group: group) else {
preconditionFailure("RawSocketBootstrap is only compatible with MultiThreadedEventLoopGroup and " +
"SelectableEventLoop. You tried constructing one with \(group) which is incompatible.")
}
self.init(validatingGroup: group)!
}
/// Create a `RawSocketBootstrap` on the `EventLoopGroup` `group`, validating that `group` is compatible.
///
/// - parameters:
/// - group: The `EventLoopGroup` to use.
public init?(validatingGroup group: EventLoopGroup) {
guard NIOOnSocketsBootstraps.isCompatible(group: group) else {
return nil
}
self._channelOptions = ChannelOptions.Storage()
self.group = group
self.channelInitializer = nil
}
/// Initialize the bound `Channel` with `initializer`. The most common task in initializer is to add
/// `ChannelHandler`s to the `ChannelPipeline`.
///
/// - parameters:
/// - handler: A closure that initializes the provided `Channel`.
public func channelInitializer(_ handler: @escaping @Sendable (Channel) -> EventLoopFuture<Void>) -> Self {
self.channelInitializer = handler
return self
}
/// Specifies a `ChannelOption` to be applied to the `Channel`.
///
/// - parameters:
/// - option: The option to be applied.
/// - value: The value for the option.
@inlinable
public func channelOption<Option: ChannelOption>(_ option: Option, value: Option.Value) -> Self {
self._channelOptions.append(key: option, value: value)
return self
}
/// Bind the `Channel` to `host`.
/// All packets or errors matching the `ipProtocol` specified are passed to the resulting `Channel`.
///
/// - parameters:
/// - host: The host to bind on.
/// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field.
public func bind(host: String, ipProtocol: NIOIPProtocol) -> EventLoopFuture<Channel> {
return bind0(ipProtocol: ipProtocol) {
return try SocketAddress.makeAddressResolvingHost(host, port: 0)
}
}
private func bind0(ipProtocol: NIOIPProtocol, _ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture<Channel> {
let address: SocketAddress
do {
address = try makeSocketAddress()
} catch {
return group.next().makeFailedFuture(error)
}
precondition(address.port == nil || address.port == 0, "port must be 0 or not set")
func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel {
return try DatagramChannel(eventLoop: eventLoop,
protocolFamily: address.protocol,
protocolSubtype: .init(ipProtocol),
socketType: .raw)
}
return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in
channel.register().flatMap {
channel.bind(to: address)
}
}
}
/// Connect the `Channel` to `host`.
///
/// - parameters:
/// - host: The host to connect to.
/// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field.
public func connect(host: String, ipProtocol: NIOIPProtocol) -> EventLoopFuture<Channel> {
return connect0(ipProtocol: ipProtocol) {
return try SocketAddress.makeAddressResolvingHost(host, port: 0)
}
}
private func connect0(ipProtocol: NIOIPProtocol, _ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture<Channel> {
let address: SocketAddress
do {
address = try makeSocketAddress()
} catch {
return group.next().makeFailedFuture(error)
}
func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel {
return try DatagramChannel(eventLoop: eventLoop,
protocolFamily: address.protocol,
protocolSubtype: .init(ipProtocol),
socketType: .raw)
}
return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in
channel.register().flatMap {
channel.connect(to: address)
}
}
}
private func withNewChannel(makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, _ bringup: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
let eventLoop = self.group.next()
let channelInitializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) }
let channelOptions = self._channelOptions
let channel: DatagramChannel
do {
channel = try makeChannel(eventLoop as! SelectableEventLoop)
} catch {
return eventLoop.makeFailedFuture(error)
}
func setupChannel() -> EventLoopFuture<Channel> {
eventLoop.assertInEventLoop()
return channelOptions.applyAllChannelOptions(to: channel).flatMap {
channelInitializer(channel)
}.flatMap {
eventLoop.assertInEventLoop()
return bringup(eventLoop, channel)
}.map {
channel
}.flatMapError { error in
eventLoop.makeFailedFuture(error)
}
}
if eventLoop.inEventLoop {
return setupChannel()
} else {
return eventLoop.flatSubmit {
setupChannel()
}
}
}
}
#if swift(>=5.6)
@available(*, unavailable)
extension NIORawSocketBootstrap: Sendable {}
#endif

View File

@ -32,7 +32,7 @@ typealias IOVector = iovec
/// - parameters: /// - parameters:
/// - protocolFamily: The protocol family to use (usually `AF_INET6` or `AF_INET`). /// - protocolFamily: The protocol family to use (usually `AF_INET6` or `AF_INET`).
/// - type: The type of the socket to create. /// - type: The type of the socket to create.
/// - protocolSubtype: The subtype of the protocol, corresponding to the `protocol` /// - protocolSubtype: The subtype of the protocol, corresponding to the `protocolSubtype`
/// argument to the socket syscall. Defaults to 0. /// argument to the socket syscall. Defaults to 0.
/// - setNonBlocking: Set non-blocking mode on the socket. /// - setNonBlocking: Set non-blocking mode on the socket.
/// - throws: An `IOError` if creation of the socket failed. /// - throws: An `IOError` if creation of the socket failed.

View File

@ -120,6 +120,7 @@ class LinuxMainRunner {
testCase(PendingDatagramWritesManagerTests.allTests), testCase(PendingDatagramWritesManagerTests.allTests),
testCase(PipeChannelTest.allTests), testCase(PipeChannelTest.allTests),
testCase(PriorityQueueTest.allTests), testCase(PriorityQueueTest.allTests),
testCase(RawSocketBootstrapTests.allTests),
testCase(SALChannelTest.allTests), testCase(SALChannelTest.allTests),
testCase(SALEventLoopTests.allTests), testCase(SALEventLoopTests.allTests),
testCase(SNIHandlerTest.allTests), testCase(SNIHandlerTest.allTests),

View File

@ -17,7 +17,7 @@ import NIOCore
@testable import NIOPosix @testable import NIOPosix
import XCTest import XCTest
private extension Channel { extension Channel {
func waitForDatagrams(count: Int) throws -> [AddressedEnvelope<ByteBuffer>] { func waitForDatagrams(count: Int) throws -> [AddressedEnvelope<ByteBuffer>] {
return try self.pipeline.context(name: "ByteReadRecorder").flatMap { context in return try self.pipeline.context(name: "ByteReadRecorder").flatMap { context in
if let future = (context.handler as? DatagramReadRecorder<ByteBuffer>)?.notifyForDatagrams(count) { if let future = (context.handler as? DatagramReadRecorder<ByteBuffer>)?.notifyForDatagrams(count) {
@ -47,7 +47,7 @@ private extension Channel {
/// A class that records datagrams received and forwards them on. /// A class that records datagrams received and forwards them on.
/// ///
/// Used extensively in tests to validate messaging expectations. /// Used extensively in tests to validate messaging expectations.
private class DatagramReadRecorder<DataType>: ChannelInboundHandler { final class DatagramReadRecorder<DataType>: ChannelInboundHandler {
typealias InboundIn = AddressedEnvelope<DataType> typealias InboundIn = AddressedEnvelope<DataType>
typealias InboundOut = AddressedEnvelope<DataType> typealias InboundOut = AddressedEnvelope<DataType>

View File

@ -0,0 +1,365 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIOCore
struct IPv4Address: Hashable {
var rawValue: UInt32
}
extension IPv4Address {
init(_ v1: UInt8, _ v2: UInt8, _ v3: UInt8, _ v4: UInt8) {
rawValue = UInt32(v1) << 24 | UInt32(v2) << 16 | UInt32(v3) << 8 | UInt32(v4)
}
}
extension IPv4Address: CustomStringConvertible {
var description: String {
let v1 = rawValue >> 24
let v2 = rawValue >> 16 & 0b1111_1111
let v3 = rawValue >> 8 & 0b1111_1111
let v4 = rawValue & 0b1111_1111
return "\(v1).\(v2).\(v3).\(v4)"
}
}
struct IPv4Header: Hashable {
static let size: Int = 20
fileprivate var versionAndIhl: UInt8
var version: UInt8 {
get {
versionAndIhl >> 4
}
set {
precondition(newValue & 0b1111_0000 == 0)
versionAndIhl = newValue << 4 | (0b0000_1111 & versionAndIhl)
assert(newValue == version, "\(newValue) != \(version) \(versionAndIhl)")
}
}
var internetHeaderLength: UInt8 {
get {
versionAndIhl & 0b0000_1111
}
set {
precondition(newValue & 0b1111_0000 == 0)
versionAndIhl = newValue | (0b1111_0000 & versionAndIhl)
assert(newValue == internetHeaderLength)
}
}
fileprivate var dscpAndEcn: UInt8
var differentiatedServicesCodePoint: UInt8 {
get {
dscpAndEcn >> 2
}
set {
precondition(newValue & 0b0000_0011 == 0)
dscpAndEcn = newValue << 2 | (0b0000_0011 & dscpAndEcn)
assert(newValue == differentiatedServicesCodePoint)
}
}
var explicitCongestionNotification: UInt8 {
get {
dscpAndEcn & 0b0000_0011
}
set {
precondition(newValue & 0b0000_0011 == 0)
dscpAndEcn = newValue | (0b1111_1100 & dscpAndEcn)
assert(newValue == explicitCongestionNotification)
}
}
var totalLength: UInt16
var identification: UInt16
fileprivate var flagsAndFragmentOffset: UInt16
var flags: UInt8 {
get {
UInt8(flagsAndFragmentOffset >> 13)
}
set {
precondition(newValue & 0b0000_0111 == 0)
flagsAndFragmentOffset = UInt16(newValue) << 13 | (0b0001_1111_1111_1111 & flagsAndFragmentOffset)
assert(newValue == flags)
}
}
var fragmentOffset: UInt16 {
get {
flagsAndFragmentOffset & 0b0001_1111_1111_1111
}
set {
precondition(newValue & 0b1110_0000_0000_0000 == 0)
flagsAndFragmentOffset = newValue | (0b1110_0000_0000_0000 & flagsAndFragmentOffset)
assert(newValue == fragmentOffset)
}
}
var timeToLive: UInt8
var `protocol`: NIOIPProtocol
var headerChecksum: UInt16
var sourceIpAddress: IPv4Address
var destinationIpAddress: IPv4Address
fileprivate init(
versionAndIhl: UInt8,
dscpAndEcn: UInt8,
totalLength: UInt16,
identification: UInt16,
flagsAndFragmentOffset: UInt16,
timeToLive: UInt8,
`protocol`: NIOIPProtocol,
headerChecksum: UInt16,
sourceIpAddress: IPv4Address,
destinationIpAddress: IPv4Address
) {
self.versionAndIhl = versionAndIhl
self.dscpAndEcn = dscpAndEcn
self.totalLength = totalLength
self.identification = identification
self.flagsAndFragmentOffset = flagsAndFragmentOffset
self.timeToLive = timeToLive
self.`protocol` = `protocol`
self.headerChecksum = headerChecksum
self.sourceIpAddress = sourceIpAddress
self.destinationIpAddress = destinationIpAddress
}
init() {
self.versionAndIhl = 0
self.dscpAndEcn = 0
self.totalLength = 0
self.identification = 0
self.flagsAndFragmentOffset = 0
self.timeToLive = 0
self.`protocol` = .init(rawValue: 0)
self.headerChecksum = 0
self.sourceIpAddress = .init(rawValue: 0)
self.destinationIpAddress = .init(rawValue: 0)
}
}
extension FixedWidthInteger {
func convertEndianness(to endianness: Endianness) -> Self {
switch endianness {
case .little:
return self.littleEndian
case .big:
return self.bigEndian
}
}
}
extension ByteBuffer {
mutating func readIPv4Header() -> IPv4Header? {
guard let (
versionAndIhl,
dscpAndEcn,
totalLength,
identification,
flagsAndFragmentOffset,
timeToLive,
`protocol`,
headerChecksum,
sourceIpAddress,
destinationIpAddress
) = self.readMultipleIntegers(as: (
UInt8,
UInt8,
UInt16,
UInt16,
UInt16,
UInt8,
UInt8,
UInt16,
UInt32,
UInt32
).self) else { return nil }
return .init(
versionAndIhl: versionAndIhl,
dscpAndEcn: dscpAndEcn,
totalLength: totalLength,
identification: identification,
flagsAndFragmentOffset: flagsAndFragmentOffset,
timeToLive: timeToLive,
protocol: .init(rawValue: `protocol`),
headerChecksum: headerChecksum,
sourceIpAddress: .init(rawValue: sourceIpAddress),
destinationIpAddress: .init(rawValue: destinationIpAddress)
)
}
mutating func readIPv4HeaderFromBSDRawSocket() -> IPv4Header? {
guard var header = self.readIPv4Header() else { return nil }
// On BSD, the total length is in host byte order
header.totalLength = header.totalLength.convertEndianness(to: .big)
// TODO: fragmentOffset is in host byte order as well but it is always zero in our tests
// and fragmentOffset is 13 bits in size so we can't just use readInteger(endianness: .host)
return header
}
mutating func readIPv4HeaderFromOSRawSocket() -> IPv4Header? {
#if canImport(Darwin)
return self.readIPv4HeaderFromBSDRawSocket()
#else
return self.readIPv4Header()
#endif
}
}
extension ByteBuffer {
@discardableResult
mutating func writeIPv4Header(_ header: IPv4Header) -> Int {
assert({
var buffer = ByteBuffer()
buffer._writeIPv4Header(header)
let writtenHeader = buffer.readIPv4Header()
return header == writtenHeader
}())
return self._writeIPv4Header(header)
}
@discardableResult
private mutating func _writeIPv4Header(_ header: IPv4Header) -> Int {
return self.writeMultipleIntegers(
header.versionAndIhl,
header.dscpAndEcn,
header.totalLength,
header.identification,
header.flagsAndFragmentOffset,
header.timeToLive,
header.`protocol`.rawValue,
header.headerChecksum,
header.sourceIpAddress.rawValue,
header.destinationIpAddress.rawValue
)
}
@discardableResult
mutating func writeIPv4HeaderToBSDRawSocket(_ header: IPv4Header) -> Int {
assert({
var buffer = ByteBuffer()
buffer._writeIPv4HeaderToBSDRawSocket(header)
let writtenHeader = buffer.readIPv4HeaderFromBSDRawSocket()
return header == writtenHeader
}())
return self._writeIPv4HeaderToBSDRawSocket(header)
}
@discardableResult
private mutating func _writeIPv4HeaderToBSDRawSocket(_ header: IPv4Header) -> Int {
var header = header
// On BSD, the total length needs to be in host byte order
header.totalLength = header.totalLength.convertEndianness(to: .big)
// TODO: fragmentOffset needs to be in host byte order as well but it is always zero in our tests
// and fragmentOffset is 13 bits in size so we can't just use writeInteger(endianness: .host)
return self._writeIPv4Header(header)
}
@discardableResult
mutating func writeIPv4HeaderToOSRawSocket(_ header: IPv4Header) -> Int {
#if canImport(Darwin)
self.writeIPv4HeaderToBSDRawSocket(header)
#else
self.writeIPv4Header(header)
#endif
}
}
extension IPv4Header {
func computeChecksum() -> UInt16 {
let checksum = ~[
UInt16(versionAndIhl) << 8 | UInt16(dscpAndEcn),
totalLength,
identification,
flagsAndFragmentOffset,
UInt16(timeToLive) << 8 | UInt16(`protocol`.rawValue),
UInt16(sourceIpAddress.rawValue >> 16),
UInt16(sourceIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
UInt16(destinationIpAddress.rawValue >> 16),
UInt16(destinationIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
].reduce(UInt16(0), onesComplementAdd)
assert(isValidChecksum(checksum))
return checksum
}
mutating func setChecksum() {
self.headerChecksum = computeChecksum()
}
func isValidChecksum(_ headerChecksum: UInt16) -> Bool {
let sum = ~[
UInt16(versionAndIhl) << 8 | UInt16(dscpAndEcn),
totalLength,
identification,
flagsAndFragmentOffset,
UInt16(timeToLive) << 8 | UInt16(`protocol`.rawValue),
headerChecksum,
UInt16(sourceIpAddress.rawValue >> 16),
UInt16(sourceIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
UInt16(destinationIpAddress.rawValue >> 16),
UInt16(destinationIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
].reduce(UInt16(0), onesComplementAdd)
return sum == 0
}
func isValidChecksum() -> Bool {
isValidChecksum(headerChecksum)
}
}
private func onesComplementAdd<Integer: FixedWidthInteger>(lhs: Integer, rhs: Integer) -> Integer {
var (sum, overflowed) = lhs.addingReportingOverflow(rhs)
if overflowed {
sum &+= 1
}
return sum
}
extension IPv4Header {
var platformIndependentTotalLengthForReceivedPacketFromRawSocket: UInt16 {
#if canImport(Darwin)
// On BSD the IP header will only contain the size of the ip packet body, not the header.
// This is known bug which can't be fixed without breaking old apps which already workaround the issue
// like e.g. we do now too.
return totalLength + 20
#elseif os(Linux)
return totalLength
#endif
}
var platformIndependentChecksumForReceivedPacketFromRawSocket: UInt16 {
#if canImport(Darwin)
// On BSD the checksum is always zero and we need to compute it
precondition(headerChecksum == 0)
return computeChecksum()
#elseif os(Linux)
return headerChecksum
#endif
}
}
extension IPv4Header: CustomStringConvertible {
var description: String {
"""
Version: \(version)
Header Length: \(internetHeaderLength * 4) bytes
Differentiated Services: \(String(differentiatedServicesCodePoint, radix: 2))
Explicit Congestion Notification: \(String(explicitCongestionNotification, radix: 2))
Total Length: \(totalLength) bytes
Identification: \(identification)
Flags: \(String(flags, radix: 2))
Fragment Offset: \(fragmentOffset) bytes
Time to Live: \(timeToLive)
Protocol: \(`protocol`)
Header Checksum: \(headerChecksum) (\(isValidChecksum() ? "valid" : "*not* valid"))
Source IP Address: \(sourceIpAddress)
Destination IP Address: \(destinationIpAddress)
"""
}
}

View File

@ -0,0 +1,36 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
//
// RawSocketBootstrapTests+XCTest.swift
//
import XCTest
///
/// NOTE: This file was generated by generate_linux_tests.rb
///
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
///
extension RawSocketBootstrapTests {
@available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings")
static var allTests : [(String, (RawSocketBootstrapTests) -> () throws -> Void)] {
return [
("testBindWithRecevMmsg", testBindWithRecevMmsg),
("testConnect", testConnect),
("testIpHdrincl", testIpHdrincl),
]
}
}

View File

@ -0,0 +1,164 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import XCTest
import NIOCore
@testable import NIOPosix
extension NIOIPProtocol {
static let reservedForTesting = Self(rawValue: 253)
}
// lazily try's to create a raw socket and caches the error if it fails
private let cachedRawSocketAPICheck = Result<Void, Error> {
let socket = try Socket(protocolFamily: .inet, type: .raw, protocolSubtype: .init(NIOIPProtocol.reservedForTesting), setNonBlocking: true)
try socket.close()
}
func XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI(file: StaticString = #filePath, line: UInt = #line) throws {
do {
try cachedRawSocketAPICheck.get()
} catch let error as IOError where error.errnoCode == EPERM {
throw XCTSkip("Raw Socket API requires higher privileges: \(error)", file: file, line: line)
}
}
final class RawSocketBootstrapTests: XCTestCase {
func testBindWithRecevMmsg() throws {
try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI()
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) }
let channel = try NIORawSocketBootstrap(group: elg)
.channelInitializer {
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
}
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
defer { XCTAssertNoThrow(try channel.close().wait()) }
try channel.configureForRecvMmsg(messageCount: 10)
let expectedMessages = (1...10).map { "Hello World \($0)" }
for message in expectedMessages {
_ = try channel.write(AddressedEnvelope(
remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0),
data: ByteBuffer(string: message)
))
}
channel.flush()
let receivedMessages = Set(try channel.waitForDatagrams(count: 10).map { envelop -> String in
var data = envelop.data
let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket())
XCTAssertEqual(header.version, 4)
XCTAssertEqual(header.protocol, .reservedForTesting)
XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes)
XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)")
XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1))
XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1))
return String(buffer: data)
})
XCTAssertEqual(receivedMessages, Set(expectedMessages))
}
func testConnect() throws {
try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI()
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) }
let readChannel = try NIORawSocketBootstrap(group: elg)
.channelInitializer {
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
}
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
defer { XCTAssertNoThrow(try readChannel.close().wait()) }
let writeChannel = try NIORawSocketBootstrap(group: elg)
.channelInitializer {
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
}
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
defer { XCTAssertNoThrow(try writeChannel.close().wait()) }
let expectedMessages = (1...10).map { "Hello World \($0)" }
for message in expectedMessages {
_ = try writeChannel.write(AddressedEnvelope(
remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0),
data: ByteBuffer(string: message)
))
}
writeChannel.flush()
let receivedMessages = Set(try readChannel.waitForDatagrams(count: 10).map { envelop -> String in
var data = envelop.data
let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket())
XCTAssertEqual(header.version, 4)
XCTAssertEqual(header.protocol, .reservedForTesting)
XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes)
XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)")
XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1))
XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1))
return String(buffer: data)
})
XCTAssertEqual(receivedMessages, Set(expectedMessages))
}
func testIpHdrincl() throws {
try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI()
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) }
let channel = try NIORawSocketBootstrap(group: elg)
.channelOption(ChannelOptions.ipOption(.ip_hdrincl), value: 1)
.channelInitializer {
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
}
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
defer { XCTAssertNoThrow(try channel.close().wait()) }
try channel.configureForRecvMmsg(messageCount: 10)
let expectedMessages = (1...10).map { "Hello World \($0)" }
for message in expectedMessages.map(ByteBuffer.init(string:)) {
var packet = ByteBuffer()
var header = IPv4Header()
header.version = 4
header.internetHeaderLength = 5
header.totalLength = UInt16(IPv4Header.size + message.readableBytes)
header.protocol = .reservedForTesting
header.timeToLive = 64
header.destinationIpAddress = .init(127, 0, 0, 1)
header.sourceIpAddress = .init(127, 0, 0, 1)
header.setChecksum()
packet.writeIPv4HeaderToOSRawSocket(header)
packet.writeImmutableBuffer(message)
try channel.writeAndFlush(AddressedEnvelope(
remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0),
data: packet
)).wait()
}
let receivedMessages = Set(try channel.waitForDatagrams(count: 10).map { envelop -> String in
var data = envelop.data
let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket())
XCTAssertEqual(header.version, 4)
XCTAssertEqual(header.protocol, .reservedForTesting)
XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes)
XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)")
XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1))
XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1))
return String(buffer: data)
})
XCTAssertEqual(receivedMessages, Set(expectedMessages))
}
}