Add `RawSocketBootstrap` (#2320)
This commit is contained in:
parent
00341c9277
commit
810544ec41
|
@ -267,6 +267,13 @@ extension NIOBSDSocket.Option {
|
||||||
/// Control multicast time-to-live.
|
/// Control multicast time-to-live.
|
||||||
public static let ip_multicast_ttl: NIOBSDSocket.Option =
|
public static let ip_multicast_ttl: NIOBSDSocket.Option =
|
||||||
NIOBSDSocket.Option(rawValue: IP_MULTICAST_TTL)
|
NIOBSDSocket.Option(rawValue: IP_MULTICAST_TTL)
|
||||||
|
|
||||||
|
/// The IPv4 layer generates an IP header when sending a packet
|
||||||
|
/// unless the ``ip_hdrincl`` socket option is enabled on the socket.
|
||||||
|
/// When it is enabled, the packet must contain an IP header. For
|
||||||
|
/// receiving, the IP header is always included in the packet.
|
||||||
|
public static let ip_hdrincl: NIOBSDSocket.Option =
|
||||||
|
NIOBSDSocket.Option(rawValue: IP_HDRINCL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPv6 Options
|
// IPv6 Options
|
||||||
|
|
|
@ -276,6 +276,11 @@ public struct ChannelOptions {
|
||||||
public static let socketOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in
|
public static let socketOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in
|
||||||
.init(level: .socket, name: name)
|
.init(level: .socket, name: name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// - seealso: `SocketOption`.
|
||||||
|
public static let ipOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in
|
||||||
|
.init(level: .ip, name: name)
|
||||||
|
}
|
||||||
|
|
||||||
/// - seealso: `SocketOption`.
|
/// - seealso: `SocketOption`.
|
||||||
public static let tcpOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in
|
public static let tcpOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in
|
||||||
|
|
|
@ -83,6 +83,14 @@ extension NIOBSDSocket.SocketType {
|
||||||
internal static let stream: NIOBSDSocket.SocketType =
|
internal static let stream: NIOBSDSocket.SocketType =
|
||||||
NIOBSDSocket.SocketType(rawValue: SOCK_STREAM)
|
NIOBSDSocket.SocketType(rawValue: SOCK_STREAM)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if os(Linux)
|
||||||
|
internal static let raw: NIOBSDSocket.SocketType =
|
||||||
|
NIOBSDSocket.SocketType(rawValue: CInt(SOCK_RAW.rawValue))
|
||||||
|
#else
|
||||||
|
internal static let raw: NIOBSDSocket.SocketType =
|
||||||
|
NIOBSDSocket.SocketType(rawValue: SOCK_RAW)
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPv4 Options
|
// IPv4 Options
|
||||||
|
|
|
@ -0,0 +1,197 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This source file is part of the SwiftNIO open source project
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
|
||||||
|
// Licensed under Apache License v2.0
|
||||||
|
//
|
||||||
|
// See LICENSE.txt for license information
|
||||||
|
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
import NIOCore
|
||||||
|
|
||||||
|
/// A `RawSocketBootstrap` is an easy way to interact with IP based protocols other then TCP and UDP.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||||
|
/// defer {
|
||||||
|
/// try! group.syncShutdownGracefully()
|
||||||
|
/// }
|
||||||
|
/// let bootstrap = RawSocketBootstrap(group: group)
|
||||||
|
/// .channelInitializer { channel in
|
||||||
|
/// channel.pipeline.addHandler(MyChannelHandler())
|
||||||
|
/// }
|
||||||
|
/// let channel = try! bootstrap.bind(host: "127.0.0.1", ipProtocol: .icmp).wait()
|
||||||
|
/// /* the Channel is now ready to send/receive IP packets */
|
||||||
|
///
|
||||||
|
/// try channel.closeFuture.wait() // Wait until the channel un-binds.
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// The `Channel` will operate on `AddressedEnvelope<ByteBuffer>` as inbound and outbound messages.
|
||||||
|
public final class NIORawSocketBootstrap {
|
||||||
|
|
||||||
|
private let group: EventLoopGroup
|
||||||
|
private var channelInitializer: Optional<ChannelInitializerCallback>
|
||||||
|
@usableFromInline
|
||||||
|
internal var _channelOptions: ChannelOptions.Storage
|
||||||
|
|
||||||
|
/// Create a `RawSocketBootstrap` on the `EventLoopGroup` `group`.
|
||||||
|
///
|
||||||
|
/// The `EventLoopGroup` `group` must be compatible, otherwise the program will crash. `RawSocketBootstrap` is
|
||||||
|
/// compatible only with `MultiThreadedEventLoopGroup` as well as the `EventLoop`s returned by
|
||||||
|
/// `MultiThreadedEventLoopGroup.next`. See `init(validatingGroup:)` for a fallible initializer for
|
||||||
|
/// situations where it's impossible to tell ahead of time if the `EventLoopGroup` is compatible or not.
|
||||||
|
///
|
||||||
|
/// - parameters:
|
||||||
|
/// - group: The `EventLoopGroup` to use.
|
||||||
|
public convenience init(group: EventLoopGroup) {
|
||||||
|
guard NIOOnSocketsBootstraps.isCompatible(group: group) else {
|
||||||
|
preconditionFailure("RawSocketBootstrap is only compatible with MultiThreadedEventLoopGroup and " +
|
||||||
|
"SelectableEventLoop. You tried constructing one with \(group) which is incompatible.")
|
||||||
|
}
|
||||||
|
self.init(validatingGroup: group)!
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a `RawSocketBootstrap` on the `EventLoopGroup` `group`, validating that `group` is compatible.
|
||||||
|
///
|
||||||
|
/// - parameters:
|
||||||
|
/// - group: The `EventLoopGroup` to use.
|
||||||
|
public init?(validatingGroup group: EventLoopGroup) {
|
||||||
|
guard NIOOnSocketsBootstraps.isCompatible(group: group) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
self._channelOptions = ChannelOptions.Storage()
|
||||||
|
self.group = group
|
||||||
|
self.channelInitializer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialize the bound `Channel` with `initializer`. The most common task in initializer is to add
|
||||||
|
/// `ChannelHandler`s to the `ChannelPipeline`.
|
||||||
|
///
|
||||||
|
/// - parameters:
|
||||||
|
/// - handler: A closure that initializes the provided `Channel`.
|
||||||
|
public func channelInitializer(_ handler: @escaping @Sendable (Channel) -> EventLoopFuture<Void>) -> Self {
|
||||||
|
self.channelInitializer = handler
|
||||||
|
return self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Specifies a `ChannelOption` to be applied to the `Channel`.
|
||||||
|
///
|
||||||
|
/// - parameters:
|
||||||
|
/// - option: The option to be applied.
|
||||||
|
/// - value: The value for the option.
|
||||||
|
@inlinable
|
||||||
|
public func channelOption<Option: ChannelOption>(_ option: Option, value: Option.Value) -> Self {
|
||||||
|
self._channelOptions.append(key: option, value: value)
|
||||||
|
return self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Bind the `Channel` to `host`.
|
||||||
|
/// All packets or errors matching the `ipProtocol` specified are passed to the resulting `Channel`.
|
||||||
|
///
|
||||||
|
/// - parameters:
|
||||||
|
/// - host: The host to bind on.
|
||||||
|
/// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field.
|
||||||
|
public func bind(host: String, ipProtocol: NIOIPProtocol) -> EventLoopFuture<Channel> {
|
||||||
|
return bind0(ipProtocol: ipProtocol) {
|
||||||
|
return try SocketAddress.makeAddressResolvingHost(host, port: 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func bind0(ipProtocol: NIOIPProtocol, _ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture<Channel> {
|
||||||
|
let address: SocketAddress
|
||||||
|
do {
|
||||||
|
address = try makeSocketAddress()
|
||||||
|
} catch {
|
||||||
|
return group.next().makeFailedFuture(error)
|
||||||
|
}
|
||||||
|
precondition(address.port == nil || address.port == 0, "port must be 0 or not set")
|
||||||
|
func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel {
|
||||||
|
return try DatagramChannel(eventLoop: eventLoop,
|
||||||
|
protocolFamily: address.protocol,
|
||||||
|
protocolSubtype: .init(ipProtocol),
|
||||||
|
socketType: .raw)
|
||||||
|
}
|
||||||
|
return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in
|
||||||
|
channel.register().flatMap {
|
||||||
|
channel.bind(to: address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connect the `Channel` to `host`.
|
||||||
|
///
|
||||||
|
/// - parameters:
|
||||||
|
/// - host: The host to connect to.
|
||||||
|
/// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field.
|
||||||
|
public func connect(host: String, ipProtocol: NIOIPProtocol) -> EventLoopFuture<Channel> {
|
||||||
|
return connect0(ipProtocol: ipProtocol) {
|
||||||
|
return try SocketAddress.makeAddressResolvingHost(host, port: 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func connect0(ipProtocol: NIOIPProtocol, _ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture<Channel> {
|
||||||
|
let address: SocketAddress
|
||||||
|
do {
|
||||||
|
address = try makeSocketAddress()
|
||||||
|
} catch {
|
||||||
|
return group.next().makeFailedFuture(error)
|
||||||
|
}
|
||||||
|
func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel {
|
||||||
|
return try DatagramChannel(eventLoop: eventLoop,
|
||||||
|
protocolFamily: address.protocol,
|
||||||
|
protocolSubtype: .init(ipProtocol),
|
||||||
|
socketType: .raw)
|
||||||
|
}
|
||||||
|
return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in
|
||||||
|
channel.register().flatMap {
|
||||||
|
channel.connect(to: address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func withNewChannel(makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, _ bringup: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
|
||||||
|
let eventLoop = self.group.next()
|
||||||
|
let channelInitializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) }
|
||||||
|
let channelOptions = self._channelOptions
|
||||||
|
|
||||||
|
let channel: DatagramChannel
|
||||||
|
do {
|
||||||
|
channel = try makeChannel(eventLoop as! SelectableEventLoop)
|
||||||
|
} catch {
|
||||||
|
return eventLoop.makeFailedFuture(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupChannel() -> EventLoopFuture<Channel> {
|
||||||
|
eventLoop.assertInEventLoop()
|
||||||
|
return channelOptions.applyAllChannelOptions(to: channel).flatMap {
|
||||||
|
channelInitializer(channel)
|
||||||
|
}.flatMap {
|
||||||
|
eventLoop.assertInEventLoop()
|
||||||
|
return bringup(eventLoop, channel)
|
||||||
|
}.map {
|
||||||
|
channel
|
||||||
|
}.flatMapError { error in
|
||||||
|
eventLoop.makeFailedFuture(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventLoop.inEventLoop {
|
||||||
|
return setupChannel()
|
||||||
|
} else {
|
||||||
|
return eventLoop.flatSubmit {
|
||||||
|
setupChannel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if swift(>=5.6)
|
||||||
|
@available(*, unavailable)
|
||||||
|
extension NIORawSocketBootstrap: Sendable {}
|
||||||
|
#endif
|
|
@ -32,7 +32,7 @@ typealias IOVector = iovec
|
||||||
/// - parameters:
|
/// - parameters:
|
||||||
/// - protocolFamily: The protocol family to use (usually `AF_INET6` or `AF_INET`).
|
/// - protocolFamily: The protocol family to use (usually `AF_INET6` or `AF_INET`).
|
||||||
/// - type: The type of the socket to create.
|
/// - type: The type of the socket to create.
|
||||||
/// - protocolSubtype: The subtype of the protocol, corresponding to the `protocol`
|
/// - protocolSubtype: The subtype of the protocol, corresponding to the `protocolSubtype`
|
||||||
/// argument to the socket syscall. Defaults to 0.
|
/// argument to the socket syscall. Defaults to 0.
|
||||||
/// - setNonBlocking: Set non-blocking mode on the socket.
|
/// - setNonBlocking: Set non-blocking mode on the socket.
|
||||||
/// - throws: An `IOError` if creation of the socket failed.
|
/// - throws: An `IOError` if creation of the socket failed.
|
||||||
|
|
|
@ -120,6 +120,7 @@ class LinuxMainRunner {
|
||||||
testCase(PendingDatagramWritesManagerTests.allTests),
|
testCase(PendingDatagramWritesManagerTests.allTests),
|
||||||
testCase(PipeChannelTest.allTests),
|
testCase(PipeChannelTest.allTests),
|
||||||
testCase(PriorityQueueTest.allTests),
|
testCase(PriorityQueueTest.allTests),
|
||||||
|
testCase(RawSocketBootstrapTests.allTests),
|
||||||
testCase(SALChannelTest.allTests),
|
testCase(SALChannelTest.allTests),
|
||||||
testCase(SALEventLoopTests.allTests),
|
testCase(SALEventLoopTests.allTests),
|
||||||
testCase(SNIHandlerTest.allTests),
|
testCase(SNIHandlerTest.allTests),
|
||||||
|
|
|
@ -17,7 +17,7 @@ import NIOCore
|
||||||
@testable import NIOPosix
|
@testable import NIOPosix
|
||||||
import XCTest
|
import XCTest
|
||||||
|
|
||||||
private extension Channel {
|
extension Channel {
|
||||||
func waitForDatagrams(count: Int) throws -> [AddressedEnvelope<ByteBuffer>] {
|
func waitForDatagrams(count: Int) throws -> [AddressedEnvelope<ByteBuffer>] {
|
||||||
return try self.pipeline.context(name: "ByteReadRecorder").flatMap { context in
|
return try self.pipeline.context(name: "ByteReadRecorder").flatMap { context in
|
||||||
if let future = (context.handler as? DatagramReadRecorder<ByteBuffer>)?.notifyForDatagrams(count) {
|
if let future = (context.handler as? DatagramReadRecorder<ByteBuffer>)?.notifyForDatagrams(count) {
|
||||||
|
@ -47,7 +47,7 @@ private extension Channel {
|
||||||
/// A class that records datagrams received and forwards them on.
|
/// A class that records datagrams received and forwards them on.
|
||||||
///
|
///
|
||||||
/// Used extensively in tests to validate messaging expectations.
|
/// Used extensively in tests to validate messaging expectations.
|
||||||
private class DatagramReadRecorder<DataType>: ChannelInboundHandler {
|
final class DatagramReadRecorder<DataType>: ChannelInboundHandler {
|
||||||
typealias InboundIn = AddressedEnvelope<DataType>
|
typealias InboundIn = AddressedEnvelope<DataType>
|
||||||
typealias InboundOut = AddressedEnvelope<DataType>
|
typealias InboundOut = AddressedEnvelope<DataType>
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,365 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This source file is part of the SwiftNIO open source project
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
|
||||||
|
// Licensed under Apache License v2.0
|
||||||
|
//
|
||||||
|
// See LICENSE.txt for license information
|
||||||
|
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
import NIOCore
|
||||||
|
|
||||||
|
struct IPv4Address: Hashable {
|
||||||
|
var rawValue: UInt32
|
||||||
|
}
|
||||||
|
|
||||||
|
extension IPv4Address {
|
||||||
|
init(_ v1: UInt8, _ v2: UInt8, _ v3: UInt8, _ v4: UInt8) {
|
||||||
|
rawValue = UInt32(v1) << 24 | UInt32(v2) << 16 | UInt32(v3) << 8 | UInt32(v4)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension IPv4Address: CustomStringConvertible {
|
||||||
|
var description: String {
|
||||||
|
let v1 = rawValue >> 24
|
||||||
|
let v2 = rawValue >> 16 & 0b1111_1111
|
||||||
|
let v3 = rawValue >> 8 & 0b1111_1111
|
||||||
|
let v4 = rawValue & 0b1111_1111
|
||||||
|
return "\(v1).\(v2).\(v3).\(v4)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct IPv4Header: Hashable {
|
||||||
|
static let size: Int = 20
|
||||||
|
|
||||||
|
fileprivate var versionAndIhl: UInt8
|
||||||
|
var version: UInt8 {
|
||||||
|
get {
|
||||||
|
versionAndIhl >> 4
|
||||||
|
}
|
||||||
|
set {
|
||||||
|
precondition(newValue & 0b1111_0000 == 0)
|
||||||
|
versionAndIhl = newValue << 4 | (0b0000_1111 & versionAndIhl)
|
||||||
|
assert(newValue == version, "\(newValue) != \(version) \(versionAndIhl)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var internetHeaderLength: UInt8 {
|
||||||
|
get {
|
||||||
|
versionAndIhl & 0b0000_1111
|
||||||
|
}
|
||||||
|
set {
|
||||||
|
precondition(newValue & 0b1111_0000 == 0)
|
||||||
|
versionAndIhl = newValue | (0b1111_0000 & versionAndIhl)
|
||||||
|
assert(newValue == internetHeaderLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fileprivate var dscpAndEcn: UInt8
|
||||||
|
var differentiatedServicesCodePoint: UInt8 {
|
||||||
|
get {
|
||||||
|
dscpAndEcn >> 2
|
||||||
|
}
|
||||||
|
set {
|
||||||
|
precondition(newValue & 0b0000_0011 == 0)
|
||||||
|
dscpAndEcn = newValue << 2 | (0b0000_0011 & dscpAndEcn)
|
||||||
|
assert(newValue == differentiatedServicesCodePoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var explicitCongestionNotification: UInt8 {
|
||||||
|
get {
|
||||||
|
dscpAndEcn & 0b0000_0011
|
||||||
|
}
|
||||||
|
set {
|
||||||
|
precondition(newValue & 0b0000_0011 == 0)
|
||||||
|
dscpAndEcn = newValue | (0b1111_1100 & dscpAndEcn)
|
||||||
|
assert(newValue == explicitCongestionNotification)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var totalLength: UInt16
|
||||||
|
var identification: UInt16
|
||||||
|
fileprivate var flagsAndFragmentOffset: UInt16
|
||||||
|
var flags: UInt8 {
|
||||||
|
get {
|
||||||
|
UInt8(flagsAndFragmentOffset >> 13)
|
||||||
|
}
|
||||||
|
set {
|
||||||
|
precondition(newValue & 0b0000_0111 == 0)
|
||||||
|
flagsAndFragmentOffset = UInt16(newValue) << 13 | (0b0001_1111_1111_1111 & flagsAndFragmentOffset)
|
||||||
|
assert(newValue == flags)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var fragmentOffset: UInt16 {
|
||||||
|
get {
|
||||||
|
flagsAndFragmentOffset & 0b0001_1111_1111_1111
|
||||||
|
}
|
||||||
|
set {
|
||||||
|
precondition(newValue & 0b1110_0000_0000_0000 == 0)
|
||||||
|
flagsAndFragmentOffset = newValue | (0b1110_0000_0000_0000 & flagsAndFragmentOffset)
|
||||||
|
assert(newValue == fragmentOffset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var timeToLive: UInt8
|
||||||
|
var `protocol`: NIOIPProtocol
|
||||||
|
var headerChecksum: UInt16
|
||||||
|
var sourceIpAddress: IPv4Address
|
||||||
|
var destinationIpAddress: IPv4Address
|
||||||
|
|
||||||
|
fileprivate init(
|
||||||
|
versionAndIhl: UInt8,
|
||||||
|
dscpAndEcn: UInt8,
|
||||||
|
totalLength: UInt16,
|
||||||
|
identification: UInt16,
|
||||||
|
flagsAndFragmentOffset: UInt16,
|
||||||
|
timeToLive: UInt8,
|
||||||
|
`protocol`: NIOIPProtocol,
|
||||||
|
headerChecksum: UInt16,
|
||||||
|
sourceIpAddress: IPv4Address,
|
||||||
|
destinationIpAddress: IPv4Address
|
||||||
|
) {
|
||||||
|
self.versionAndIhl = versionAndIhl
|
||||||
|
self.dscpAndEcn = dscpAndEcn
|
||||||
|
self.totalLength = totalLength
|
||||||
|
self.identification = identification
|
||||||
|
self.flagsAndFragmentOffset = flagsAndFragmentOffset
|
||||||
|
self.timeToLive = timeToLive
|
||||||
|
self.`protocol` = `protocol`
|
||||||
|
self.headerChecksum = headerChecksum
|
||||||
|
self.sourceIpAddress = sourceIpAddress
|
||||||
|
self.destinationIpAddress = destinationIpAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
init() {
|
||||||
|
self.versionAndIhl = 0
|
||||||
|
self.dscpAndEcn = 0
|
||||||
|
self.totalLength = 0
|
||||||
|
self.identification = 0
|
||||||
|
self.flagsAndFragmentOffset = 0
|
||||||
|
self.timeToLive = 0
|
||||||
|
self.`protocol` = .init(rawValue: 0)
|
||||||
|
self.headerChecksum = 0
|
||||||
|
self.sourceIpAddress = .init(rawValue: 0)
|
||||||
|
self.destinationIpAddress = .init(rawValue: 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension FixedWidthInteger {
|
||||||
|
func convertEndianness(to endianness: Endianness) -> Self {
|
||||||
|
switch endianness {
|
||||||
|
case .little:
|
||||||
|
return self.littleEndian
|
||||||
|
case .big:
|
||||||
|
return self.bigEndian
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
extension ByteBuffer {
|
||||||
|
mutating func readIPv4Header() -> IPv4Header? {
|
||||||
|
guard let (
|
||||||
|
versionAndIhl,
|
||||||
|
dscpAndEcn,
|
||||||
|
totalLength,
|
||||||
|
identification,
|
||||||
|
flagsAndFragmentOffset,
|
||||||
|
timeToLive,
|
||||||
|
`protocol`,
|
||||||
|
headerChecksum,
|
||||||
|
sourceIpAddress,
|
||||||
|
destinationIpAddress
|
||||||
|
) = self.readMultipleIntegers(as: (
|
||||||
|
UInt8,
|
||||||
|
UInt8,
|
||||||
|
UInt16,
|
||||||
|
UInt16,
|
||||||
|
UInt16,
|
||||||
|
UInt8,
|
||||||
|
UInt8,
|
||||||
|
UInt16,
|
||||||
|
UInt32,
|
||||||
|
UInt32
|
||||||
|
).self) else { return nil }
|
||||||
|
return .init(
|
||||||
|
versionAndIhl: versionAndIhl,
|
||||||
|
dscpAndEcn: dscpAndEcn,
|
||||||
|
totalLength: totalLength,
|
||||||
|
identification: identification,
|
||||||
|
flagsAndFragmentOffset: flagsAndFragmentOffset,
|
||||||
|
timeToLive: timeToLive,
|
||||||
|
protocol: .init(rawValue: `protocol`),
|
||||||
|
headerChecksum: headerChecksum,
|
||||||
|
sourceIpAddress: .init(rawValue: sourceIpAddress),
|
||||||
|
destinationIpAddress: .init(rawValue: destinationIpAddress)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
mutating func readIPv4HeaderFromBSDRawSocket() -> IPv4Header? {
|
||||||
|
guard var header = self.readIPv4Header() else { return nil }
|
||||||
|
// On BSD, the total length is in host byte order
|
||||||
|
header.totalLength = header.totalLength.convertEndianness(to: .big)
|
||||||
|
// TODO: fragmentOffset is in host byte order as well but it is always zero in our tests
|
||||||
|
// and fragmentOffset is 13 bits in size so we can't just use readInteger(endianness: .host)
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
|
||||||
|
mutating func readIPv4HeaderFromOSRawSocket() -> IPv4Header? {
|
||||||
|
#if canImport(Darwin)
|
||||||
|
return self.readIPv4HeaderFromBSDRawSocket()
|
||||||
|
#else
|
||||||
|
return self.readIPv4Header()
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension ByteBuffer {
|
||||||
|
@discardableResult
|
||||||
|
mutating func writeIPv4Header(_ header: IPv4Header) -> Int {
|
||||||
|
assert({
|
||||||
|
var buffer = ByteBuffer()
|
||||||
|
buffer._writeIPv4Header(header)
|
||||||
|
let writtenHeader = buffer.readIPv4Header()
|
||||||
|
return header == writtenHeader
|
||||||
|
}())
|
||||||
|
return self._writeIPv4Header(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
@discardableResult
|
||||||
|
private mutating func _writeIPv4Header(_ header: IPv4Header) -> Int {
|
||||||
|
return self.writeMultipleIntegers(
|
||||||
|
header.versionAndIhl,
|
||||||
|
header.dscpAndEcn,
|
||||||
|
header.totalLength,
|
||||||
|
header.identification,
|
||||||
|
header.flagsAndFragmentOffset,
|
||||||
|
header.timeToLive,
|
||||||
|
header.`protocol`.rawValue,
|
||||||
|
header.headerChecksum,
|
||||||
|
header.sourceIpAddress.rawValue,
|
||||||
|
header.destinationIpAddress.rawValue
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@discardableResult
|
||||||
|
mutating func writeIPv4HeaderToBSDRawSocket(_ header: IPv4Header) -> Int {
|
||||||
|
assert({
|
||||||
|
var buffer = ByteBuffer()
|
||||||
|
buffer._writeIPv4HeaderToBSDRawSocket(header)
|
||||||
|
let writtenHeader = buffer.readIPv4HeaderFromBSDRawSocket()
|
||||||
|
return header == writtenHeader
|
||||||
|
}())
|
||||||
|
return self._writeIPv4HeaderToBSDRawSocket(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
@discardableResult
|
||||||
|
private mutating func _writeIPv4HeaderToBSDRawSocket(_ header: IPv4Header) -> Int {
|
||||||
|
var header = header
|
||||||
|
// On BSD, the total length needs to be in host byte order
|
||||||
|
header.totalLength = header.totalLength.convertEndianness(to: .big)
|
||||||
|
// TODO: fragmentOffset needs to be in host byte order as well but it is always zero in our tests
|
||||||
|
// and fragmentOffset is 13 bits in size so we can't just use writeInteger(endianness: .host)
|
||||||
|
return self._writeIPv4Header(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
@discardableResult
|
||||||
|
mutating func writeIPv4HeaderToOSRawSocket(_ header: IPv4Header) -> Int {
|
||||||
|
#if canImport(Darwin)
|
||||||
|
self.writeIPv4HeaderToBSDRawSocket(header)
|
||||||
|
#else
|
||||||
|
self.writeIPv4Header(header)
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension IPv4Header {
|
||||||
|
func computeChecksum() -> UInt16 {
|
||||||
|
let checksum = ~[
|
||||||
|
UInt16(versionAndIhl) << 8 | UInt16(dscpAndEcn),
|
||||||
|
totalLength,
|
||||||
|
identification,
|
||||||
|
flagsAndFragmentOffset,
|
||||||
|
UInt16(timeToLive) << 8 | UInt16(`protocol`.rawValue),
|
||||||
|
UInt16(sourceIpAddress.rawValue >> 16),
|
||||||
|
UInt16(sourceIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
|
||||||
|
UInt16(destinationIpAddress.rawValue >> 16),
|
||||||
|
UInt16(destinationIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
|
||||||
|
].reduce(UInt16(0), onesComplementAdd)
|
||||||
|
assert(isValidChecksum(checksum))
|
||||||
|
return checksum
|
||||||
|
}
|
||||||
|
mutating func setChecksum() {
|
||||||
|
self.headerChecksum = computeChecksum()
|
||||||
|
}
|
||||||
|
func isValidChecksum(_ headerChecksum: UInt16) -> Bool {
|
||||||
|
let sum = ~[
|
||||||
|
UInt16(versionAndIhl) << 8 | UInt16(dscpAndEcn),
|
||||||
|
totalLength,
|
||||||
|
identification,
|
||||||
|
flagsAndFragmentOffset,
|
||||||
|
UInt16(timeToLive) << 8 | UInt16(`protocol`.rawValue),
|
||||||
|
headerChecksum,
|
||||||
|
UInt16(sourceIpAddress.rawValue >> 16),
|
||||||
|
UInt16(sourceIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
|
||||||
|
UInt16(destinationIpAddress.rawValue >> 16),
|
||||||
|
UInt16(destinationIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
|
||||||
|
].reduce(UInt16(0), onesComplementAdd)
|
||||||
|
return sum == 0
|
||||||
|
}
|
||||||
|
func isValidChecksum() -> Bool {
|
||||||
|
isValidChecksum(headerChecksum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func onesComplementAdd<Integer: FixedWidthInteger>(lhs: Integer, rhs: Integer) -> Integer {
|
||||||
|
var (sum, overflowed) = lhs.addingReportingOverflow(rhs)
|
||||||
|
if overflowed {
|
||||||
|
sum &+= 1
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
extension IPv4Header {
|
||||||
|
var platformIndependentTotalLengthForReceivedPacketFromRawSocket: UInt16 {
|
||||||
|
#if canImport(Darwin)
|
||||||
|
// On BSD the IP header will only contain the size of the ip packet body, not the header.
|
||||||
|
// This is known bug which can't be fixed without breaking old apps which already workaround the issue
|
||||||
|
// like e.g. we do now too.
|
||||||
|
return totalLength + 20
|
||||||
|
#elseif os(Linux)
|
||||||
|
return totalLength
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
var platformIndependentChecksumForReceivedPacketFromRawSocket: UInt16 {
|
||||||
|
#if canImport(Darwin)
|
||||||
|
// On BSD the checksum is always zero and we need to compute it
|
||||||
|
precondition(headerChecksum == 0)
|
||||||
|
return computeChecksum()
|
||||||
|
#elseif os(Linux)
|
||||||
|
return headerChecksum
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension IPv4Header: CustomStringConvertible {
|
||||||
|
var description: String {
|
||||||
|
"""
|
||||||
|
Version: \(version)
|
||||||
|
Header Length: \(internetHeaderLength * 4) bytes
|
||||||
|
Differentiated Services: \(String(differentiatedServicesCodePoint, radix: 2))
|
||||||
|
Explicit Congestion Notification: \(String(explicitCongestionNotification, radix: 2))
|
||||||
|
Total Length: \(totalLength) bytes
|
||||||
|
Identification: \(identification)
|
||||||
|
Flags: \(String(flags, radix: 2))
|
||||||
|
Fragment Offset: \(fragmentOffset) bytes
|
||||||
|
Time to Live: \(timeToLive)
|
||||||
|
Protocol: \(`protocol`)
|
||||||
|
Header Checksum: \(headerChecksum) (\(isValidChecksum() ? "valid" : "*not* valid"))
|
||||||
|
Source IP Address: \(sourceIpAddress)
|
||||||
|
Destination IP Address: \(destinationIpAddress)
|
||||||
|
"""
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This source file is part of the SwiftNIO open source project
|
||||||
|
//
|
||||||
|
// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors
|
||||||
|
// Licensed under Apache License v2.0
|
||||||
|
//
|
||||||
|
// See LICENSE.txt for license information
|
||||||
|
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// RawSocketBootstrapTests+XCTest.swift
|
||||||
|
//
|
||||||
|
import XCTest
|
||||||
|
|
||||||
|
///
|
||||||
|
/// NOTE: This file was generated by generate_linux_tests.rb
|
||||||
|
///
|
||||||
|
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
|
||||||
|
///
|
||||||
|
|
||||||
|
extension RawSocketBootstrapTests {
|
||||||
|
|
||||||
|
@available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings")
|
||||||
|
static var allTests : [(String, (RawSocketBootstrapTests) -> () throws -> Void)] {
|
||||||
|
return [
|
||||||
|
("testBindWithRecevMmsg", testBindWithRecevMmsg),
|
||||||
|
("testConnect", testConnect),
|
||||||
|
("testIpHdrincl", testIpHdrincl),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,164 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This source file is part of the SwiftNIO open source project
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
|
||||||
|
// Licensed under Apache License v2.0
|
||||||
|
//
|
||||||
|
// See LICENSE.txt for license information
|
||||||
|
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
import XCTest
|
||||||
|
import NIOCore
|
||||||
|
@testable import NIOPosix
|
||||||
|
|
||||||
|
extension NIOIPProtocol {
|
||||||
|
static let reservedForTesting = Self(rawValue: 253)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lazily try's to create a raw socket and caches the error if it fails
|
||||||
|
private let cachedRawSocketAPICheck = Result<Void, Error> {
|
||||||
|
let socket = try Socket(protocolFamily: .inet, type: .raw, protocolSubtype: .init(NIOIPProtocol.reservedForTesting), setNonBlocking: true)
|
||||||
|
try socket.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI(file: StaticString = #filePath, line: UInt = #line) throws {
|
||||||
|
do {
|
||||||
|
try cachedRawSocketAPICheck.get()
|
||||||
|
} catch let error as IOError where error.errnoCode == EPERM {
|
||||||
|
throw XCTSkip("Raw Socket API requires higher privileges: \(error)", file: file, line: line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
final class RawSocketBootstrapTests: XCTestCase {
|
||||||
|
func testBindWithRecevMmsg() throws {
|
||||||
|
try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI()
|
||||||
|
|
||||||
|
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||||
|
defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) }
|
||||||
|
let channel = try NIORawSocketBootstrap(group: elg)
|
||||||
|
.channelInitializer {
|
||||||
|
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
|
||||||
|
}
|
||||||
|
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
|
||||||
|
defer { XCTAssertNoThrow(try channel.close().wait()) }
|
||||||
|
try channel.configureForRecvMmsg(messageCount: 10)
|
||||||
|
let expectedMessages = (1...10).map { "Hello World \($0)" }
|
||||||
|
for message in expectedMessages {
|
||||||
|
_ = try channel.write(AddressedEnvelope(
|
||||||
|
remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0),
|
||||||
|
data: ByteBuffer(string: message)
|
||||||
|
))
|
||||||
|
}
|
||||||
|
channel.flush()
|
||||||
|
|
||||||
|
let receivedMessages = Set(try channel.waitForDatagrams(count: 10).map { envelop -> String in
|
||||||
|
var data = envelop.data
|
||||||
|
let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket())
|
||||||
|
XCTAssertEqual(header.version, 4)
|
||||||
|
XCTAssertEqual(header.protocol, .reservedForTesting)
|
||||||
|
XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes)
|
||||||
|
XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)")
|
||||||
|
XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1))
|
||||||
|
XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1))
|
||||||
|
return String(buffer: data)
|
||||||
|
})
|
||||||
|
|
||||||
|
XCTAssertEqual(receivedMessages, Set(expectedMessages))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testConnect() throws {
|
||||||
|
try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI()
|
||||||
|
|
||||||
|
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||||
|
defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) }
|
||||||
|
let readChannel = try NIORawSocketBootstrap(group: elg)
|
||||||
|
.channelInitializer {
|
||||||
|
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
|
||||||
|
}
|
||||||
|
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
|
||||||
|
defer { XCTAssertNoThrow(try readChannel.close().wait()) }
|
||||||
|
|
||||||
|
let writeChannel = try NIORawSocketBootstrap(group: elg)
|
||||||
|
.channelInitializer {
|
||||||
|
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
|
||||||
|
}
|
||||||
|
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
|
||||||
|
defer { XCTAssertNoThrow(try writeChannel.close().wait()) }
|
||||||
|
|
||||||
|
let expectedMessages = (1...10).map { "Hello World \($0)" }
|
||||||
|
for message in expectedMessages {
|
||||||
|
_ = try writeChannel.write(AddressedEnvelope(
|
||||||
|
remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0),
|
||||||
|
data: ByteBuffer(string: message)
|
||||||
|
))
|
||||||
|
}
|
||||||
|
writeChannel.flush()
|
||||||
|
|
||||||
|
let receivedMessages = Set(try readChannel.waitForDatagrams(count: 10).map { envelop -> String in
|
||||||
|
var data = envelop.data
|
||||||
|
let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket())
|
||||||
|
XCTAssertEqual(header.version, 4)
|
||||||
|
XCTAssertEqual(header.protocol, .reservedForTesting)
|
||||||
|
XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes)
|
||||||
|
XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)")
|
||||||
|
XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1))
|
||||||
|
XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1))
|
||||||
|
return String(buffer: data)
|
||||||
|
})
|
||||||
|
|
||||||
|
XCTAssertEqual(receivedMessages, Set(expectedMessages))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testIpHdrincl() throws {
|
||||||
|
try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI()
|
||||||
|
|
||||||
|
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||||
|
defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) }
|
||||||
|
let channel = try NIORawSocketBootstrap(group: elg)
|
||||||
|
.channelOption(ChannelOptions.ipOption(.ip_hdrincl), value: 1)
|
||||||
|
.channelInitializer {
|
||||||
|
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
|
||||||
|
}
|
||||||
|
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
|
||||||
|
defer { XCTAssertNoThrow(try channel.close().wait()) }
|
||||||
|
try channel.configureForRecvMmsg(messageCount: 10)
|
||||||
|
let expectedMessages = (1...10).map { "Hello World \($0)" }
|
||||||
|
for message in expectedMessages.map(ByteBuffer.init(string:)) {
|
||||||
|
var packet = ByteBuffer()
|
||||||
|
var header = IPv4Header()
|
||||||
|
header.version = 4
|
||||||
|
header.internetHeaderLength = 5
|
||||||
|
header.totalLength = UInt16(IPv4Header.size + message.readableBytes)
|
||||||
|
header.protocol = .reservedForTesting
|
||||||
|
header.timeToLive = 64
|
||||||
|
header.destinationIpAddress = .init(127, 0, 0, 1)
|
||||||
|
header.sourceIpAddress = .init(127, 0, 0, 1)
|
||||||
|
header.setChecksum()
|
||||||
|
packet.writeIPv4HeaderToOSRawSocket(header)
|
||||||
|
packet.writeImmutableBuffer(message)
|
||||||
|
try channel.writeAndFlush(AddressedEnvelope(
|
||||||
|
remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0),
|
||||||
|
data: packet
|
||||||
|
)).wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
let receivedMessages = Set(try channel.waitForDatagrams(count: 10).map { envelop -> String in
|
||||||
|
var data = envelop.data
|
||||||
|
let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket())
|
||||||
|
XCTAssertEqual(header.version, 4)
|
||||||
|
XCTAssertEqual(header.protocol, .reservedForTesting)
|
||||||
|
XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes)
|
||||||
|
XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)")
|
||||||
|
XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1))
|
||||||
|
XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1))
|
||||||
|
return String(buffer: data)
|
||||||
|
})
|
||||||
|
|
||||||
|
XCTAssertEqual(receivedMessages, Set(expectedMessages))
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue