Initial commit.
This commit is contained in:
commit
db9edd5ede
|
@ -0,0 +1,3 @@
|
|||
Package.resolved
|
||||
.build
|
||||
PostgresWireServer.xcodeproj
|
|
@ -0,0 +1,23 @@
|
|||
// swift-tools-version:4.0
|
||||
import PackageDescription
|
||||
|
||||
let package = Package(
|
||||
name: "PostgresWireServer",
|
||||
products: [
|
||||
.library(name: "PostgresWireServer", targets: ["PostgresWireServer"]),
|
||||
.executable(name: "PostgresWireServerExample", targets: ["PostgresWireServerExample"]),
|
||||
],
|
||||
dependencies: [
|
||||
.package(url: "https://github.com/IBM-Swift/BlueSocket.git", from: Version("0.2.0")),
|
||||
.package(url: "https://github.com/IBM-Swift/HeliumLogger.git", from: Version("1.7.1")),
|
||||
],
|
||||
targets: [
|
||||
.target(name: "PostgresWireServer", dependencies: [
|
||||
"Socket",
|
||||
"HeliumLogger"
|
||||
], path: "Sources/PostgresWireServer"),
|
||||
.target(name: "PostgresWireServerExample", dependencies: [
|
||||
"PostgresWireServer"
|
||||
], path: "Sources/PostgresWireServerExample")
|
||||
]
|
||||
)
|
|
@ -0,0 +1,43 @@
|
|||
# Postgres wire server
|
||||
|
||||
This package allows you to create a server implementing the PostgreSQL wire protocol (PQ).
|
||||
|
||||
## How it works
|
||||
|
||||
See [usage example](/Sources/PostgresWireServerExample/main.swift).
|
||||
|
||||
## Installation
|
||||
|
||||
#### Swift Package Manager (SPM)
|
||||
|
||||
You can install the driver using Swift Package Manager by adding the following line to your ```Package.swift``` as a dependency:
|
||||
|
||||
```
|
||||
.Package(url: "https://github.com/pixelspark/postgres-wire-server.git", majorVersion: 1)
|
||||
```
|
||||
|
||||
To use in an Xcode project, generate an Xcode project file using SPM:
|
||||
```
|
||||
swift package generate-xcodeproj
|
||||
```
|
||||
|
||||
## MIT license
|
||||
|
||||
````
|
||||
Copyright (c) 2017 Pixelspark, Tommy van der Vorst
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
````
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions of all kinds - from typo fixes to complete refactors and new features. Just be sure to contact us if you want to work on something big, to prevent double effort. You can help in the following ways:
|
||||
|
||||
* Open an issue with suggestions for improvements
|
||||
* Submit a pull request (bug fix, new feature, improved documentation)
|
||||
|
||||
Note that before we can accept any new code to the repository, we need you to confirm in writing that your contribution is made available to us under the terms of the MIT license.
|
|
@ -0,0 +1,821 @@
|
|||
import Foundation
|
||||
import Socket
|
||||
import Dispatch
|
||||
import LoggerAPI
|
||||
|
||||
public final class QueryClientConnection<PreparedStatementType: PreparedStatement> {
|
||||
private enum State {
|
||||
case new
|
||||
case ready // building a query
|
||||
case querying // sending results
|
||||
case closed
|
||||
}
|
||||
|
||||
public private(set) var username: String? = nil
|
||||
public private(set) var password: String? = nil
|
||||
public private(set) var majorVersion: UInt16? = nil
|
||||
public private(set) var minorVersion: UInt16? = nil
|
||||
|
||||
let socket: Socket
|
||||
private var state: State = .new
|
||||
private weak var server: QueryServer<PreparedStatementType>?
|
||||
private var portals: [String: Portal<PreparedStatementType>] = [:]
|
||||
private var preparedStatements: [String: PreparedStatementType] = [:]
|
||||
private var currentPortalName: String? = nil
|
||||
private var bufferedData = Data()
|
||||
|
||||
private let bufferSize = 4096
|
||||
private let isLittleEndian = Int32(42).littleEndian == Int32(42)
|
||||
|
||||
public init(socket: Socket, server: QueryServer<PreparedStatementType>) {
|
||||
self.socket = socket
|
||||
self.server = server
|
||||
self.run()
|
||||
}
|
||||
|
||||
deinit {
|
||||
switch self.state {
|
||||
case .closed:
|
||||
break
|
||||
default:
|
||||
self.server?.connection(didClose: self)
|
||||
self.socket.close()
|
||||
}
|
||||
}
|
||||
|
||||
private func readInt32() -> UInt32? {
|
||||
do {
|
||||
if let data = try self.read(length: 4) {
|
||||
let x = data.map { return UInt8($0) }
|
||||
return self.isLittleEndian ? UInt32(x) : UInt32(x).byteSwapped
|
||||
}
|
||||
return nil
|
||||
}
|
||||
catch {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private func readParameters(length: UInt32) throws -> [String: String]? {
|
||||
if let data = try self.read(length: length) {
|
||||
let elements = data.split(separator: 0x0, maxSplits: Int.max, omittingEmptySubsequences: false)
|
||||
let strs = elements.map { d -> String in
|
||||
let dx = Data(d)
|
||||
return String(data: dx, encoding: .utf8) ?? ""
|
||||
}
|
||||
|
||||
var parameters: [String: String] = [:]
|
||||
for idx in stride(from: 0, to: strs.count, by: 2) {
|
||||
parameters[strs[idx]] = strs[idx+1]
|
||||
}
|
||||
return parameters
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
private func read(length: UInt32) throws -> Data? {
|
||||
var data = Data(capacity: Int(length))
|
||||
|
||||
// Do we have any leftover bytes?
|
||||
if !bufferedData.isEmpty {
|
||||
let maxBuffered = min(Int(length), bufferedData.count)
|
||||
data.append(bufferedData.subdata(in: 0..<maxBuffered))
|
||||
bufferedData.removeSubrange(0..<maxBuffered)
|
||||
}
|
||||
|
||||
// Fetch bytes from socket
|
||||
while data.count < Int(length) {
|
||||
let n = try self.socket.read(into: &data)
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Save leftover bytes for later
|
||||
if data.count > length {
|
||||
bufferedData.append(data.subdata(in: Int(length)..<data.count))
|
||||
data.removeSubrange(Int(length)..<data.count)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
private func readByte() throws -> CChar? {
|
||||
if let data = try self.read(length: 1) {
|
||||
let values = data.map { $0 }
|
||||
return CChar(values[0])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
private func readAuthentication() throws -> String? {
|
||||
if try self.readByte() == CChar(Character("p").codePoint) {
|
||||
// Password authentication, get password
|
||||
if let len = self.readInt32(), let pwData = try self.read(length: len - UInt32(4)) {
|
||||
return String(data: pwData.subdata(in: 0..<Int(len-4-1)), encoding: .utf8)
|
||||
}
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private func readBind() throws -> Bool {
|
||||
/* Bind message
|
||||
Byte1('B') Identifies the message as a Bind command.
|
||||
Int32 Length of message contents in bytes, including self.
|
||||
String The name of the destination portal (an empty string selects the unnamed
|
||||
portal).
|
||||
String The name of the source prepared statement (an empty string selects the
|
||||
unnamed prepared statement).
|
||||
*/
|
||||
if let messageLength = self.readInt32(), let messageData = try self.read(length: messageLength - 4) {
|
||||
var reader = DataReader(data: messageData)
|
||||
if let destinationPortalName = reader.readZeroTerminatedString(),
|
||||
let preparedStatementName = reader.readZeroTerminatedString(),
|
||||
let numberOfParameterFormatCodes = reader.readUInt16() {
|
||||
/* Int16 The number of parameter format codes that follow (denoted C below). This
|
||||
can be zero to indicate that there are no parameters or that the parameters
|
||||
all use the default format (text); or one, in which case the specified
|
||||
format code is applied to all parameters; or it can equal the actual number
|
||||
of parameters.
|
||||
|
||||
Int16[C] The parameter format codes. Each must presently be zero (text) or one (binary). */
|
||||
let parameterFormatCodes = try (0..<numberOfParameterFormatCodes).map { _ -> PQFormat in
|
||||
guard let r = reader.readUInt16() else { throw QueryServerError.protocolError }
|
||||
guard let format = PQFormat(rawValue: r) else { throw QueryServerError.protocolError }
|
||||
return format
|
||||
}
|
||||
|
||||
/*
|
||||
Int16 The number of parameter values that follow (possibly zero). This must match
|
||||
the number of parameters needed by the query.
|
||||
|
||||
Next, the following pair of fields appear for each parameter:
|
||||
|
||||
Int32 The length of the parameter value, in bytes (this count does not include
|
||||
itself). Can be zero. As a special case, -1 indicates a NULL parameter
|
||||
value. No value bytes follow in the NULL case.
|
||||
Byten The value of the parameter, in the format indicated by the associated
|
||||
format code. n is the above length. */
|
||||
guard let numberOfParameterValues = reader.readUInt16() else { throw QueryServerError.protocolError }
|
||||
|
||||
let parameterValues = try (0..<numberOfParameterValues).map { _ -> Data in
|
||||
guard let length = reader.readUInt32() else { throw QueryServerError.protocolError }
|
||||
guard let bytes = reader.readBytes(Int(length)) else { throw QueryServerError.protocolError }
|
||||
return bytes
|
||||
}
|
||||
|
||||
/* After the last parameter, the following fields appear:
|
||||
|
||||
Int16 The number of result-column format codes that follow (denoted R below).
|
||||
This can be zero to indicate that there are no result columns or that the
|
||||
result columns should all use the default format (text); or one, in which
|
||||
case the specified format code is applied to all result columns (if any);
|
||||
or it can equal the actual number of result columns of the query.
|
||||
|
||||
Int16[R] The result-column format codes. Each must presently be zero (text) or
|
||||
one (binary). */
|
||||
/*guard let numberOfResultFormatCodes = reader.readUInt16() else { throw QueryServerError.protocolError }
|
||||
let resultFormatCodes = try (0..<(Int(numberOfResultFormatCodes))).map { _ -> PQFormat in
|
||||
guard let r = reader.readUInt16() else { throw QueryServerError.protocolError }
|
||||
guard let format = PQFormat(rawValue: r) else { throw QueryServerError.protocolError }
|
||||
return format
|
||||
}*/
|
||||
|
||||
let parsedParameterValues = try self.parse(parameters: parameterValues, formats: parameterFormatCodes)
|
||||
|
||||
if let statement = self.preparedStatements[preparedStatementName] {
|
||||
if let _ = self.portals[destinationPortalName], !destinationPortalName.isEmpty {
|
||||
throw QueryServerError.portalAlreadyExists
|
||||
}
|
||||
else {
|
||||
self.portals[destinationPortalName] = Portal(statement: statement, parameters: parsedParameterValues)
|
||||
|
||||
/* Should send message BindComplete to client:
|
||||
Byte1('2') Identifies the message as a Bind-complete indicator.
|
||||
Int32(4) Length of message contents in bytes, including self. */
|
||||
let buf = Data(bytes: [UInt8(Character("2").codePoint), 0, 0, 0, 4])
|
||||
try self.socket.write(from: buf)
|
||||
self.state = .ready
|
||||
return true
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Statement not found
|
||||
throw QueryServerError.preparedStatementNotFound
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func parse(parameters: [Data], formats: [PQFormat]) throws -> [PQValue] {
|
||||
var values = Array<PQValue>(repeating: PQValue.null, count: parameters.count)
|
||||
for (idx, data) in parameters.enumerated() {
|
||||
let format: PQFormat
|
||||
if idx < formats.count {
|
||||
format = formats[idx]
|
||||
}
|
||||
else {
|
||||
// When there is only one format code, this is the one we will use
|
||||
// Otherwise, default to text.
|
||||
format = (formats.count == 1) ? formats[0] : .text
|
||||
}
|
||||
|
||||
switch format {
|
||||
case .text:
|
||||
if let s = String(data: data, encoding: .utf8) {
|
||||
values.append(PQValue.text(s))
|
||||
}
|
||||
else {
|
||||
values.append(PQValue.null)
|
||||
}
|
||||
|
||||
case .binary:
|
||||
if let s = String(data: data, encoding: .utf8) {
|
||||
values.append(PQValue.text(s))
|
||||
}
|
||||
else {
|
||||
values.append(PQValue.null)
|
||||
}
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
private func readClose() throws -> Bool {
|
||||
/* Close message
|
||||
Byte1('C') Identifies the message as a Close command.
|
||||
Int32 Length of message contents in bytes, including self.
|
||||
Byte1 'S' to close a prepared statement; or 'P' to close a portal.
|
||||
String The name of the prepared statement or portal to close (an empty string
|
||||
selects the unnamed prepared statement or portal). */
|
||||
if let messageLength = self.readInt32(), let messageData = try self.read(length: messageLength - 4) {
|
||||
var reader = DataReader(data: messageData)
|
||||
if let type = reader.readBytes(1), let name = reader.readZeroTerminatedString() {
|
||||
if type[0] == CChar(Character("S").codePoint) {
|
||||
// Close prepared statement
|
||||
if let _ = self.preparedStatements[name] {
|
||||
self.preparedStatements[name] = nil
|
||||
|
||||
// Send close complete
|
||||
let buf = Data(bytes: [UInt8(Character("3").codePoint), 0, 0, 0, 4])
|
||||
try self.socket.write(from: buf)
|
||||
return true
|
||||
}
|
||||
else {
|
||||
throw QueryServerError.preparedStatementNotFound
|
||||
}
|
||||
}
|
||||
else if type[0] == CChar(Character("P").codePoint) {
|
||||
// Close portal
|
||||
if let _ = self.portals[name] {
|
||||
self.portals[name] = nil
|
||||
|
||||
// Send close complete
|
||||
let buf = Data(bytes: [UInt8(Character("3").codePoint), 0, 0, 0, 4])
|
||||
try self.socket.write(from: buf)
|
||||
return true
|
||||
}
|
||||
else {
|
||||
throw QueryServerError.portalNotFound
|
||||
}
|
||||
}
|
||||
else {
|
||||
throw QueryServerError.protocolError
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
private func readParse() throws -> Bool {
|
||||
/* Parse message: this should parse a statement into a prepared statement and store
|
||||
it somewhere in a [name: statement] dictionary. If the name is omitted, the statement
|
||||
is erased at the next Parse (with name=unnamed) or Query. A stored prepared statement
|
||||
cannot be overwritten unless first closed (except for the unnamed one).
|
||||
|
||||
Byte1('P') Identifies the message as a Parse command.
|
||||
Int32 Length of message contents in bytes, including self.
|
||||
String The name of the destination prepared statement (an empty string selects
|
||||
the unnamed prepared statement).
|
||||
String The query string to be parsed.
|
||||
Int16 The number of parameter data types specified (may be zero). Note that this
|
||||
is not an indication of the number of parameters that might appear in the
|
||||
query string, only the number that the frontend wants to prespecify types
|
||||
for.
|
||||
|
||||
Then, for each parameter, there is the following:
|
||||
|
||||
Int32 Specifies the object ID of the parameter data type. Placing a zero here is
|
||||
equivalent to leaving the type unspecified. */
|
||||
if let len = self.readInt32(), let messageData = try self.read(length: len - UInt32(4)) {
|
||||
var reader = DataReader(data: messageData)
|
||||
|
||||
if let destinationName = reader.readZeroTerminatedString(),
|
||||
let query = reader.readZeroTerminatedString() {
|
||||
// let numParameterType = reader.readUInt16() {
|
||||
// let parameterTypes = (0..<numParameterType).map { _ in return reader.readUInt32() ?? 0 }
|
||||
|
||||
// Remember prepared statement
|
||||
guard let server = self.server else { return false }
|
||||
let statement = try server.prepare(query, connection: self)
|
||||
if self.preparedStatements[destinationName] != nil && !destinationName.isEmpty {
|
||||
throw QueryServerError.preparedStatementAlreadyExists
|
||||
}
|
||||
self.preparedStatements[destinationName] = statement
|
||||
|
||||
// Send ParseComplete message ('1' + In32(5) indicating length of total message)
|
||||
let buf = Data(bytes: [UInt8(Character("1").codePoint), 0, 0, 0, 4])
|
||||
try self.socket.write(from: buf)
|
||||
self.state = .ready
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
private func sendReadyForQuery() throws {
|
||||
// Send 'ready for query' (Z 5 I)
|
||||
let buf = Data(bytes: [UInt8(Character("Z").codePoint), 0, 0, 0, 5, UInt8(Character("I").codePoint)])
|
||||
try self.socket.write(from: buf)
|
||||
}
|
||||
|
||||
private func sendRowDescription(for statement: PreparedStatementType) throws {
|
||||
if statement.willReturnRows {
|
||||
/* RowDescription (B)
|
||||
Byte1('T') Identifies the message as a row description.
|
||||
Int32 Length of message contents in bytes, including self.
|
||||
Int16 Specifies the number of fields in a row (may be zero).
|
||||
|
||||
Then, for each field, there is the following:
|
||||
|
||||
String The field name.
|
||||
Int32 If the field can be identified as a column of a specific table, the object ID
|
||||
of the table; otherwise zero.
|
||||
Int16 If the field can be identified as a column of a specific table, the attribute
|
||||
number of the column; otherwise zero.
|
||||
Int32 The object ID of the field's data type.
|
||||
Int16 The data type size (see pg_type.typlen). Note that negative values denote
|
||||
variable-width types.
|
||||
Int32 The type modifier (see pg_attribute.atttypmod). The meaning of the modifier
|
||||
is type-specific.
|
||||
Int16 The format code being used for the field. Currently will be zero (text) or
|
||||
one (binary). In a RowDescription returned from the statement variant of
|
||||
Describe, the format code is not yet known and will always be zero. */
|
||||
|
||||
/// Request columns from prepared statement and send description
|
||||
if let cp = self.currentPortalName, let portal = self.portals[cp] {
|
||||
try send(description: try statement.fields(for: portal.parameters))
|
||||
}
|
||||
else {
|
||||
try send(description: try statement.fields(for: []))
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Send NoData response
|
||||
let buf = Data(bytes: [UInt8(Character("n").codePoint), 0, 0, 0, 4])
|
||||
try self.socket.write(from: buf)
|
||||
}
|
||||
}
|
||||
|
||||
/** Send a result set back to the client. */
|
||||
private func send(result: ResultSet) throws {
|
||||
if let e = result.error {
|
||||
try self.send(error: e)
|
||||
return
|
||||
}
|
||||
|
||||
// Send back result
|
||||
while result.hasRow {
|
||||
let row = try result.row()
|
||||
try self.send(row: row)
|
||||
}
|
||||
}
|
||||
|
||||
private func readDescribe() throws -> Bool {
|
||||
/* Describe (F)
|
||||
Byte1('D') Identifies the message as a Describe command.
|
||||
Int32 Length of message contents in bytes, including self.
|
||||
Byte1 'S' to describe a prepared statement; or 'P' to describe a portal.
|
||||
String The name of the prepared statement or portal to describe (an empty string
|
||||
selects the unnamed prepared statement or portal). */
|
||||
if let messageLength = self.readInt32(), let messageData = try self.read(length: messageLength - 4) {
|
||||
var reader = DataReader(data: messageData)
|
||||
|
||||
if let type = reader.readBytes(1), let name = reader.readZeroTerminatedString() {
|
||||
if type[0] == UInt8(Character("S").codePoint) {
|
||||
// Describe statement
|
||||
guard let s = self.preparedStatements[name] else { throw QueryServerError.preparedStatementNotFound }
|
||||
|
||||
// Send parameter description packet
|
||||
// 't' + length of message (Int32) + parameter count (Int16)
|
||||
let buf = Data(bytes: [UInt8(Character("t").codePoint), 0, 0, 0, 7, 0, 0])
|
||||
try self.socket.write(from: buf)
|
||||
try self.sendRowDescription(for: s)
|
||||
return true
|
||||
}
|
||||
else if type[0] == UInt8(Character("P").codePoint) {
|
||||
// Describe portal
|
||||
guard let s = self.portals[name] else { throw QueryServerError.portalNotFound }
|
||||
try self.sendRowDescription(for: s.statement)
|
||||
return true
|
||||
}
|
||||
else {
|
||||
throw QueryServerError.protocolError
|
||||
}
|
||||
}
|
||||
else {
|
||||
throw QueryServerError.protocolError
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
private func readQuery() throws -> Bool {
|
||||
if let len = self.readInt32(), let queryData = try self.read(length: len - UInt32(4)) {
|
||||
let trimmed = queryData.subdata(in: 0..<(queryData.endIndex.advanced(by: -1)))
|
||||
if let q = String(data: trimmed, encoding: .utf8) {
|
||||
if let server = self.server {
|
||||
let st = try server.prepare(q, connection: self)
|
||||
self.state = .querying
|
||||
self.currentPortalName = nil
|
||||
try self.sendRowDescription(for: st)
|
||||
|
||||
// When result is nil, there is no result
|
||||
try server.query(st, parameters: [], connection: self) { resultSet in
|
||||
if let result = resultSet {
|
||||
assert(st.willReturnRows, "results may only be returned when the statement promised to do so")
|
||||
try self.send(result: result)
|
||||
}
|
||||
else {
|
||||
assert(!st.willReturnRows, "statements that promise to return rows should result in a non-nil result set")
|
||||
}
|
||||
|
||||
// TODO: obtain tag from prepared statement
|
||||
try self.sendQueryComplete(tag: "SELECT")
|
||||
try self.sendReadyForQuery()
|
||||
self.state = .ready
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func readExecute() throws -> Bool {
|
||||
/* Execute message.
|
||||
Byte1('E') Identifies the message as an Execute command.
|
||||
Int32 Length of message contents in bytes, including self.
|
||||
String The name of the portal to execute (an empty string selects the unnamed
|
||||
portal).
|
||||
Int32 Maximum number of rows to return, if portal contains a query that returns
|
||||
rows (ignored otherwise). Zero denotes "no limit". */
|
||||
if let messageLength = self.readInt32(), let messageData = try self.read(length: messageLength - 4) {
|
||||
var reader = DataReader(data: messageData)
|
||||
if let portalName = reader.readZeroTerminatedString() {
|
||||
guard let portal = self.portals[portalName] else { throw QueryServerError.portalNotFound }
|
||||
guard let s = self.server else { return false }
|
||||
self.state = .querying
|
||||
self.currentPortalName = portalName
|
||||
try s.query(portal.statement, parameters: portal.parameters, connection: self) { result in
|
||||
if let result = result {
|
||||
try self.send(result: result)
|
||||
}
|
||||
try self.sendQueryComplete(tag: "SELECT") // TODO fetch tag from command
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
private func readSync() throws -> Bool {
|
||||
/* Sync message
|
||||
Byte1('S') Identifies the message as a Sync command.
|
||||
Int32(4) Length of message contents in bytes, including self. */
|
||||
if let messageLength = self.readInt32() {
|
||||
_ = try self.read(length: messageLength - 4)
|
||||
|
||||
// Close the active portal
|
||||
if let portalName = self.currentPortalName {
|
||||
self.portals[portalName] = nil
|
||||
self.currentPortalName = nil
|
||||
self.state = .ready
|
||||
try self.sendReadyForQuery()
|
||||
return true
|
||||
}
|
||||
else {
|
||||
throw QueryServerError.portalNotFound
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/** Reads the next packet in preparing/ready state. Returns whether the connection should continue
|
||||
to process packets. */
|
||||
private func readPacket() throws -> Bool {
|
||||
guard let messageType = try self.readByte() else { return false }
|
||||
guard let messageLetter = Unicode.Scalar(Int(messageType)) else { throw QueryServerError.protocolError }
|
||||
|
||||
switch Character(messageLetter) {
|
||||
case "C": return try self.readClose()
|
||||
case "E": return try self.readExecute()
|
||||
case "P": return try self.readParse()
|
||||
case "Q": return try self.readQuery()
|
||||
case "B": return try self.readBind()
|
||||
case "D": return try self.readDescribe()
|
||||
case "S": return try self.readSync()
|
||||
case "X": return false
|
||||
default:
|
||||
throw QueryServerError.protocolError
|
||||
}
|
||||
}
|
||||
|
||||
func send(row: [PQValue]) throws {
|
||||
switch self.state {
|
||||
case .querying, .closed: break
|
||||
default: fatalError("invalid state")
|
||||
}
|
||||
|
||||
var buf = Data()
|
||||
|
||||
for value in row {
|
||||
switch value {
|
||||
case .null:
|
||||
buf.append(bytesOf: Int32(0).bigEndian)
|
||||
default:
|
||||
let data = value.text.data(using: .utf8)!
|
||||
buf.append(bytesOf: Int32(data.count + 1).bigEndian)
|
||||
buf.append(data)
|
||||
buf.append(0)
|
||||
}
|
||||
}
|
||||
|
||||
var packet = Data()
|
||||
packet.append(UInt8(Character("D").codePoint))
|
||||
packet.append(bytesOf: Int32(buf.count + 4 + 2).bigEndian)
|
||||
packet.append(bytesOf: Int16(row.count).bigEndian)
|
||||
packet.append(buf)
|
||||
try self.socket.write(from: packet)
|
||||
}
|
||||
|
||||
private func send(error: String, severity: PQSeverity = .error, code: String = "42000", endsQuery: Bool = true) throws {
|
||||
var buf = Data()
|
||||
buf.append(UInt8(Character("S").codePoint))
|
||||
let sd = severity.rawValue.data(using: .utf8)!
|
||||
buf.append(sd)
|
||||
buf.append(0)
|
||||
|
||||
buf.append(UInt8(Character("C").codePoint))
|
||||
let cd = code.data(using: .utf8)!
|
||||
buf.append(cd)
|
||||
buf.append(0)
|
||||
|
||||
buf.append(UInt8(Character("M").codePoint))
|
||||
let md = error.data(using: .utf8)!
|
||||
buf.append(md)
|
||||
buf.append(0)
|
||||
|
||||
// Message terminator
|
||||
buf.append(0)
|
||||
|
||||
|
||||
var packet = Data()
|
||||
packet.append(UInt8(Character("E").codePoint))
|
||||
packet.append(bytesOf: Int32(buf.count + 4).bigEndian)
|
||||
packet.append(buf)
|
||||
try self.socket.write(from: packet)
|
||||
|
||||
if endsQuery {
|
||||
self.state = .ready
|
||||
self.run()
|
||||
}
|
||||
}
|
||||
|
||||
private func send(description: [PQField]) throws {
|
||||
var buffer = Data()
|
||||
|
||||
for field in description {
|
||||
let fn = field.name.data(using: .utf8)
|
||||
buffer.append(fn!)
|
||||
buffer.append(0)
|
||||
buffer.append(bytesOf: field.tableId.bigEndian)
|
||||
buffer.append(bytesOf: field.columnId.bigEndian)
|
||||
buffer.append(bytesOf: field.type.rawValue.bigEndian)
|
||||
buffer.append(bytesOf: field.type.typeSize.bigEndian)
|
||||
buffer.append(bytesOf: field.typeModifier.bigEndian)
|
||||
buffer.append(bytesOf: Int16(0).bigEndian) // Binary=1, text=0
|
||||
}
|
||||
|
||||
var packet = Data()
|
||||
packet.append(UInt8(Character("T").codePoint))
|
||||
packet.append(bytesOf: Int32(6 + buffer.count).bigEndian)
|
||||
packet.append(bytesOf: Int16(description.count).bigEndian)
|
||||
packet.append(buffer)
|
||||
try self.socket.write(from: packet)
|
||||
}
|
||||
|
||||
private func sendQueryComplete(tag: String) throws {
|
||||
let data = tag.data(using: .utf8)!
|
||||
|
||||
var packet = Data()
|
||||
packet.append(bytesOf: UInt8(Character("C").codePoint))
|
||||
packet.append(bytesOf: UInt32(data.count + 4 + 1).bigEndian)
|
||||
packet.append(data)
|
||||
packet.append(UInt8(0))
|
||||
try self.socket.write(from: packet)
|
||||
}
|
||||
|
||||
private func run() {
|
||||
// Get the global concurrent queue...
|
||||
let queue = DispatchQueue.global(qos: .default)
|
||||
|
||||
// Create the run loop work item and dispatch to the default priority global queue...
|
||||
queue.async { [unowned self] in
|
||||
var shouldKeepRunning = true
|
||||
do {
|
||||
switch self.state {
|
||||
case .new:
|
||||
if let len = self.readInt32(), let msg = self.readInt32() {
|
||||
if len == 8 && msg == 80877103 {
|
||||
// No SSL, thank you
|
||||
try self.socket.write(from: "N")
|
||||
}
|
||||
else if len > UInt32(8) {
|
||||
// Read client version number
|
||||
self.majorVersion = UInt16(msg >> 16)
|
||||
self.minorVersion = UInt16(msg & 0xFFFF)
|
||||
|
||||
// Read parameters
|
||||
if let p = try self.readParameters(length: len - UInt32(8)) {
|
||||
self.username = p["user"]
|
||||
|
||||
// Send authentication request
|
||||
let buf = Data(bytes: [UInt8(Character("R").codePoint), 0, 0, 0, 8, 0, 0, 0, 3])
|
||||
try self.socket.write(from: buf)
|
||||
|
||||
// Read authentication
|
||||
if let pw = try self.readAuthentication() {
|
||||
self.password = pw
|
||||
|
||||
// Send authentication success
|
||||
let buf = Data(bytes: [UInt8(Character("R").codePoint), 0, 0, 0, 8, 0, 0, 0, 0])
|
||||
try self.socket.write(from: buf)
|
||||
|
||||
self.state = .ready
|
||||
try self.sendReadyForQuery()
|
||||
}
|
||||
else {
|
||||
shouldKeepRunning = false
|
||||
}
|
||||
}
|
||||
else {
|
||||
shouldKeepRunning = false
|
||||
}
|
||||
}
|
||||
else {
|
||||
shouldKeepRunning = false
|
||||
}
|
||||
}
|
||||
|
||||
case .ready, .querying:
|
||||
if try !self.readPacket() {
|
||||
shouldKeepRunning = false
|
||||
}
|
||||
|
||||
case .closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
catch {
|
||||
try? self.send(error: error.localizedDescription)
|
||||
shouldKeepRunning = false
|
||||
}
|
||||
|
||||
if shouldKeepRunning {
|
||||
self.run()
|
||||
}
|
||||
else {
|
||||
self.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func close() {
|
||||
switch self.state {
|
||||
case .closed:
|
||||
break
|
||||
default:
|
||||
self.socket.close()
|
||||
self.state = .closed
|
||||
self.server?.connection(didClose: self)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** A portal is a prepared statement with bound parameters. */
|
||||
fileprivate class Portal<PreparedStatementType: PreparedStatement> {
|
||||
let statement: PreparedStatementType
|
||||
let parameters: [PQValue]
|
||||
var result: ResultSet? = nil
|
||||
|
||||
init(statement: PreparedStatementType, parameters: [PQValue]) {
|
||||
self.statement = statement
|
||||
self.parameters = parameters
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fileprivate struct DataReader {
|
||||
private static let isLittleEndian = Int32(42).littleEndian == Int32(42)
|
||||
private var data: Data
|
||||
|
||||
init(data: Data) {
|
||||
self.data = data
|
||||
}
|
||||
|
||||
mutating func readZeroTerminatedString() -> String? {
|
||||
if let nullIndex = data.index(of: 0), let str = String(data: data.subdata(in: 0..<nullIndex), encoding: .utf8) {
|
||||
data = data.subdata(in: nullIndex.advanced(by: 1)..<data.endIndex)
|
||||
return str
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
mutating func readBytes(_ length: Int) -> Data? {
|
||||
if data.count >= length {
|
||||
let read = data.subdata(in: 0..<length)
|
||||
data = data.subdata(in: data.startIndex.advanced(by: length)..<data.endIndex)
|
||||
return read
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
mutating func readUInt16() -> UInt16? {
|
||||
if data.count >= 2 {
|
||||
let values = data.subdata(in: 0..<2).map { $0 }
|
||||
let number = DataReader.isLittleEndian ? UInt16(values) : UInt16(values).byteSwapped
|
||||
data = data.subdata(in: 2..<data.endIndex)
|
||||
return number
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
mutating func readUInt32() -> UInt32? {
|
||||
if data.count >= 4 {
|
||||
let values = data.subdata(in: 0..<4).map { $0 }
|
||||
let number = DataReader.isLittleEndian ? UInt32(values) : UInt32(values).byteSwapped
|
||||
data = data.subdata(in: 4..<data.endIndex)
|
||||
return number
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
fileprivate extension UnsignedInteger {
|
||||
init(_ bytes: [UInt8]) {
|
||||
precondition(bytes.count <= MemoryLayout<Self>.size)
|
||||
|
||||
var value : UInt64 = 0
|
||||
|
||||
for byte in bytes {
|
||||
value <<= 8
|
||||
value |= UInt64(byte)
|
||||
}
|
||||
|
||||
self.init(value)
|
||||
}
|
||||
}
|
||||
|
||||
fileprivate extension Character {
|
||||
var codePoint: Int {
|
||||
get {
|
||||
let s = String(self).unicodeScalars
|
||||
return Int(s[s.startIndex].value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fileprivate extension Data {
|
||||
mutating func append<T>(bytesOf value: T) {
|
||||
var value = value
|
||||
let byteCount = MemoryLayout<T>.size
|
||||
withUnsafePointer(to: &value) { ptr in
|
||||
ptr.withMemoryRebound(to: UInt8.self, capacity: byteCount) { rptr in
|
||||
self.append(rptr, count: byteCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
import Foundation
|
||||
import Socket
|
||||
import Dispatch
|
||||
import HeliumLogger
|
||||
|
||||
public struct PQField {
|
||||
var name: String
|
||||
var tableId: Int32 = 0
|
||||
var columnId: Int16 = 0
|
||||
var type: PQFieldType
|
||||
var typeModifier: Int32 = -1
|
||||
|
||||
public init(name: String, type: PQFieldType) {
|
||||
self.name = name
|
||||
self.type = type
|
||||
}
|
||||
}
|
||||
|
||||
/** List of Postgres types by Oid. More can be found by querying a Postgres instance:
|
||||
SELECT ' case ' || typname || ' = ' || oid FROM pg_type; */
|
||||
public enum PQFieldType: Int32 {
|
||||
case int = 23
|
||||
case text = 25
|
||||
case bool = 16
|
||||
case float4 = 700
|
||||
case float8 = 701
|
||||
case null = 0
|
||||
|
||||
var typeSize: Int16 {
|
||||
switch self {
|
||||
case .int: return 4
|
||||
case .bool: return 1
|
||||
case .float4: return 4
|
||||
case .float8: return 8
|
||||
case .null: return 0
|
||||
case .text: return -1 // variable length
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public enum PQValue {
|
||||
case int(Int32)
|
||||
case text(String)
|
||||
case bool(Bool)
|
||||
case float4(Double)
|
||||
case float8(Double)
|
||||
case null
|
||||
|
||||
var type: PQFieldType {
|
||||
switch self {
|
||||
case .bool(_): return .bool
|
||||
case .float4(_): return .float4
|
||||
case .float8(_): return .float8
|
||||
case .int(_): return .int
|
||||
case .text(_): return .text
|
||||
case .null: return .null
|
||||
}
|
||||
}
|
||||
|
||||
var text: String {
|
||||
switch self {
|
||||
case .text(let s): return s
|
||||
case .null: return ""
|
||||
case .bool(let b): return b ? "t" : "f"
|
||||
case .float4(let d): return "\(d)"
|
||||
case .float8(let d): return "\(d)"
|
||||
case .int(let i): return "\(i)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public enum QueryServerError: LocalizedError {
|
||||
case protocolError
|
||||
case preparedStatementNotFound
|
||||
case portalAlreadyExists
|
||||
case portalNotFound
|
||||
case preparedStatementAlreadyExists
|
||||
|
||||
public var errorDescription: String? {
|
||||
switch self {
|
||||
case .protocolError: return "protocol error"
|
||||
case .preparedStatementNotFound: return "prepared statement was not found"
|
||||
case .portalAlreadyExists: return "portal already exists"
|
||||
case .preparedStatementAlreadyExists: return "prepared statement already exists"
|
||||
case .portalNotFound: return "portal not found"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal enum PQSeverity: String {
|
||||
case error = "ERROR"
|
||||
case fatal = "FATAL"
|
||||
case info = "INFO"
|
||||
}
|
||||
|
||||
internal enum PQFormat: UInt16 {
|
||||
case text = 0
|
||||
case binary = 1
|
||||
}
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
import Foundation
|
||||
import Socket
|
||||
import LoggerAPI
|
||||
import Dispatch
|
||||
|
||||
/** Represents a prepared statement. Prepared statements are instantiated by your QueryServer
|
||||
subclass's instance method `prepare`. */
|
||||
public protocol PreparedStatement {
|
||||
/** Whether execution of this statement will (can) return any rows. Usually 'true' for SELECT,
|
||||
'false' for DDL/DML statements. */
|
||||
var willReturnRows: Bool { get }
|
||||
|
||||
func fields(for parameters: [PQValue]) throws -> [PQField]
|
||||
}
|
||||
|
||||
/** Represents a query result set. The rows returned from this result set should match the columns
|
||||
returned by a call to the `fields` method on the corresponding PreparedStatement. */
|
||||
public protocol ResultSet: class {
|
||||
var error: String? { get }
|
||||
var hasRow: Bool { get }
|
||||
func row() throws -> [PQValue]
|
||||
}
|
||||
|
||||
/** The query server listens on a socket and instantiates QueryClientConnection objects for each
|
||||
client - this object will further handle communications. The QueryServer class should be subclasses
|
||||
to implement a server. The methods `prepare` and `query` should be overridden. */
|
||||
open class QueryServer<PreparedStatementType: PreparedStatement> {
|
||||
public enum Family {
|
||||
case ipv4
|
||||
case ipv6
|
||||
|
||||
fileprivate var socketFamily: Socket.ProtocolFamily {
|
||||
switch self {
|
||||
case .ipv4: return Socket.ProtocolFamily.inet
|
||||
case .ipv6: return Socket.ProtocolFamily.inet6
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public let port: Int
|
||||
public let family: Family
|
||||
|
||||
private var connectedSockets = [Int32: QueryClientConnection<PreparedStatementType>]()
|
||||
private var listenSocket: Socket? = nil
|
||||
private var continueRunning = true
|
||||
private let socketLockQueue = DispatchQueue(label: "popsiql.socketLock")
|
||||
|
||||
public init(port: Int, family: Family = .ipv6) {
|
||||
self.port = port
|
||||
self.family = family
|
||||
}
|
||||
|
||||
deinit {
|
||||
self.connectedSockets = [:]
|
||||
self.listenSocket?.close()
|
||||
}
|
||||
|
||||
func connection(didClose connection: QueryClientConnection<PreparedStatementType>) {
|
||||
let fd = connection.socket.socketfd
|
||||
self.socketLockQueue.async {
|
||||
self.connectedSockets[fd] = nil
|
||||
}
|
||||
}
|
||||
|
||||
/** Overriden by child classes; returns a prepared statement for the given SQL query string. */
|
||||
open func prepare(_ sql: String, connection: QueryClientConnection<PreparedStatementType>) throws -> PreparedStatementType {
|
||||
fatalError("Must override")
|
||||
}
|
||||
|
||||
/** Overridden by child classes to perform queries. Should return nil for empty results (e.g.
|
||||
DML/DDL commands) when statement.willReturnRows is false. */
|
||||
open func query(_ query: PreparedStatementType, parameters: [PQValue], connection: QueryClientConnection<PreparedStatementType>, callback: @escaping (ResultSet?) throws -> ()) throws {
|
||||
fatalError("Must override")
|
||||
}
|
||||
|
||||
public func run() {
|
||||
let queue = DispatchQueue.global(qos: .userInteractive)
|
||||
|
||||
queue.async { [weak self] in
|
||||
do {
|
||||
// Create an IPV6 socket...
|
||||
if let s = self {
|
||||
s.listenSocket = try Socket.create(family: s.family.socketFamily)
|
||||
|
||||
guard let socket = self?.listenSocket else {
|
||||
return
|
||||
}
|
||||
|
||||
try socket.listen(on: s.port)
|
||||
}
|
||||
else {
|
||||
return
|
||||
}
|
||||
|
||||
repeat {
|
||||
if let s = self?.listenSocket {
|
||||
let newSocket = try s.acceptClientConnection()
|
||||
self?.addNewConnection(socket: newSocket)
|
||||
}
|
||||
|
||||
} while (self?.continueRunning ?? false)
|
||||
|
||||
}
|
||||
catch let error {
|
||||
guard let socketError = error as? Socket.Error else {
|
||||
return
|
||||
}
|
||||
|
||||
if self?.continueRunning ?? false {
|
||||
Log.error("[PSQL] Error reported:\n \(socketError.description)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func addNewConnection(socket: Socket) {
|
||||
do {
|
||||
try socket.setBlocking(mode: true)
|
||||
}
|
||||
catch {
|
||||
Log.error("[PSQL] Could not set blocking mode: \(error.localizedDescription)")
|
||||
}
|
||||
|
||||
// Add the new socket to the list of connected sockets...
|
||||
socketLockQueue.sync { [unowned self, socket] in
|
||||
self.connectedSockets[socket.socketfd] = QueryClientConnection<PreparedStatementType>(socket: socket, server: self)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
import PostgresWireServer
|
||||
import Dispatch
|
||||
|
||||
class MyPreparedStatement: PreparedStatement {
|
||||
let sql: String
|
||||
|
||||
init(sql: String) throws {
|
||||
self.sql = sql
|
||||
}
|
||||
|
||||
public var willReturnRows: Bool {
|
||||
return true
|
||||
}
|
||||
|
||||
public func fields(for parameters: [PQValue]) throws -> [PQField] {
|
||||
return [PQField(name: "foo", type: .text)]
|
||||
}
|
||||
}
|
||||
|
||||
class MyResultSet: ResultSet {
|
||||
let rows = [["bar"]]
|
||||
var idx = 0
|
||||
|
||||
func row() throws -> [PQValue] {
|
||||
assert(self.hasRow, "should not request next row when has no row")
|
||||
let v = rows[idx].map { PQValue.text($0) }
|
||||
idx += 1
|
||||
return v
|
||||
}
|
||||
|
||||
var hasRow: Bool {
|
||||
return idx < self.rows.count
|
||||
}
|
||||
|
||||
var error: String? {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
class MyQueryServer: QueryServer<MyPreparedStatement> {
|
||||
override public func prepare(_ sql: String, connection: QueryClientConnection<MyPreparedStatement>) throws -> MyPreparedStatement {
|
||||
return try MyPreparedStatement(sql: sql)
|
||||
}
|
||||
|
||||
public override func query(_ query: MyPreparedStatement, parameters: [PQValue], connection: QueryClientConnection<MyPreparedStatement>, callback: @escaping (ResultSet?) throws -> ()) throws {
|
||||
try callback(MyResultSet())
|
||||
}
|
||||
}
|
||||
|
||||
let server = MyQueryServer(port: 6789)
|
||||
|
||||
server.run()
|
||||
|
||||
dispatchMain()
|
Loading…
Reference in New Issue