Make PooledBuffer safer. (#2363)

Motivation:

PooledBuffer is an inherently unsafe type, but its original incarnation
was less safe than it needed to be. In particular, we can rewrite it to
ensure that it is compatible with automatic reference counting.

Modifications:

- Rewrite PooledBuffer to use ManagedBuffer
- Clean up alignment math
- Use scoped accessors
- Add hooks for future non-scoped access

Result:

Safer, clearer code
This commit is contained in:
Cory Benfield 2023-02-08 07:59:17 +00:00 committed by GitHub
parent 1e7ad9a0db
commit 39047aec7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 249 additions and 133 deletions

View File

@ -78,80 +78,81 @@ private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWrite
let buffer = bufferPool.get()
defer { bufferPool.put(buffer) }
let (iovecs, storageRefs) = buffer.get()
for p in pending.flushedWrites {
// Must not write more than Int32.max in one go.
// TODO(cory): I can't see this limit documented in a man page anywhere, but it seems
// plausible given that a similar limit exists for TCP. For now we assume it's present
// in UDP until I can do some research to validate the existence of this limit.
guard (Socket.writevLimitBytes - toWrite >= p.data.readableBytes) else {
if c == 0 {
// The first buffer is larger than the writev limit. Let's throw, and fall back to linear processing.
throw IOError(errnoCode: EMSGSIZE, reason: "synthetic error for overlarge write")
} else {
return try buffer.withUnsafePointers { iovecs, storageRefs in
for p in pending.flushedWrites {
// Must not write more than Int32.max in one go.
// TODO(cory): I can't see this limit documented in a man page anywhere, but it seems
// plausible given that a similar limit exists for TCP. For now we assume it's present
// in UDP until I can do some research to validate the existence of this limit.
guard (Socket.writevLimitBytes - toWrite >= p.data.readableBytes) else {
if c == 0 {
// The first buffer is larger than the writev limit. Let's throw, and fall back to linear processing.
throw IOError(errnoCode: EMSGSIZE, reason: "synthetic error for overlarge write")
} else {
break
}
}
// Must not write more than writevLimitIOVectors in one go
guard c < Socket.writevLimitIOVectors else {
break
}
}
// Must not write more than writevLimitIOVectors in one go
guard c < Socket.writevLimitIOVectors else {
break
}
let toWriteForThisBuffer = p.data.readableBytes
toWrite += numericCast(toWriteForThisBuffer)
let toWriteForThisBuffer = p.data.readableBytes
toWrite += numericCast(toWriteForThisBuffer)
p.data.withUnsafeReadableBytesWithStorageManagement { ptr, storageRef in
storageRefs[c] = storageRef.retain()
p.data.withUnsafeReadableBytesWithStorageManagement { ptr, storageRef in
storageRefs[c] = storageRef.retain()
/// From man page of `sendmsg(2)`:
///
/// > The `msg_name` field is used on an unconnected socket to specify
/// > the target address for a datagram. It points to a buffer
/// > containing the address; the `msg_namelen` field should be set to
/// > the size of the address. For a connected socket, these fields
/// > should be specified as `NULL` and 0, respectively.
let address: UnsafeMutablePointer<sockaddr_storage>?
let addressLen: socklen_t
let protocolFamily: NIOBSDSocket.ProtocolFamily
if let envelopeAddress = p.address {
precondition(pending.remoteAddress == nil, "Pending write with address on connected socket.")
address = addresses.baseAddress! + c
addressLen = p.copySocketAddress(address!)
protocolFamily = envelopeAddress.protocol
} else {
guard let connectedRemoteAddress = pending.remoteAddress else {
preconditionFailure("Pending write without address on unconnected socket.")
/// From man page of `sendmsg(2)`:
///
/// > The `msg_name` field is used on an unconnected socket to specify
/// > the target address for a datagram. It points to a buffer
/// > containing the address; the `msg_namelen` field should be set to
/// > the size of the address. For a connected socket, these fields
/// > should be specified as `NULL` and 0, respectively.
let address: UnsafeMutablePointer<sockaddr_storage>?
let addressLen: socklen_t
let protocolFamily: NIOBSDSocket.ProtocolFamily
if let envelopeAddress = p.address {
precondition(pending.remoteAddress == nil, "Pending write with address on connected socket.")
address = addresses.baseAddress! + c
addressLen = p.copySocketAddress(address!)
protocolFamily = envelopeAddress.protocol
} else {
guard let connectedRemoteAddress = pending.remoteAddress else {
preconditionFailure("Pending write without address on unconnected socket.")
}
address = nil
addressLen = 0
protocolFamily = connectedRemoteAddress.protocol
}
address = nil
addressLen = 0
protocolFamily = connectedRemoteAddress.protocol
iovecs[c] = iovec(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer))
var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[c])
controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: protocolFamily)
let controlMessageBytePointer = controlBytes.validControlBytes
let msg = msghdr(msg_name: address,
msg_namelen: addressLen,
msg_iov: iovecs.baseAddress! + c,
msg_iovlen: 1,
msg_control: controlMessageBytePointer.baseAddress,
msg_controllen: .init(controlMessageBytePointer.count),
msg_flags: 0)
msgs[c] = MMsgHdr(msg_hdr: msg, msg_len: 0)
}
iovecs[c] = iovec(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer))
var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[c])
controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: protocolFamily)
let controlMessageBytePointer = controlBytes.validControlBytes
let msg = msghdr(msg_name: address,
msg_namelen: addressLen,
msg_iov: iovecs.baseAddress! + c,
msg_iovlen: 1,
msg_control: controlMessageBytePointer.baseAddress,
msg_controllen: .init(controlMessageBytePointer.count),
msg_flags: 0)
msgs[c] = MMsgHdr(msg_hdr: msg, msg_len: 0)
c += 1
}
c += 1
}
defer {
for i in 0..<c {
storageRefs[i].release()
defer {
for i in 0..<c {
storageRefs[i].release()
}
}
return try body(UnsafeMutableBufferPointer(start: msgs.baseAddress!, count: c))
}
return try body(UnsafeMutableBufferPointer(start: msgs.baseAddress!, count: c))
}
/// This holds the states of the currently pending datagram writes. The core is a `MarkedCircularBuffer` which holds all the

