diff --git a/Package.swift b/Package.swift index ae0e8a5d..f568a686 100644 --- a/Package.swift +++ b/Package.swift @@ -54,7 +54,7 @@ let package = Package( .product(name: "NIOFoundationCompat", package: "swift-nio"), .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), ], - swiftSettings: swiftSettings + swiftSettings: swiftSettings + [.enableExperimentalFeature("Lifetimes")] ), .target( name: "_ConnectionPoolModule", diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift index bdfcbd2c..b7807097 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -98,6 +98,131 @@ public struct PostgresCopyFromWriter: Sendable { } } +// PostgresBinaryCopyFromWriter relies on non-Escapable types, which were only introduced in Swift 6.2 +#if compiler(>=6.2) +/// Handle to send binary data for a `COPY ... FROM STDIN` query to the backend. +/// +/// It takes care of serializing `PostgresEncodable` column types into the binary format that Postgres expects. +public struct PostgresBinaryCopyFromWriter: ~Copyable { + /// Handle to serialize columns into a row that is being written by `PostgresBinaryCopyFromWriter`. + public struct ColumnWriter: ~Escapable, ~Copyable { + /// Pointer to the `PostgresBinaryCopyFromWriter` that is gathering the serialized data. + @usableFromInline + let underlying: UnsafeMutablePointer<PostgresBinaryCopyFromWriter> + + /// The number of columns that have been written by this `ColumnWriter`. 
+ @usableFromInline + var columns: UInt16 = 0 + + /// - Warning: Do not call directly, call `withColumnWriter` instead + @usableFromInline + init(_underlying: UnsafeMutablePointer<PostgresBinaryCopyFromWriter>) { + self.underlying = _underlying + } + + @usableFromInline + static func withColumnWriter<T>( + writingTo underlying: inout PostgresBinaryCopyFromWriter, + body: (inout ColumnWriter) throws -> T + ) rethrows -> T { + return try withUnsafeMutablePointer(to: &underlying) { pointerToUnderlying in + // We can guarantee that `ColumnWriter` never outlives `underlying` because `ColumnWriter` is + // `~Escapable` and thus cannot escape the context of the closure to `withUnsafeMutablePointer`. + // To model this without resorting to unsafe pointers, we would need to be able to declare an `inout` + // reference to `PostgresBinaryCopyFromWriter` as a member of `ColumnWriter`, which isn't possible at + // the moment (https://github.com/swiftlang/swift/issues/85832). + var columnWriter = ColumnWriter(_underlying: pointerToUnderlying) + return try body(&columnWriter) + } + } + + /// Serialize a single column to a row. + /// + /// - Important: It is critical that the data type encoded here exactly matches the data type in the + /// database. For example, if the database stores a 4-byte integer the corresponding `writeColumn` must + /// be called with an `Int32`. Serializing an integer of a different width will cause a deserialization + /// failure in the backend. + @inlinable + #if compiler(<6.3) + @_lifetime(&self) + #endif + public mutating func writeColumn(_ column: (some PostgresEncodable)?) throws { + columns += 1 + try invokeWriteColumn(on: underlying, column) + } + + // Needed to work around https://github.com/swiftlang/swift/issues/83309, copying the implementation into + // `writeColumn` causes an assertion failure when thread sanitizer is enabled. + @inlinable + func invokeWriteColumn( + on writer: UnsafeMutablePointer<PostgresBinaryCopyFromWriter>, + _ column: (some PostgresEncodable)? 
+ ) throws { + try writer.pointee.writeColumn(column) + } + } + + /// The underlying `PostgresCopyFromWriter` that sends the serialized data to the backend. + @usableFromInline let underlying: PostgresCopyFromWriter + + /// The buffer in which we accumulate binary data. Once this buffer exceeds `bufferSize`, we flush it to + /// the backend. + @usableFromInline var buffer = ByteBuffer() + + /// Once `buffer` exceeds this size, it gets flushed to the backend. + @usableFromInline let bufferSize: Int + + init(underlying: PostgresCopyFromWriter, bufferSize: Int) { + self.underlying = underlying + // Allocate 10% more than the buffer size because we only flush the buffer once it has exceeded `bufferSize` + buffer.reserveCapacity(bufferSize + bufferSize / 10) + self.bufferSize = bufferSize + } + + /// Serialize a single row to the backend. Call `writeColumn` on `columnWriter` for every column that should be + /// included in the row. + @inlinable + public mutating func writeRow(_ body: (_ columnWriter: inout ColumnWriter) throws -> Void) async throws { + // Write a placeholder for the number of columns + let columnIndex = buffer.writerIndex + buffer.writeInteger(UInt16(0)) + + let columns = try ColumnWriter.withColumnWriter(writingTo: &self) { columnWriter in + try body(&columnWriter) + return columnWriter.columns + } + + // Fill in the number of columns + buffer.setInteger(columns, at: columnIndex) + + if buffer.readableBytes > bufferSize { + try await flush() + } + } + + /// Serialize a single column to the buffer. Should only be called by `ColumnWriter`. + @inlinable + mutating func writeColumn(_ column: (some PostgresEncodable)?) throws { + guard let column else { + buffer.writeInteger(Int32(-1)) + return + } + try buffer.writeLengthPrefixed(as: Int32.self) { buffer in + let startIndex = buffer.writerIndex + try column.encode(into: &buffer, context: .default) + return buffer.writerIndex - startIndex + } + } + + /// Flush any pending data in the buffer to the backend. 
+ @usableFromInline + mutating func flush() async throws { + try await underlying.write(buffer) + buffer.clear() + } +} +#endif + /// Specifies the format in which data is transferred to the backend in a COPY operation. /// /// See the Postgres documentation at https://www.postgresql.org/docs/current/sql-copy.html for the option's meanings @@ -113,8 +238,14 @@ public struct PostgresCopyFromFormat: Sendable { public init() {} } + /// Options that can be used to modify the `binary` format of a COPY operation. + public struct BinaryOptions: Sendable { + public init() {} + } + enum Format { case text(TextOptions) + case binary(BinaryOptions) } var format: Format @@ -122,6 +253,10 @@ public struct PostgresCopyFromFormat: Sendable { public static func text(_ options: TextOptions) -> PostgresCopyFromFormat { return PostgresCopyFromFormat(format: .text(options)) } + + public static func binary(_ options: BinaryOptions) -> PostgresCopyFromFormat { + return PostgresCopyFromFormat(format: .binary(options)) + } } /// Create a `COPY ... FROM STDIN` query based on the given parameters. @@ -153,6 +288,8 @@ private func buildCopyFromQuery( // Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection. queryOptions.append("DELIMITER U&'\\\(String(format: "%04x", delimiter.value))'") } + case .binary: + queryOptions.append("FORMAT binary") } precondition(!queryOptions.isEmpty) query += " WITH (" @@ -162,6 +299,51 @@ private func buildCopyFromQuery( } extension PostgresConnection { + #if compiler(>=6.2) + /// Copy data into a table using a `COPY FROM STDIN` query, transferring data in a binary format. + /// + /// - Parameters: + /// - table: The name of the table into which to copy the data. + /// - columns: The names of the columns to copy. If an empty array is passed, all columns are assumed to be copied. + /// - bufferSize: How many bytes to accumulate in a local buffer before flushing it to the database. 
Can affect + /// performance characteristics of the copy operation. + /// - writeData: Closure that produces the data for the table, to be streamed to the backend. Call `writeRow` on the + /// writer provided by the closure to send data to the backend and return from the closure once all data is sent. + /// Throw an error from the closure to fail the data transfer. The error thrown by the closure will be rethrown + /// by the `copyFromBinary` function. + /// + /// - Important: The table and column names are inserted into the `COPY FROM` query as passed and might thus be + /// susceptible to SQL injection. Ensure no untrusted data is contained in these strings. + public func copyFromBinary( + table: String, + columns: [String] = [], + options: PostgresCopyFromFormat.BinaryOptions = .init(), + bufferSize: Int = 100_000, + logger: Logger, + file: String = #fileID, + line: Int = #line, + writeData: (inout PostgresBinaryCopyFromWriter) async throws -> Void + ) async throws { + try await copyFrom(table: table, columns: columns, format: .binary(options), logger: logger) { writer in + var header = ByteBuffer() + header.writeString("PGCOPY\n") + header.writeInteger(UInt8(0xff)) + header.writeString("\r\n\0") + + // Flags field (32 bits), no flags set + header.writeInteger(UInt32(0)) + + // Header extension area length: zero, i.e. no extension data follows + header.writeInteger(UInt32(0)) + try await writer.write(header) + + var binaryWriter = PostgresBinaryCopyFromWriter(underlying: writer, bufferSize: bufferSize) + try await writeData(&binaryWriter) + try await binaryWriter.flush() + } + } + #endif + /// Copy data into a table using a `COPY
FROM STDIN` query. /// /// - Parameters: diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 35581edb..0587effa 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -487,4 +487,40 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "42601") // scanner_yyerror } } + + #if compiler(>=6.2) // copyFromBinary is only available in Swift 6.2+ + func testCopyFromBinary() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get() + _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get() + + try await conn.copyFromBinary(table: "copy_table", columns: ["id", "name"], logger: .psqlTest) { writer in + let records: [(id: Int, name: String)] = [ + (1, "Alice"), + (42, "Bob") + ] + for record in records { + try await writer.writeRow { columnWriter in + try columnWriter.writeColumn(Int32(record.id)) + try columnWriter.writeColumn(record.name) + } + } + } + let rows = try await conn.query("SELECT id, name FROM copy_table ORDER BY id").get().rows.map { try $0.decode((Int, String).self) } + guard rows.count == 2 else { + XCTFail("Expected 2 rows, received \(rows.count)") + return + } + XCTAssertEqual(rows[0].0, 1) + XCTAssertEqual(rows[0].1, "Alice") + XCTAssertEqual(rows[1].0, 42) + XCTAssertEqual(rows[1].1, "Bob") + } + #endif } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 28c593cb..9a351e9e 100644 --- 
a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -936,6 +936,70 @@ import Synchronization } } + #if compiler(>=6.2) // copyFromBinary is only available in Swift 6.2+ + @Test func testCopyFromBinary() async throws { + try await self.withAsyncTestingChannel { connection, channel in + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> Void in + taskGroup.addTask { + try await connection.copyFromBinary(table: "copy_table", logger: .psqlTest) { + writer in + try await writer.writeRow { columnWriter in + try columnWriter.writeColumn(Int32(1)) + try columnWriter.writeColumn("Alice") + } + try await writer.writeRow { columnWriter in + try columnWriter.writeColumn(Int32(2)) + try columnWriter.writeColumn("Bob") + } + } + } + + let copyRequest = try await channel.waitForUnpreparedRequest() + #expect(copyRequest.parse.query == #"COPY "copy_table" FROM STDIN WITH (FORMAT binary)"#) + + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + try await channel.writeInbound( + PostgresBackendMessage.copyInResponse( + .init(format: .binary, columnFormats: [.binary, .binary]))) + + let copyData = try await channel.waitForCopyData() + #expect(copyData.result == .done) + var data = copyData.data + // Signature + #expect(data.readString(length: 7) == "PGCOPY\n") + #expect(data.readInteger(as: UInt8.self) == 0xff) + #expect(data.readString(length: 3) == "\r\n\0") + // Flags + #expect(data.readInteger(as: UInt32.self) == 0) + // Header extension area length + #expect(data.readInteger(as: UInt32.self) == 0) + + struct Row: Equatable { + let id: Int32 + let name: String + } + var rows: [Row] = [] + while data.readableBytes > 0 { + // Number of columns + #expect(data.readInteger(as: UInt16.self) == 2) + // 'id' column length (an Int32 is 4 bytes) + #expect(data.readInteger(as: UInt32.self) == 4) + let id = data.readInteger(as: Int32.self) + // 'name' column length + let nameLength = 
data.readInteger(as: 
UInt32.self) + let name = data.readString(length: Int(try #require(nameLength))) + rows.append(Row(id: try #require(id), name: try #require(name))) + } + #expect(rows == [Row(id: 1, name: "Alice"), Row(id: 2, name: "Bob")]) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1")) + + try await channel.waitForPostgresFrontendMessage(\.sync) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + } + } + #endif + func withAsyncTestingChannel(_ body: (PostgresConnection, NIOAsyncTestingChannel) async throws -> ()) async throws { let eventLoop = NIOAsyncTestingEventLoop() let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in