Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ let package = Package(
.package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.6.3"),
.package(url: "https://github.com/vapor/sql-kit.git", from: "3.32.0"),
.package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.9.1"),
],
targets: [
.target(
Expand All @@ -30,6 +31,7 @@ let package = Package(
.product(name: "NIOPosix", package: "swift-nio"),
.product(name: "Logging", package: "swift-log"),
.product(name: "SQLKit", package: "sql-kit"),
.product(name: "ServiceLifecycle", package: "swift-service-lifecycle"),
],
swiftSettings: swiftSettings
),
Expand Down
28 changes: 17 additions & 11 deletions Sources/FluentKit/Database/Databases.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import struct NIOConcurrencyHelpers.NIOLock
import NIOCore
import NIOPosix
import Logging
import ServiceLifecycle

public struct DatabaseConfigurationFactory: Sendable {
public let make: @Sendable () -> any DatabaseConfiguration
Expand All @@ -12,7 +13,7 @@ public struct DatabaseConfigurationFactory: Sendable {
}
}

public final class Databases: @unchecked Sendable { // @unchecked is safe here; mutable data is protected by lock
public final class Databases: @unchecked Sendable, Service { // @unchecked is safe here; mutable data is protected by lock
public let eventLoopGroup: any EventLoopGroup
public let threadPool: NIOThreadPool

Expand All @@ -25,7 +26,7 @@ public final class Databases: @unchecked Sendable { // @unchecked is safe here;

// Synchronize access across threads.
private var lock: NIOLock

public struct Middleware {
let databases: Databases

Expand All @@ -40,7 +41,7 @@ public final class Databases: @unchecked Sendable { // @unchecked is safe here;
self.databases.configurations[id] = configuration
}
}

public func clear(on id: DatabaseID? = nil) {
self.databases.lock.withLockVoid {
let id = id ?? self.databases._requireDefaultID()
Expand All @@ -54,23 +55,23 @@ public final class Databases: @unchecked Sendable { // @unchecked is safe here;
public var middleware: Middleware {
.init(databases: self)
}

public init(threadPool: NIOThreadPool, on eventLoopGroup: any EventLoopGroup) {
self.eventLoopGroup = eventLoopGroup
self.threadPool = threadPool
self.configurations = [:]
self.drivers = [:]
self.lock = .init()
}

public func use(
_ configuration: DatabaseConfigurationFactory,
as id: DatabaseID,
isDefault: Bool? = nil
) {
self.use(configuration.make(), as: id, isDefault: isDefault)
}

public func use(
_ driver: any DatabaseConfiguration,
as id: DatabaseID,
Expand All @@ -89,13 +90,13 @@ public final class Databases: @unchecked Sendable { // @unchecked is safe here;
self.defaultID = id
}
}

public func configuration(for id: DatabaseID? = nil) -> (any DatabaseConfiguration)? {
self.lock.withLock {
self.configurations[id ?? self._requireDefaultID()]
}
}

public func database(
_ id: DatabaseID? = nil,
logger: Logger,
Expand Down Expand Up @@ -150,10 +151,10 @@ public final class Databases: @unchecked Sendable { // @unchecked is safe here;
self.drivers = [:]
}
}

public func shutdownAsync() async {
var driversToShutdown: [any DatabaseDriver] = []

self.lock.withLockVoid {
for driver in self.drivers.values {
driversToShutdown.append(driver)
Expand All @@ -171,11 +172,16 @@ public final class Databases: @unchecked Sendable { // @unchecked is safe here;
}
return configuration
}

private func _requireDefaultID() -> DatabaseID {
guard let id = self.defaultID else {
fatalError("No default database configured.")
}
return id
}

public func run() async throws {
try await gracefulShutdown()
await self.shutdownAsync()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
try await gracefulShutdown()
await self.shutdownAsync()
do {
try await gracefulShutdown()
} catch is CancellationError {
// ignore; we still need to shut down for sudden cancellation
}
await self.shutdownAsync()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably want to trigger a shutdown for any error actually right? So a try? might be better.

}
}