Add `collect(upTo:) -> ByteBuffer` and variations to `AsyncSequence` (#2038)

This commit is contained in:
David Nadoba 2022-02-03 19:31:09 +01:00 committed by GitHub
parent 3a3e6cb9e3
commit 4f2c6a3e0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 461 additions and 2 deletions

View File

@ -78,3 +78,10 @@ This product contains a derivation of Fabian Fett's 'Base64.swift'.
* https://github.com/fabianfett/swift-base64-kit/blob/master/LICENSE * https://github.com/fabianfett/swift-base64-kit/blob/master/LICENSE
* HOMEPAGE: * HOMEPAGE:
* https://github.com/fabianfett/swift-base64-kit * https://github.com/fabianfett/swift-base64-kit
This product contains a derivation of "XCTest+AsyncAwait.swift" from AsyncHTTPClient.
* LICENSE (Apache License 2.0):
* https://www.apache.org/licenses/LICENSE-2.0
* HOMEPAGE:
* https://github.com/swift-server/async-http-client

View File

@ -2,7 +2,7 @@
// //
// This source file is part of the SwiftNIO open source project // This source file is part of the SwiftNIO open source project
// //
// Copyright (c) 2021 Apple Inc. and the SwiftNIO project authors // Copyright (c) 2021-2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0 // Licensed under Apache License v2.0
// //
// See LICENSE.txt for license information // See LICENSE.txt for license information
@ -218,4 +218,122 @@ extension ChannelPipeline {
try await self.addHandlers(handlers, position: position) try await self.addHandlers(handlers, position: position)
} }
} }
public struct NIOTooManyBytesError: Error {
public init() {}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension AsyncSequence where Element: RandomAccessCollection, Element.Element == UInt8 {
/// Accumulates an ``Swift/AsyncSequence`` of ``Swift/RandomAccessCollection``s into a single `accumulationBuffer`.
/// - Parameters:
/// - accumulationBuffer: buffer to write all the elements of `self` into
/// - maxBytes: The maximum number of bytes this method is allowed to write into `accumulationBuffer`
/// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`.
/// Note that previous elements of `self` might already be write to `accumulationBuffer`.
@inlinable
public func collect(
upTo maxBytes: Int,
into accumulationBuffer: inout ByteBuffer
) async throws {
precondition(maxBytes >= 0, "`maxBytes` must be greater than or equal to zero")
var bytesRead = 0
for try await fragment in self {
bytesRead += fragment.count
guard bytesRead <= maxBytes else {
throw NIOTooManyBytesError()
}
accumulationBuffer.writeBytes(fragment)
}
}
/// Accumulates an ``Swift/AsyncSequence`` of ``Swift/RandomAccessCollection``s into a single ``NIO/ByteBuffer``.
/// - Parameters:
/// - maxBytes: The maximum number of bytes this method is allowed to accumulate
/// - allocator: Allocator used for allocating the result `ByteBuffer`
/// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`.
@inlinable
public func collect(
upTo maxBytes: Int,
using allocator: ByteBufferAllocator
) async throws -> ByteBuffer {
precondition(maxBytes >= 0, "`maxBytes` must be greater than or equal to zero")
var accumulationBuffer = allocator.buffer(capacity: Swift.min(maxBytes, 1024))
try await self.collect(upTo: maxBytes, into: &accumulationBuffer)
return accumulationBuffer
}
}
// MARK: optimised methods for ByteBuffer
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension AsyncSequence where Element == ByteBuffer {
/// Accumulates an ``Swift/AsyncSequence`` of ``ByteBuffer``s into a single `accumulationBuffer`.
/// - Parameters:
/// - accumulationBuffer: buffer to write all the elements of `self` into
/// - maxBytes: The maximum number of bytes this method is allowed to write into `accumulationBuffer`
/// - Throws: ``NIOTooManyBytesError`` if the the sequence contains more than `maxBytes`.
/// Note that previous elements of `self` might be already write to `accumulationBuffer`.
@inlinable
public func collect(
upTo maxBytes: Int,
into accumulationBuffer: inout ByteBuffer
) async throws {
precondition(maxBytes >= 0, "`maxBytes` must be greater than or equal to zero")
var bytesRead = 0
for try await fragment in self {
bytesRead += fragment.readableBytes
guard bytesRead <= maxBytes else {
throw NIOTooManyBytesError()
}
accumulationBuffer.writeImmutableBuffer(fragment)
}
}
/// Accumulates an ``Swift/AsyncSequence`` of ``ByteBuffer``s into a single ``ByteBuffer``.
/// - Parameters:
/// - maxBytes: The maximum number of bytes this method is allowed to accumulate
/// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`.
@inlinable
public func collect(
upTo maxBytes: Int
) async throws -> ByteBuffer {
precondition(maxBytes >= 0, "`maxBytes` must be greater than or equal to zero")
// we use the first `ByteBuffer` to accumulate all subsequent `ByteBuffer`s into.
// this has also the benefit of not copying at all,
// if the async sequence contains only one element.
var iterator = self.makeAsyncIterator()
guard var head = try await iterator.next() else {
return ByteBuffer()
}
guard head.readableBytes <= maxBytes else {
throw NIOTooManyBytesError()
}
let tail = AsyncSequenceFromIterator(iterator)
// it is guaranteed that
// `maxBytes >= 0 && head.readableBytes >= 0 && head.readableBytes <= maxBytes`
// This implies that `maxBytes - head.readableBytes >= 0`
// we can therefore use wrapping subtraction
try await tail.collect(upTo: maxBytes &- head.readableBytes, into: &head)
return head
}
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@usableFromInline
struct AsyncSequenceFromIterator<AsyncIterator: AsyncIteratorProtocol>: AsyncSequence {
@usableFromInline typealias Element = AsyncIterator.Element
@usableFromInline var iterator: AsyncIterator
@inlinable init(_ iterator: AsyncIterator) {
self.iterator = iterator
}
@inlinable func makeAsyncIterator() -> AsyncIterator {
self.iterator
}
}
#endif #endif

View File

@ -2,7 +2,7 @@
// //
// This source file is part of the SwiftNIO open source project // This source file is part of the SwiftNIO open source project
// //
// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors // Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0 // Licensed under Apache License v2.0
// //
// See LICENSE.txt for license information // See LICENSE.txt for license information
@ -48,6 +48,7 @@ class LinuxMainRunnerImpl: LinuxMainRunner {
testCase(AdaptiveRecvByteBufferAllocatorTest.allTests), testCase(AdaptiveRecvByteBufferAllocatorTest.allTests),
testCase(AddressedEnvelopeTests.allTests), testCase(AddressedEnvelopeTests.allTests),
testCase(ApplicationProtocolNegotiationHandlerTests.allTests), testCase(ApplicationProtocolNegotiationHandlerTests.allTests),
testCase(AsyncSequenceCollectTests.allTests),
testCase(Base64Test.allTests), testCase(Base64Test.allTests),
testCase(BaseObjectTest.allTests), testCase(BaseObjectTest.allTests),
testCase(BlockingIOThreadPoolTest.allTests), testCase(BlockingIOThreadPoolTest.allTests),

View File

@ -0,0 +1,34 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
//
// AsyncSequenceTests+XCTest.swift
//
import XCTest
///
/// NOTE: This file was generated by generate_linux_tests.rb
///
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
///
extension AsyncSequenceCollectTests {
@available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings")
static var allTests : [(String, (AsyncSequenceCollectTests) -> () throws -> Void)] {
return [
("testAsyncSequenceCollect", testAsyncSequenceCollect),
]
}
}

View File

@ -0,0 +1,193 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIOCore
import XCTest
fileprivate struct TestCase {
var buffers: [[UInt8]]
var file: StaticString
var line: UInt
init(_ buffers: [[UInt8]], file: StaticString = #file, line: UInt = #line) {
self.buffers = buffers
self.file = file
self.line = line
}
}
final class AsyncSequenceCollectTests: XCTestCase {
func testAsyncSequenceCollect() {
#if compiler(>=5.5.2) && canImport(_Concurrency)
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
XCTAsyncTest(timeout: 5) {
let testCases = [
TestCase([
[],
]),
TestCase([
[],
[],
]),
TestCase([
[0],
[],
]),
TestCase([
[],
[0],
]),
TestCase([
[0],
[1],
]),
TestCase([
[0],
[1],
]),
TestCase([
[0],
[1],
[2],
]),
TestCase([
[],
[0],
[],
[1],
[],
[2],
[],
]),
TestCase([
[0],
[1],
[2],
[],
[],
]),
TestCase([
Array(0..<10),
]),
TestCase([
Array(0..<10),
Array(10..<20),
]),
TestCase([
Array(0..<10),
Array(10..<20),
Array(20..<30),
]),
TestCase([
Array(0..<10),
Array(10..<20),
Array(20..<30),
Array(repeating: 99, count: 1000),
]),
]
for testCase in testCases {
let expectedBytes = testCase.buffers.flatMap({ $0 })
// happy case where maxBytes is exactly the same as number of buffers received
// test for the generic version
let accumulatedBytes1 = try await testCase.buffers
.asAsyncSequence()
.collect(upTo: expectedBytes.count, using: .init())
XCTAssertEqual(
accumulatedBytes1,
ByteBuffer(bytes: expectedBytes),
file: testCase.file,
line: testCase.line
)
// test for the `ByteBuffer` optimised version
let accumulatedBytes2 = try await testCase.buffers
.map(ByteBuffer.init(bytes:))
.asAsyncSequence()
.collect(upTo: expectedBytes.count)
XCTAssertEqual(
accumulatedBytes2,
ByteBuffer(bytes: expectedBytes),
file: testCase.file,
line: testCase.line
)
// unhappy case where maxBytes is one byte less than actually received
guard expectedBytes.count >= 1 else {
continue
}
// test for the generic version
await XCTAssertThrowsError(
try await testCase.buffers
.asAsyncSequence()
.collect(upTo: max(expectedBytes.count - 1, 0), using: .init()),
file: testCase.file,
line: testCase.line
) { error in
XCTAssertTrue(
error is NIOTooManyBytesError,
file: testCase.file,
line: testCase.line
)
}
// test for the `ByteBuffer` optimised version
await XCTAssertThrowsError(
try await testCase.buffers
.map(ByteBuffer.init(bytes:))
.asAsyncSequence()
.collect(upTo: max(expectedBytes.count - 1, 0)),
file: testCase.file,
line: testCase.line
) { error in
XCTAssertTrue(
error is NIOTooManyBytesError,
file: testCase.file,
line: testCase.line
)
}
}
}
#endif
}
}
#if compiler(>=5.5.2) && canImport(_Concurrency)
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
struct AsyncSequenceFromSyncSequence<Wrapped: Sequence>: AsyncSequence {
typealias Element = Wrapped.Element
struct AsyncIterator: AsyncIteratorProtocol {
fileprivate var iterator: Wrapped.Iterator
mutating func next() async throws -> Wrapped.Element? {
self.iterator.next()
}
}
fileprivate let wrapped: Wrapped
func makeAsyncIterator() -> AsyncIterator {
.init(iterator: self.wrapped.makeIterator())
}
}
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension Sequence {
/// Turns `self` into an `AsyncSequence` by wending each element of `self` asynchronously.
func asAsyncSequence() -> AsyncSequenceFromSyncSequence<Self> {
.init(wrapped: self)
}
}
#endif

View File

@ -0,0 +1,106 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
//
// This source file is part of the AsyncHTTPClient open source project
//
// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
/*
* Copyright 2021, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if compiler(>=5.5.2) && canImport(_Concurrency)
import XCTest
extension XCTestCase {
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
/// Cross-platform XCTest support for async-await tests.
///
/// Currently the Linux implementation of XCTest doesn't have async-await support.
/// Until it does, we make use of this shim which uses a detached `Task` along with
/// `XCTest.wait(for:timeout:)` to wrap the operation.
///
/// - NOTE: Support for Linux is tracked by https://bugs.swift.org/browse/SR-14403.
/// - NOTE: Implementation currently in progress: https://github.com/apple/swift-corelibs-xctest/pull/326
func XCTAsyncTest(
expectationDescription: String = "Async operation",
timeout: TimeInterval = 30,
file: StaticString = #filePath,
line: UInt = #line,
function: StaticString = #function,
operation: @escaping @Sendable () async throws -> Void
) {
let expectation = self.expectation(description: expectationDescription)
Task {
do {
try await operation()
} catch {
XCTFail("Error thrown while executing \(function): \(error)", file: file, line: line)
Thread.callStackSymbols.forEach { print($0) }
}
expectation.fulfill()
}
self.wait(for: [expectation], timeout: timeout)
}
}
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
internal func XCTAssertThrowsError<T>(
_ expression: @autoclosure () async throws -> T,
file: StaticString = #file,
line: UInt = #line,
verify: (Error) -> Void = { _ in }
) async {
do {
_ = try await expression()
XCTFail("Expression did not throw error", file: file, line: line)
} catch {
verify(error)
}
}
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
internal func XCTAssertNoThrowWithResult<Result>(
_ expression: @autoclosure () async throws -> Result,
file: StaticString = #file,
line: UInt = #line
) async -> Result? {
do {
return try await expression()
} catch {
XCTFail("Expression did throw: \(error)", file: file, line: line)
}
return nil
}
#endif