Make PooledBuffer safer. (#2363)
Motivation:

PooledBuffer is an inherently unsafe type, but its original incarnation was less safe than it needed to be. In particular, we can rewrite it to ensure that it is compatible with automatic reference counting.

Modifications:

- Rewrite PooledBuffer to use ManagedBuffer
- Clean up alignment math
- Use scoped accessors
- Add hooks for future non-scoped access

Result:

Safer, clearer code
parent 1e7ad9a0db
commit 39047aec7c
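The heart of the change is the move from an unscoped accessor (a `get()` returning raw pointers) to scoped accessors. A self-contained model of the new call-site shape, using a stand-in type (`Pooled`, with `Int` slots) rather than the real `PooledBuffer` from the diff below:

final class Pooled {
    private var iovecs = [Int](repeating: 0, count: 4)

    // Slots are only valid inside `body`; the backing storage provably
    // stays alive (and can be sentinel-checked) for that whole scope.
    func withUnsafePointers<R>(_ body: (inout [Int]) throws -> R) rethrows -> R {
        try body(&self.iovecs)
    }
}

let buffer = Pooled()
let first = buffer.withUnsafePointers { (iovecs: inout [Int]) -> Int in
    iovecs[0] = 42   // fill the vectors and perform the I/O inside the scope
    return iovecs[0]
}
print(first)         // 42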
@@ -78,80 +78,81 @@ private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWrit

The vector-fill loop moves inside the scoped accessor; the pool get()/put() pairing is unchanged. The resulting code:

    let buffer = bufferPool.get()
    defer { bufferPool.put(buffer) }

    return try buffer.withUnsafePointers { iovecs, storageRefs in
        for p in pending.flushedWrites {
            // Must not write more than Int32.max in one go.
            // TODO(cory): I can't see this limit documented in a man page anywhere, but it seems
            // plausible given that a similar limit exists for TCP. For now we assume it's present
            // in UDP until I can do some research to validate the existence of this limit.
            guard (Socket.writevLimitBytes - toWrite >= p.data.readableBytes) else {
                if c == 0 {
                    // The first buffer is larger than the writev limit. Let's throw, and fall back to linear processing.
                    throw IOError(errnoCode: EMSGSIZE, reason: "synthetic error for overlarge write")
                } else {
                    break
                }
            }

            // Must not write more than writevLimitIOVectors in one go
            guard c < Socket.writevLimitIOVectors else {
                break
            }

            let toWriteForThisBuffer = p.data.readableBytes
            toWrite += numericCast(toWriteForThisBuffer)

            p.data.withUnsafeReadableBytesWithStorageManagement { ptr, storageRef in
                storageRefs[c] = storageRef.retain()

                /// From man page of `sendmsg(2)`:
                ///
                /// > The `msg_name` field is used on an unconnected socket to specify
                /// > the target address for a datagram. It points to a buffer
                /// > containing the address; the `msg_namelen` field should be set to
                /// > the size of the address. For a connected socket, these fields
                /// > should be specified as `NULL` and 0, respectively.
                let address: UnsafeMutablePointer<sockaddr_storage>?
                let addressLen: socklen_t
                let protocolFamily: NIOBSDSocket.ProtocolFamily
                if let envelopeAddress = p.address {
                    precondition(pending.remoteAddress == nil, "Pending write with address on connected socket.")
                    address = addresses.baseAddress! + c
                    addressLen = p.copySocketAddress(address!)
                    protocolFamily = envelopeAddress.protocol
                } else {
                    guard let connectedRemoteAddress = pending.remoteAddress else {
                        preconditionFailure("Pending write without address on unconnected socket.")
                    }
                    address = nil
                    addressLen = 0
                    protocolFamily = connectedRemoteAddress.protocol
                }

                iovecs[c] = iovec(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer))

                var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[c])
                controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: protocolFamily)
                let controlMessageBytePointer = controlBytes.validControlBytes

                let msg = msghdr(msg_name: address,
                                 msg_namelen: addressLen,
                                 msg_iov: iovecs.baseAddress! + c,
                                 msg_iovlen: 1,
                                 msg_control: controlMessageBytePointer.baseAddress,
                                 msg_controllen: .init(controlMessageBytePointer.count),
                                 msg_flags: 0)
                msgs[c] = MMsgHdr(msg_hdr: msg, msg_len: 0)
            }
            c += 1
        }

        defer {
            for i in 0..<c {
                storageRefs[i].release()
            }
        }
        return try body(UnsafeMutableBufferPointer(start: msgs.baseAddress!, count: c))
    }

/// This holds the states of the currently pending datagram writes. The core is a `MarkedCircularBuffer` which holds all the
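The byte-limit guard is written as a subtraction from the limit rather than an addition to the running total, so the check itself can never overshoot the cap. A standalone sketch of both limits, using assumed constants that mirror Socket.writevLimitBytes (Int32.max) and Socket.writevLimitIOVectors (commonly IOV_MAX, 1024); this is not NIO API:

let writevLimitBytes = Int(Int32.max)
let writevLimitIOVectors = 1024

var toWrite = 0
var c = 0
for readableBytes in [2_000_000_000, 2_000_000_000, 500] {
    // Subtracting the running total from the limit avoids ever forming
    // toWrite + readableBytes, so the accounting cannot exceed the cap.
    guard writevLimitBytes - toWrite >= readableBytes, c < writevLimitIOVectors else {
        break
    }
    toWrite += readableBytes
    c += 1
}
print(c, toWrite)  // 1 2000000000: the second 2 GB chunk is left for the next writev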
@@ -31,46 +31,47 @@ private func doPendingWriteVectorOperation(pending: PendingStreamWritesState,

The stream write path gets the same restructuring. The resulting code:

                                            _ body: (UnsafeBufferPointer<IOVector>) throws -> IOResult<Int>) throws -> (itemCount: Int, writeResult: IOResult<Int>) {
    let buffer = bufferPool.get()
    defer { bufferPool.put(buffer) }

    return try buffer.withUnsafePointers { iovecs, storageRefs in
        // Clamp the number of writes we're willing to issue to the limit for writev.
        var count = min(iovecs.count, storageRefs.count)
        count = min(pending.flushedChunks, count)

        // the number of storage refs that we need to decrease later.
        var numberOfUsedStorageSlots = 0
        var toWrite: Int = 0

        loop: for i in 0..<count {
            let p = pending[i]
            switch p.data {
            case .byteBuffer(let buffer):
                // Must not write more than Int32.max in one go.
                guard (numberOfUsedStorageSlots == 0) || (Socket.writevLimitBytes - toWrite >= buffer.readableBytes) else {
                    break loop
                }
                let toWriteForThisBuffer = min(Socket.writevLimitBytes, buffer.readableBytes)
                toWrite += numericCast(toWriteForThisBuffer)

                buffer.withUnsafeReadableBytesWithStorageManagement { ptr, storageRef in
                    storageRefs[i] = storageRef.retain()
                    iovecs[i] = IOVector(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer))
                }
                numberOfUsedStorageSlots += 1
            case .fileRegion:
                assert(numberOfUsedStorageSlots != 0, "first item in doPendingWriteVectorOperation was a FileRegion")
                // We found a FileRegion so stop collecting
                break loop
            }
        }

        defer {
            for i in 0..<numberOfUsedStorageSlots {
                storageRefs[i].release()
            }
        }
        let result = try body(UnsafeBufferPointer(start: iovecs.baseAddress!, count: numberOfUsedStorageSlots))
        /* if we hit a limit, we really wanted to write more than we have so the caller should retry us */
        return (numberOfUsedStorageSlots, result)
    }

/// The result of a single write operation, usually `write`, `sendfile` or `writev`.
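Inside the scoped accessor, each collected buffer's storage is retained before its raw pointer is stashed into an iovec, and released only after the syscall returns. A minimal standalone illustration of that ownership discipline (not NIO code; Storage stands in for a ByteBuffer's backing storage):

final class Storage {
    var bytes = [UInt8](repeating: 0, count: 64)
}

var storageRefs: [Unmanaged<AnyObject>] = []
let owner = Storage()

// +1 before the raw pointer goes into an iovec: even if every other
// reference to `owner` disappears mid-write, the bytes stay alive.
storageRefs.append(Unmanaged.passRetained(owner as AnyObject))

defer {
    // The balancing -1 once the syscall has returned.
    for ref in storageRefs {
        ref.release()
    }
}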
@@ -50,3 +50,155 @@ class Pool<Element: PoolElement> {
        }
    }
}

/// A ``PooledBuffer`` is used to track an allocation of memory required
/// by a `Channel` or `EventLoopGroup`.
///
/// ``PooledBuffer`` is a reference type with inline storage. It is intended to
/// be bound to a single thread, and ensures that the allocation it stores does not
/// get freed before the buffer is out of use.
struct PooledBuffer: PoolElement {
    private static let sentinelValue = MemorySentinel(0xdeadbeef)

    private let storage: BackingStorage

    init() {
        self.storage = .create(iovectorCount: Socket.writevLimitIOVectors)
        self.configureSentinel()
    }

    func evictedFromPool() {
        self.validateSentinel()
    }

    func withUnsafePointers<ReturnValue>(
        _ body: (UnsafeMutableBufferPointer<IOVector>, UnsafeMutableBufferPointer<Unmanaged<AnyObject>>) throws -> ReturnValue
    ) rethrows -> ReturnValue {
        defer {
            self.validateSentinel()
        }
        return try self.storage.withUnsafeMutableTypedPointers { iovecPointer, ownerPointer, _ in
            try body(iovecPointer, ownerPointer)
        }
    }

    /// Yields buffer pointers containing this ``PooledBuffer``'s readable bytes. You may hold a pointer to those bytes
    /// even after the closure has returned iff you model the lifetime of those bytes correctly using the `Unmanaged`
    /// instance. If you don't require the pointer after the closure returns, use ``withUnsafePointers``.
    ///
    /// If you escape the pointer from the closure, you _must_ call `storageManagement.retain()` to get ownership to
    /// the bytes and you also must call `storageManagement.release()` if you no longer require those bytes. Calls to
    /// `retain` and `release` must be balanced.
    ///
    /// - parameters:
    ///     - body: The closure that will accept the yielded pointers and the `storageManagement`.
    /// - returns: The value returned by `body`.
    func withUnsafePointersWithStorageManagement<ReturnValue>(
        _ body: (UnsafeMutableBufferPointer<IOVector>, UnsafeMutableBufferPointer<Unmanaged<AnyObject>>, Unmanaged<AnyObject>) throws -> ReturnValue
    ) rethrows -> ReturnValue {
        let storageRef: Unmanaged<AnyObject> = Unmanaged.passUnretained(self.storage)
        return try self.storage.withUnsafeMutableTypedPointers { iovecPointer, ownerPointer, _ in
            try body(iovecPointer, ownerPointer, storageRef)
        }
    }

    private func configureSentinel() {
        self.storage.withUnsafeMutableTypedPointers { _, _, sentinelPointer in
            sentinelPointer.pointee = Self.sentinelValue
        }
    }

    private func validateSentinel() {
        self.storage.withUnsafeMutableTypedPointers { _, _, sentinelPointer in
            precondition(sentinelPointer.pointee == Self.sentinelValue, "Detected memory handling error!")
        }
    }
}
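The storage-management variant is the hook for future non-scoped access mentioned in the commit message. A hypothetical caller inside the module, following the contract spelled out in the doc comment (the startWrite/finishWrite names are illustrative, not NIO API):

func startWrite(buffer: PooledBuffer) -> UnsafeMutableBufferPointer<IOVector> {
    buffer.withUnsafePointersWithStorageManagement { (iovecs, _, storageManagement) -> UnsafeMutableBufferPointer<IOVector> in
        // +1 on the backing storage: the pointers may now outlive the closure.
        _ = storageManagement.retain()
        return iovecs
    }
}

func finishWrite(buffer: PooledBuffer) {
    buffer.withUnsafePointersWithStorageManagement { _, _, storageManagement in
        // Balance the earlier retain once the pointers are no longer in use.
        storageManagement.release()
    }
}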

extension PooledBuffer {
    fileprivate typealias MemorySentinel = UInt32

    fileprivate struct PooledBufferHead {
        let iovectorCount: Int

        let spaceForIOVectors: Int

        let spaceForBufferOwners: Int

        init(iovectorCount: Int) {
            var spaceForIOVectors = MemoryLayout<IOVector>.stride * iovectorCount
            spaceForIOVectors.roundUpToAlignment(for: Unmanaged<AnyObject>.self)

            var spaceForBufferOwners = MemoryLayout<Unmanaged<AnyObject>>.stride * iovectorCount
            spaceForBufferOwners.roundUpToAlignment(for: MemorySentinel.self)

            self.iovectorCount = iovectorCount
            self.spaceForIOVectors = spaceForIOVectors
            self.spaceForBufferOwners = spaceForBufferOwners
        }

        var totalByteCount: Int {
            self.spaceForIOVectors + self.spaceForBufferOwners + MemoryLayout<MemorySentinel>.size
        }

        var iovectorOffset: Int {
            0
        }

        var bufferOwnersOffset: Int {
            self.spaceForIOVectors
        }

        var memorySentinelOffset: Int {
            self.spaceForIOVectors + self.spaceForBufferOwners
        }
    }

    fileprivate final class BackingStorage: ManagedBuffer<PooledBufferHead, UInt8> {
        static func create(iovectorCount: Int) -> Self {
            let head = PooledBufferHead(iovectorCount: iovectorCount)

            let baseStorage = Self.create(minimumCapacity: head.totalByteCount) { _ in
                head
            }

            // Here we set up our memory bindings.
            let storage = unsafeDowncast(baseStorage, to: Self.self)
            storage.withUnsafeMutablePointers { headPointer, tailPointer in
                UnsafeRawPointer(tailPointer).bindMemory(to: IOVector.self, capacity: headPointer.pointee.spaceForIOVectors)
                UnsafeRawPointer(tailPointer + headPointer.pointee.spaceForIOVectors).bindMemory(to: Unmanaged<AnyObject>.self, capacity: headPointer.pointee.spaceForBufferOwners)
                UnsafeRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).bindMemory(to: MemorySentinel.self, capacity: MemoryLayout<MemorySentinel>.size)
            }

            return storage
        }

        func withUnsafeMutableTypedPointers<ReturnType>(
            _ body: (UnsafeMutableBufferPointer<IOVector>, UnsafeMutableBufferPointer<Unmanaged<AnyObject>>, UnsafeMutablePointer<MemorySentinel>) throws -> ReturnType
        ) rethrows -> ReturnType {
            return try self.withUnsafeMutablePointers { headPointer, tailPointer in
                let iovecPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.iovectorOffset).assumingMemoryBound(to: IOVector.self)
                let ownersPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.bufferOwnersOffset).assumingMemoryBound(to: Unmanaged<AnyObject>.self)
                let sentinelPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).assumingMemoryBound(to: MemorySentinel.self)

                let iovecBufferPointer = UnsafeMutableBufferPointer(
                    start: iovecPointer, count: headPointer.pointee.iovectorCount
                )
                let ownersBufferPointer = UnsafeMutableBufferPointer(
                    start: ownersPointer, count: headPointer.pointee.iovectorCount
                )
                return try body(iovecBufferPointer, ownersBufferPointer, sentinelPointer)
            }
        }
    }
}

extension Int {
    fileprivate mutating func roundUpToAlignment<Type>(for: Type.Type) {
        // Alignment is always positive, we can use unchecked subtraction here.
        let alignmentGuide = MemoryLayout<Type>.alignment &- 1

        // But we can't use unchecked addition.
        self = (self + alignmentGuide) & (~alignmentGuide)
    }
}
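The layout math is easiest to see with concrete numbers. A worked sketch, assuming a typical 64-bit platform (IOVector stride 16, Unmanaged stride 8, UInt32 sentinel) and an iovectorCount of 1024:

// spaceForIOVectors    = 16 * 1024 = 16384  (already Unmanaged-aligned)
// spaceForBufferOwners =  8 * 1024 =  8192  (already UInt32-aligned)
// totalByteCount       = 16384 + 8192 + 4 = 24580
//
// roundUpToAlignment is the classic power-of-two rounding trick:
var offset = 100
let alignment = 16                              // must be a power of two
offset = (offset + (alignment - 1)) & ~(alignment - 1)
print(offset)                                   // 112, smallest multiple of 16 >= 100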
@@ -30,44 +30,6 @@ internal func withAutoReleasePool<T>(_ execute: () throws -> T) rethrows -> T {
    #endif
}

The old implementation, a hand-managed UnsafeMutableRawPointer whose sentinel was only verified in debug builds, is deleted here:

struct PooledBuffer: PoolElement {
    private let bufferSize: Int
    private let buffer: UnsafeMutableRawPointer

    init() {
        precondition(MemoryLayout<IOVector>.alignment >= MemoryLayout<Unmanaged<AnyObject>>.alignment)
        self.bufferSize = (MemoryLayout<IOVector>.stride + MemoryLayout<Unmanaged<AnyObject>>.stride) * Socket.writevLimitIOVectors
        var byteCount = self.bufferSize
        debugOnly {
            byteCount += MemoryLayout<UInt32>.stride
        }
        self.buffer = UnsafeMutableRawPointer.allocate(byteCount: byteCount, alignment: MemoryLayout<IOVector>.alignment)
        debugOnly {
            self.buffer.storeBytes(of: 0xdeadbee, toByteOffset: self.bufferSize, as: UInt32.self)
        }
    }

    func evictedFromPool() {
        debugOnly {
            assert(0xdeadbee == self.buffer.load(fromByteOffset: self.bufferSize, as: UInt32.self))
        }
        self.buffer.deallocate()
    }

    func get() -> (UnsafeMutableBufferPointer<IOVector>, UnsafeMutableBufferPointer<Unmanaged<AnyObject>>) {
        let count = Socket.writevLimitIOVectors
        let iovecs = self.buffer.bindMemory(to: IOVector.self, capacity: count)
        let storageRefs = (self.buffer + (count * MemoryLayout<IOVector>.stride)).bindMemory(to: Unmanaged<AnyObject>.self, capacity: count)
        assert(UnsafeMutableRawPointer(iovecs) >= self.buffer)
        assert(UnsafeMutableRawPointer(iovecs) <= (self.buffer + self.bufferSize))
        assert(UnsafeMutableRawPointer(storageRefs) >= self.buffer)
        assert(UnsafeMutableRawPointer(storageRefs) <= (self.buffer + self.bufferSize))
        assert(UnsafeMutableRawPointer(iovecs + count) == UnsafeMutableRawPointer(storageRefs))
        assert(UnsafeMutableRawPointer(storageRefs + count) <= (self.buffer + bufferSize))
        return (UnsafeMutableBufferPointer(start: iovecs, count: count), UnsafeMutableBufferPointer(start: storageRefs, count: count))
    }
}

/// `EventLoop` implementation that uses a `Selector` to get notified once there is more I/O or tasks to process.
/// The whole processing of I/O and tasks is done by a `NIOThread` that is tied to the `SelectableEventLoop`. This `NIOThread`
/// is guaranteed to never change!
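What the sentinel buys, as a standalone sketch (not NIO code): a known word sits just past the owner slots, and any out-of-bounds write that clobbers it is caught at the next use or eviction. The deleted code above only checked it under debugOnly/assert; the replacement validates with precondition, so the check also fires in release builds:

let sentinelValue: UInt32 = 0xdeadbeef
var storedSentinel = sentinelValue   // models the word just past the owner slots

// ... code using the pooled pointers runs here; an out-of-bounds write
// past the owners array would overwrite storedSentinel ...

precondition(storedSentinel == sentinelValue, "Detected memory handling error!")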