Add `collect(upTo:) -> ByteBuffer` and variations to `AsyncSequence` (#2038)
This commit is contained in:
parent
3a3e6cb9e3
commit
4f2c6a3e0b
|
@ -78,3 +78,10 @@ This product contains a derivation of Fabian Fett's 'Base64.swift'.
|
|||
* https://github.com/fabianfett/swift-base64-kit/blob/master/LICENSE
|
||||
* HOMEPAGE:
|
||||
* https://github.com/fabianfett/swift-base64-kit
|
||||
|
||||
This product contains a derivation of "XCTest+AsyncAwait.swift" from AsyncHTTPClient.
|
||||
|
||||
* LICENSE (Apache License 2.0):
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* HOMEPAGE:
|
||||
* https://github.com/swift-server/async-http-client
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2021 Apple Inc. and the SwiftNIO project authors
|
||||
// Copyright (c) 2021-2022 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
|
@ -218,4 +218,122 @@ extension ChannelPipeline {
|
|||
try await self.addHandlers(handlers, position: position)
|
||||
}
|
||||
}
|
||||
|
||||
public struct NIOTooManyBytesError: Error {
|
||||
public init() {}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
extension AsyncSequence where Element: RandomAccessCollection, Element.Element == UInt8 {
|
||||
/// Accumulates an ``Swift/AsyncSequence`` of ``Swift/RandomAccessCollection``s into a single `accumulationBuffer`.
|
||||
/// - Parameters:
|
||||
/// - accumulationBuffer: buffer to write all the elements of `self` into
|
||||
/// - maxBytes: The maximum number of bytes this method is allowed to write into `accumulationBuffer`
|
||||
/// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`.
|
||||
/// Note that previous elements of `self` might already be write to `accumulationBuffer`.
|
||||
@inlinable
|
||||
public func collect(
|
||||
upTo maxBytes: Int,
|
||||
into accumulationBuffer: inout ByteBuffer
|
||||
) async throws {
|
||||
precondition(maxBytes >= 0, "`maxBytes` must be greater than or equal to zero")
|
||||
var bytesRead = 0
|
||||
for try await fragment in self {
|
||||
bytesRead += fragment.count
|
||||
guard bytesRead <= maxBytes else {
|
||||
throw NIOTooManyBytesError()
|
||||
}
|
||||
accumulationBuffer.writeBytes(fragment)
|
||||
}
|
||||
}
|
||||
|
||||
/// Accumulates an ``Swift/AsyncSequence`` of ``Swift/RandomAccessCollection``s into a single ``NIO/ByteBuffer``.
|
||||
/// - Parameters:
|
||||
/// - maxBytes: The maximum number of bytes this method is allowed to accumulate
|
||||
/// - allocator: Allocator used for allocating the result `ByteBuffer`
|
||||
/// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`.
|
||||
@inlinable
|
||||
public func collect(
|
||||
upTo maxBytes: Int,
|
||||
using allocator: ByteBufferAllocator
|
||||
) async throws -> ByteBuffer {
|
||||
precondition(maxBytes >= 0, "`maxBytes` must be greater than or equal to zero")
|
||||
var accumulationBuffer = allocator.buffer(capacity: Swift.min(maxBytes, 1024))
|
||||
try await self.collect(upTo: maxBytes, into: &accumulationBuffer)
|
||||
return accumulationBuffer
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: optimised methods for ByteBuffer
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
extension AsyncSequence where Element == ByteBuffer {
|
||||
/// Accumulates an ``Swift/AsyncSequence`` of ``ByteBuffer``s into a single `accumulationBuffer`.
|
||||
/// - Parameters:
|
||||
/// - accumulationBuffer: buffer to write all the elements of `self` into
|
||||
/// - maxBytes: The maximum number of bytes this method is allowed to write into `accumulationBuffer`
|
||||
/// - Throws: ``NIOTooManyBytesError`` if the the sequence contains more than `maxBytes`.
|
||||
/// Note that previous elements of `self` might be already write to `accumulationBuffer`.
|
||||
@inlinable
|
||||
public func collect(
|
||||
upTo maxBytes: Int,
|
||||
into accumulationBuffer: inout ByteBuffer
|
||||
) async throws {
|
||||
precondition(maxBytes >= 0, "`maxBytes` must be greater than or equal to zero")
|
||||
var bytesRead = 0
|
||||
for try await fragment in self {
|
||||
bytesRead += fragment.readableBytes
|
||||
guard bytesRead <= maxBytes else {
|
||||
throw NIOTooManyBytesError()
|
||||
}
|
||||
accumulationBuffer.writeImmutableBuffer(fragment)
|
||||
}
|
||||
}
|
||||
|
||||
/// Accumulates an ``Swift/AsyncSequence`` of ``ByteBuffer``s into a single ``ByteBuffer``.
|
||||
/// - Parameters:
|
||||
/// - maxBytes: The maximum number of bytes this method is allowed to accumulate
|
||||
/// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`.
|
||||
@inlinable
|
||||
public func collect(
|
||||
upTo maxBytes: Int
|
||||
) async throws -> ByteBuffer {
|
||||
precondition(maxBytes >= 0, "`maxBytes` must be greater than or equal to zero")
|
||||
// we use the first `ByteBuffer` to accumulate all subsequent `ByteBuffer`s into.
|
||||
// this has also the benefit of not copying at all,
|
||||
// if the async sequence contains only one element.
|
||||
var iterator = self.makeAsyncIterator()
|
||||
guard var head = try await iterator.next() else {
|
||||
return ByteBuffer()
|
||||
}
|
||||
guard head.readableBytes <= maxBytes else {
|
||||
throw NIOTooManyBytesError()
|
||||
}
|
||||
|
||||
let tail = AsyncSequenceFromIterator(iterator)
|
||||
// it is guaranteed that
|
||||
// `maxBytes >= 0 && head.readableBytes >= 0 && head.readableBytes <= maxBytes`
|
||||
// This implies that `maxBytes - head.readableBytes >= 0`
|
||||
// we can therefore use wrapping subtraction
|
||||
try await tail.collect(upTo: maxBytes &- head.readableBytes, into: &head)
|
||||
return head
|
||||
}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@usableFromInline
|
||||
struct AsyncSequenceFromIterator<AsyncIterator: AsyncIteratorProtocol>: AsyncSequence {
|
||||
@usableFromInline typealias Element = AsyncIterator.Element
|
||||
|
||||
@usableFromInline var iterator: AsyncIterator
|
||||
|
||||
@inlinable init(_ iterator: AsyncIterator) {
|
||||
self.iterator = iterator
|
||||
}
|
||||
|
||||
@inlinable func makeAsyncIterator() -> AsyncIterator {
|
||||
self.iterator
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors
|
||||
// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
|
@ -48,6 +48,7 @@ class LinuxMainRunnerImpl: LinuxMainRunner {
|
|||
testCase(AdaptiveRecvByteBufferAllocatorTest.allTests),
|
||||
testCase(AddressedEnvelopeTests.allTests),
|
||||
testCase(ApplicationProtocolNegotiationHandlerTests.allTests),
|
||||
testCase(AsyncSequenceCollectTests.allTests),
|
||||
testCase(Base64Test.allTests),
|
||||
testCase(BaseObjectTest.allTests),
|
||||
testCase(BlockingIOThreadPoolTest.allTests),
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// AsyncSequenceTests+XCTest.swift
|
||||
//
|
||||
import XCTest
|
||||
|
||||
///
|
||||
/// NOTE: This file was generated by generate_linux_tests.rb
|
||||
///
|
||||
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
|
||||
///
|
||||
|
||||
extension AsyncSequenceCollectTests {
|
||||
|
||||
@available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings")
|
||||
static var allTests : [(String, (AsyncSequenceCollectTests) -> () throws -> Void)] {
|
||||
return [
|
||||
("testAsyncSequenceCollect", testAsyncSequenceCollect),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,193 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
import NIOCore
|
||||
import XCTest
|
||||
|
||||
fileprivate struct TestCase {
|
||||
var buffers: [[UInt8]]
|
||||
var file: StaticString
|
||||
var line: UInt
|
||||
init(_ buffers: [[UInt8]], file: StaticString = #file, line: UInt = #line) {
|
||||
self.buffers = buffers
|
||||
self.file = file
|
||||
self.line = line
|
||||
}
|
||||
}
|
||||
|
||||
final class AsyncSequenceCollectTests: XCTestCase {
|
||||
func testAsyncSequenceCollect() {
|
||||
#if compiler(>=5.5.2) && canImport(_Concurrency)
|
||||
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
|
||||
XCTAsyncTest(timeout: 5) {
|
||||
let testCases = [
|
||||
TestCase([
|
||||
[],
|
||||
]),
|
||||
TestCase([
|
||||
[],
|
||||
[],
|
||||
]),
|
||||
TestCase([
|
||||
[0],
|
||||
[],
|
||||
]),
|
||||
TestCase([
|
||||
[],
|
||||
[0],
|
||||
]),
|
||||
TestCase([
|
||||
[0],
|
||||
[1],
|
||||
]),
|
||||
TestCase([
|
||||
[0],
|
||||
[1],
|
||||
]),
|
||||
TestCase([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
]),
|
||||
TestCase([
|
||||
[],
|
||||
[0],
|
||||
[],
|
||||
[1],
|
||||
[],
|
||||
[2],
|
||||
[],
|
||||
]),
|
||||
TestCase([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[],
|
||||
[],
|
||||
]),
|
||||
TestCase([
|
||||
Array(0..<10),
|
||||
]),
|
||||
TestCase([
|
||||
Array(0..<10),
|
||||
Array(10..<20),
|
||||
]),
|
||||
TestCase([
|
||||
Array(0..<10),
|
||||
Array(10..<20),
|
||||
Array(20..<30),
|
||||
]),
|
||||
TestCase([
|
||||
Array(0..<10),
|
||||
Array(10..<20),
|
||||
Array(20..<30),
|
||||
Array(repeating: 99, count: 1000),
|
||||
]),
|
||||
]
|
||||
for testCase in testCases {
|
||||
let expectedBytes = testCase.buffers.flatMap({ $0 })
|
||||
|
||||
// happy case where maxBytes is exactly the same as number of buffers received
|
||||
|
||||
// test for the generic version
|
||||
let accumulatedBytes1 = try await testCase.buffers
|
||||
.asAsyncSequence()
|
||||
.collect(upTo: expectedBytes.count, using: .init())
|
||||
XCTAssertEqual(
|
||||
accumulatedBytes1,
|
||||
ByteBuffer(bytes: expectedBytes),
|
||||
file: testCase.file,
|
||||
line: testCase.line
|
||||
)
|
||||
|
||||
// test for the `ByteBuffer` optimised version
|
||||
let accumulatedBytes2 = try await testCase.buffers
|
||||
.map(ByteBuffer.init(bytes:))
|
||||
.asAsyncSequence()
|
||||
.collect(upTo: expectedBytes.count)
|
||||
XCTAssertEqual(
|
||||
accumulatedBytes2,
|
||||
ByteBuffer(bytes: expectedBytes),
|
||||
file: testCase.file,
|
||||
line: testCase.line
|
||||
)
|
||||
|
||||
// unhappy case where maxBytes is one byte less than actually received
|
||||
guard expectedBytes.count >= 1 else {
|
||||
continue
|
||||
}
|
||||
|
||||
// test for the generic version
|
||||
await XCTAssertThrowsError(
|
||||
try await testCase.buffers
|
||||
.asAsyncSequence()
|
||||
.collect(upTo: max(expectedBytes.count - 1, 0), using: .init()),
|
||||
file: testCase.file,
|
||||
line: testCase.line
|
||||
) { error in
|
||||
XCTAssertTrue(
|
||||
error is NIOTooManyBytesError,
|
||||
file: testCase.file,
|
||||
line: testCase.line
|
||||
)
|
||||
}
|
||||
|
||||
// test for the `ByteBuffer` optimised version
|
||||
await XCTAssertThrowsError(
|
||||
try await testCase.buffers
|
||||
.map(ByteBuffer.init(bytes:))
|
||||
.asAsyncSequence()
|
||||
.collect(upTo: max(expectedBytes.count - 1, 0)),
|
||||
file: testCase.file,
|
||||
line: testCase.line
|
||||
) { error in
|
||||
XCTAssertTrue(
|
||||
error is NIOTooManyBytesError,
|
||||
file: testCase.file,
|
||||
line: testCase.line
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#if compiler(>=5.5.2) && canImport(_Concurrency)
|
||||
|
||||
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
|
||||
struct AsyncSequenceFromSyncSequence<Wrapped: Sequence>: AsyncSequence {
|
||||
typealias Element = Wrapped.Element
|
||||
struct AsyncIterator: AsyncIteratorProtocol {
|
||||
fileprivate var iterator: Wrapped.Iterator
|
||||
mutating func next() async throws -> Wrapped.Element? {
|
||||
self.iterator.next()
|
||||
}
|
||||
}
|
||||
|
||||
fileprivate let wrapped: Wrapped
|
||||
|
||||
func makeAsyncIterator() -> AsyncIterator {
|
||||
.init(iterator: self.wrapped.makeIterator())
|
||||
}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
|
||||
extension Sequence {
|
||||
/// Turns `self` into an `AsyncSequence` by wending each element of `self` asynchronously.
|
||||
func asAsyncSequence() -> AsyncSequenceFromSyncSequence<Self> {
|
||||
.init(wrapped: self)
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,106 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the AsyncHTTPClient open source project
|
||||
//
|
||||
// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
/*
|
||||
* Copyright 2021, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#if compiler(>=5.5.2) && canImport(_Concurrency)
|
||||
import XCTest
|
||||
|
||||
extension XCTestCase {
|
||||
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
|
||||
/// Cross-platform XCTest support for async-await tests.
|
||||
///
|
||||
/// Currently the Linux implementation of XCTest doesn't have async-await support.
|
||||
/// Until it does, we make use of this shim which uses a detached `Task` along with
|
||||
/// `XCTest.wait(for:timeout:)` to wrap the operation.
|
||||
///
|
||||
/// - NOTE: Support for Linux is tracked by https://bugs.swift.org/browse/SR-14403.
|
||||
/// - NOTE: Implementation currently in progress: https://github.com/apple/swift-corelibs-xctest/pull/326
|
||||
func XCTAsyncTest(
|
||||
expectationDescription: String = "Async operation",
|
||||
timeout: TimeInterval = 30,
|
||||
file: StaticString = #filePath,
|
||||
line: UInt = #line,
|
||||
function: StaticString = #function,
|
||||
operation: @escaping @Sendable () async throws -> Void
|
||||
) {
|
||||
let expectation = self.expectation(description: expectationDescription)
|
||||
Task {
|
||||
do {
|
||||
try await operation()
|
||||
} catch {
|
||||
XCTFail("Error thrown while executing \(function): \(error)", file: file, line: line)
|
||||
Thread.callStackSymbols.forEach { print($0) }
|
||||
}
|
||||
expectation.fulfill()
|
||||
}
|
||||
self.wait(for: [expectation], timeout: timeout)
|
||||
}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
|
||||
internal func XCTAssertThrowsError<T>(
|
||||
_ expression: @autoclosure () async throws -> T,
|
||||
file: StaticString = #file,
|
||||
line: UInt = #line,
|
||||
verify: (Error) -> Void = { _ in }
|
||||
) async {
|
||||
do {
|
||||
_ = try await expression()
|
||||
XCTFail("Expression did not throw error", file: file, line: line)
|
||||
} catch {
|
||||
verify(error)
|
||||
}
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
|
||||
internal func XCTAssertNoThrowWithResult<Result>(
|
||||
_ expression: @autoclosure () async throws -> Result,
|
||||
file: StaticString = #file,
|
||||
line: UInt = #line
|
||||
) async -> Result? {
|
||||
do {
|
||||
return try await expression()
|
||||
} catch {
|
||||
XCTFail("Expression did throw: \(error)", file: file, line: line)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
#endif
|
Loading…
Reference in New Issue