824 lines
26 KiB
Swift
824 lines
26 KiB
Swift
import Foundation
|
|
import Socket
|
|
import Dispatch
|
|
import LoggerAPI
|
|
|
|
public final class QueryClientConnection<PreparedStatementType: PreparedStatement> {
|
|
/// Lifecycle of a client connection, as driven by the read loop.
private enum State {
    /// Startup (SSL negotiation, startup message, authentication) has not completed yet.
    case new
    /// Building a query.
    case ready
    /// Sending results.
    case querying
    /// The socket has been closed; no further packets are processed.
    case closed
}
|
|
|
|
/// Value of the `user` startup parameter, set once the startup message has been read.
public private(set) var username: String?
/// Password supplied by the client during authentication.
public private(set) var password: String?
/// Upper 16 bits of the version word in the client's startup message.
public private(set) var majorVersion: UInt16?
/// Lower 16 bits of the version word in the client's startup message.
public private(set) var minorVersion: UInt16?

/// Socket used for all reads from and writes to the client.
let socket: Socket
/// Current position in the connection lifecycle.
private var state = State.new
/// Owning server; held weakly, so it may be nil.
private weak var server: QueryServer<PreparedStatementType>?
/// Bound portals by name (the empty string is the unnamed portal).
private var portals = [String: Portal<PreparedStatementType>]()
/// Prepared statements by name (the empty string is the unnamed statement).
private var preparedStatements = [String: PreparedStatementType]()
/// Name of the portal most recently passed to Execute, if any.
private var currentPortalName: String?
/// Bytes received from the socket beyond what the last `read(length:)` call asked for.
private var bufferedData = Data()

// NOTE(review): `bufferSize` is not referenced anywhere in the visible code — confirm before removing.
private let bufferSize = 4096
/// True when the host stores integers in little-endian byte order.
private let isLittleEndian = Int32(42).littleEndian == Int32(42)
|
|
|
|
/// Creates a connection for an accepted client socket and immediately schedules its read loop.
/// - Parameters:
///   - socket: The socket to read requests from and write responses to.
///   - server: The server that prepares and runs statements; stored weakly.
public init(socket: Socket, server: QueryServer<PreparedStatementType>) {
    self.server = server
    self.socket = socket
    run()
}
|
|
|
|
/// Notifies the server and closes the socket if the connection was never closed explicitly.
deinit {
    // Nothing left to do when close() already ran.
    if case .closed = state { return }

    // NOTE(review): `self` is handed to the server while it is being deallocated;
    // confirm that connection(didClose:) does not retain it.
    server?.connection(didClose: self)
    socket.close()
}
|
|
|
|
/// Reads a 32-bit unsigned integer in network (big-endian) byte order.
///
/// The four bytes are folded most-significant-first, which yields the numeric value on any
/// host. The previous implementation additionally byte-swapped the already-correct value
/// on big-endian hosts.
/// - Returns: The decoded value, or nil when the bytes could not be read (including when
///   the underlying read throws).
private func readInt32() -> UInt32? {
    do {
        guard let data = try self.read(length: 4) else { return nil }
        return data.reduce(UInt32(0)) { ($0 << 8) | UInt32($1) }
    }
    catch {
        return nil
    }
}
|
|
|
|
/// Reads the name/value parameter list of a startup message.
///
/// The payload is a sequence of zero-terminated strings alternating between names and
/// values, ended by an empty name. Fields that are not valid UTF-8 are read as "".
/// - Parameter length: Number of payload bytes to consume.
/// - Returns: The parameters by name, or nil when the payload could not be read.
/// - Throws: Whatever the underlying read throws.
private func readParameters(length: UInt32) throws -> [String: String]? {
    guard let data = try self.read(length: length) else { return nil }

    let fields = data
        .split(separator: 0x0, maxSplits: Int.max, omittingEmptySubsequences: false)
        .map { String(data: Data($0), encoding: .utf8) ?? "" }

    var parameters: [String: String] = [:]
    var index = 0
    // Only consume complete pairs: an odd number of fields from a malformed client
    // previously indexed past the end of the array and crashed.
    while index + 1 < fields.count {
        let name = fields[index]
        // An empty name is the list terminator, not a parameter.
        if name.isEmpty { break }
        parameters[name] = fields[index + 1]
        index += 2
    }
    return parameters
}
|
|
|
|
/// Reads exactly `length` bytes, first from previously buffered bytes and then from the socket.
///
/// Anything received beyond `length` is kept in `bufferedData` for the next call.
/// - Parameter length: Number of bytes wanted.
/// - Returns: Exactly `length` bytes, or nil when a socket read returns zero or fewer bytes.
/// - Throws: Whatever the socket read throws.
private func read(length: UInt32) throws -> Data? {
    let wanted = Int(length)
    var data = Data(capacity: wanted)

    // Serve as much as possible from bytes left over by an earlier read.
    if !bufferedData.isEmpty {
        let available = min(wanted, bufferedData.count)
        data.append(bufferedData.subdata(in: 0..<available))
        bufferedData.removeSubrange(0..<available)
    }

    // Pull from the socket until enough bytes have arrived
    // (presumably read(into:) appends to `data` — confirm against the Socket API).
    while data.count < wanted {
        guard try self.socket.read(into: &data) > 0 else { return nil }
    }

    // Keep any surplus for the next call.
    if data.count > wanted {
        let surplus = wanted..<data.count
        bufferedData.append(data.subdata(in: surplus))
        data.removeSubrange(surplus)
    }
    return data
}
|
|
|
|
/// Reads a single byte from the connection.
///
/// The byte is reinterpreted bit-for-bit as a `CChar`. The checked `CChar(_:)` conversion
/// used before traps on any byte above 127, which let a client crash the server with one
/// non-ASCII byte.
/// - Returns: The byte, or nil when nothing could be read.
/// - Throws: Whatever the underlying read throws.
private func readByte() throws -> CChar? {
    guard let data = try self.read(length: 1), let byte = data.first else { return nil }
    return CChar(bitPattern: byte)
}
|
|
|
|
/// Reads a password message ('p', Int32 length, zero-terminated password).
///
/// - Returns: The password without its terminator, or nil when the next message is not a
///   password message, its length is too short to hold a terminator, the payload could not
///   be read, or the password is not valid UTF-8.
/// - Throws: Whatever the underlying read throws.
private func readAuthentication() throws -> String? {
    guard let type = try self.readByte(), type == CChar(bitPattern: UInt8(ascii: "p")) else {
        return nil
    }

    // The length counts its own four bytes plus at least the terminator. Smaller values
    // previously underflowed the unsigned subtraction and trapped.
    guard let len = self.readInt32(), len > 4 else { return nil }
    guard let pwData = try self.read(length: len - 4) else { return nil }

    // Drop the trailing zero terminator.
    return String(data: pwData.dropLast(), encoding: .utf8)
}
|
|
|
|
/// Handles a Bind message: binds parameter values to a prepared statement and stores the
/// result as a portal, then answers with BindComplete.
///
/// - Returns: true when the portal was created; false when the message could not be read
///   from the socket.
/// - Throws: `QueryServerError.protocolError` for a malformed message,
///   `.preparedStatementNotFound` for an unknown statement, `.portalAlreadyExists` when a
///   named portal is bound twice, or whatever the socket throws.
private func readBind() throws -> Bool {
    /* Bind message
     Byte1('B')  Identifies the message as a Bind command.
     Int32       Length of message contents in bytes, including self.
     String      The name of the destination portal (an empty string selects the unnamed
                 portal).
     String      The name of the source prepared statement (an empty string selects the
                 unnamed prepared statement).
     */
    guard let messageLength = self.readInt32() else { return false }
    // The length includes itself; smaller values would underflow the subtraction and trap.
    guard messageLength >= 4 else { throw QueryServerError.protocolError }
    guard let messageData = try self.read(length: messageLength - 4) else { return false }

    var reader = DataReader(data: messageData)

    // A body that cannot be decoded is a protocol error. It used to return true without
    // ever answering the client.
    guard let destinationPortalName = reader.readZeroTerminatedString(),
        let preparedStatementName = reader.readZeroTerminatedString(),
        let numberOfParameterFormatCodes = reader.readUInt16() else {
        throw QueryServerError.protocolError
    }

    /* Int16     The number of parameter format codes that follow (denoted C below). This
                 can be zero to indicate that there are no parameters or that the parameters
                 all use the default format (text); or one, in which case the specified
                 format code is applied to all parameters; or it can equal the actual number
                 of parameters.

     Int16[C]    The parameter format codes. Each must presently be zero (text) or one (binary). */
    let parameterFormatCodes = try (0..<numberOfParameterFormatCodes).map { _ -> PQFormat in
        guard let raw = reader.readUInt16(), let format = PQFormat(rawValue: raw) else {
            throw QueryServerError.protocolError
        }
        return format
    }

    /* Int16     The number of parameter values that follow (possibly zero). This must match
                 the number of parameters needed by the query.

     Next, the following pair of fields appear for each parameter:

     Int32       The length of the parameter value, in bytes (this count does not include
                 itself). Can be zero. As a special case, -1 indicates a NULL parameter
                 value. No value bytes follow in the NULL case.
     Byten       The value of the parameter, in the format indicated by the associated
                 format code. n is the above length. */
    guard let numberOfParameterValues = reader.readUInt16() else { throw QueryServerError.protocolError }

    // NOTE(review): a NULL parameter (length -1) arrives here as 0xFFFFFFFF, which
    // readBytes cannot satisfy, so NULL parameters are rejected as a protocol error.
    let parameterValues = try (0..<numberOfParameterValues).map { _ -> Data in
        guard let length = reader.readUInt32(), let bytes = reader.readBytes(Int(length)) else {
            throw QueryServerError.protocolError
        }
        return bytes
    }

    /* After the last parameter, the following fields appear:

     Int16       The number of result-column format codes that follow (denoted R below).
                 This can be zero to indicate that there are no result columns or that the
                 result columns should all use the default format (text); or one, in which
                 case the specified format code is applied to all result columns (if any);
                 or it can equal the actual number of result columns of the query.

     Int16[R]    The result-column format codes. Each must presently be zero (text) or
                 one (binary).

     The result-column format codes are not read here. */

    let parsedParameterValues = try self.parse(parameters: parameterValues, formats: parameterFormatCodes)

    guard let statement = self.preparedStatements[preparedStatementName] else {
        // Statement not found
        throw QueryServerError.preparedStatementNotFound
    }

    // Only the unnamed portal may be silently replaced.
    if !destinationPortalName.isEmpty && self.portals[destinationPortalName] != nil {
        throw QueryServerError.portalAlreadyExists
    }
    self.portals[destinationPortalName] = Portal(statement: statement, parameters: parsedParameterValues)

    /* BindComplete:
     Byte1('2')  Identifies the message as a Bind-complete indicator.
     Int32(4)    Length of message contents in bytes, including self. */
    let buf = Data([UInt8(ascii: "2"), 0, 0, 0, 4])
    try self.socket.write(from: buf)
    self.state = .ready
    return true
}
|
|
|
|
/// Converts raw Bind parameter payloads into values, one result per input parameter.
///
/// The previous implementation pre-filled the result with `parameters.count` nulls and
/// then appended the decoded values, so it returned twice as many entries as parameters
/// and every real value sat at the wrong index.
/// - Parameters:
///   - parameters: Raw parameter payloads in positional order.
///   - formats: Format codes; when fewer codes than parameters are given, a single code
///     applies to all parameters and otherwise text is assumed.
/// - Returns: One value per parameter; a payload that is not valid UTF-8 becomes `.null`.
private func parse(parameters: [Data], formats: [PQFormat]) throws -> [PQValue] {
    return parameters.enumerated().map { (idx, data) -> PQValue in
        let format: PQFormat
        if idx < formats.count {
            format = formats[idx]
        }
        else {
            // When there is only one format code, this is the one we will use.
            // Otherwise, default to text.
            format = (formats.count == 1) ? formats[0] : .text
        }

        switch format {
        // Binary payloads are decoded as UTF-8 text, exactly like text payloads.
        case .text, .binary:
            if let s = String(data: data, encoding: .utf8) {
                return PQValue.text(s)
            }
            return PQValue.null
        }
    }
}
|
|
|
|
/// Handles a Close message: discards a prepared statement or portal and answers with
/// CloseComplete.
///
/// Closing a name that does not exist is not an error in the PostgreSQL protocol, so
/// CloseComplete is sent either way. This used to throw and terminate the connection.
/// - Returns: true when CloseComplete was sent; false when the message could not be read
///   from the socket.
/// - Throws: `QueryServerError.protocolError` for a malformed message, or whatever the
///   socket throws.
private func readClose() throws -> Bool {
    /* Close message
     Byte1('C')  Identifies the message as a Close command.
     Int32       Length of message contents in bytes, including self.
     Byte1       'S' to close a prepared statement; or 'P' to close a portal.
     String      The name of the prepared statement or portal to close (an empty string
                 selects the unnamed prepared statement or portal). */
    guard let messageLength = self.readInt32() else { return false }
    // The length includes itself; smaller values would underflow the subtraction and trap.
    guard messageLength >= 4 else { throw QueryServerError.protocolError }
    guard let messageData = try self.read(length: messageLength - 4) else { return false }

    var reader = DataReader(data: messageData)
    guard let type = reader.readBytes(1), let kind = type.first,
        let name = reader.readZeroTerminatedString() else {
        throw QueryServerError.protocolError
    }

    switch kind {
    case UInt8(ascii: "S"):
        // Close prepared statement
        self.preparedStatements[name] = nil
    case UInt8(ascii: "P"):
        // Close portal
        self.portals[name] = nil
    default:
        throw QueryServerError.protocolError
    }

    // Send close complete
    let buf = Data([UInt8(ascii: "3"), 0, 0, 0, 4])
    try self.socket.write(from: buf)
    return true
}
|
|
|
|
/// Handles a Parse message: prepares the query text and stores it under the given name,
/// then answers with ParseComplete.
///
/// - Returns: true when the statement was stored; false when the message could not be read
///   from the socket or the server is gone.
/// - Throws: `QueryServerError.protocolError` for a malformed message,
///   `.preparedStatementAlreadyExists` when a named statement is redefined, or whatever the
///   server's `prepare` or the socket throws.
private func readParse() throws -> Bool {
    /* Parse message: this should parse a statement into a prepared statement and store
     it somewhere in a [name: statement] dictionary. If the name is omitted, the statement
     is erased at the next Parse (with name=unnamed) or Query. A stored prepared statement
     cannot be overwritten unless first closed (except for the unnamed one).

     Byte1('P')  Identifies the message as a Parse command.
     Int32       Length of message contents in bytes, including self.
     String      The name of the destination prepared statement (an empty string selects
                 the unnamed prepared statement).
     String      The query string to be parsed.
     Int16       The number of parameter data types specified (may be zero). Note that this
                 is not an indication of the number of parameters that might appear in the
                 query string, only the number that the frontend wants to prespecify types
                 for.

     Then, for each parameter, there is the following:

     Int32       Specifies the object ID of the parameter data type. Placing a zero here is
                 equivalent to leaving the type unspecified.

     The parameter type list is not read here. */
    guard let len = self.readInt32() else { return false }
    // The length includes itself; smaller values would underflow the subtraction and trap.
    guard len >= 4 else { throw QueryServerError.protocolError }
    guard let messageData = try self.read(length: len - 4) else { return false }

    var reader = DataReader(data: messageData)
    guard let destinationName = reader.readZeroTerminatedString(),
        let query = reader.readZeroTerminatedString() else {
        throw QueryServerError.protocolError
    }

    guard let server = self.server else { return false }

    // Reject a duplicate name before doing the work of preparing the statement.
    if !destinationName.isEmpty && self.preparedStatements[destinationName] != nil {
        throw QueryServerError.preparedStatementAlreadyExists
    }

    // Remember prepared statement
    self.preparedStatements[destinationName] = try server.prepare(query, connection: self)

    // Send ParseComplete message ('1' + Int32(4), the length of the contents including itself)
    let buf = Data([UInt8(ascii: "1"), 0, 0, 0, 4])
    try self.socket.write(from: buf)
    self.state = .ready
    return true
}
|
|
|
|
/// Writes a ReadyForQuery message with status 'I' to the client.
/// - Throws: Whatever the socket write throws.
private func sendReadyForQuery() throws {
    // 'Z', Int32 length 5 (itself plus one status byte), then the status 'I'.
    let message: [UInt8] = [UInt8(ascii: "Z"), 0, 0, 0, 5, UInt8(ascii: "I")]
    try socket.write(from: Data(message))
}
|
|
|
|
/// Describes the rows a statement will produce: a RowDescription when the statement returns
/// rows, a NoData message otherwise.
///
/// The fields are requested with the parameters of the current portal when there is one,
/// and with no parameters otherwise.
/// - Parameter statement: The statement to describe.
/// - Throws: Whatever `fields(for:)` or the socket write throws.
private func sendRowDescription(for statement: PreparedStatementType) throws {
    guard statement.willReturnRows else {
        // Send NoData response
        try socket.write(from: Data([UInt8(ascii: "n"), 0, 0, 0, 4]))
        return
    }

    /* RowDescription (B)
     Byte1('T')  Identifies the message as a row description.
     Int32       Length of message contents in bytes, including self.
     Int16       Specifies the number of fields in a row (may be zero).

     Then, for each field, there is the following:

     String      The field name.
     Int32       If the field can be identified as a column of a specific table, the object ID
                 of the table; otherwise zero.
     Int16       If the field can be identified as a column of a specific table, the attribute
                 number of the column; otherwise zero.
     Int32       The object ID of the field's data type.
     Int16       The data type size (see pg_type.typlen). Note that negative values denote
                 variable-width types.
     Int32       The type modifier (see pg_attribute.atttypmod). The meaning of the modifier
                 is type-specific.
     Int16       The format code being used for the field. Currently will be zero (text) or
                 one (binary). In a RowDescription returned from the statement variant of
                 Describe, the format code is not yet known and will always be zero. */

    // Request columns from the prepared statement and send the description.
    let boundParameters = currentPortalName.flatMap { portals[$0] }?.parameters ?? []
    try send(description: try statement.fields(for: boundParameters))
}
|
|
|
|
/** Streams a result set to the client: an error response when the result carries an error,
 otherwise one data row per available row. */
private func send(result: ResultSet) throws {
    if let message = result.error {
        return try send(error: message)
    }

    // Send back result
    while result.hasRow {
        try send(row: try result.row())
    }
}
|
|
|
|
/// Handles a Describe message for a prepared statement ('S') or a portal ('P').
///
/// For a statement, a ParameterDescription reporting zero parameters is sent first, followed
/// by the row description. For a portal only the row description is sent.
/// - Returns: true when the description was sent; false when the message could not be read
///   from the socket.
/// - Throws: `QueryServerError.protocolError` for a malformed message,
///   `.preparedStatementNotFound` / `.portalNotFound` for unknown names, or whatever the
///   socket throws.
private func readDescribe() throws -> Bool {
    /* Describe (F)
     Byte1('D')  Identifies the message as a Describe command.
     Int32       Length of message contents in bytes, including self.
     Byte1       'S' to describe a prepared statement; or 'P' to describe a portal.
     String      The name of the prepared statement or portal to describe (an empty string
                 selects the unnamed prepared statement or portal). */
    guard let messageLength = self.readInt32() else { return false }
    // The length includes itself; smaller values would underflow the subtraction and trap.
    guard messageLength >= 4 else { throw QueryServerError.protocolError }
    guard let messageData = try self.read(length: messageLength - 4) else { return false }

    var reader = DataReader(data: messageData)
    guard let type = reader.readBytes(1), let kind = type.first,
        let name = reader.readZeroTerminatedString() else {
        throw QueryServerError.protocolError
    }

    switch kind {
    case UInt8(ascii: "S"):
        // Describe statement
        guard let s = self.preparedStatements[name] else { throw QueryServerError.preparedStatementNotFound }

        // Send parameter description packet:
        // 't' + length of message contents (Int32) + parameter count (Int16).
        // The contents are 4 + 2 = 6 bytes. Announcing 7 made the client swallow the
        // first byte of the following message.
        let buf = Data([UInt8(ascii: "t"), 0, 0, 0, 6, 0, 0])
        try self.socket.write(from: buf)
        try self.sendRowDescription(for: s)
        return true

    case UInt8(ascii: "P"):
        // Describe portal
        guard let s = self.portals[name] else { throw QueryServerError.portalNotFound }
        try self.sendRowDescription(for: s.statement)
        return true

    default:
        throw QueryServerError.protocolError
    }
}
|
|
|
|
/// Handles a simple Query message: prepares the query text, sends its row description, runs
/// it and streams the result, followed by CommandComplete and ReadyForQuery.
///
/// - Returns: true when the query was handed to the server; false when the message could
///   not be read, is not valid UTF-8, or the server is gone.
/// - Throws: `QueryServerError.protocolError` for an impossible length, or whatever the
///   server or the socket throws.
private func readQuery() throws -> Bool {
    guard let len = self.readInt32() else { return false }
    // The contents are the length itself plus a zero-terminated string, so at least five
    // bytes. Shorter values used to underflow the subtraction or build the range 0..<-1.
    guard len > 4 else { throw QueryServerError.protocolError }
    guard let queryData = try self.read(length: len - 4) else { return false }

    // Drop the trailing zero terminator before decoding.
    guard let q = String(data: queryData.dropLast(), encoding: .utf8) else { return false }

    // Without a server nothing can be executed; stop instead of reporting success.
    guard let server = self.server else { return false }

    let st = try server.prepare(q, connection: self)
    self.state = .querying
    self.currentPortalName = nil
    try self.sendRowDescription(for: st)

    // When result is nil, there is no result
    try server.query(st, parameters: [], connection: self) { resultSet in
        if let result = resultSet {
            assert(st.willReturnRows, "results may only be returned when the statement promised to do so")
            try self.send(result: result)
        }
        else {
            assert(!st.willReturnRows, "statements that promise to return rows should result in a non-nil result set")
        }

        // TODO: obtain tag from prepared statement
        try self.sendQueryComplete(tag: "SELECT")
        try self.sendReadyForQuery()
        self.state = .ready
    }
    return true
}
|
|
|
|
/// Handles an Execute message: runs the named portal and streams its result, followed by
/// CommandComplete.
///
/// - Returns: true when the portal was handed to the server; false when the message could
///   not be read from the socket or the server is gone.
/// - Throws: `QueryServerError.protocolError` for a malformed message, `.portalNotFound`
///   for an unknown portal, or whatever the server or the socket throws.
private func readExecute() throws -> Bool {
    /* Execute message.
     Byte1('E')  Identifies the message as an Execute command.
     Int32       Length of message contents in bytes, including self.
     String      The name of the portal to execute (an empty string selects the unnamed
                 portal).
     Int32       Maximum number of rows to return, if portal contains a query that returns
                 rows (ignored otherwise). Zero denotes "no limit".

     The row limit is not read here. */
    guard let messageLength = self.readInt32() else { return false }
    // The length includes itself; smaller values would underflow the subtraction and trap.
    guard messageLength >= 4 else { throw QueryServerError.protocolError }
    guard let messageData = try self.read(length: messageLength - 4) else { return false }

    var reader = DataReader(data: messageData)
    guard let portalName = reader.readZeroTerminatedString() else {
        throw QueryServerError.protocolError
    }
    guard let portal = self.portals[portalName] else { throw QueryServerError.portalNotFound }
    guard let s = self.server else { return false }

    self.state = .querying
    self.currentPortalName = portalName
    try s.query(portal.statement, parameters: portal.parameters, connection: self) { result in
        if let result = result {
            try self.send(result: result)
        }
        try self.sendQueryComplete(tag: "SELECT") // TODO fetch tag from command
    }
    return true
}
|
|
|
|
/// Handles a Sync message: discards the portal that was last executed (if any) and answers
/// with ReadyForQuery.
///
/// Sync is answered even when no portal was executed, for example after a Parse followed
/// directly by Sync. That case used to throw `portalNotFound` and terminate the
/// connection.
/// - Returns: true when ReadyForQuery was sent; false when the message could not be read
///   from the socket.
/// - Throws: `QueryServerError.protocolError` for an impossible length, or whatever the
///   socket throws.
private func readSync() throws -> Bool {
    /* Sync message
     Byte1('S')  Identifies the message as a Sync command.
     Int32(4)    Length of message contents in bytes, including self. */
    guard let messageLength = self.readInt32() else { return false }
    // The length includes itself; smaller values would underflow the subtraction and trap.
    guard messageLength >= 4 else { throw QueryServerError.protocolError }
    guard try self.read(length: messageLength - 4) != nil else { return false }

    // Close the active portal
    if let portalName = self.currentPortalName {
        self.portals[portalName] = nil
        self.currentPortalName = nil
    }

    self.state = .ready
    try self.sendReadyForQuery()
    return true
}
|
|
|
|
/** Reads one message in the ready/querying state and dispatches it on its type byte.
 Returns whether the connection should keep processing packets; an unknown type byte
 throws a protocol error. */
private func readPacket() throws -> Bool {
    guard let typeByte = try readByte() else { return false }
    guard let scalar = Unicode.Scalar(Int(typeByte)) else { throw QueryServerError.protocolError }

    switch Character(scalar) {
    case "X":
        // 'X' ends the session.
        return false
    case "B":
        return try readBind()
    case "C":
        return try readClose()
    case "D":
        return try readDescribe()
    case "E":
        return try readExecute()
    case "P":
        return try readParse()
    case "Q":
        return try readQuery()
    case "S":
        return try readSync()
    default:
        throw QueryServerError.protocolError
    }
}
|
|
|
|
/// Writes one DataRow message containing the given values in text form.
///
/// Each column is encoded as an Int32 byte count followed by exactly that many bytes; a
/// NULL column is encoded as a count of -1 with no bytes. The previous encoding sent NULL
/// as a zero-length value and appended a NUL byte (counted in the length) to every value,
/// neither of which matches the DataRow format.
/// - Parameter row: The column values in order.
/// - Throws: Whatever the socket write throws.
/// - Important: Calling this in a state other than `.querying` or `.closed` is a programmer
///   error and stops the process.
func send(row: [PQValue]) throws {
    switch self.state {
    case .querying, .closed: break
    default: fatalError("invalid state")
    }

    var buf = Data()

    for value in row {
        switch value {
        case .null:
            // -1 marks NULL; no value bytes follow.
            buf.append(bytesOf: Int32(-1).bigEndian)

        default:
            let data = value.text.data(using: .utf8) ?? Data()
            buf.append(bytesOf: Int32(data.count).bigEndian)
            buf.append(data)
        }
    }

    // 'D', Int32 length (itself + column count + columns), Int16 column count, columns.
    var packet = Data()
    packet.append(UInt8(ascii: "D"))
    packet.append(bytesOf: Int32(buf.count + 4 + 2).bigEndian)
    packet.append(bytesOf: Int16(row.count).bigEndian)
    packet.append(buf)
    try self.socket.write(from: packet)
}
|
|
|
|
/// Writes an ErrorResponse message with severity ('S'), code ('C') and message ('M') fields.
///
/// - Parameters:
///   - error: Human-readable message.
///   - severity: Severity field; defaults to `.error`.
///   - code: Error code field; defaults to "42000".
///   - endsQuery: When true, the connection returns to the `.ready` state afterwards.
/// - Throws: Whatever the socket write throws.
///
/// This method no longer schedules `run()`. Both callers (the read loop's catch block and
/// `send(result:)`, which runs inside a query started by the read loop) are reached while
/// the read loop is still active. Scheduling another iteration here started a second loop
/// reading from the same socket concurrently.
private func send(error: String, severity: PQSeverity = .error, code: String = "42000", endsQuery: Bool = true) throws {
    var buf = Data()

    // Each field is a one-byte tag followed by a zero-terminated string.
    let fields: [(tag: UInt8, value: String)] = [
        (UInt8(ascii: "S"), severity.rawValue),
        (UInt8(ascii: "C"), code),
        (UInt8(ascii: "M"), error)
    ]
    for field in fields {
        buf.append(field.tag)
        buf.append(Data(field.value.utf8))
        buf.append(0)
    }

    // Message terminator
    buf.append(0)

    var packet = Data()
    packet.append(UInt8(ascii: "E"))
    packet.append(bytesOf: Int32(buf.count + 4).bigEndian)
    packet.append(buf)
    try self.socket.write(from: packet)

    if endsQuery {
        self.state = .ready
    }
}
|
|
|
|
/// Writes a RowDescription ('T') message listing the given fields; every field is announced
/// with format code 0 (text).
/// - Parameter description: The fields in column order.
/// - Throws: Whatever the socket write throws.
private func send(description: [PQField]) throws {
    var fieldBytes = Data()

    description.forEach { field in
        // Zero-terminated field name.
        fieldBytes.append(field.name.data(using: .utf8)!)
        fieldBytes.append(0)
        // Fixed-size attributes, all in network byte order.
        fieldBytes.append(bytesOf: field.tableId.bigEndian)
        fieldBytes.append(bytesOf: field.columnId.bigEndian)
        fieldBytes.append(bytesOf: field.type.rawValue.bigEndian)
        fieldBytes.append(bytesOf: field.type.typeSize.bigEndian)
        fieldBytes.append(bytesOf: field.typeModifier.bigEndian)
        fieldBytes.append(bytesOf: Int16(0).bigEndian) // Binary=1, text=0
    }

    // Type byte, Int32 length (itself + field count + fields), Int16 field count, fields.
    let contentLength = Int32(6 + fieldBytes.count)
    let fieldCount = Int16(description.count)

    var packet = Data()
    packet.append(UInt8(ascii: "T"))
    packet.append(bytesOf: contentLength.bigEndian)
    packet.append(bytesOf: fieldCount.bigEndian)
    packet.append(fieldBytes)
    try socket.write(from: packet)
}
|
|
|
|
/// Writes a CommandComplete ('C') message carrying the given command tag.
/// - Parameter tag: The command tag, sent as a zero-terminated string.
/// - Throws: Whatever the socket write throws.
private func sendQueryComplete(tag: String) throws {
    let tagBytes = Data(tag.utf8)

    // Type byte, then Int32 length covering itself, the tag and its terminator.
    var packet = Data([UInt8(ascii: "C")])
    packet.append(bytesOf: UInt32(tagBytes.count + 5).bigEndian)
    packet.append(tagBytes)
    packet.append(UInt8(0))
    try socket.write(from: packet)
}
|
|
|
|
/// Schedules one iteration of the connection's read loop on the global queue.
///
/// An iteration handles the startup exchange (state `.new`) or one protocol message
/// (`.ready` / `.querying`), then either schedules the next iteration or closes the
/// connection. Any thrown error is reported to the client before closing.
private func run() {
    // Get the global concurrent queue...
    let queue = DispatchQueue.global(qos: .default)

    // Create the run loop work item and dispatch to the default priority global queue...
    queue.async { [weak self] in
        guard let s = self else { return }

        var shouldKeepRunning = true
        do {
            switch s.state {
            case .new:
                shouldKeepRunning = try s.handleStartupPacket()

            case .ready, .querying:
                shouldKeepRunning = try s.readPacket()

            case .closed:
                return
            }
        }
        catch {
            // The connection is closed right below, so do not let the error path move it
            // back to the ready state or schedule further work.
            try? s.send(error: error.localizedDescription, endsQuery: false)
            shouldKeepRunning = false
        }

        if shouldKeepRunning {
            s.run()
        }
        else {
            s.close()
        }
    }
}

/// Performs one step of the startup exchange: either declines an SSL request, or reads the
/// startup message, requests a password, reads it and reports success.
/// - Returns: Whether the read loop should continue.
/// - Throws: Whatever the socket or the message readers throw.
private func handleStartupPacket() throws -> Bool {
    // A header that cannot be read means the client is gone. Previously this case left
    // the loop running, so it spun forever on a dead socket.
    guard let len = readInt32(), let msg = readInt32() else { return false }

    if len == 8 && msg == 80877103 {
        // No SSL, thank you
        try socket.write(from: "N")
        return true
    }

    guard len > UInt32(8) else { return false }

    // Read client version number
    majorVersion = UInt16(msg >> 16)
    minorVersion = UInt16(msg & 0xFFFF)

    // Read parameters
    guard let p = try readParameters(length: len - UInt32(8)) else { return false }
    username = p["user"]

    // Send authentication request (code 3)
    try socket.write(from: Data([UInt8(ascii: "R"), 0, 0, 0, 8, 0, 0, 0, 3]))

    // Read authentication
    guard let pw = try readAuthentication() else { return false }
    password = pw

    // Send authentication success (code 0)
    try socket.write(from: Data([UInt8(ascii: "R"), 0, 0, 0, 8, 0, 0, 0, 0]))

    state = .ready
    try sendReadyForQuery()
    return true
}
|
|
|
|
/// Closes the socket, marks the connection closed and notifies the server.
/// Does nothing when the connection is already closed.
func close() {
    if case .closed = state { return }

    socket.close()
    state = .closed
    server?.connection(didClose: self)
}
|
|
}
|
|
|
|
/** A prepared statement together with the parameter values bound to it. */
fileprivate class Portal<PreparedStatementType: PreparedStatement> {
    /// The statement this portal executes.
    let statement: PreparedStatementType
    /// The bound parameter values, in positional order.
    let parameters: [PQValue]
    /// Result slot; starts out empty.
    var result: ResultSet?

    init(statement: PreparedStatementType, parameters: [PQValue]) {
        self.parameters = parameters
        self.statement = statement
    }
}
|
|
|
|
|
|
/// Sequential reader over the body of a protocol message.
///
/// Keeps a cursor into the data instead of re-copying the unread remainder after every
/// read. Integers are decoded from network (big-endian) byte order on every host; the
/// previous implementation byte-swapped the already-correct value on big-endian hosts.
fileprivate struct DataReader {
    private let data: Data
    /// Index of the next unread byte.
    private var offset: Data.Index

    init(data: Data) {
        self.data = data
        self.offset = data.startIndex
    }

    /// Number of bytes that have not been consumed yet.
    private var remaining: Int {
        return data.endIndex - offset
    }

    /// Reads a UTF-8 string up to the next zero byte and consumes the terminator.
    /// - Returns: The string, or nil (consuming nothing) when there is no terminator or the
    ///   bytes are not valid UTF-8.
    mutating func readZeroTerminatedString() -> String? {
        let unread = data[offset..<data.endIndex]
        guard let nullIndex = unread.firstIndex(of: 0),
            let str = String(data: data.subdata(in: offset..<nullIndex), encoding: .utf8) else {
            return nil
        }
        offset = nullIndex + 1
        return str
    }

    /// Reads exactly `length` bytes.
    /// - Returns: The bytes, or nil (consuming nothing) when `length` is negative or more
    ///   than what remains.
    mutating func readBytes(_ length: Int) -> Data? {
        guard length >= 0, length <= remaining else { return nil }
        let read = data.subdata(in: offset..<(offset + length))
        offset += length
        return read
    }

    /// Reads a big-endian 16-bit unsigned integer, or nil when fewer than 2 bytes remain.
    mutating func readUInt16() -> UInt16? {
        guard let bytes = readBytes(2) else { return nil }
        return bytes.reduce(UInt16(0)) { ($0 << 8) | UInt16($1) }
    }

    /// Reads a big-endian 32-bit unsigned integer, or nil when fewer than 4 bytes remain.
    mutating func readUInt32() -> UInt32? {
        guard let bytes = readBytes(4) else { return nil }
        return bytes.reduce(UInt32(0)) { ($0 << 8) | UInt32($1) }
    }
}
|
|
|
|
fileprivate extension UnsignedInteger {
    /// Builds a value from bytes given most-significant byte first.
    /// - Parameter bytes: At most `MemoryLayout<Self>.size` bytes; an empty array yields zero.
    init(_ bytes: [UInt8]) {
        precondition(bytes.count <= MemoryLayout<Self>.size)

        // Shift the accumulated value up one byte and merge in the next byte.
        let folded = bytes.reduce(UInt64(0)) { accumulated, byte in
            (accumulated << 8) | UInt64(byte)
        }
        self.init(folded)
    }
}
|
|
|
|
fileprivate extension Character {
    /// Numeric value of the character's first Unicode scalar.
    var codePoint: Int {
        // A Character always contains at least one scalar, so `first` cannot be nil.
        return Int(String(self).unicodeScalars.first!.value)
    }
}
|
|
|
|
fileprivate extension Data {
    /// Appends the in-memory bytes of `value`, `MemoryLayout<T>.size` of them, in the order
    /// they are stored; callers convert to the desired byte order first.
    mutating func append<T>(bytesOf value: T) {
        var copy = value
        append(Data(bytes: &copy, count: MemoryLayout<T>.size))
    }
}
|