Add `RawSocketBootstrap` (#2320)
This commit is contained in:
parent
00341c9277
commit
810544ec41
|
@ -267,6 +267,13 @@ extension NIOBSDSocket.Option {
|
|||
/// Control multicast time-to-live.
|
||||
public static let ip_multicast_ttl: NIOBSDSocket.Option =
|
||||
NIOBSDSocket.Option(rawValue: IP_MULTICAST_TTL)
|
||||
|
||||
/// The IPv4 layer generates an IP header when sending a packet
|
||||
/// unless the ``ip_hdrincl`` socket option is enabled on the socket.
|
||||
/// When it is enabled, the packet must contain an IP header. For
|
||||
/// receiving, the IP header is always included in the packet.
|
||||
public static let ip_hdrincl: NIOBSDSocket.Option =
|
||||
NIOBSDSocket.Option(rawValue: IP_HDRINCL)
|
||||
}
|
||||
|
||||
// IPv6 Options
|
||||
|
|
|
@ -277,6 +277,11 @@ public struct ChannelOptions {
|
|||
.init(level: .socket, name: name)
|
||||
}
|
||||
|
||||
/// - seealso: `SocketOption`.
|
||||
public static let ipOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in
|
||||
.init(level: .ip, name: name)
|
||||
}
|
||||
|
||||
/// - seealso: `SocketOption`.
|
||||
public static let tcpOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in
|
||||
.init(level: .tcp, name: name)
|
||||
|
|
|
@ -83,6 +83,14 @@ extension NIOBSDSocket.SocketType {
|
|||
internal static let stream: NIOBSDSocket.SocketType =
|
||||
NIOBSDSocket.SocketType(rawValue: SOCK_STREAM)
|
||||
#endif
|
||||
|
||||
#if os(Linux)
|
||||
internal static let raw: NIOBSDSocket.SocketType =
|
||||
NIOBSDSocket.SocketType(rawValue: CInt(SOCK_RAW.rawValue))
|
||||
#else
|
||||
internal static let raw: NIOBSDSocket.SocketType =
|
||||
NIOBSDSocket.SocketType(rawValue: SOCK_RAW)
|
||||
#endif
|
||||
}
|
||||
|
||||
// IPv4 Options
|
||||
|
|
|
@ -0,0 +1,197 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
import NIOCore
|
||||
|
||||
/// A `RawSocketBootstrap` is an easy way to interact with IP based protocols other then TCP and UDP.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```swift
|
||||
/// let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
/// defer {
|
||||
/// try! group.syncShutdownGracefully()
|
||||
/// }
|
||||
/// let bootstrap = RawSocketBootstrap(group: group)
|
||||
/// .channelInitializer { channel in
|
||||
/// channel.pipeline.addHandler(MyChannelHandler())
|
||||
/// }
|
||||
/// let channel = try! bootstrap.bind(host: "127.0.0.1", ipProtocol: .icmp).wait()
|
||||
/// /* the Channel is now ready to send/receive IP packets */
|
||||
///
|
||||
/// try channel.closeFuture.wait() // Wait until the channel un-binds.
|
||||
/// ```
|
||||
///
|
||||
/// The `Channel` will operate on `AddressedEnvelope<ByteBuffer>` as inbound and outbound messages.
|
||||
public final class NIORawSocketBootstrap {
|
||||
|
||||
private let group: EventLoopGroup
|
||||
private var channelInitializer: Optional<ChannelInitializerCallback>
|
||||
@usableFromInline
|
||||
internal var _channelOptions: ChannelOptions.Storage
|
||||
|
||||
/// Create a `RawSocketBootstrap` on the `EventLoopGroup` `group`.
|
||||
///
|
||||
/// The `EventLoopGroup` `group` must be compatible, otherwise the program will crash. `RawSocketBootstrap` is
|
||||
/// compatible only with `MultiThreadedEventLoopGroup` as well as the `EventLoop`s returned by
|
||||
/// `MultiThreadedEventLoopGroup.next`. See `init(validatingGroup:)` for a fallible initializer for
|
||||
/// situations where it's impossible to tell ahead of time if the `EventLoopGroup` is compatible or not.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - group: The `EventLoopGroup` to use.
|
||||
public convenience init(group: EventLoopGroup) {
|
||||
guard NIOOnSocketsBootstraps.isCompatible(group: group) else {
|
||||
preconditionFailure("RawSocketBootstrap is only compatible with MultiThreadedEventLoopGroup and " +
|
||||
"SelectableEventLoop. You tried constructing one with \(group) which is incompatible.")
|
||||
}
|
||||
self.init(validatingGroup: group)!
|
||||
}
|
||||
|
||||
/// Create a `RawSocketBootstrap` on the `EventLoopGroup` `group`, validating that `group` is compatible.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - group: The `EventLoopGroup` to use.
|
||||
public init?(validatingGroup group: EventLoopGroup) {
|
||||
guard NIOOnSocketsBootstraps.isCompatible(group: group) else {
|
||||
return nil
|
||||
}
|
||||
self._channelOptions = ChannelOptions.Storage()
|
||||
self.group = group
|
||||
self.channelInitializer = nil
|
||||
}
|
||||
|
||||
/// Initialize the bound `Channel` with `initializer`. The most common task in initializer is to add
|
||||
/// `ChannelHandler`s to the `ChannelPipeline`.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - handler: A closure that initializes the provided `Channel`.
|
||||
public func channelInitializer(_ handler: @escaping @Sendable (Channel) -> EventLoopFuture<Void>) -> Self {
|
||||
self.channelInitializer = handler
|
||||
return self
|
||||
}
|
||||
|
||||
/// Specifies a `ChannelOption` to be applied to the `Channel`.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - option: The option to be applied.
|
||||
/// - value: The value for the option.
|
||||
@inlinable
|
||||
public func channelOption<Option: ChannelOption>(_ option: Option, value: Option.Value) -> Self {
|
||||
self._channelOptions.append(key: option, value: value)
|
||||
return self
|
||||
}
|
||||
|
||||
/// Bind the `Channel` to `host`.
|
||||
/// All packets or errors matching the `ipProtocol` specified are passed to the resulting `Channel`.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - host: The host to bind on.
|
||||
/// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field.
|
||||
public func bind(host: String, ipProtocol: NIOIPProtocol) -> EventLoopFuture<Channel> {
|
||||
return bind0(ipProtocol: ipProtocol) {
|
||||
return try SocketAddress.makeAddressResolvingHost(host, port: 0)
|
||||
}
|
||||
}
|
||||
|
||||
private func bind0(ipProtocol: NIOIPProtocol, _ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture<Channel> {
|
||||
let address: SocketAddress
|
||||
do {
|
||||
address = try makeSocketAddress()
|
||||
} catch {
|
||||
return group.next().makeFailedFuture(error)
|
||||
}
|
||||
precondition(address.port == nil || address.port == 0, "port must be 0 or not set")
|
||||
func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel {
|
||||
return try DatagramChannel(eventLoop: eventLoop,
|
||||
protocolFamily: address.protocol,
|
||||
protocolSubtype: .init(ipProtocol),
|
||||
socketType: .raw)
|
||||
}
|
||||
return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in
|
||||
channel.register().flatMap {
|
||||
channel.bind(to: address)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect the `Channel` to `host`.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - host: The host to connect to.
|
||||
/// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field.
|
||||
public func connect(host: String, ipProtocol: NIOIPProtocol) -> EventLoopFuture<Channel> {
|
||||
return connect0(ipProtocol: ipProtocol) {
|
||||
return try SocketAddress.makeAddressResolvingHost(host, port: 0)
|
||||
}
|
||||
}
|
||||
|
||||
private func connect0(ipProtocol: NIOIPProtocol, _ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture<Channel> {
|
||||
let address: SocketAddress
|
||||
do {
|
||||
address = try makeSocketAddress()
|
||||
} catch {
|
||||
return group.next().makeFailedFuture(error)
|
||||
}
|
||||
func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel {
|
||||
return try DatagramChannel(eventLoop: eventLoop,
|
||||
protocolFamily: address.protocol,
|
||||
protocolSubtype: .init(ipProtocol),
|
||||
socketType: .raw)
|
||||
}
|
||||
return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in
|
||||
channel.register().flatMap {
|
||||
channel.connect(to: address)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func withNewChannel(makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, _ bringup: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
|
||||
let eventLoop = self.group.next()
|
||||
let channelInitializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) }
|
||||
let channelOptions = self._channelOptions
|
||||
|
||||
let channel: DatagramChannel
|
||||
do {
|
||||
channel = try makeChannel(eventLoop as! SelectableEventLoop)
|
||||
} catch {
|
||||
return eventLoop.makeFailedFuture(error)
|
||||
}
|
||||
|
||||
func setupChannel() -> EventLoopFuture<Channel> {
|
||||
eventLoop.assertInEventLoop()
|
||||
return channelOptions.applyAllChannelOptions(to: channel).flatMap {
|
||||
channelInitializer(channel)
|
||||
}.flatMap {
|
||||
eventLoop.assertInEventLoop()
|
||||
return bringup(eventLoop, channel)
|
||||
}.map {
|
||||
channel
|
||||
}.flatMapError { error in
|
||||
eventLoop.makeFailedFuture(error)
|
||||
}
|
||||
}
|
||||
|
||||
if eventLoop.inEventLoop {
|
||||
return setupChannel()
|
||||
} else {
|
||||
return eventLoop.flatSubmit {
|
||||
setupChannel()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if swift(>=5.6)
|
||||
@available(*, unavailable)
|
||||
extension NIORawSocketBootstrap: Sendable {}
|
||||
#endif
|
|
@ -32,7 +32,7 @@ typealias IOVector = iovec
|
|||
/// - parameters:
|
||||
/// - protocolFamily: The protocol family to use (usually `AF_INET6` or `AF_INET`).
|
||||
/// - type: The type of the socket to create.
|
||||
/// - protocolSubtype: The subtype of the protocol, corresponding to the `protocol`
|
||||
/// - protocolSubtype: The subtype of the protocol, corresponding to the `protocolSubtype`
|
||||
/// argument to the socket syscall. Defaults to 0.
|
||||
/// - setNonBlocking: Set non-blocking mode on the socket.
|
||||
/// - throws: An `IOError` if creation of the socket failed.
|
||||
|
|
|
@ -120,6 +120,7 @@ class LinuxMainRunner {
|
|||
testCase(PendingDatagramWritesManagerTests.allTests),
|
||||
testCase(PipeChannelTest.allTests),
|
||||
testCase(PriorityQueueTest.allTests),
|
||||
testCase(RawSocketBootstrapTests.allTests),
|
||||
testCase(SALChannelTest.allTests),
|
||||
testCase(SALEventLoopTests.allTests),
|
||||
testCase(SNIHandlerTest.allTests),
|
||||
|
|
|
@ -17,7 +17,7 @@ import NIOCore
|
|||
@testable import NIOPosix
|
||||
import XCTest
|
||||
|
||||
private extension Channel {
|
||||
extension Channel {
|
||||
func waitForDatagrams(count: Int) throws -> [AddressedEnvelope<ByteBuffer>] {
|
||||
return try self.pipeline.context(name: "ByteReadRecorder").flatMap { context in
|
||||
if let future = (context.handler as? DatagramReadRecorder<ByteBuffer>)?.notifyForDatagrams(count) {
|
||||
|
@ -47,7 +47,7 @@ private extension Channel {
|
|||
/// A class that records datagrams received and forwards them on.
|
||||
///
|
||||
/// Used extensively in tests to validate messaging expectations.
|
||||
private class DatagramReadRecorder<DataType>: ChannelInboundHandler {
|
||||
final class DatagramReadRecorder<DataType>: ChannelInboundHandler {
|
||||
typealias InboundIn = AddressedEnvelope<DataType>
|
||||
typealias InboundOut = AddressedEnvelope<DataType>
|
||||
|
||||
|
|
|
@ -0,0 +1,365 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
import NIOCore
|
||||
|
||||
struct IPv4Address: Hashable {
|
||||
var rawValue: UInt32
|
||||
}
|
||||
|
||||
extension IPv4Address {
|
||||
init(_ v1: UInt8, _ v2: UInt8, _ v3: UInt8, _ v4: UInt8) {
|
||||
rawValue = UInt32(v1) << 24 | UInt32(v2) << 16 | UInt32(v3) << 8 | UInt32(v4)
|
||||
}
|
||||
}
|
||||
|
||||
extension IPv4Address: CustomStringConvertible {
|
||||
var description: String {
|
||||
let v1 = rawValue >> 24
|
||||
let v2 = rawValue >> 16 & 0b1111_1111
|
||||
let v3 = rawValue >> 8 & 0b1111_1111
|
||||
let v4 = rawValue & 0b1111_1111
|
||||
return "\(v1).\(v2).\(v3).\(v4)"
|
||||
}
|
||||
}
|
||||
|
||||
struct IPv4Header: Hashable {
|
||||
static let size: Int = 20
|
||||
|
||||
fileprivate var versionAndIhl: UInt8
|
||||
var version: UInt8 {
|
||||
get {
|
||||
versionAndIhl >> 4
|
||||
}
|
||||
set {
|
||||
precondition(newValue & 0b1111_0000 == 0)
|
||||
versionAndIhl = newValue << 4 | (0b0000_1111 & versionAndIhl)
|
||||
assert(newValue == version, "\(newValue) != \(version) \(versionAndIhl)")
|
||||
}
|
||||
}
|
||||
var internetHeaderLength: UInt8 {
|
||||
get {
|
||||
versionAndIhl & 0b0000_1111
|
||||
}
|
||||
set {
|
||||
precondition(newValue & 0b1111_0000 == 0)
|
||||
versionAndIhl = newValue | (0b1111_0000 & versionAndIhl)
|
||||
assert(newValue == internetHeaderLength)
|
||||
}
|
||||
}
|
||||
fileprivate var dscpAndEcn: UInt8
|
||||
var differentiatedServicesCodePoint: UInt8 {
|
||||
get {
|
||||
dscpAndEcn >> 2
|
||||
}
|
||||
set {
|
||||
precondition(newValue & 0b0000_0011 == 0)
|
||||
dscpAndEcn = newValue << 2 | (0b0000_0011 & dscpAndEcn)
|
||||
assert(newValue == differentiatedServicesCodePoint)
|
||||
}
|
||||
}
|
||||
var explicitCongestionNotification: UInt8 {
|
||||
get {
|
||||
dscpAndEcn & 0b0000_0011
|
||||
}
|
||||
set {
|
||||
precondition(newValue & 0b0000_0011 == 0)
|
||||
dscpAndEcn = newValue | (0b1111_1100 & dscpAndEcn)
|
||||
assert(newValue == explicitCongestionNotification)
|
||||
}
|
||||
}
|
||||
var totalLength: UInt16
|
||||
var identification: UInt16
|
||||
fileprivate var flagsAndFragmentOffset: UInt16
|
||||
var flags: UInt8 {
|
||||
get {
|
||||
UInt8(flagsAndFragmentOffset >> 13)
|
||||
}
|
||||
set {
|
||||
precondition(newValue & 0b0000_0111 == 0)
|
||||
flagsAndFragmentOffset = UInt16(newValue) << 13 | (0b0001_1111_1111_1111 & flagsAndFragmentOffset)
|
||||
assert(newValue == flags)
|
||||
}
|
||||
}
|
||||
var fragmentOffset: UInt16 {
|
||||
get {
|
||||
flagsAndFragmentOffset & 0b0001_1111_1111_1111
|
||||
}
|
||||
set {
|
||||
precondition(newValue & 0b1110_0000_0000_0000 == 0)
|
||||
flagsAndFragmentOffset = newValue | (0b1110_0000_0000_0000 & flagsAndFragmentOffset)
|
||||
assert(newValue == fragmentOffset)
|
||||
}
|
||||
}
|
||||
var timeToLive: UInt8
|
||||
var `protocol`: NIOIPProtocol
|
||||
var headerChecksum: UInt16
|
||||
var sourceIpAddress: IPv4Address
|
||||
var destinationIpAddress: IPv4Address
|
||||
|
||||
fileprivate init(
|
||||
versionAndIhl: UInt8,
|
||||
dscpAndEcn: UInt8,
|
||||
totalLength: UInt16,
|
||||
identification: UInt16,
|
||||
flagsAndFragmentOffset: UInt16,
|
||||
timeToLive: UInt8,
|
||||
`protocol`: NIOIPProtocol,
|
||||
headerChecksum: UInt16,
|
||||
sourceIpAddress: IPv4Address,
|
||||
destinationIpAddress: IPv4Address
|
||||
) {
|
||||
self.versionAndIhl = versionAndIhl
|
||||
self.dscpAndEcn = dscpAndEcn
|
||||
self.totalLength = totalLength
|
||||
self.identification = identification
|
||||
self.flagsAndFragmentOffset = flagsAndFragmentOffset
|
||||
self.timeToLive = timeToLive
|
||||
self.`protocol` = `protocol`
|
||||
self.headerChecksum = headerChecksum
|
||||
self.sourceIpAddress = sourceIpAddress
|
||||
self.destinationIpAddress = destinationIpAddress
|
||||
}
|
||||
|
||||
init() {
|
||||
self.versionAndIhl = 0
|
||||
self.dscpAndEcn = 0
|
||||
self.totalLength = 0
|
||||
self.identification = 0
|
||||
self.flagsAndFragmentOffset = 0
|
||||
self.timeToLive = 0
|
||||
self.`protocol` = .init(rawValue: 0)
|
||||
self.headerChecksum = 0
|
||||
self.sourceIpAddress = .init(rawValue: 0)
|
||||
self.destinationIpAddress = .init(rawValue: 0)
|
||||
}
|
||||
}
|
||||
|
||||
extension FixedWidthInteger {
|
||||
func convertEndianness(to endianness: Endianness) -> Self {
|
||||
switch endianness {
|
||||
case .little:
|
||||
return self.littleEndian
|
||||
case .big:
|
||||
return self.bigEndian
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
extension ByteBuffer {
|
||||
mutating func readIPv4Header() -> IPv4Header? {
|
||||
guard let (
|
||||
versionAndIhl,
|
||||
dscpAndEcn,
|
||||
totalLength,
|
||||
identification,
|
||||
flagsAndFragmentOffset,
|
||||
timeToLive,
|
||||
`protocol`,
|
||||
headerChecksum,
|
||||
sourceIpAddress,
|
||||
destinationIpAddress
|
||||
) = self.readMultipleIntegers(as: (
|
||||
UInt8,
|
||||
UInt8,
|
||||
UInt16,
|
||||
UInt16,
|
||||
UInt16,
|
||||
UInt8,
|
||||
UInt8,
|
||||
UInt16,
|
||||
UInt32,
|
||||
UInt32
|
||||
).self) else { return nil }
|
||||
return .init(
|
||||
versionAndIhl: versionAndIhl,
|
||||
dscpAndEcn: dscpAndEcn,
|
||||
totalLength: totalLength,
|
||||
identification: identification,
|
||||
flagsAndFragmentOffset: flagsAndFragmentOffset,
|
||||
timeToLive: timeToLive,
|
||||
protocol: .init(rawValue: `protocol`),
|
||||
headerChecksum: headerChecksum,
|
||||
sourceIpAddress: .init(rawValue: sourceIpAddress),
|
||||
destinationIpAddress: .init(rawValue: destinationIpAddress)
|
||||
)
|
||||
}
|
||||
|
||||
mutating func readIPv4HeaderFromBSDRawSocket() -> IPv4Header? {
|
||||
guard var header = self.readIPv4Header() else { return nil }
|
||||
// On BSD, the total length is in host byte order
|
||||
header.totalLength = header.totalLength.convertEndianness(to: .big)
|
||||
// TODO: fragmentOffset is in host byte order as well but it is always zero in our tests
|
||||
// and fragmentOffset is 13 bits in size so we can't just use readInteger(endianness: .host)
|
||||
return header
|
||||
}
|
||||
|
||||
mutating func readIPv4HeaderFromOSRawSocket() -> IPv4Header? {
|
||||
#if canImport(Darwin)
|
||||
return self.readIPv4HeaderFromBSDRawSocket()
|
||||
#else
|
||||
return self.readIPv4Header()
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
extension ByteBuffer {
|
||||
@discardableResult
|
||||
mutating func writeIPv4Header(_ header: IPv4Header) -> Int {
|
||||
assert({
|
||||
var buffer = ByteBuffer()
|
||||
buffer._writeIPv4Header(header)
|
||||
let writtenHeader = buffer.readIPv4Header()
|
||||
return header == writtenHeader
|
||||
}())
|
||||
return self._writeIPv4Header(header)
|
||||
}
|
||||
|
||||
@discardableResult
|
||||
private mutating func _writeIPv4Header(_ header: IPv4Header) -> Int {
|
||||
return self.writeMultipleIntegers(
|
||||
header.versionAndIhl,
|
||||
header.dscpAndEcn,
|
||||
header.totalLength,
|
||||
header.identification,
|
||||
header.flagsAndFragmentOffset,
|
||||
header.timeToLive,
|
||||
header.`protocol`.rawValue,
|
||||
header.headerChecksum,
|
||||
header.sourceIpAddress.rawValue,
|
||||
header.destinationIpAddress.rawValue
|
||||
)
|
||||
}
|
||||
|
||||
@discardableResult
|
||||
mutating func writeIPv4HeaderToBSDRawSocket(_ header: IPv4Header) -> Int {
|
||||
assert({
|
||||
var buffer = ByteBuffer()
|
||||
buffer._writeIPv4HeaderToBSDRawSocket(header)
|
||||
let writtenHeader = buffer.readIPv4HeaderFromBSDRawSocket()
|
||||
return header == writtenHeader
|
||||
}())
|
||||
return self._writeIPv4HeaderToBSDRawSocket(header)
|
||||
}
|
||||
|
||||
@discardableResult
|
||||
private mutating func _writeIPv4HeaderToBSDRawSocket(_ header: IPv4Header) -> Int {
|
||||
var header = header
|
||||
// On BSD, the total length needs to be in host byte order
|
||||
header.totalLength = header.totalLength.convertEndianness(to: .big)
|
||||
// TODO: fragmentOffset needs to be in host byte order as well but it is always zero in our tests
|
||||
// and fragmentOffset is 13 bits in size so we can't just use writeInteger(endianness: .host)
|
||||
return self._writeIPv4Header(header)
|
||||
}
|
||||
|
||||
@discardableResult
|
||||
mutating func writeIPv4HeaderToOSRawSocket(_ header: IPv4Header) -> Int {
|
||||
#if canImport(Darwin)
|
||||
self.writeIPv4HeaderToBSDRawSocket(header)
|
||||
#else
|
||||
self.writeIPv4Header(header)
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
extension IPv4Header {
|
||||
func computeChecksum() -> UInt16 {
|
||||
let checksum = ~[
|
||||
UInt16(versionAndIhl) << 8 | UInt16(dscpAndEcn),
|
||||
totalLength,
|
||||
identification,
|
||||
flagsAndFragmentOffset,
|
||||
UInt16(timeToLive) << 8 | UInt16(`protocol`.rawValue),
|
||||
UInt16(sourceIpAddress.rawValue >> 16),
|
||||
UInt16(sourceIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
|
||||
UInt16(destinationIpAddress.rawValue >> 16),
|
||||
UInt16(destinationIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
|
||||
].reduce(UInt16(0), onesComplementAdd)
|
||||
assert(isValidChecksum(checksum))
|
||||
return checksum
|
||||
}
|
||||
mutating func setChecksum() {
|
||||
self.headerChecksum = computeChecksum()
|
||||
}
|
||||
func isValidChecksum(_ headerChecksum: UInt16) -> Bool {
|
||||
let sum = ~[
|
||||
UInt16(versionAndIhl) << 8 | UInt16(dscpAndEcn),
|
||||
totalLength,
|
||||
identification,
|
||||
flagsAndFragmentOffset,
|
||||
UInt16(timeToLive) << 8 | UInt16(`protocol`.rawValue),
|
||||
headerChecksum,
|
||||
UInt16(sourceIpAddress.rawValue >> 16),
|
||||
UInt16(sourceIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
|
||||
UInt16(destinationIpAddress.rawValue >> 16),
|
||||
UInt16(destinationIpAddress.rawValue & 0b0000_0000_0000_0000_1111_1111_1111_1111),
|
||||
].reduce(UInt16(0), onesComplementAdd)
|
||||
return sum == 0
|
||||
}
|
||||
func isValidChecksum() -> Bool {
|
||||
isValidChecksum(headerChecksum)
|
||||
}
|
||||
}
|
||||
|
||||
private func onesComplementAdd<Integer: FixedWidthInteger>(lhs: Integer, rhs: Integer) -> Integer {
|
||||
var (sum, overflowed) = lhs.addingReportingOverflow(rhs)
|
||||
if overflowed {
|
||||
sum &+= 1
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
extension IPv4Header {
|
||||
var platformIndependentTotalLengthForReceivedPacketFromRawSocket: UInt16 {
|
||||
#if canImport(Darwin)
|
||||
// On BSD the IP header will only contain the size of the ip packet body, not the header.
|
||||
// This is known bug which can't be fixed without breaking old apps which already workaround the issue
|
||||
// like e.g. we do now too.
|
||||
return totalLength + 20
|
||||
#elseif os(Linux)
|
||||
return totalLength
|
||||
#endif
|
||||
}
|
||||
var platformIndependentChecksumForReceivedPacketFromRawSocket: UInt16 {
|
||||
#if canImport(Darwin)
|
||||
// On BSD the checksum is always zero and we need to compute it
|
||||
precondition(headerChecksum == 0)
|
||||
return computeChecksum()
|
||||
#elseif os(Linux)
|
||||
return headerChecksum
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
extension IPv4Header: CustomStringConvertible {
|
||||
var description: String {
|
||||
"""
|
||||
Version: \(version)
|
||||
Header Length: \(internetHeaderLength * 4) bytes
|
||||
Differentiated Services: \(String(differentiatedServicesCodePoint, radix: 2))
|
||||
Explicit Congestion Notification: \(String(explicitCongestionNotification, radix: 2))
|
||||
Total Length: \(totalLength) bytes
|
||||
Identification: \(identification)
|
||||
Flags: \(String(flags, radix: 2))
|
||||
Fragment Offset: \(fragmentOffset) bytes
|
||||
Time to Live: \(timeToLive)
|
||||
Protocol: \(`protocol`)
|
||||
Header Checksum: \(headerChecksum) (\(isValidChecksum() ? "valid" : "*not* valid"))
|
||||
Source IP Address: \(sourceIpAddress)
|
||||
Destination IP Address: \(destinationIpAddress)
|
||||
"""
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// RawSocketBootstrapTests+XCTest.swift
|
||||
//
|
||||
import XCTest
|
||||
|
||||
///
|
||||
/// NOTE: This file was generated by generate_linux_tests.rb
|
||||
///
|
||||
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
|
||||
///
|
||||
|
||||
extension RawSocketBootstrapTests {
|
||||
|
||||
@available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings")
|
||||
static var allTests : [(String, (RawSocketBootstrapTests) -> () throws -> Void)] {
|
||||
return [
|
||||
("testBindWithRecevMmsg", testBindWithRecevMmsg),
|
||||
("testConnect", testConnect),
|
||||
("testIpHdrincl", testIpHdrincl),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
import XCTest
|
||||
import NIOCore
|
||||
@testable import NIOPosix
|
||||
|
||||
extension NIOIPProtocol {
|
||||
static let reservedForTesting = Self(rawValue: 253)
|
||||
}
|
||||
|
||||
// lazily try's to create a raw socket and caches the error if it fails
|
||||
private let cachedRawSocketAPICheck = Result<Void, Error> {
|
||||
let socket = try Socket(protocolFamily: .inet, type: .raw, protocolSubtype: .init(NIOIPProtocol.reservedForTesting), setNonBlocking: true)
|
||||
try socket.close()
|
||||
}
|
||||
|
||||
func XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI(file: StaticString = #filePath, line: UInt = #line) throws {
|
||||
do {
|
||||
try cachedRawSocketAPICheck.get()
|
||||
} catch let error as IOError where error.errnoCode == EPERM {
|
||||
throw XCTSkip("Raw Socket API requires higher privileges: \(error)", file: file, line: line)
|
||||
}
|
||||
}
|
||||
|
||||
final class RawSocketBootstrapTests: XCTestCase {
|
||||
func testBindWithRecevMmsg() throws {
|
||||
try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI()
|
||||
|
||||
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) }
|
||||
let channel = try NIORawSocketBootstrap(group: elg)
|
||||
.channelInitializer {
|
||||
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
|
||||
}
|
||||
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
|
||||
defer { XCTAssertNoThrow(try channel.close().wait()) }
|
||||
try channel.configureForRecvMmsg(messageCount: 10)
|
||||
let expectedMessages = (1...10).map { "Hello World \($0)" }
|
||||
for message in expectedMessages {
|
||||
_ = try channel.write(AddressedEnvelope(
|
||||
remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0),
|
||||
data: ByteBuffer(string: message)
|
||||
))
|
||||
}
|
||||
channel.flush()
|
||||
|
||||
let receivedMessages = Set(try channel.waitForDatagrams(count: 10).map { envelop -> String in
|
||||
var data = envelop.data
|
||||
let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket())
|
||||
XCTAssertEqual(header.version, 4)
|
||||
XCTAssertEqual(header.protocol, .reservedForTesting)
|
||||
XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes)
|
||||
XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)")
|
||||
XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1))
|
||||
XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1))
|
||||
return String(buffer: data)
|
||||
})
|
||||
|
||||
XCTAssertEqual(receivedMessages, Set(expectedMessages))
|
||||
}
|
||||
|
||||
func testConnect() throws {
|
||||
try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI()
|
||||
|
||||
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) }
|
||||
let readChannel = try NIORawSocketBootstrap(group: elg)
|
||||
.channelInitializer {
|
||||
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
|
||||
}
|
||||
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
|
||||
defer { XCTAssertNoThrow(try readChannel.close().wait()) }
|
||||
|
||||
let writeChannel = try NIORawSocketBootstrap(group: elg)
|
||||
.channelInitializer {
|
||||
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
|
||||
}
|
||||
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
|
||||
defer { XCTAssertNoThrow(try writeChannel.close().wait()) }
|
||||
|
||||
let expectedMessages = (1...10).map { "Hello World \($0)" }
|
||||
for message in expectedMessages {
|
||||
_ = try writeChannel.write(AddressedEnvelope(
|
||||
remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0),
|
||||
data: ByteBuffer(string: message)
|
||||
))
|
||||
}
|
||||
writeChannel.flush()
|
||||
|
||||
let receivedMessages = Set(try readChannel.waitForDatagrams(count: 10).map { envelop -> String in
|
||||
var data = envelop.data
|
||||
let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket())
|
||||
XCTAssertEqual(header.version, 4)
|
||||
XCTAssertEqual(header.protocol, .reservedForTesting)
|
||||
XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes)
|
||||
XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)")
|
||||
XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1))
|
||||
XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1))
|
||||
return String(buffer: data)
|
||||
})
|
||||
|
||||
XCTAssertEqual(receivedMessages, Set(expectedMessages))
|
||||
}
|
||||
|
||||
func testIpHdrincl() throws {
|
||||
try XCTSkipIfUserHasNotEnoughRightsForRawSocketAPI()
|
||||
|
||||
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) }
|
||||
let channel = try NIORawSocketBootstrap(group: elg)
|
||||
.channelOption(ChannelOptions.ipOption(.ip_hdrincl), value: 1)
|
||||
.channelInitializer {
|
||||
$0.pipeline.addHandler(DatagramReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
|
||||
}
|
||||
.bind(host: "127.0.0.1", ipProtocol: .reservedForTesting).wait()
|
||||
defer { XCTAssertNoThrow(try channel.close().wait()) }
|
||||
try channel.configureForRecvMmsg(messageCount: 10)
|
||||
let expectedMessages = (1...10).map { "Hello World \($0)" }
|
||||
for message in expectedMessages.map(ByteBuffer.init(string:)) {
|
||||
var packet = ByteBuffer()
|
||||
var header = IPv4Header()
|
||||
header.version = 4
|
||||
header.internetHeaderLength = 5
|
||||
header.totalLength = UInt16(IPv4Header.size + message.readableBytes)
|
||||
header.protocol = .reservedForTesting
|
||||
header.timeToLive = 64
|
||||
header.destinationIpAddress = .init(127, 0, 0, 1)
|
||||
header.sourceIpAddress = .init(127, 0, 0, 1)
|
||||
header.setChecksum()
|
||||
packet.writeIPv4HeaderToOSRawSocket(header)
|
||||
packet.writeImmutableBuffer(message)
|
||||
try channel.writeAndFlush(AddressedEnvelope(
|
||||
remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0),
|
||||
data: packet
|
||||
)).wait()
|
||||
}
|
||||
|
||||
let receivedMessages = Set(try channel.waitForDatagrams(count: 10).map { envelop -> String in
|
||||
var data = envelop.data
|
||||
let header = try XCTUnwrap(data.readIPv4HeaderFromOSRawSocket())
|
||||
XCTAssertEqual(header.version, 4)
|
||||
XCTAssertEqual(header.protocol, .reservedForTesting)
|
||||
XCTAssertEqual(Int(header.platformIndependentTotalLengthForReceivedPacketFromRawSocket), IPv4Header.size + data.readableBytes)
|
||||
XCTAssertTrue(header.isValidChecksum(header.platformIndependentChecksumForReceivedPacketFromRawSocket), "\(header)")
|
||||
XCTAssertEqual(header.sourceIpAddress, .init(127, 0, 0, 1))
|
||||
XCTAssertEqual(header.destinationIpAddress, .init(127, 0, 0, 1))
|
||||
return String(buffer: data)
|
||||
})
|
||||
|
||||
XCTAssertEqual(receivedMessages, Set(expectedMessages))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue