Skip to content

Commit 5072b01

Browse files
authored
KTOR-9102 Fix closing underlying connection in Java engine, fix sse closing (#5246)
1 parent 835d7f9 commit 5072b01

File tree

5 files changed

+94
-40
lines changed

5 files changed

+94
-40
lines changed

ktor-client/ktor-client-core/common/src/io/ktor/client/plugins/sse/DefaultClientSSESession.kt

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,21 @@
44

55
package io.ktor.client.plugins.sse
66

7-
import io.ktor.client.network.sockets.SocketTimeoutException
7+
import io.ktor.client.network.sockets.*
88
import io.ktor.client.request.*
99
import io.ktor.http.*
1010
import io.ktor.sse.*
11+
import io.ktor.util.*
1112
import io.ktor.util.logging.*
12-
import io.ktor.util.rootCause
1313
import io.ktor.utils.io.*
1414
import io.ktor.utils.io.CancellationException
15+
import kotlinx.atomicfu.atomic
1516
import kotlinx.coroutines.*
1617
import kotlinx.coroutines.flow.Flow
1718
import kotlinx.coroutines.flow.catch
1819
import kotlinx.coroutines.flow.flow
1920
import kotlinx.coroutines.flow.onCompletion
21+
import kotlinx.io.IOException
2022
import kotlin.coroutines.CoroutineContext
2123

2224
@OptIn(InternalAPI::class)
@@ -33,13 +35,12 @@ public class DefaultClientSSESession(
3335
private val maxReconnectionAttempts = content.maxReconnectionAttempts
3436
private var needToReconnect = maxReconnectionAttempts > 0
3537
private var bodyBuffer: BodyBuffer = content.bufferPolicy.toBodyBuffer()
36-
3738
private val initialRequest = content.initialRequest
38-
3939
private val clientForReconnection = initialRequest.attributes[SSEClientForReconnectionAttr]
40-
4140
private val callContext = content.callContext
4241

42+
private val closed = atomic(false)
43+
4344
override fun bodyBuffer(): ByteArray = bodyBuffer.toByteArray()
4445

4546
public constructor(
@@ -64,27 +65,26 @@ public class DefaultClientSSESession(
6465
if (needToReconnect) {
6566
doReconnection()
6667
} else {
67-
close()
68+
break
6869
}
6970
}
7071
}.catch { cause ->
71-
when (cause) {
72-
is CancellationException -> {
73-
// CancellationException will be handled by onCompletion operator
74-
}
72+
if (cause is CancellationException) {
73+
return@catch
74+
}
7575

76+
LOGGER.trace { "Error during SSE session processing: $cause" }
77+
throw cause
78+
}.onCompletion { cause ->
79+
close()
80+
81+
when (cause) {
82+
null -> return@onCompletion
83+
is CancellationException -> return@onCompletion
7684
else -> {
77-
LOGGER.trace { "Error during SSE session processing: $cause" }
78-
close()
79-
throw cause
85+
throw SSEClientException(cause = cause, message = cause.message)
8086
}
8187
}
82-
}.onCompletion { cause ->
83-
// Because catch operator only catch throwable occurs in upstream flow, so we use onCompletion operator instead
84-
// to handle CancellationException occurs in either upstream flow or downstream flow.
85-
if (cause is CancellationException) {
86-
close()
87-
}
8888
}
8989

9090
init {
@@ -144,6 +144,8 @@ public class DefaultClientSSESession(
144144
get() = _incoming
145145

146146
private fun close() {
147+
if (!closed.compareAndSet(expect = false, update = true)) return
148+
147149
coroutineContext.cancel()
148150
input.cancel()
149151
callContext.cancel()
@@ -152,7 +154,7 @@ public class DefaultClientSSESession(
152154
private suspend fun ByteReadChannel.tryParseEvent(): ServerSentEvent? =
153155
try {
154156
parseEvent()
155-
} catch (cause: ClosedByteChannelException) {
157+
} catch (cause: IOException) {
156158
val rootCause = cause.rootCause
157159
if (rootCause is SocketTimeoutException) {
158160
throw rootCause

ktor-client/ktor-client-java/jvm/src/io/ktor/client/engine/java/JavaHttpResponseBodyHandler.kt

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,26 @@
44

55
package io.ktor.client.engine.java
66

7-
import io.ktor.client.plugins.sse.*
87
import io.ktor.client.request.*
98
import io.ktor.http.*
109
import io.ktor.util.date.*
1110
import io.ktor.utils.io.*
12-
import kotlinx.atomicfu.*
13-
import kotlinx.coroutines.*
14-
import kotlinx.coroutines.channels.*
15-
import java.io.*
16-
import java.net.http.*
17-
import java.nio.*
18-
import java.util.concurrent.*
19-
import kotlin.coroutines.*
11+
import kotlinx.atomicfu.atomic
12+
import kotlinx.coroutines.CoroutineScope
13+
import kotlinx.coroutines.Job
14+
import kotlinx.coroutines.channels.Channel
15+
import kotlinx.coroutines.channels.ClosedReceiveChannelException
16+
import kotlinx.coroutines.channels.consume
17+
import kotlinx.coroutines.isActive
18+
import kotlinx.coroutines.launch
19+
import java.io.IOException
20+
import java.net.http.HttpClient
21+
import java.net.http.HttpResponse
22+
import java.nio.ByteBuffer
23+
import java.util.concurrent.CompletableFuture
24+
import java.util.concurrent.CompletionStage
25+
import java.util.concurrent.Flow
26+
import kotlin.coroutines.CoroutineContext
2027

2128
internal class JavaHttpResponseBodyHandler(
2229
private val coroutineContext: CoroutineContext,
@@ -94,8 +101,7 @@ internal class JavaHttpResponseBodyHandler(
94101
}
95102
}.apply {
96103
invokeOnCompletion {
97-
responseChannel.close(it)
98-
consumerJob.complete()
104+
close(it)
99105
}
100106
}
101107
}
@@ -147,7 +153,7 @@ internal class JavaHttpResponseBodyHandler(
147153
return CompletableFuture.completedStage(httpResponse)
148154
}
149155

150-
private fun close(cause: Throwable) {
156+
private fun close(cause: Throwable?) {
151157
if (!closed.compareAndSet(expect = false, update = true)) {
152158
return
153159
}
@@ -156,7 +162,7 @@ internal class JavaHttpResponseBodyHandler(
156162
queue.close(cause)
157163
subscription.getAndSet(null)?.cancel()
158164
} finally {
159-
consumerJob.completeExceptionally(cause)
165+
cause?.let(consumerJob::completeExceptionally) ?: consumerJob.complete()
160166
responseChannel.cancel(cause)
161167
}
162168
}
Binary file not shown.

ktor-client/ktor-client-tests/common/test/io/ktor/client/tests/plugins/ServerSentEventsTest.kt

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ class ServerSentEventsTest : ClientLoader() {
632632
}
633633
}
634634

635-
assertTrue(events.size == 7)
635+
assertEquals(7, events.size)
636636
events.forEachIndexed { index, event ->
637637
assertEquals(index + 1, event.id?.toInt())
638638
}
@@ -788,7 +788,7 @@ class ServerSentEventsTest : ClientLoader() {
788788
test { client ->
789789
try {
790790
client.sse("$TEST_SERVER/sse/hello") {
791-
incoming.collect { it }
791+
incoming.collect { }
792792
throw IllegalStateException("exception")
793793
}
794794
} catch (e: SSEClientException) {
@@ -810,7 +810,7 @@ class ServerSentEventsTest : ClientLoader() {
810810
client.sse(urlString = "$TEST_SERVER/sse/hello", {
811811
bufferPolicy(SSEBufferPolicy.Off)
812812
}) {
813-
incoming.collect { it }
813+
incoming.collect { }
814814
throw IllegalStateException("exception")
815815
}
816816
} catch (e: SSEClientException) {
@@ -842,7 +842,7 @@ class ServerSentEventsTest : ClientLoader() {
842842
bufferPolicy(SSEBufferPolicy.LastLines(count))
843843
}
844844
) {
845-
incoming.collect { it }
845+
incoming.collect { }
846846
throw IllegalStateException("exception")
847847
}
848848
} catch (e: SSEClientException) {
@@ -873,7 +873,7 @@ class ServerSentEventsTest : ClientLoader() {
873873
bufferPolicy(SSEBufferPolicy.LastEvent)
874874
}
875875
) {
876-
incoming.collect { it }
876+
incoming.collect { }
877877
throw IllegalStateException("exception")
878878
}
879879
} catch (e: SSEClientException) {
@@ -896,7 +896,7 @@ class ServerSentEventsTest : ClientLoader() {
896896
bufferPolicy(SSEBufferPolicy.LastEvents(2))
897897
}
898898
) {
899-
incoming.collect { it }
899+
incoming.collect {}
900900
throw IllegalStateException("exception")
901901
}
902902
} catch (e: SSEClientException) {
@@ -1006,6 +1006,28 @@ class ServerSentEventsTest : ClientLoader() {
10061006
}
10071007
}
10081008

1009+
@Test
1010+
fun testCancellingUnderlyingConnection() = clientTests {
1011+
config {
1012+
install(SSE)
1013+
}
1014+
1015+
test { client ->
1016+
val sseSession = client.sseSession("$TEST_SERVER/sse/active-sessions")
1017+
sseSession.incoming.collect {
1018+
assertEquals("ok", it.data)
1019+
sseSession.cancel()
1020+
}
1021+
withTimeout(5000) {
1022+
while (true) {
1023+
val count = client.get("$TEST_SERVER/sse/active-sessions-count").bodyAsText().toInt()
1024+
if (count == 0) break
1025+
delay(100)
1026+
}
1027+
}
1028+
}
1029+
}
1030+
10091031
private fun checkBody(expected: String, actual: String?, count: Int? = null) {
10101032
assertNotNull(actual)
10111033
val expectedLines = expected.split("\r\n").let { lines ->

ktor-test-server/src/main/kotlin/test/server/tests/ServerSentEvents.kt

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,25 @@ internal fun Application.serverSentEvents() {
178178
call.response.header(HttpHeaders.ContentType, ContentType.Application.Json.toString())
179179
call.respond(HttpStatusCode.BadRequest, "{ 'error': 'Bad request' }")
180180
}
181+
182+
var activeSessions = 0
183+
get("/active-sessions-count") {
184+
call.respond("$activeSessions")
185+
}
186+
get("/active-sessions") {
187+
activeSessions = 1
188+
call.respondBytesWriter(contentType = ContentType.Text.EventStream) {
189+
try {
190+
while (!isClosedForWrite) {
191+
writeSseEvent(SseEvent("ok"))
192+
flush()
193+
delay(100)
194+
}
195+
} finally {
196+
activeSessions = 0
197+
}
198+
}
199+
}
181200
}
182201
}
183202
}
@@ -188,7 +207,12 @@ private suspend fun ApplicationCall.respondSseEvents(events: Flow<SseEvent>) {
188207
}
189208
}
190209

191-
private suspend fun ByteWriteChannel.writeSseEvents(events: Flow<SseEvent>): Unit = events.collect { event ->
210+
211+
private suspend fun ByteWriteChannel.writeSseEvents(events: Flow<SseEvent>): Unit = events.collect {
212+
writeSseEvent(it)
213+
}
214+
215+
private suspend fun ByteWriteChannel.writeSseEvent(event: SseEvent) {
192216
if (event.id != null) {
193217
writeStringUtf8WithNewlineAndFlush("id: ${event.id}")
194218
}

0 commit comments

Comments
 (0)