View File

@ -31,46 +31,47 @@ private func doPendingWriteVectorOperation(pending: PendingStreamWritesState,
_ body: (UnsafeBufferPointer<IOVector>) throws -> IOResult<Int>) throws -> (itemCount: Int, writeResult: IOResult<Int>) {
let buffer = bufferPool.get()
defer { bufferPool.put(buffer) }
let (iovecs, storageRefs) = buffer.get()
// Clamp the number of writes we're willing to issue to the limit for writev.
var count = min(iovecs.count, storageRefs.count)
count = min(pending.flushedChunks, count)
return try buffer.withUnsafePointers { iovecs, storageRefs in
// Clamp the number of writes we're willing to issue to the limit for writev.
var count = min(iovecs.count, storageRefs.count)
count = min(pending.flushedChunks, count)
// the numbers of storage refs that we need to decrease later.
var numberOfUsedStorageSlots = 0
var toWrite: Int = 0
// the numbers of storage refs that we need to decrease later.
var numberOfUsedStorageSlots = 0
var toWrite: Int = 0
loop: for i in 0..<count {
let p = pending[i]
switch p.data {
case .byteBuffer(let buffer):
// Must not write more than Int32.max in one go.
guard (numberOfUsedStorageSlots == 0) || (Socket.writevLimitBytes - toWrite >= buffer.readableBytes) else {
loop: for i in 0..<count {
let p = pending[i]
switch p.data {
case .byteBuffer(let buffer):
// Must not write more than Int32.max in one go.
guard (numberOfUsedStorageSlots == 0) || (Socket.writevLimitBytes - toWrite >= buffer.readableBytes) else {
break loop
}
let toWriteForThisBuffer = min(Socket.writevLimitBytes, buffer.readableBytes)
toWrite += numericCast(toWriteForThisBuffer)
buffer.withUnsafeReadableBytesWithStorageManagement { ptr, storageRef in
storageRefs[i] = storageRef.retain()
iovecs[i] = IOVector(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer))
}
numberOfUsedStorageSlots += 1
case .fileRegion:
assert(numberOfUsedStorageSlots != 0, "first item in doPendingWriteVectorOperation was a FileRegion")
// We found a FileRegion so stop collecting
break loop
}
let toWriteForThisBuffer = min(Socket.writevLimitBytes, buffer.readableBytes)
toWrite += numericCast(toWriteForThisBuffer)
buffer.withUnsafeReadableBytesWithStorageManagement { ptr, storageRef in
storageRefs[i] = storageRef.retain()
iovecs[i] = IOVector(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer))
}
defer {
for i in 0..<numberOfUsedStorageSlots {
storageRefs[i].release()
}
numberOfUsedStorageSlots += 1
case .fileRegion:
assert(numberOfUsedStorageSlots != 0, "first item in doPendingWriteVectorOperation was a FileRegion")
// We found a FileRegion so stop collecting
break loop
}
let result = try body(UnsafeBufferPointer(start: iovecs.baseAddress!, count: numberOfUsedStorageSlots))
/* if we hit a limit, we really wanted to write more than we have so the caller should retry us */
return (numberOfUsedStorageSlots, result)
}
defer {
for i in 0..<numberOfUsedStorageSlots {
storageRefs[i].release()
}
}
let result = try body(UnsafeBufferPointer(start: iovecs.baseAddress!, count: numberOfUsedStorageSlots))
/* if we hit a limit, we really wanted to write more than we have so the caller should retry us */
return (numberOfUsedStorageSlots, result)
}
/// The result of a single write operation, usually `write`, `sendfile` or `writev`.

View File

@ -50,3 +50,155 @@ class Pool<Element: PoolElement> {
}
}
}
/// A ``PooledBuffer`` is used to track an allocation of memory required
/// by a `Channel` or `EventLoopGroup`.
///
/// ``PooledBuffer`` is a reference type with inline storage. It is intended to
/// be bound to a single thread, and ensures that the allocation it stores does not
/// get freed before the buffer is out of use.
struct PooledBuffer: PoolElement {
private static let sentinelValue = MemorySentinel(0xdeadbeef)
private let storage: BackingStorage
init() {
self.storage = .create(iovectorCount: Socket.writevLimitIOVectors)
self.configureSentinel()
}
func evictedFromPool() {
self.validateSentinel()
}
func withUnsafePointers<ReturnValue>(
_ body: (UnsafeMutableBufferPointer<IOVector>, UnsafeMutableBufferPointer<Unmanaged<AnyObject>>) throws -> ReturnValue
) rethrows -> ReturnValue {
defer {
self.validateSentinel()
}
return try self.storage.withUnsafeMutableTypedPointers { iovecPointer, ownerPointer, _ in
try body(iovecPointer, ownerPointer)
}
}
/// Yields buffer pointers containing this ``PooledBuffer``'s readable bytes. You may hold a pointer to those bytes
/// even after the closure has returned iff you model the lifetime of those bytes correctly using the `Unmanaged`
/// instance. If you don't require the pointer after the closure returns, use ``withUnsafePointers``.
///
/// If you escape the pointer from the closure, you _must_ call `storageManagement.retain()` to get ownership to
/// the bytes and you also must call `storageManagement.release()` if you no longer require those bytes. Calls to
/// `retain` and `release` must be balanced.
///
/// - parameters:
/// - body: The closure that will accept the yielded pointers and the `storageManagement`.
/// - returns: The value returned by `body`.
func withUnsafePointersWithStorageManagement<ReturnValue>(
_ body: (UnsafeMutableBufferPointer<IOVector>, UnsafeMutableBufferPointer<Unmanaged<AnyObject>>, Unmanaged<AnyObject>) throws -> ReturnValue
) rethrows -> ReturnValue {
let storageRef: Unmanaged<AnyObject> = Unmanaged.passUnretained(self.storage)
return try self.storage.withUnsafeMutableTypedPointers { iovecPointer, ownerPointer, _ in
try body(iovecPointer, ownerPointer, storageRef)
}
}
private func configureSentinel() {
self.storage.withUnsafeMutableTypedPointers { _, _, sentinelPointer in
sentinelPointer.pointee = Self.sentinelValue
}
}
private func validateSentinel() {
self.storage.withUnsafeMutableTypedPointers { _, _, sentinelPointer in
precondition(sentinelPointer.pointee == Self.sentinelValue, "Detected memory handling error!")
}
}
}
extension PooledBuffer {
fileprivate typealias MemorySentinel = UInt32
fileprivate struct PooledBufferHead {
let iovectorCount: Int
let spaceForIOVectors: Int
let spaceForBufferOwners: Int
init(iovectorCount: Int) {
var spaceForIOVectors = MemoryLayout<IOVector>.stride * iovectorCount
spaceForIOVectors.roundUpToAlignment(for: Unmanaged<AnyObject>.self)
var spaceForBufferOwners = MemoryLayout<Unmanaged<AnyObject>>.stride * iovectorCount
spaceForBufferOwners.roundUpToAlignment(for: MemorySentinel.self)
self.iovectorCount = iovectorCount
self.spaceForIOVectors = spaceForIOVectors
self.spaceForBufferOwners = spaceForBufferOwners
}
var totalByteCount: Int {
self.spaceForIOVectors + self.spaceForBufferOwners + MemoryLayout<MemorySentinel>.size
}
var iovectorOffset: Int {
0
}
var bufferOwnersOffset: Int {
self.spaceForIOVectors
}
var memorySentinelOffset: Int {
self.spaceForIOVectors + self.spaceForBufferOwners
}
}
fileprivate final class BackingStorage: ManagedBuffer<PooledBufferHead, UInt8> {
static func create(iovectorCount: Int) -> Self {
let head = PooledBufferHead(iovectorCount: iovectorCount)
let baseStorage = Self.create(minimumCapacity: head.totalByteCount) { _ in
head
}
// Here we set up our memory bindings.
let storage = unsafeDowncast(baseStorage, to: Self.self)
storage.withUnsafeMutablePointers { headPointer, tailPointer in
UnsafeRawPointer(tailPointer).bindMemory(to: IOVector.self, capacity: headPointer.pointee.spaceForIOVectors)
UnsafeRawPointer(tailPointer + headPointer.pointee.spaceForIOVectors).bindMemory(to: Unmanaged<AnyObject>.self, capacity: headPointer.pointee.spaceForBufferOwners)
UnsafeRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).bindMemory(to: MemorySentinel.self, capacity: MemoryLayout<MemorySentinel>.size)
}
return storage
}
func withUnsafeMutableTypedPointers<ReturnType>(
_ body: (UnsafeMutableBufferPointer<IOVector>, UnsafeMutableBufferPointer<Unmanaged<AnyObject>>, UnsafeMutablePointer<MemorySentinel>) throws -> ReturnType
) rethrows -> ReturnType {
return try self.withUnsafeMutablePointers { headPointer, tailPointer in
let iovecPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.iovectorOffset).assumingMemoryBound(to: IOVector.self)
let ownersPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.bufferOwnersOffset).assumingMemoryBound(to: Unmanaged<AnyObject>.self)
let sentinelPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).assumingMemoryBound(to: MemorySentinel.self)
let iovecBufferPointer = UnsafeMutableBufferPointer(
start: iovecPointer, count: headPointer.pointee.iovectorCount
)
let ownersBufferPointer = UnsafeMutableBufferPointer(
start: ownersPointer, count: headPointer.pointee.iovectorCount
)
return try body(iovecBufferPointer, ownersBufferPointer, sentinelPointer)
}
}
}
}
extension Int {
fileprivate mutating func roundUpToAlignment<Type>(for: Type.Type) {
// Alignment is always positive, we can use unchecked subtraction here.
let alignmentGuide = MemoryLayout<Type>.alignment &- 1
// But we can't use unchecked addition.
self = (self + alignmentGuide) & (~alignmentGuide)
}
}

