diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 2c99cadb..cf7524ce 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -78,29 +78,29 @@ jobs:
if: matrix.skip_release != '1'
run: make XCODEBUILD_ARGUMENT="${{ matrix.command }}" CONFIG=Release PLATFORM="${{ matrix.platform }}" xcodebuild
- linux:
- name: linux
- strategy:
- matrix:
- swift-version: ["5.10"]
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
- - uses: swift-actions/setup-swift@v2
- with:
- swift-version: ${{ matrix.swift-version }}
- - name: Cache build
- uses: actions/cache@v3
- with:
- path: |
- .build
- key: |
- build-spm-linux-${{ matrix.swift-version }}-${{ hashFiles('**/Sources/**/*.swift', '**/Tests/**/*.swift', '**/Package.resolved') }}
- restore-keys: |
- build-spm-linux-${{ matrix.swift-version }}-
- - run: make dot-env
- - name: Run tests
- run: swift test --skip IntegrationTests
+ # linux:
+ # name: linux
+ # strategy:
+ # matrix:
+ # swift-version: ["5.10"]
+ # runs-on: ubuntu-latest
+ # steps:
+ # - uses: actions/checkout@v4
+ # - uses: swift-actions/setup-swift@v2
+ # with:
+ # swift-version: ${{ matrix.swift-version }}
+ # - name: Cache build
+ # uses: actions/cache@v3
+ # with:
+ # path: |
+ # .build
+ # key: |
+ # build-spm-linux-${{ matrix.swift-version }}-${{ hashFiles('**/Sources/**/*.swift', '**/Tests/**/*.swift', '**/Package.resolved') }}
+ # restore-keys: |
+ # build-spm-linux-${{ matrix.swift-version }}-
+ # - run: make dot-env
+ # - name: Run tests
+ # run: swift test --skip IntegrationTests
# library-evolution:
# name: Library (evolution)
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index b70bb56d..50c1dd70 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "2.24.1"
+ ".": "2.24.2"
}
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ced012f7..6f8335d6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,12 @@
# Changelog
+## [2.24.2](https://github.com/supabase/supabase-swift/compare/v2.24.1...v2.24.2) (2025-01-08)
+
+
+### Bug Fixes
+
+* **realtime:** auto reconnect after calling disconnect, and several refactors ([#627](https://github.com/supabase/supabase-swift/issues/627)) ([1887f4f](https://github.com/supabase/supabase-swift/commit/1887f4f376e172bb7fbcec84506fea6c4797fde7))
+
## [2.24.1](https://github.com/supabase/supabase-swift/compare/v2.24.0...v2.24.1) (2024-12-16)
diff --git a/Examples/Examples/Auth/SignInWithApple.swift b/Examples/Examples/Auth/SignInWithApple.swift
index f87591ce..399dead4 100644
--- a/Examples/Examples/Auth/SignInWithApple.swift
+++ b/Examples/Examples/Auth/SignInWithApple.swift
@@ -5,6 +5,7 @@
// Created by Guilherme Souza on 16/12/23.
//
+import Auth
import AuthenticationServices
import SwiftUI
@@ -14,7 +15,7 @@ struct SignInWithApple: View {
var body: some View {
VStack {
SignInWithAppleButton { request in
- request.requestedScopes = [.email]
+ request.requestedScopes = [.email, .fullName]
} onCompletion: { result in
switch result {
case let .failure(error):
@@ -29,16 +30,23 @@ struct SignInWithApple: View {
return
}
- guard let identityToken = credential.identityToken.flatMap({ String(
- data: $0,
- encoding: .utf8
- ) }) else {
+ guard
+ let identityToken = credential.identityToken.flatMap({
+ String(
+ data: $0,
+ encoding: .utf8
+ )
+ })
+ else {
debug("Invalid identity token")
return
}
Task {
- await signInWithApple(using: identityToken)
+ await signInWithApple(
+ using: identityToken,
+ fullName: credential.fullName?.formatted()
+ )
}
}
}
@@ -55,13 +63,21 @@ struct SignInWithApple: View {
}
}
- private func signInWithApple(using idToken: String) async {
+ private func signInWithApple(using idToken: String, fullName: String?) async {
actionState = .inFlight
let result = await Result {
- _ = try await supabase.auth.signInWithIdToken(credentials: .init(
- provider: .apple,
- idToken: idToken
- ))
+ _ = try await supabase.auth.signInWithIdToken(
+ credentials: .init(
+ provider: .apple,
+ idToken: idToken
+ ))
+
+ // fullName is provided only in the first time (account creation),
+ // so checking if it is non-nil to not erase data on login.
+ if let fullName {
+ _ = try? await supabase.auth.update(
+ user: UserAttributes(data: ["full_name": .string(fullName)]))
+ }
}
actionState = .result(result)
}
diff --git a/Examples/Examples/Info.plist b/Examples/Examples/Info.plist
index 8070cab2..c3751b58 100644
--- a/Examples/Examples/Info.plist
+++ b/Examples/Examples/Info.plist
@@ -12,6 +12,14 @@
$(PRODUCT_BUNDLE_IDENTIFIER)
+
+ CFBundleTypeRole
+ Editor
+ CFBundleURLSchemes
+
+ DOT_REVERSED_IOS_CLIENT_ID
+
+
GIDClientID
YOUR_IOS_CLIENT_ID
diff --git a/Package.resolved b/Package.resolved
index 2a6d5e7d..aae76928 100644
--- a/Package.resolved
+++ b/Package.resolved
@@ -14,8 +14,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/pointfreeco/swift-concurrency-extras",
"state" : {
- "revision" : "163409ef7dae9d960b87f34b51587b6609a76c1f",
- "version" : "1.3.0"
+ "revision" : "82a4ae7170d98d8538ec77238b7eb8e7199ef2e8",
+ "version" : "1.3.1"
}
},
{
diff --git a/Package.swift b/Package.swift
index 86d770cd..3d192186 100644
--- a/Package.swift
+++ b/Package.swift
@@ -92,10 +92,7 @@ let package = Package(
.product(name: "InlineSnapshotTesting", package: "swift-snapshot-testing"),
.product(name: "XCTestDynamicOverlay", package: "xctest-dynamic-overlay"),
"Helpers",
- "Auth",
- "PostgREST",
- "Realtime",
- "Storage",
+ "Supabase",
"TestHelpers",
],
resources: [.process("Fixtures")]
diff --git a/Sources/Helpers/EventEmitter.swift b/Sources/Helpers/EventEmitter.swift
index 4dd48e6f..99ad965f 100644
--- a/Sources/Helpers/EventEmitter.swift
+++ b/Sources/Helpers/EventEmitter.swift
@@ -8,6 +8,9 @@
import ConcurrencyExtras
import Foundation
+/// A token for cancelling observations.
+///
+/// When this token gets deallocated it cancels the observation it was associated with. Store this token in another object to keep the observation alive.
public final class ObservationToken: @unchecked Sendable, Hashable {
private let _isCancelled = LockIsolated(false)
package var onCancel: @Sendable () -> Void
@@ -44,9 +47,7 @@ public final class ObservationToken: @unchecked Sendable, Hashable {
public func hash(into hasher: inout Hasher) {
hasher.combine(ObjectIdentifier(self))
}
-}
-extension ObservationToken {
public func store(in collection: inout some RangeReplaceableCollection) {
collection.append(self)
}
@@ -59,9 +60,15 @@ extension ObservationToken {
package final class EventEmitter: Sendable {
public typealias Listener = @Sendable (Event) -> Void
- private let listeners = LockIsolated<[(key: ObjectIdentifier, listener: Listener)]>([])
- private let _lastEvent: LockIsolated
- package var lastEvent: Event { _lastEvent.value }
+ struct MutableState {
+ var listeners: [(key: ObjectIdentifier, listener: Listener)] = []
+ var lastEvent: Event
+ }
+
+ let mutableState: LockIsolated
+
+ /// The last event emitted by this Emiter, or the initial event.
+ package var lastEvent: Event { mutableState.lastEvent }
let emitsLastEventWhenAttaching: Bool
@@ -69,10 +76,13 @@ package final class EventEmitter: Sendable {
initialEvent event: Event,
emitsLastEventWhenAttaching: Bool = true
) {
- _lastEvent = LockIsolated(event)
+ mutableState = LockIsolated(MutableState(lastEvent: event))
self.emitsLastEventWhenAttaching = emitsLastEventWhenAttaching
}
+ /// Attaches a new listener for observing event emissions.
+ ///
+ /// If emitter initialized with `emitsLastEventWhenAttaching = true`, listener gets called right away with last event.
package func attach(_ listener: @escaping Listener) -> ObservationToken {
defer {
if emitsLastEventWhenAttaching {
@@ -84,21 +94,24 @@ package final class EventEmitter: Sendable {
let key = ObjectIdentifier(token)
token.onCancel = { [weak self] in
- self?.listeners.withValue {
- $0.removeAll { $0.key == key }
+ self?.mutableState.withValue {
+ $0.listeners.removeAll { $0.key == key }
}
}
- listeners.withValue {
- $0.append((key, listener))
+ mutableState.withValue {
+ $0.listeners.append((key, listener))
}
return token
}
+ /// Trigger a new event on all attached listeners, or a specific listener owned by the `token` provided.
package func emit(_ event: Event, to token: ObservationToken? = nil) {
- _lastEvent.setValue(event)
- let listeners = listeners.value
+ let listeners = mutableState.withValue {
+ $0.lastEvent = event
+ return $0.listeners
+ }
if let token {
listeners.first { $0.key == ObjectIdentifier(token) }?.listener(event)
@@ -109,6 +122,7 @@ package final class EventEmitter: Sendable {
}
}
+ /// Returns a new ``AsyncStream`` for observing events emitted by this emitter.
package func stream() -> AsyncStream {
AsyncStream { continuation in
let token = attach { status in
diff --git a/Sources/Helpers/Version.swift b/Sources/Helpers/Version.swift
index b774eb6d..e00d9405 100644
--- a/Sources/Helpers/Version.swift
+++ b/Sources/Helpers/Version.swift
@@ -1 +1 @@
-package let version = "2.24.1" // {x-release-please-version}
+package let version = "2.24.2" // {x-release-please-version}
diff --git a/Sources/Realtime/V2/PushV2.swift b/Sources/Realtime/V2/PushV2.swift
index 199e6b74..884fc981 100644
--- a/Sources/Realtime/V2/PushV2.swift
+++ b/Sources/Realtime/V2/PushV2.swift
@@ -31,7 +31,7 @@ actor PushV2 {
return .error
}
- await channel.socket.push(message)
+ channel.socket.push(message)
if !channel.config.broadcast.acknowledgeBroadcasts {
// channel was configured with `ack = false`,
@@ -40,7 +40,7 @@ actor PushV2 {
}
do {
- return try await withTimeout(interval: channel.socket.options().timeoutInterval) {
+ return try await withTimeout(interval: channel.socket.options.timeoutInterval) {
await withCheckedContinuation { continuation in
self.receivedContinuation = continuation
}
diff --git a/Sources/Realtime/V2/RealtimeChannelV2.swift b/Sources/Realtime/V2/RealtimeChannelV2.swift
index 41f9797c..5a39318f 100644
--- a/Sources/Realtime/V2/RealtimeChannelV2.swift
+++ b/Sources/Realtime/V2/RealtimeChannelV2.swift
@@ -25,46 +25,6 @@ public struct RealtimeChannelConfig: Sendable {
public var isPrivate: Bool
}
-struct Socket: Sendable {
- var broadcastURL: @Sendable () -> URL
- var status: @Sendable () -> RealtimeClientStatus
- var options: @Sendable () -> RealtimeClientOptions
- var accessToken: @Sendable () async -> String?
- var apiKey: @Sendable () -> String?
- var makeRef: @Sendable () -> Int
-
- var connect: @Sendable () async -> Void
- var addChannel: @Sendable (_ channel: RealtimeChannelV2) -> Void
- var removeChannel: @Sendable (_ channel: RealtimeChannelV2) async -> Void
- var push: @Sendable (_ message: RealtimeMessageV2) async -> Void
- var httpSend: @Sendable (_ request: Helpers.HTTPRequest) async throws -> Helpers.HTTPResponse
-}
-
-extension Socket {
- init(client: RealtimeClientV2) {
- self.init(
- broadcastURL: { [weak client] in client?.broadcastURL ?? URL(string: "http://localhost")! },
- status: { [weak client] in client?.status ?? .disconnected },
- options: { [weak client] in client?.options ?? .init() },
- accessToken: { [weak client] in
- if let accessToken = try? await client?.options.accessToken?() {
- return accessToken
- }
- return client?.mutableState.accessToken
- },
- apiKey: { [weak client] in client?.apikey },
- makeRef: { [weak client] in client?.makeRef() ?? 0 },
- connect: { [weak client] in await client?.connect() },
- addChannel: { [weak client] in client?.addChannel($0) },
- removeChannel: { [weak client] in await client?.removeChannel($0) },
- push: { [weak client] in await client?.push($0) },
- httpSend: { [weak client] in
- try await client?.http.send($0) ?? .init(data: Data(), response: HTTPURLResponse())
- }
- )
- }
-}
-
public final class RealtimeChannelV2: Sendable {
struct MutableState {
var clientChanges: [PostgresJoinConfig] = []
@@ -77,7 +37,8 @@ public final class RealtimeChannelV2: Sendable {
let topic: String
let config: RealtimeChannelConfig
let logger: (any SupabaseLogger)?
- let socket: Socket
+ let socket: RealtimeClientV2
+ var joinRef: String? { mutableState.joinRef }
let callbackManager = CallbackManager()
private let statusEventEmitter = EventEmitter(initialEvent: .unsubscribed)
@@ -105,7 +66,7 @@ public final class RealtimeChannelV2: Sendable {
init(
topic: String,
config: RealtimeChannelConfig,
- socket: Socket,
+ socket: RealtimeClientV2,
logger: (any SupabaseLogger)?
) {
self.topic = topic
@@ -120,8 +81,8 @@ public final class RealtimeChannelV2: Sendable {
/// Subscribes to the channel
public func subscribe() async {
- if socket.status() != .connected {
- if socket.options().connectOnSubscribe != true {
+ if socket.status != .connected {
+ if socket.options.connectOnSubscribe != true {
reportIssue(
"You can't subscribe to a channel while the realtime client is not connected. Did you forget to call `realtime.connect()`?"
)
@@ -130,8 +91,6 @@ public final class RealtimeChannelV2: Sendable {
await socket.connect()
}
- socket.addChannel(self)
-
status = .subscribing
logger?.debug("Subscribing to channel \(topic)")
@@ -144,10 +103,10 @@ public final class RealtimeChannelV2: Sendable {
let payload = RealtimeJoinPayload(
config: joinConfig,
- accessToken: await socket.accessToken()
+ accessToken: await socket._getAccessToken()
)
- let joinRef = socket.makeRef().description
+ let joinRef = socket.makeRef()
mutableState.withValue { $0.joinRef = joinRef }
logger?.debug("Subscribing to channel with body: \(joinConfig)")
@@ -159,7 +118,7 @@ public final class RealtimeChannelV2: Sendable {
)
do {
- try await withTimeout(interval: socket.options().timeoutInterval) { [self] in
+ try await withTimeout(interval: socket.options.timeoutInterval) { [self] in
_ = await statusChange.first { @Sendable in $0 == .subscribed }
}
} catch {
@@ -215,17 +174,17 @@ public final class RealtimeChannelV2: Sendable {
}
var headers: HTTPFields = [.contentType: "application/json"]
- if let apiKey = socket.apiKey() {
+ if let apiKey = socket.options.apikey {
headers[.apiKey] = apiKey
}
- if let accessToken = await socket.accessToken() {
+ if let accessToken = await socket._getAccessToken() {
headers[.authorization] = "Bearer \(accessToken)"
}
let task = Task { [headers] in
- _ = try? await socket.httpSend(
+ _ = try? await socket.http.send(
HTTPRequest(
- url: socket.broadcastURL(),
+ url: socket.broadcastURL,
method: .post,
headers: headers,
body: JSONEncoder().encode(
@@ -245,7 +204,7 @@ public final class RealtimeChannelV2: Sendable {
}
if config.broadcast.acknowledgeBroadcasts {
- try? await withTimeout(interval: socket.options().timeoutInterval) {
+ try? await withTimeout(interval: socket.options.timeoutInterval) {
await task.value
}
}
@@ -406,7 +365,7 @@ public final class RealtimeChannelV2: Sendable {
callbackManager.triggerBroadcast(event: event, json: payload)
case .close:
- await socket.removeChannel(self)
+ socket._remove(self)
logger?.debug("Unsubscribed from channel \(message.topic)")
status = .unsubscribed
@@ -582,7 +541,7 @@ public final class RealtimeChannelV2: Sendable {
let push = mutableState.withValue {
let message = RealtimeMessageV2(
joinRef: $0.joinRef,
- ref: ref ?? socket.makeRef().description,
+ ref: ref ?? socket.makeRef(),
topic: self.topic,
event: event,
payload: payload
diff --git a/Sources/Realtime/V2/RealtimeClientV2.swift b/Sources/Realtime/V2/RealtimeClientV2.swift
index e56a023e..0ad1b9c2 100644
--- a/Sources/Realtime/V2/RealtimeClientV2.swift
+++ b/Sources/Realtime/V2/RealtimeClientV2.swift
@@ -15,11 +15,14 @@ import Helpers
public typealias JSONObject = Helpers.JSONObject
+/// Factory function for returning a new WebSocket connection.
+typealias WebSocketTransport = @Sendable () async throws -> any WebSocket
+
public final class RealtimeClientV2: Sendable {
struct MutableState {
var accessToken: String?
var ref = 0
- var pendingHeartbeatRef: Int?
+ var pendingHeartbeatRef: String?
/// Long-running task that keeps sending heartbeat messages.
var heartbeatTask: Task?
@@ -28,20 +31,29 @@ public final class RealtimeClientV2: Sendable {
var messageTask: Task?
var connectionTask: Task?
- var channels: [String: RealtimeChannelV2] = [:]
- var sendBuffer: [@Sendable () async -> Void] = []
+ var channels: [RealtimeChannelV2] = []
+ var sendBuffer: [@Sendable () -> Void] = []
+
+ var conn: (any WebSocket)?
}
let url: URL
let options: RealtimeClientOptions
- let ws: any WebSocketClient
+ let wsTransport: WebSocketTransport
let mutableState = LockIsolated(MutableState())
let http: any HTTPClientType
let apikey: String?
+ var conn: (any WebSocket)? {
+ mutableState.conn
+ }
+
/// All managed channels indexed by their topics.
public var channels: [String: RealtimeChannelV2] {
- mutableState.channels
+ mutableState.channels.reduce(
+ into: [:],
+ { $0[$1.topic] = $1 }
+ )
}
private let statusEventEmitter = EventEmitter(initialEvent: .disconnected)
@@ -80,13 +92,17 @@ public final class RealtimeClientV2: Sendable {
self.init(
url: url,
options: options,
- ws: WebSocket(
- realtimeURL: Self.realtimeWebSocketURL(
- baseURL: Self.realtimeBaseURL(url: url),
- apikey: options.apikey
- ),
- options: options
- ),
+ wsTransport: {
+ let configuration = URLSessionConfiguration.default
+ configuration.httpAdditionalHeaders = options.headers.dictionary
+ return try await URLSessionWebSocket.connect(
+ to: Self.realtimeWebSocketURL(
+ baseURL: Self.realtimeBaseURL(url: url),
+ apikey: options.apikey
+ ),
+ configuration: configuration
+ )
+ },
http: HTTPClient(
fetch: options.fetch ?? { try await URLSession.shared.data(for: $0) },
interceptors: interceptors
@@ -97,12 +113,12 @@ public final class RealtimeClientV2: Sendable {
init(
url: URL,
options: RealtimeClientOptions,
- ws: any WebSocketClient,
+ wsTransport: @escaping WebSocketTransport,
http: any HTTPClientType
) {
self.url = url
self.options = options
- self.ws = ws
+ self.wsTransport = wsTransport
self.http = http
apikey = options.apikey
@@ -119,7 +135,7 @@ public final class RealtimeClientV2: Sendable {
mutableState.withValue {
$0.heartbeatTask?.cancel()
$0.messageTask?.cancel()
- $0.channels = [:]
+ $0.channels = []
}
}
@@ -149,21 +165,12 @@ public final class RealtimeClientV2: Sendable {
status = .connecting
- for await connectionStatus in ws.connect() {
- if Task.isCancelled {
- break
- }
-
- switch connectionStatus {
- case .connected:
- await onConnected(reconnect: reconnect)
-
- case .disconnected:
- await onDisconnected()
-
- case let .error(error):
- await onError(error)
- }
+ do {
+ let conn = try await wsTransport()
+ mutableState.withValue { $0.conn = conn }
+ onConnected(reconnect: reconnect)
+ } catch {
+ onError(error)
}
}
@@ -175,37 +182,46 @@ public final class RealtimeClientV2: Sendable {
_ = await statusChange.first { @Sendable in $0 == .connected }
}
- private func onConnected(reconnect: Bool) async {
+ private func onConnected(reconnect: Bool) {
status = .connected
options.logger?.debug("Connected to realtime WebSocket")
listenForMessages()
startHeartbeating()
if reconnect {
- await rejoinChannels()
+ rejoinChannels()
}
- await flushSendBuffer()
+ flushSendBuffer()
}
- private func onDisconnected() async {
+ private func onDisconnected() {
options.logger?
.debug(
"WebSocket disconnected. Trying again in \(options.reconnectDelay)"
)
- await reconnect()
+ reconnect()
}
- private func onError(_ error: (any Error)?) async {
+ private func onError(_ error: (any Error)?) {
options.logger?
.debug(
"WebSocket error \(error?.localizedDescription ?? ""). Trying again in \(options.reconnectDelay)"
)
- await reconnect()
+ reconnect()
}
- private func reconnect() async {
- disconnect()
- await connect(reconnect: true)
+ private func onClose(code: Int?, reason: String?) {
+ options.logger?.debug(
+ "WebSocket closed. Code: \(code?.description ?? ""), Reason: \(reason ?? "")")
+
+ reconnect()
+ }
+
+ private func reconnect() {
+ Task {
+ disconnect()
+ await connect(reconnect: true)
+ }
}
/// Creates a new channel and bind it to this client.
@@ -226,17 +242,28 @@ public final class RealtimeClientV2: Sendable {
)
options(&config)
- return RealtimeChannelV2(
+ let channel = RealtimeChannelV2(
topic: "realtime:\(topic)",
config: config,
- socket: Socket(client: self),
+ socket: self,
logger: self.options.logger
)
+
+ mutableState.withValue {
+ $0.channels.append(channel)
+ }
+
+ return channel
}
+ @available(
+ *, deprecated,
+ message:
+ "Client handles channels automatically, this method will be removed on the next major release."
+ )
public func addChannel(_ channel: RealtimeChannelV2) {
mutableState.withValue {
- $0.channels[channel.topic] = channel
+ $0.channels.append(channel)
}
}
@@ -248,16 +275,20 @@ public final class RealtimeClientV2: Sendable {
await channel.unsubscribe()
}
- mutableState.withValue {
- $0.channels[channel.topic] = nil
- }
-
if channels.isEmpty {
options.logger?.debug("No more subscribed channel in socket")
disconnect()
}
}
+ func _remove(_ channel: RealtimeChannelV2) {
+ mutableState.withValue {
+ $0.channels.removeAll {
+ $0.joinRef == channel.joinRef
+ }
+ }
+ }
+
/// Unsubscribes and removes all channels.
public func removeAllChannels() async {
await withTaskGroup(of: Void.self) { group in
@@ -269,35 +300,44 @@ public final class RealtimeClientV2: Sendable {
}
}
- private func rejoinChannels() async {
- await withTaskGroup(of: Void.self) { group in
+ func _getAccessToken() async -> String? {
+ if let accessToken = try? await options.accessToken?() {
+ return accessToken
+ }
+ return mutableState.accessToken
+ }
+
+ private func rejoinChannels() {
+ Task {
for channel in channels.values {
- group.addTask {
- await channel.subscribe()
- }
+ await channel.subscribe()
}
-
- await group.waitForAll()
}
}
private func listenForMessages() {
let messageTask = Task { [weak self] in
- guard let self else { return }
+ guard let self, let conn = self.conn else { return }
do {
- for try await message in ws.receive() {
- if Task.isCancelled {
- return
- }
+ for await event in conn.events {
+ if Task.isCancelled { return }
- await onMessage(message)
+ switch event {
+ case .binary:
+ self.options.logger?.error("Unsupported binary event received.")
+ break
+ case .text(let text):
+ let data = Data(text.utf8)
+ let message = try JSONDecoder().decode(RealtimeMessageV2.self, from: data)
+ await onMessage(message)
+
+ case let .close(code, reason):
+ onClose(code: code, reason: reason)
+ }
}
} catch {
- options.logger?.debug(
- "Error while listening for messages. Trying again in \(options.reconnectDelay) \(error)"
- )
- await reconnect()
+ onError(error)
}
}
mutableState.withValue {
@@ -312,7 +352,7 @@ public final class RealtimeClientV2: Sendable {
if Task.isCancelled {
break
}
- await self?.sendHeartbeat()
+ self?.sendHeartbeat()
}
}
mutableState.withValue {
@@ -320,8 +360,8 @@ public final class RealtimeClientV2: Sendable {
}
}
- private func sendHeartbeat() async {
- let pendingHeartbeatRef: Int? = mutableState.withValue {
+ private func sendHeartbeat() {
+ let pendingHeartbeatRef: String? = mutableState.withValue {
if $0.pendingHeartbeatRef != nil {
$0.pendingHeartbeatRef = nil
return nil
@@ -333,10 +373,10 @@ public final class RealtimeClientV2: Sendable {
}
if let pendingHeartbeatRef {
- await push(
+ push(
RealtimeMessageV2(
joinRef: nil,
- ref: pendingHeartbeatRef.description,
+ ref: pendingHeartbeatRef,
topic: "phoenix",
event: "heartbeat",
payload: [:]
@@ -344,7 +384,7 @@ public final class RealtimeClientV2: Sendable {
)
} else {
options.logger?.debug("Heartbeat timeout")
- await reconnect()
+ reconnect()
}
}
@@ -354,13 +394,17 @@ public final class RealtimeClientV2: Sendable {
/// - reason: A custom reason for the disconnect.
public func disconnect(code: Int? = nil, reason: String? = nil) {
options.logger?.debug("Closing WebSocket connection")
+
+ conn?.close(code: code, reason: reason)
+
mutableState.withValue {
$0.ref = 0
$0.messageTask?.cancel()
$0.heartbeatTask?.cancel()
$0.connectionTask?.cancel()
+ $0.conn = nil
}
- ws.disconnect(code: code, reason: reason)
+
status = .disconnected
}
@@ -405,35 +449,33 @@ public final class RealtimeClientV2: Sendable {
}
private func onMessage(_ message: RealtimeMessageV2) async {
- let channel = mutableState.withValue {
- let channel = $0.channels[message.topic]
-
- if let ref = message.ref, Int(ref) == $0.pendingHeartbeatRef {
+ let channels = mutableState.withValue {
+ if let ref = message.ref, ref == $0.pendingHeartbeatRef {
$0.pendingHeartbeatRef = nil
options.logger?.debug("heartbeat received")
} else {
options.logger?
- .debug("Received event \(message.event) for channel \(channel?.topic ?? "null")")
+ .debug("Received event \(message.event) for channel \(message.topic)")
}
- return channel
+
+ return $0.channels.filter { $0.topic == message.topic }
}
- if let channel {
+ for channel in channels {
await channel.onMessage(message)
- } else {
- options.logger?.warning("No channel subscribed to \(message.topic). Ignoring message.")
}
}
/// Push out a message if the socket is connected.
///
/// If the socket is not connected, the message gets enqueued within a local buffer, and sent out when a connection is next established.
- public func push(_ message: RealtimeMessageV2) async {
+ public func push(_ message: RealtimeMessageV2) {
let callback = { @Sendable [weak self] in
do {
// Check cancellation before sending, because this push may have been cancelled before a connection was established.
try Task.checkCancellation()
- try await self?.ws.send(message)
+ let data = try JSONEncoder().encode(message)
+ self?.conn?.send(String(decoding: data, as: UTF8.self))
} catch {
self?.options.logger?.error(
"""
@@ -447,7 +489,7 @@ public final class RealtimeClientV2: Sendable {
}
if status == .connected {
- await callback()
+ callback()
} else {
mutableState.withValue {
$0.sendBuffer.append(callback)
@@ -455,22 +497,17 @@ public final class RealtimeClientV2: Sendable {
}
}
- private func flushSendBuffer() async {
- let sendBuffer = mutableState.withValue {
- let copy = $0.sendBuffer
+ private func flushSendBuffer() {
+ mutableState.withValue {
+ $0.sendBuffer.forEach { $0() }
$0.sendBuffer = []
- return copy
- }
-
- for send in sendBuffer {
- await send()
}
}
- func makeRef() -> Int {
+ func makeRef() -> String {
mutableState.withValue {
$0.ref += 1
- return $0.ref
+ return $0.ref.description
}
}
diff --git a/Sources/Realtime/V2/WebSocketClient.swift b/Sources/Realtime/V2/WebSocketClient.swift
deleted file mode 100644
index 0634f774..00000000
--- a/Sources/Realtime/V2/WebSocketClient.swift
+++ /dev/null
@@ -1,153 +0,0 @@
-//
-// WebSocketClient.swift
-//
-//
-// Created by Guilherme Souza on 29/12/23.
-//
-
-import ConcurrencyExtras
-import Foundation
-import Helpers
-
-#if canImport(FoundationNetworking)
- import FoundationNetworking
-#endif
-
-enum WebSocketClientError: Error {
- case unsupportedData
-}
-
-enum ConnectionStatus {
- case connected
- case disconnected(reason: String, code: URLSessionWebSocketTask.CloseCode)
- case error((any Error)?)
-}
-
-protocol WebSocketClient: Sendable {
- func send(_ message: RealtimeMessageV2) async throws
- func receive() -> AsyncThrowingStream
- func connect() -> AsyncStream
- func disconnect(code: Int?, reason: String?)
-}
-
-final class WebSocket: NSObject, URLSessionWebSocketDelegate, WebSocketClient, @unchecked Sendable {
- private let realtimeURL: URL
- private let configuration: URLSessionConfiguration
- private let logger: (any SupabaseLogger)?
-
- struct MutableState {
- var continuation: AsyncStream.Continuation?
- var task: URLSessionWebSocketTask?
- }
-
- private let mutableState = LockIsolated(MutableState())
-
- init(realtimeURL: URL, options: RealtimeClientOptions) {
- self.realtimeURL = realtimeURL
-
- let sessionConfiguration = URLSessionConfiguration.default
- sessionConfiguration.httpAdditionalHeaders = options.headers.dictionary
- configuration = sessionConfiguration
- logger = options.logger
- }
-
- deinit {
- mutableState.task?.cancel(with: .goingAway, reason: nil)
- }
-
- func connect() -> AsyncStream {
- mutableState.withValue { state in
- let session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil)
- let task = session.webSocketTask(with: realtimeURL)
- state.task = task
- task.resume()
-
- let (stream, continuation) = AsyncStream.makeStream()
- state.continuation = continuation
- return stream
- }
- }
-
- func disconnect(code: Int?, reason: String?) {
- mutableState.withValue { state in
- if let code {
- state.task?.cancel(
- with: URLSessionWebSocketTask.CloseCode(rawValue: code) ?? .invalid,
- reason: reason?.data(using: .utf8))
- } else {
- state.task?.cancel()
- }
- }
- }
-
- func receive() -> AsyncThrowingStream {
- AsyncThrowingStream { [weak self] in
- guard let self else { return nil }
-
- let task = mutableState.task
-
- guard
- let message = try await task?.receive(),
- !Task.isCancelled
- else { return nil }
-
- switch message {
- case .data(let data):
- let message = try JSONDecoder().decode(RealtimeMessageV2.self, from: data)
- return message
-
- case .string(let string):
- guard let data = string.data(using: .utf8) else {
- throw WebSocketClientError.unsupportedData
- }
-
- let message = try JSONDecoder().decode(RealtimeMessageV2.self, from: data)
- return message
-
- @unknown default:
- assertionFailure("Unsupported message type.")
- task?.cancel(with: .unsupportedData, reason: nil)
- throw WebSocketClientError.unsupportedData
- }
- }
- }
-
- func send(_ message: RealtimeMessageV2) async throws {
- logger?.verbose("Sending message: \(message)")
-
- let data = try JSONEncoder().encode(message)
- try await mutableState.task?.send(.data(data))
- }
-
- // MARK: - URLSessionWebSocketDelegate
-
- func urlSession(
- _: URLSession,
- webSocketTask _: URLSessionWebSocketTask,
- didOpenWithProtocol _: String?
- ) {
- mutableState.continuation?.yield(.connected)
- }
-
- func urlSession(
- _: URLSession,
- webSocketTask _: URLSessionWebSocketTask,
- didCloseWith closeCode: URLSessionWebSocketTask.CloseCode,
- reason: Data?
- ) {
- let status = ConnectionStatus.disconnected(
- reason: reason.flatMap { String(data: $0, encoding: .utf8) } ?? "",
- code: closeCode
- )
-
- mutableState.continuation?.yield(status)
- }
-
- func urlSession(
- _: URLSession,
- task _: URLSessionTask,
- didCompleteWithError error: (any Error)?
- ) {
- mutableState.continuation?.yield(.error(error))
- }
-}
diff --git a/Sources/Realtime/WebSocket/URLSessionWebSocket.swift b/Sources/Realtime/WebSocket/URLSessionWebSocket.swift
new file mode 100644
index 00000000..61bafc70
--- /dev/null
+++ b/Sources/Realtime/WebSocket/URLSessionWebSocket.swift
@@ -0,0 +1,297 @@
+import ConcurrencyExtras
+import Foundation
+
+#if canImport(FoundationNetworking)
+ import FoundationNetworking
+#endif
+
+/// A WebSocket connection that uses `URLSession`.
+final class URLSessionWebSocket: WebSocket {
+ private init(
+ _task: URLSessionWebSocketTask,
+ _protocol: String
+ ) {
+ self._task = _task
+ self._protocol = _protocol
+
+ _scheduleReceive()
+ }
+
+ /// Create a new WebSocket connection.
+ /// - Parameters:
+ /// - url: The URL to connect to.
+ /// - protocols: An optional array of protocols to negotiate with the server.
+ /// - configuration: An optional `URLSessionConfiguration` to use for the connection.
+ /// - Returns: A `URLSessionWebSocket` instance.
+ /// - Throws: An error if the connection fails.
+ static func connect(
+ to url: URL,
+ protocols: [String]? = nil,
+ configuration: URLSessionConfiguration? = nil
+ ) async throws -> URLSessionWebSocket {
+ guard url.scheme == "ws" || url.scheme == "wss" else {
+ preconditionFailure("only ws: and wss: schemes are supported")
+ }
+
+ // It is safe to use `nonisolated(unsafe)` because all completion handlers runs on the same queue.
+ nonisolated(unsafe) var continuation: CheckedContinuation!
+ nonisolated(unsafe) var webSocket: URLSessionWebSocket?
+
+ let session = URLSession.sessionWithConfiguration(
+ configuration ?? .default,
+ onComplete: { session, task, error in
+ if let webSocket {
+ // There are three possibilities here:
+ // 1. the peer sent a close Frame, `onWebSocketTaskClosed` was already
+ // called and `_connectionClosed` is a no-op.
+ // 2. we sent a close Frame (through `close()`) and `_connectionClosed`
+ // is a no-op.
+ // 3. an error occurred (e.g. network failure) and `_connectionClosed`
+ // will signal that and close `event`.
+ webSocket._connectionClosed(
+ code: 1006, reason: Data("abnormal close".utf8))
+ } else if let error {
+ continuation.resume(
+ throwing: WebSocketError.connection(
+ message: "connection ended unexpectedly", error: error))
+ } else {
+ // `onWebSocketTaskOpened` should have been called and resumed continuation.
+ // So either there was an error creating the connection or a logic error.
+ assertionFailure("expected an error or `onWebSocketTaskOpened` to have been called first")
+ }
+ },
+ onWebSocketTaskOpened: { session, task, `protocol` in
+ webSocket = URLSessionWebSocket(_task: task, _protocol: `protocol` ?? "")
+ continuation.resume(returning: webSocket!)
+ },
+ onWebSocketTaskClosed: { session, task, code, reason in
+ assert(webSocket != nil, "connection should exist by this time")
+ webSocket!._connectionClosed(code: code, reason: reason)
+ }
+ )
+
+ session.webSocketTask(with: url, protocols: protocols ?? []).resume()
+ return try await withCheckedThrowingContinuation { continuation = $0 }
+ }
+
+ let _task: URLSessionWebSocketTask
+ let _protocol: String
+
+ struct MutableState {
+ var isClosed = false
+ var onEvent: (@Sendable (WebSocketEvent) -> Void)?
+
+ var closeCode: Int?
+ var closeReason: String?
+ }
+
+ let mutableState = LockIsolated(MutableState())
+
+ var closeCode: Int? {
+ mutableState.value.closeCode
+ }
+
+ var closeReason: String? {
+ mutableState.value.closeReason
+ }
+
+ var isClosed: Bool {
+ mutableState.value.isClosed
+ }
+
+ private func _handleMessage(_ value: URLSessionWebSocketTask.Message) {
+ guard !isClosed else { return }
+
+ let event =
+ switch value {
+ case .string(let string):
+ WebSocketEvent.text(string)
+ case .data(let data):
+ WebSocketEvent.binary(data)
+ @unknown default:
+ fatalError("Unsupported message.")
+ }
+ _trigger(event)
+ _scheduleReceive()
+ }
+
+ private func _scheduleReceive() {
+ _task.receive { [weak self] result in
+ switch result {
+ case .success(let value): self?._handleMessage(value)
+ case .failure(let error): self?._closeConnectionWithError(error)
+ }
+ }
+ }
+
+ private func _closeConnectionWithError(_ error: any Error) {
+ let nsError = error as NSError
+ if nsError.domain == NSPOSIXErrorDomain && nsError.code == 57 {
+ // Socket is not connected.
+ // onWebsocketTaskClosed/onComplete will be invoked and may indicate a close code.
+ return
+ }
+ let (code, reason) =
+ switch (nsError.domain, nsError.code) {
+ case (NSPOSIXErrorDomain, 100):
+ (1002, nsError.localizedDescription)
+ case (_, _):
+ (1006, nsError.localizedDescription)
+ }
+ _task.cancel()
+ _connectionClosed(code: code, reason: Data(reason.utf8))
+ }
+
+ private func _connectionClosed(code: Int?, reason: Data?) {
+ guard !isClosed else { return }
+
+ let closeReason = reason.map { String(decoding: $0, as: UTF8.self) } ?? ""
+ _trigger(.close(code: code, reason: closeReason))
+ }
+
+ func send(_ text: String) {
+ guard !isClosed else {
+ return
+ }
+
+ _task.send(.string(text)) { [weak self] error in
+ if let error {
+ self?._closeConnectionWithError(error)
+ }
+ }
+ }
+
+ var onEvent: (@Sendable (WebSocketEvent) -> Void)? {
+ get { mutableState.value.onEvent }
+ set { mutableState.withValue { $0.onEvent = newValue } }
+ }
+
+ private func _trigger(_ event: WebSocketEvent) {
+ mutableState.withValue {
+ $0.onEvent?(event)
+
+ if case .close(let code, let reason) = event {
+ $0.onEvent = nil
+ $0.isClosed = true
+ $0.closeCode = code
+ $0.closeReason = reason
+ }
+ }
+ }
+
+ func send(_ binary: Data) {
+ guard !isClosed else {
+ return
+ }
+
+ _task.send(.data(binary)) { [weak self] error in
+ if let error {
+ self?._closeConnectionWithError(error)
+ }
+ }
+ }
+
+ func close(code: Int?, reason: String?) {
+ guard !isClosed else {
+ return
+ }
+
+ if code != nil, code != 1000, !(code! >= 3000 && code! <= 4999) {
+ preconditionFailure(
+ "Invalid argument: \(code!), close code must be 1000 or in the range 3000-4999")
+ }
+
+ if reason != nil, reason!.utf8.count > 123 {
+ preconditionFailure("reason must be <= 123 bytes long and encoded as UTF-8")
+ }
+
+ mutableState.withValue {
+ if !$0.isClosed {
+ if code != nil {
+ let reason = reason ?? ""
+ _task.cancel(
+ with: URLSessionWebSocketTask.CloseCode(rawValue: code!)!,
+ reason: Data(reason.utf8)
+ )
+ } else {
+ _task.cancel()
+ }
+ }
+ }
+ }
+
+ var `protocol`: String { _protocol }
+}
+
+extension URLSession {
+ static func sessionWithConfiguration(
+ _ configuration: URLSessionConfiguration,
+ onComplete: (@Sendable (URLSession, URLSessionTask, (any Error)?) -> Void)? = nil,
+ onWebSocketTaskOpened: (@Sendable (URLSession, URLSessionWebSocketTask, String?) -> Void)? =
+ nil,
+ onWebSocketTaskClosed: (@Sendable (URLSession, URLSessionWebSocketTask, Int?, Data?) -> Void)? =
+ nil
+ ) -> URLSession {
+ let queue = OperationQueue()
+ queue.maxConcurrentOperationCount = 1
+
+ let hasDelegate =
+ onComplete != nil || onWebSocketTaskOpened != nil || onWebSocketTaskClosed != nil
+
+ if hasDelegate {
+ return URLSession(
+ configuration: configuration,
+ delegate: _Delegate(
+ onComplete: onComplete,
+ onWebSocketTaskOpened: onWebSocketTaskOpened,
+ onWebSocketTaskClosed: onWebSocketTaskClosed
+ ),
+ delegateQueue: queue
+ )
+ } else {
+ return URLSession(configuration: configuration)
+ }
+ }
+}
+
+final class _Delegate: NSObject, URLSessionDelegate, URLSessionDataDelegate, URLSessionTaskDelegate,
+ URLSessionWebSocketDelegate
+{
+ let onComplete: (@Sendable (URLSession, URLSessionTask, (any Error)?) -> Void)?
+ let onWebSocketTaskOpened: (@Sendable (URLSession, URLSessionWebSocketTask, String?) -> Void)?
+ let onWebSocketTaskClosed: (@Sendable (URLSession, URLSessionWebSocketTask, Int?, Data?) -> Void)?
+
+ init(
+ onComplete: (@Sendable (URLSession, URLSessionTask, (any Error)?) -> Void)?,
+ onWebSocketTaskOpened: (
+ @Sendable (URLSession, URLSessionWebSocketTask, String?) -> Void
+ )?,
+ onWebSocketTaskClosed: (
+ @Sendable (URLSession, URLSessionWebSocketTask, Int?, Data?) -> Void
+ )?
+ ) {
+ self.onComplete = onComplete
+ self.onWebSocketTaskOpened = onWebSocketTaskOpened
+ self.onWebSocketTaskClosed = onWebSocketTaskClosed
+ }
+
+ func urlSession(
+ _ session: URLSession, task: URLSessionTask, didCompleteWithError error: (any Error)?
+ ) {
+ onComplete?(session, task, error)
+ }
+
+ func urlSession(
+ _ session: URLSession, webSocketTask: URLSessionWebSocketTask,
+ didOpenWithProtocol protocol: String?
+ ) {
+ onWebSocketTaskOpened?(session, webSocketTask, `protocol`)
+ }
+
+ func urlSession(
+ _ session: URLSession, webSocketTask: URLSessionWebSocketTask,
+ didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?
+ ) {
+ onWebSocketTaskClosed?(session, webSocketTask, closeCode.rawValue, reason)
+ }
+}
diff --git a/Sources/Realtime/WebSocket/WebSocket.swift b/Sources/Realtime/WebSocket/WebSocket.swift
new file mode 100644
index 00000000..8512c335
--- /dev/null
+++ b/Sources/Realtime/WebSocket/WebSocket.swift
@@ -0,0 +1,90 @@
+import Foundation
+
+/// Represents events that can occur on a WebSocket connection.
+enum WebSocketEvent: Sendable, Hashable {
+ case text(String)
+ case binary(Data)
+ case close(code: Int?, reason: String)
+}
+
+/// Represents errors that can occur on a WebSocket connection.
+enum WebSocketError: Error, LocalizedError {
+ /// An error occurred while connecting to the peer.
+ case connection(message: String, error: any Error)
+
+ var errorDescription: String? {
+ switch self {
+ case .connection(let message, let error): "\(message) \(error.localizedDescription)"
+ }
+ }
+}
+
+/// The interface for WebSocket connection.
+protocol WebSocket: Sendable, AnyObject {
+ var closeCode: Int? { get }
+ var closeReason: String? { get }
+
+ /// Sends text data to the connected peer.
+ /// - Parameter text: The text data to send.
+ func send(_ text: String)
+
+ /// Sends binary data to the connected peer.
+ /// - Parameter binary: The binary data to send.
+ func send(_ binary: Data)
+
+ /// Closes the WebSocket connection and the ``events`` `AsyncStream`.
+ ///
+ /// Sends a Close frame to the peer. If the optional `code` and `reason` arguments are given, they will be included in the Close frame. If no `code` is set then the peer will see a 1005 status code. If no `reason` is set then the peer will not receive a reason string.
+ /// - Parameters:
+ /// - code: The close code to send to the peer.
+ /// - reason: The reason for closing the connection.
+ func close(code: Int?, reason: String?)
+
+ /// Listen for event messages in the connection.
+ var onEvent: (@Sendable (WebSocketEvent) -> Void)? { get set }
+
+ /// The WebSocket subprotocol negotiated with the peer.
+ ///
+ /// Will be the empty string if no subprotocol was negotiated.
+ ///
+ /// See [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9).
+ var `protocol`: String { get }
+
+ /// Whether connection is closed.
+ var isClosed: Bool { get }
+}
+
+extension WebSocket {
+ /// Closes the WebSocket connection and the ``events`` `AsyncStream`.
+ ///
+ /// Sends a Close frame to the peer. If the optional `code` and `reason` arguments are given, they will be included in the Close frame. If no `code` is set then the peer will see a 1005 status code. If no `reason` is set then the peer will not receive a reason string.
+ func close() {
+ self.close(code: nil, reason: nil)
+ }
+
+ /// An `AsyncStream` of ``WebSocketEvent`` received from the peer.
+ ///
+ /// Data received by the peer will be delivered as a ``WebSocketEvent/text(_:)`` or ``WebSocketEvent/binary(_:)``.
+ ///
+ /// If a ``WebSocketEvent/close(code:reason:)`` event is received then the `AsyncStream` will be closed. A ``WebSocketEvent/close(code:reason:)`` event indicates either that:
+ ///
+ /// - A close frame was received from the peer. `code` and `reason` will be set by the peer.
+ /// - A failure occurred (e.g. the peer disconnected). `code` and `reason` will be a failure code defined by [RFC-6455](https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1) (e.g. 1006).
+ ///
+ /// Errors will never appear in this `AsyncStream`.
+ var events: AsyncStream {
+ let (stream, continuation) = AsyncStream.makeStream()
+ self.onEvent = { event in
+ continuation.yield(event)
+
+ if case .close = event {
+ continuation.finish()
+ }
+ }
+
+ continuation.onTermination = { _ in
+ self.onEvent = nil
+ }
+ return stream
+ }
+}
diff --git a/Tests/IntegrationTests/RealtimeIntegrationTests.swift b/Tests/IntegrationTests/RealtimeIntegrationTests.swift
index 4b2b543a..74e5f7f3 100644
--- a/Tests/IntegrationTests/RealtimeIntegrationTests.swift
+++ b/Tests/IntegrationTests/RealtimeIntegrationTests.swift
@@ -7,25 +7,33 @@
import ConcurrencyExtras
import CustomDump
+import Helpers
+import InlineSnapshotTesting
import PostgREST
-@testable import Realtime
import Supabase
import TestHelpers
import XCTest
+@testable import Realtime
+
+struct TestLogger: SupabaseLogger {
+ func log(message: SupabaseLogMessage) {
+ print(message.description)
+ }
+}
+
final class RealtimeIntegrationTests: XCTestCase {
- let realtime = RealtimeClientV2(
- url: URL(string: "\(DotEnv.SUPABASE_URL)/realtime/v1")!,
- options: RealtimeClientOptions(
- headers: ["apikey": DotEnv.SUPABASE_ANON_KEY]
- )
- )
- let db = PostgrestClient(
- url: URL(string: "\(DotEnv.SUPABASE_URL)/rest/v1")!,
- headers: [
- "apikey": DotEnv.SUPABASE_ANON_KEY,
- ]
+ static let reconnectDelay: TimeInterval = 1
+
+ let client = SupabaseClient(
+ supabaseURL: URL(string: DotEnv.SUPABASE_URL)!,
+ supabaseKey: DotEnv.SUPABASE_ANON_KEY,
+ options: SupabaseClientOptions(
+ realtime: RealtimeClientOptions(
+ reconnectDelay: reconnectDelay
+ )
+ )
)
override func invokeTest() {
@@ -34,23 +42,26 @@ final class RealtimeIntegrationTests: XCTestCase {
}
}
- func testBroadcast() async throws {
- let expectation = expectation(description: "receivedBroadcastMessages")
- expectation.expectedFulfillmentCount = 3
+ func testDisconnectByUser_shouldNotReconnect() async {
+ await client.realtimeV2.connect()
+ XCTAssertEqual(client.realtimeV2.status, .connected)
+
+ client.realtimeV2.disconnect()
- let channel = realtime.channel("integration") {
+ /// Wait for the reconnection delay
+ try? await Task.sleep(
+ nanoseconds: NSEC_PER_SEC * UInt64(Self.reconnectDelay) + 1)
+
+ XCTAssertEqual(client.realtimeV2.status, .disconnected)
+ }
+
+ func testBroadcast() async throws {
+ let channel = client.realtimeV2.channel("integration") {
$0.broadcast.receiveOwnBroadcasts = true
}
- let receivedMessages = LockIsolated<[JSONObject]>([])
-
- Task {
- for await message in channel.broadcastStream(event: "test") {
- receivedMessages.withValue {
- $0.append(message)
- }
- expectation.fulfill()
- }
+ let receivedMessagesTask = Task {
+ await channel.broadcastStream(event: "test").prefix(3).collect()
}
await Task.yield()
@@ -65,41 +76,44 @@ final class RealtimeIntegrationTests: XCTestCase {
try await channel.broadcast(event: "test", message: Message(value: 2))
try await channel.broadcast(event: "test", message: ["value": 3, "another_value": 42])
- await fulfillment(of: [expectation], timeout: 0.5)
+ let receivedMessages = try await withTimeout(interval: 5) {
+ await receivedMessagesTask.value
+ }
- expectNoDifference(
- receivedMessages.value,
+ assertInlineSnapshot(of: receivedMessages, as: .json) {
+ """
[
- [
- "event": "test",
- "payload": [
- "value": 1,
- ],
- "type": "broadcast",
- ],
- [
- "event": "test",
- "payload": [
- "value": 2,
- ],
- "type": "broadcast",
- ],
- [
- "event": "test",
- "payload": [
- "value": 3,
- "another_value": 42,
- ],
- "type": "broadcast",
- ],
+ {
+ "event" : "test",
+ "payload" : {
+ "value" : 1
+ },
+ "type" : "broadcast"
+ },
+ {
+ "event" : "test",
+ "payload" : {
+ "value" : 2
+ },
+ "type" : "broadcast"
+ },
+ {
+ "event" : "test",
+ "payload" : {
+ "another_value" : 42,
+ "value" : 3
+ },
+ "type" : "broadcast"
+ }
]
- )
+ """
+ }
await channel.unsubscribe()
}
func testBroadcastWithUnsubscribedChannel() async throws {
- let channel = realtime.channel("integration") {
+ let channel = client.realtimeV2.channel("integration") {
$0.broadcast.acknowledgeBroadcasts = true
}
@@ -113,22 +127,12 @@ final class RealtimeIntegrationTests: XCTestCase {
}
func testPresence() async throws {
- let channel = realtime.channel("integration") {
+ let channel = client.realtimeV2.channel("integration") {
$0.broadcast.receiveOwnBroadcasts = true
}
- let expectation = expectation(description: "presenceChange")
- expectation.expectedFulfillmentCount = 4
-
- let receivedPresenceChanges = LockIsolated<[any PresenceAction]>([])
-
- Task {
- for await presence in channel.presenceChange() {
- receivedPresenceChanges.withValue {
- $0.append(presence)
- }
- expectation.fulfill()
- }
+ let receivedPresenceChangesTask = Task {
+ await channel.presenceChange().prefix(4).collect()
}
await Task.yield()
@@ -144,14 +148,16 @@ final class RealtimeIntegrationTests: XCTestCase {
await channel.untrack()
- await fulfillment(of: [expectation], timeout: 0.5)
+ let receivedPresenceChanges = try await withTimeout(interval: 5) {
+ await receivedPresenceChangesTask.value
+ }
- let joins = try receivedPresenceChanges.value.map { try $0.decodeJoins(as: UserState.self) }
- let leaves = try receivedPresenceChanges.value.map { try $0.decodeLeaves(as: UserState.self) }
+ let joins = try receivedPresenceChanges.map { try $0.decodeJoins(as: UserState.self) }
+ let leaves = try receivedPresenceChanges.map { try $0.decodeLeaves(as: UserState.self) }
expectNoDifference(
joins,
[
- [], // This is the first PRESENCE_STATE event.
+ [], // This is the first PRESENCE_STATE event.
[UserState(email: "test@supabase.com")],
[UserState(email: "test2@supabase.com")],
[],
@@ -161,7 +167,7 @@ final class RealtimeIntegrationTests: XCTestCase {
expectNoDifference(
leaves,
[
- [], // This is the first PRESENCE_STATE event.
+ [], // This is the first PRESENCE_STATE event.
[],
[UserState(email: "test@supabase.com")],
[UserState(email: "test2@supabase.com")],
@@ -171,86 +177,87 @@ final class RealtimeIntegrationTests: XCTestCase {
await channel.unsubscribe()
}
- // FIXME: Test getting stuck
-// func testPostgresChanges() async throws {
-// let channel = realtime.channel("db-changes")
-//
-// let receivedInsertActions = Task {
-// await channel.postgresChange(InsertAction.self, schema: "public").prefix(1).collect()
-// }
-//
-// let receivedUpdateActions = Task {
-// await channel.postgresChange(UpdateAction.self, schema: "public").prefix(1).collect()
-// }
-//
-// let receivedDeleteActions = Task {
-// await channel.postgresChange(DeleteAction.self, schema: "public").prefix(1).collect()
-// }
-//
-// let receivedAnyActionsTask = Task {
-// await channel.postgresChange(AnyAction.self, schema: "public").prefix(3).collect()
-// }
-//
-// await Task.yield()
-// await channel.subscribe()
-//
-// struct Entry: Codable, Equatable {
-// let key: String
-// let value: AnyJSON
-// }
-//
-// let key = try await (
-// db.from("key_value_storage")
-// .insert(["key": AnyJSON.string(UUID().uuidString), "value": "value1"]).select().single()
-// .execute().value as Entry
-// ).key
-// try await db.from("key_value_storage").update(["value": "value2"]).eq("key", value: key)
-// .execute()
-// try await db.from("key_value_storage").delete().eq("key", value: key).execute()
-//
-// let insertedEntries = try await receivedInsertActions.value.map {
-// try $0.decodeRecord(
-// as: Entry.self,
-// decoder: JSONDecoder()
-// )
-// }
-// let updatedEntries = try await receivedUpdateActions.value.map {
-// try $0.decodeRecord(
-// as: Entry.self,
-// decoder: JSONDecoder()
-// )
-// }
-// let deletedEntryIds = await receivedDeleteActions.value.compactMap {
-// $0.oldRecord["key"]?.stringValue
-// }
-//
-// expectNoDifference(insertedEntries, [Entry(key: key, value: "value1")])
-// expectNoDifference(updatedEntries, [Entry(key: key, value: "value2")])
-// expectNoDifference(deletedEntryIds, [key])
-//
-// let receivedAnyActions = await receivedAnyActionsTask.value
-// XCTAssertEqual(receivedAnyActions.count, 3)
-//
-// if case let .insert(action) = receivedAnyActions[0] {
-// let record = try action.decodeRecord(as: Entry.self, decoder: JSONDecoder())
-// expectNoDifference(record, Entry(key: key, value: "value1"))
-// } else {
-// XCTFail("Expected a `AnyAction.insert` on `receivedAnyActions[0]`")
-// }
-//
-// if case let .update(action) = receivedAnyActions[1] {
-// let record = try action.decodeRecord(as: Entry.self, decoder: JSONDecoder())
-// expectNoDifference(record, Entry(key: key, value: "value2"))
-// } else {
-// XCTFail("Expected a `AnyAction.update` on `receivedAnyActions[1]`")
-// }
-//
-// if case let .delete(action) = receivedAnyActions[2] {
-// expectNoDifference(key, action.oldRecord["key"]?.stringValue)
-// } else {
-// XCTFail("Expected a `AnyAction.delete` on `receivedAnyActions[2]`")
-// }
-//
-// await channel.unsubscribe()
-// }
+ func testPostgresChanges() async throws {
+ let channel = client.realtimeV2.channel("db-changes")
+
+ let receivedInsertActions = Task {
+ await channel.postgresChange(InsertAction.self, schema: "public").prefix(1).collect()
+ }
+
+ let receivedUpdateActions = Task {
+ await channel.postgresChange(UpdateAction.self, schema: "public").prefix(1).collect()
+ }
+
+ let receivedDeleteActions = Task {
+ await channel.postgresChange(DeleteAction.self, schema: "public").prefix(1).collect()
+ }
+
+ let receivedAnyActionsTask = Task {
+ await channel.postgresChange(AnyAction.self, schema: "public").prefix(3).collect()
+ }
+
+ await Task.yield()
+ await channel.subscribe()
+
+ struct Entry: Codable, Equatable {
+ let key: String
+ let value: AnyJSON
+ }
+
+ // Wait until a system event for makind sure DB change listeners are set before making DB changes.
+ _ = await channel.system().first(where: { _ in true })
+
+ let key = try await
+ (client.from("key_value_storage")
+ .insert(["key": AnyJSON.string(UUID().uuidString), "value": "value1"]).select().single()
+ .execute().value as Entry).key
+ try await client.from("key_value_storage").update(["value": "value2"]).eq("key", value: key)
+ .execute()
+ try await client.from("key_value_storage").delete().eq("key", value: key).execute()
+
+ let insertedEntries = try await receivedInsertActions.value.map {
+ try $0.decodeRecord(
+ as: Entry.self,
+ decoder: JSONDecoder()
+ )
+ }
+ let updatedEntries = try await receivedUpdateActions.value.map {
+ try $0.decodeRecord(
+ as: Entry.self,
+ decoder: JSONDecoder()
+ )
+ }
+ let deletedEntryIds = await receivedDeleteActions.value.compactMap {
+ $0.oldRecord["key"]?.stringValue
+ }
+
+ expectNoDifference(insertedEntries, [Entry(key: key, value: "value1")])
+ expectNoDifference(updatedEntries, [Entry(key: key, value: "value2")])
+ expectNoDifference(deletedEntryIds, [key])
+
+ let receivedAnyActions = await receivedAnyActionsTask.value
+ XCTAssertEqual(receivedAnyActions.count, 3)
+
+ if case let .insert(action) = receivedAnyActions[0] {
+ let record = try action.decodeRecord(as: Entry.self, decoder: JSONDecoder())
+ expectNoDifference(record, Entry(key: key, value: "value1"))
+ } else {
+ XCTFail("Expected a `AnyAction.insert` on `receivedAnyActions[0]`")
+ }
+
+ if case let .update(action) = receivedAnyActions[1] {
+ let record = try action.decodeRecord(as: Entry.self, decoder: JSONDecoder())
+ expectNoDifference(record, Entry(key: key, value: "value2"))
+ } else {
+ XCTFail("Expected a `AnyAction.update` on `receivedAnyActions[1]`")
+ }
+
+ if case let .delete(action) = receivedAnyActions[2] {
+ expectNoDifference(key, action.oldRecord["key"]?.stringValue)
+ } else {
+ XCTFail("Expected a `AnyAction.delete` on `receivedAnyActions[2]`")
+ }
+
+ await channel.unsubscribe()
+ }
}
diff --git a/Tests/RealtimeTests/FakeWebSocket.swift b/Tests/RealtimeTests/FakeWebSocket.swift
new file mode 100644
index 00000000..357f7ddd
--- /dev/null
+++ b/Tests/RealtimeTests/FakeWebSocket.swift
@@ -0,0 +1,118 @@
+import ConcurrencyExtras
+import Foundation
+
+@testable import Realtime
+
+final class FakeWebSocket: WebSocket {
+ struct MutableState {
+ var isClosed: Bool = false
+ weak var other: FakeWebSocket?
+ var onEvent: (@Sendable (WebSocketEvent) -> Void)?
+
+ var sentEvents: [WebSocketEvent] = []
+ var receivedEvents: [WebSocketEvent] = []
+ var closeCode: Int?
+ var closeReason: String?
+ }
+
+ private let mutableState = LockIsolated(MutableState())
+
+ private init(`protocol`: String) {
+ self.`protocol` = `protocol`
+ }
+
+ /// Events send by this connection.
+ var sentEvents: [WebSocketEvent] {
+ mutableState.value.sentEvents
+ }
+
+ /// Events received by this connection.
+ var receivedEvents: [WebSocketEvent] {
+ mutableState.value.receivedEvents
+ }
+
+ var closeCode: Int? {
+ mutableState.value.closeCode
+ }
+
+ var closeReason: String? {
+ mutableState.value.closeReason
+ }
+
+ func close(code: Int?, reason: String?) {
+ mutableState.withValue { s in
+ if s.isClosed { return }
+
+ s.sentEvents.append(.close(code: code, reason: reason ?? ""))
+
+ s.isClosed = true
+ if s.other?.isClosed == false {
+ s.other?._trigger(.close(code: code ?? 1005, reason: reason ?? ""))
+ }
+ }
+ }
+
+ func send(_ text: String) {
+ mutableState.withValue {
+ guard !$0.isClosed else { return }
+
+ $0.sentEvents.append(.text(text))
+
+ if $0.other?.isClosed == false {
+ $0.other?._trigger(.text(text))
+ }
+ }
+ }
+
+ func send(_ binary: Data) {
+ mutableState.withValue {
+ guard !$0.isClosed else { return }
+
+ $0.sentEvents.append(.binary(binary))
+
+ if $0.other?.isClosed == false {
+ $0.other?._trigger(.binary(binary))
+ }
+ }
+ }
+
+ var onEvent: (@Sendable (WebSocketEvent) -> Void)? {
+ get { mutableState.value.onEvent }
+ set { mutableState.withValue { $0.onEvent = newValue } }
+ }
+
+ let `protocol`: String
+
+ var isClosed: Bool {
+ mutableState.value.isClosed
+ }
+
+ func _trigger(_ event: WebSocketEvent) {
+ mutableState.withValue {
+ $0.receivedEvents.append(event)
+ $0.onEvent?(event)
+
+ if case .close(let code, let reason) = event {
+ $0.onEvent = nil
+ $0.isClosed = true
+ $0.closeCode = code
+ $0.closeReason = reason
+ }
+ }
+ }
+
+ /// Creates a pair of fake ``WebSocket``s that are connected to each other.
+ ///
+ /// Sending a message on one ``WebSocket`` will result in that same message being
+ /// received by the other.
+ ///
+ /// This can be useful in constructing tests.
+ static func fakes(`protocol`: String = "") -> (FakeWebSocket, FakeWebSocket) {
+ let (peer1, peer2) = (FakeWebSocket(protocol: `protocol`), FakeWebSocket(protocol: `protocol`))
+
+ peer1.mutableState.withValue { $0.other = peer2 }
+ peer2.mutableState.withValue { $0.other = peer1 }
+
+ return (peer1, peer2)
+ }
+}
diff --git a/Tests/RealtimeTests/MockWebSocketClient.swift b/Tests/RealtimeTests/MockWebSocketClient.swift
deleted file mode 100644
index bcabc958..00000000
--- a/Tests/RealtimeTests/MockWebSocketClient.swift
+++ /dev/null
@@ -1,98 +0,0 @@
-//
-// MockWebSocketClient.swift
-//
-//
-// Created by Guilherme Souza on 29/12/23.
-//
-
-import ConcurrencyExtras
-import Foundation
-@testable import Realtime
-import XCTestDynamicOverlay
-
-#if canImport(FoundationNetworking)
- import FoundationNetworking
-#endif
-
-final class MockWebSocketClient: WebSocketClient {
- struct MutableState {
- var receiveContinuation: AsyncThrowingStream.Continuation?
- var sentMessages: [RealtimeMessageV2] = []
- var onCallback: ((RealtimeMessageV2) -> RealtimeMessageV2?)?
- var connectContinuation: AsyncStream.Continuation?
-
- var sendMessageBuffer: [RealtimeMessageV2] = []
- var connectionStatusBuffer: [ConnectionStatus] = []
- }
-
- private let mutableState = LockIsolated(MutableState())
-
- var sentMessages: [RealtimeMessageV2] {
- mutableState.sentMessages
- }
-
- func send(_ message: RealtimeMessageV2) async throws {
- mutableState.withValue {
- $0.sentMessages.append(message)
-
- if let callback = $0.onCallback, let response = callback(message) {
- mockReceive(response)
- }
- }
- }
-
- func mockReceive(_ message: RealtimeMessageV2) {
- mutableState.withValue {
- if let continuation = $0.receiveContinuation {
- continuation.yield(message)
- } else {
- $0.sendMessageBuffer.append(message)
- }
- }
- }
-
- func on(_ callback: @escaping (RealtimeMessageV2) -> RealtimeMessageV2?) {
- mutableState.withValue {
- $0.onCallback = callback
- }
- }
-
- func receive() -> AsyncThrowingStream {
- let (stream, continuation) = AsyncThrowingStream.makeStream()
- mutableState.withValue {
- $0.receiveContinuation = continuation
-
- while !$0.sendMessageBuffer.isEmpty {
- let message = $0.sendMessageBuffer.removeFirst()
- $0.receiveContinuation?.yield(message)
- }
- }
- return stream
- }
-
- func mockConnect(_ status: ConnectionStatus) {
- mutableState.withValue {
- if let continuation = $0.connectContinuation {
- continuation.yield(status)
- } else {
- $0.connectionStatusBuffer.append(status)
- }
- }
- }
-
- func connect() -> AsyncStream {
- let (stream, continuation) = AsyncStream.makeStream()
- mutableState.withValue {
- $0.connectContinuation = continuation
-
- while !$0.connectionStatusBuffer.isEmpty {
- let status = $0.connectionStatusBuffer.removeFirst()
- $0.connectContinuation?.yield(status)
- }
- }
- return stream
- }
-
- func disconnect(code: Int?, reason: String?) {
- }
-}
diff --git a/Tests/RealtimeTests/RealtimeChannelTests.swift b/Tests/RealtimeTests/RealtimeChannelTests.swift
index a6403cd3..c213d2d6 100644
--- a/Tests/RealtimeTests/RealtimeChannelTests.swift
+++ b/Tests/RealtimeTests/RealtimeChannelTests.swift
@@ -19,7 +19,10 @@ final class RealtimeChannelTests: XCTestCase {
presence: PresenceJoinConfig(),
isPrivate: false
),
- socket: .mock,
+ socket: RealtimeClientV2(
+ url: URL(string: "https://localhost:54321/realtime/v1")!,
+ options: RealtimeClientOptions()
+ ),
logger: nil
)
@@ -126,21 +129,3 @@ final class RealtimeChannelTests: XCTestCase {
}
}
}
-
-extension Socket {
- static var mock: Socket {
- Socket(
- broadcastURL: unimplemented(),
- status: unimplemented(),
- options: unimplemented(),
- accessToken: unimplemented(),
- apiKey: unimplemented(),
- makeRef: unimplemented(),
- connect: unimplemented(),
- addChannel: unimplemented(),
- removeChannel: unimplemented(),
- push: unimplemented(),
- httpSend: unimplemented()
- )
- }
-}
diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift
index 3497738f..3e0c19cc 100644
--- a/Tests/RealtimeTests/RealtimeTests.swift
+++ b/Tests/RealtimeTests/RealtimeTests.swift
@@ -21,27 +21,32 @@ final class RealtimeTests: XCTestCase {
}
}
- var ws: MockWebSocketClient!
+ var server: FakeWebSocket!
+ var client: FakeWebSocket!
var http: HTTPClientMock!
var sut: RealtimeClientV2!
+ let heartbeatInterval: TimeInterval = 1
+ let reconnectDelay: TimeInterval = 1
+ let timeoutInterval: TimeInterval = 2
+
override func setUp() {
super.setUp()
- ws = MockWebSocketClient()
+ (client, server) = FakeWebSocket.fakes()
http = HTTPClientMock()
sut = RealtimeClientV2(
url: url,
options: RealtimeClientOptions(
headers: ["apikey": apiKey],
- heartbeatInterval: 1,
- reconnectDelay: 1,
- timeoutInterval: 2,
+ heartbeatInterval: heartbeatInterval,
+ reconnectDelay: reconnectDelay,
+ timeoutInterval: timeoutInterval,
accessToken: {
"custom.access.token"
}
),
- ws: ws,
+ wsTransport: { self.client },
http: http
)
}
@@ -75,7 +80,7 @@ final class RealtimeTests: XCTestCase {
}
.store(in: &subscriptions)
- await connectSocketAndWait()
+ await sut.connect()
XCTAssertEqual(socketStatuses.value, [.disconnected, .connecting, .connected])
@@ -93,47 +98,57 @@ final class RealtimeTests: XCTestCase {
}
.store(in: &subscriptions)
- ws.mockReceive(.messagesSubscribed)
- await channel.subscribe()
+ let subscribeTask = Task {
+ await channel.subscribe()
+ }
+ await Task.yield()
+ server.send(.messagesSubscribed)
+
+ // Wait until it subscribes to assert WS events
+ await subscribeTask.value
- assertInlineSnapshot(of: ws.sentMessages, as: .json) {
+ XCTAssertEqual(channelStatuses.value, [.unsubscribed, .subscribing, .subscribed])
+
+ assertInlineSnapshot(of: client.sentEvents.map(\.json), as: .json) {
"""
[
{
- "event" : "phx_join",
- "join_ref" : "1",
- "payload" : {
- "access_token" : "custom.access.token",
- "config" : {
- "broadcast" : {
- "ack" : false,
- "self" : false
- },
- "postgres_changes" : [
- {
- "event" : "INSERT",
- "schema" : "public",
- "table" : "messages"
+ "text" : {
+ "event" : "phx_join",
+ "join_ref" : "1",
+ "payload" : {
+ "access_token" : "custom.access.token",
+ "config" : {
+ "broadcast" : {
+ "ack" : false,
+ "self" : false
},
- {
- "event" : "UPDATE",
- "schema" : "public",
- "table" : "messages"
+ "postgres_changes" : [
+ {
+ "event" : "INSERT",
+ "schema" : "public",
+ "table" : "messages"
+ },
+ {
+ "event" : "UPDATE",
+ "schema" : "public",
+ "table" : "messages"
+ },
+ {
+ "event" : "DELETE",
+ "schema" : "public",
+ "table" : "messages"
+ }
+ ],
+ "presence" : {
+ "key" : ""
},
- {
- "event" : "DELETE",
- "schema" : "public",
- "table" : "messages"
- }
- ],
- "presence" : {
- "key" : ""
- },
- "private" : false
- }
- },
- "ref" : "1",
- "topic" : "realtime:public:messages"
+ "private" : false
+ }
+ },
+ "ref" : "1",
+ "topic" : "realtime:public:messages"
+ }
}
]
"""
@@ -144,38 +159,39 @@ final class RealtimeTests: XCTestCase {
let channel = sut.channel("public:messages")
let joinEventCount = LockIsolated(0)
- ws.on { message in
- if message.event == "heartbeat" {
- return RealtimeMessageV2(
- joinRef: message.joinRef,
- ref: message.ref,
- topic: "phoenix",
- event: "phx_reply",
- payload: [
- "response": [:],
- "status": "ok",
- ]
+ server.onEvent = { @Sendable [server] event in
+ guard let msg = event.realtimeMessage else { return }
+
+ if msg.event == "heartbeat" {
+ server?.send(
+ RealtimeMessageV2(
+ joinRef: msg.joinRef,
+ ref: msg.ref,
+ topic: "phoenix",
+ event: "phx_reply",
+ payload: ["response": [:]]
+ )
)
- }
-
- if message.event == "phx_join" {
+ } else if msg.event == "phx_join" {
joinEventCount.withValue { $0 += 1 }
// Skip first join.
if joinEventCount.value == 2 {
- return .messagesSubscribed
+ server?.send(.messagesSubscribed)
}
}
-
- return nil
}
- await connectSocketAndWait()
+ await sut.connect()
await channel.subscribe()
- try? await Task.sleep(nanoseconds: NSEC_PER_SEC * 2)
+ // Wait for the timeout for rejoining.
+ await sleep(seconds: UInt64(timeoutInterval))
- assertInlineSnapshot(of: ws.sentMessages.filter { $0.event == "phx_join" }, as: .json) {
+ let events = client.sentEvents.compactMap { $0.realtimeMessage }.filter {
+ $0.event == "phx_join"
+ }
+ assertInlineSnapshot(of: events, as: .json) {
"""
[
{
@@ -231,25 +247,27 @@ final class RealtimeTests: XCTestCase {
let expectation = expectation(description: "heartbeat")
expectation.expectedFulfillmentCount = 2
- ws.on { message in
- if message.event == "heartbeat" {
+ server.onEvent = { @Sendable [server] event in
+ guard let msg = event.realtimeMessage else { return }
+
+ if msg.event == "heartbeat" {
expectation.fulfill()
- return RealtimeMessageV2(
- joinRef: message.joinRef,
- ref: message.ref,
- topic: "phoenix",
- event: "phx_reply",
- payload: [
- "response": [:],
- "status": "ok",
- ]
+ server?.send(
+ RealtimeMessageV2(
+ joinRef: msg.joinRef,
+ ref: msg.ref,
+ topic: "phoenix",
+ event: "phx_reply",
+ payload: [
+ "response": [:],
+ "status": "ok",
+ ]
+ )
)
}
-
- return nil
}
- await connectSocketAndWait()
+ await sut.connect()
await fulfillment(of: [expectation], timeout: 3)
}
@@ -257,25 +275,21 @@ final class RealtimeTests: XCTestCase {
func testHeartbeat_whenNoResponse_shouldReconnect() async throws {
let sentHeartbeatExpectation = expectation(description: "sentHeartbeat")
- ws.on {
- if $0.event == "heartbeat" {
+ server.onEvent = { @Sendable in
+ if $0.realtimeMessage?.event == "heartbeat" {
sentHeartbeatExpectation.fulfill()
}
-
- return nil
}
let statuses = LockIsolated<[RealtimeClientStatus]>([])
-
- Task {
- for await status in sut.statusChange {
- statuses.withValue {
- $0.append(status)
- }
+ let subscription = sut.onStatusChange { status in
+ statuses.withValue {
+ $0.append(status)
}
}
- await Task.yield()
- await connectSocketAndWait()
+ defer { subscription.cancel() }
+
+ await sut.connect()
await fulfillment(of: [sentHeartbeatExpectation], timeout: 2)
@@ -283,10 +297,10 @@ final class RealtimeTests: XCTestCase {
XCTAssertNotNil(pendingHeartbeatRef)
// Wait until next heartbeat
- try await Task.sleep(nanoseconds: NSEC_PER_SEC * 2)
+ await sleep(seconds: 2)
// Wait for reconnect delay
- try await Task.sleep(nanoseconds: NSEC_PER_SEC * 1)
+ await sleep(seconds: 1)
XCTAssertEqual(
statuses.value,
@@ -296,6 +310,7 @@ final class RealtimeTests: XCTestCase {
.connected,
.disconnected,
.connecting,
+ .connected,
]
)
}
@@ -365,11 +380,6 @@ final class RealtimeTests: XCTestCase {
let token = "sb-token"
await sut.setAuth(token)
}
-
- private func connectSocketAndWait() async {
- ws.mockConnect(.connected)
- await sut.connect()
- }
}
extension RealtimeMessageV2 {
@@ -390,3 +400,38 @@ extension RealtimeMessageV2 {
]
)
}
+
+extension FakeWebSocket {
+ func send(_ message: RealtimeMessageV2) {
+ try! self.send(String(decoding: JSONEncoder().encode(message), as: UTF8.self))
+ }
+}
+
+extension WebSocketEvent {
+ var json: Any {
+ switch self {
+ case .binary(let data):
+ let json = try? JSONSerialization.jsonObject(with: data)
+ return ["binary": json]
+ case .text(let text):
+ let json = try? JSONSerialization.jsonObject(with: Data(text.utf8))
+ return ["text": json]
+ case .close(let code, let reason):
+ return [
+ "close": [
+ "code": code as Any,
+ "reason": reason,
+ ]
+ ]
+ }
+ }
+
+ var realtimeMessage: RealtimeMessageV2? {
+ guard case .text(let text) = self else { return nil }
+ return try? JSONDecoder().decode(RealtimeMessageV2.self, from: Data(text.utf8))
+ }
+}
+
+func sleep(seconds: UInt64) async {
+ try? await Task.sleep(nanoseconds: NSEC_PER_SEC * seconds)
+}
diff --git a/Tests/RealtimeTests/_PushTests.swift b/Tests/RealtimeTests/_PushTests.swift
index 67efc7a1..943fe01e 100644
--- a/Tests/RealtimeTests/_PushTests.swift
+++ b/Tests/RealtimeTests/_PushTests.swift
@@ -6,12 +6,13 @@
//
import ConcurrencyExtras
-@testable import Realtime
import TestHelpers
import XCTest
+@testable import Realtime
+
final class _PushTests: XCTestCase {
- var ws: MockWebSocketClient!
+ var ws: FakeWebSocket!
var socket: RealtimeClientV2!
override func invokeTest() {
@@ -23,13 +24,14 @@ final class _PushTests: XCTestCase {
override func setUp() {
super.setUp()
- ws = MockWebSocketClient()
+ let (client, server) = FakeWebSocket.fakes()
+ ws = server
socket = RealtimeClientV2(
url: URL(string: "https://localhost:54321/v1/realtime")!,
options: RealtimeClientOptions(
headers: ["apiKey": "apikey"]
),
- ws: ws,
+ wsTransport: { client },
http: HTTPClientMock()
)
}
@@ -42,7 +44,7 @@ final class _PushTests: XCTestCase {
presence: .init(),
isPrivate: false
),
- socket: Socket(client: socket),
+ socket: socket,
logger: nil
)
let push = PushV2(
@@ -61,34 +63,35 @@ final class _PushTests: XCTestCase {
}
// FIXME: Flaky test, it fails some time due the task scheduling, even tho we're using withMainSerialExecutor.
-// func testPushWithAck() async {
-// let channel = RealtimeChannelV2(
-// topic: "realtime:users",
-// config: RealtimeChannelConfig(
-// broadcast: .init(acknowledgeBroadcasts: true),
-// presence: .init()
-// ),
-// socket: socket,
-// logger: nil
-// )
-// let push = PushV2(
-// channel: channel,
-// message: RealtimeMessageV2(
-// joinRef: nil,
-// ref: "1",
-// topic: "realtime:users",
-// event: "broadcast",
-// payload: [:]
-// )
-// )
-//
-// let task = Task {
-// await push.send()
-// }
-// await Task.megaYield()
-// await push.didReceive(status: .ok)
-//
-// let status = await task.value
-// XCTAssertEqual(status, .ok)
-// }
+ // func testPushWithAck() async {
+ // let channel = RealtimeChannelV2(
+ // topic: "realtime:users",
+ // config: RealtimeChannelConfig(
+ // broadcast: .init(acknowledgeBroadcasts: true),
+ // presence: .init(),
+ // isPrivate: false
+ // ),
+ // socket: Socket(client: socket),
+ // logger: nil
+ // )
+ // let push = PushV2(
+ // channel: channel,
+ // message: RealtimeMessageV2(
+ // joinRef: nil,
+ // ref: "1",
+ // topic: "realtime:users",
+ // event: "broadcast",
+ // payload: [:]
+ // )
+ // )
+ //
+ // let task = Task {
+ // await push.send()
+ // }
+ // await Task.yield()
+ // await push.didReceive(status: .ok)
+ //
+ // let status = await task.value
+ // XCTAssertEqual(status, .ok)
+ // }
}