View File

@ -30,44 +30,6 @@ internal func withAutoReleasePool<T>(_ execute: () throws -> T) rethrows -> T {
#endif
}
struct PooledBuffer: PoolElement {
private let bufferSize: Int
private let buffer: UnsafeMutableRawPointer
init() {
precondition(MemoryLayout<IOVector>.alignment >= MemoryLayout<Unmanaged<AnyObject>>.alignment)
self.bufferSize = (MemoryLayout<IOVector>.stride + MemoryLayout<Unmanaged<AnyObject>>.stride) * Socket.writevLimitIOVectors
var byteCount = self.bufferSize
debugOnly {
byteCount += MemoryLayout<UInt32>.stride
}
self.buffer = UnsafeMutableRawPointer.allocate(byteCount: byteCount, alignment: MemoryLayout<IOVector>.alignment)
debugOnly {
self.buffer.storeBytes(of: 0xdeadbee, toByteOffset: self.bufferSize, as: UInt32.self)
}
}
func evictedFromPool() {
debugOnly {
assert(0xdeadbee == self.buffer.load(fromByteOffset: self.bufferSize, as: UInt32.self))
}
self.buffer.deallocate()
}
func get() -> (UnsafeMutableBufferPointer<IOVector>, UnsafeMutableBufferPointer<Unmanaged<AnyObject>>) {
let count = Socket.writevLimitIOVectors
let iovecs = self.buffer.bindMemory(to: IOVector.self, capacity: count)
let storageRefs = (self.buffer + (count * MemoryLayout<IOVector>.stride)).bindMemory(to: Unmanaged<AnyObject>.self, capacity: count)
assert(UnsafeMutableRawPointer(iovecs) >= self.buffer)
assert(UnsafeMutableRawPointer(iovecs) <= (self.buffer + self.bufferSize))
assert(UnsafeMutableRawPointer(storageRefs) >= self.buffer)
assert(UnsafeMutableRawPointer(storageRefs) <= (self.buffer + self.bufferSize))
assert(UnsafeMutableRawPointer(iovecs + count) == UnsafeMutableRawPointer(storageRefs))
assert(UnsafeMutableRawPointer(storageRefs + count) <= (self.buffer + bufferSize))
return (UnsafeMutableBufferPointer(start: iovecs, count: count), UnsafeMutableBufferPointer(start: storageRefs, count: count))
}
}
/// `EventLoop` implementation that uses a `Selector` to get notified once there is more I/O or tasks to process.
/// The whole processing of I/O and tasks is done by a `NIOThread` that is tied to the `SelectableEventLoop`. This `NIOThread`
/// is guaranteed to never change!