From be42621f4c430d5dcd9d087b7e47619684603d83 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Fri, 29 Mar 2024 13:52:44 +0100 Subject: [PATCH] =?UTF-8?q?Don=E2=80=99t=20crash=20sourcekit-lsp=20if=20a?= =?UTF-8?q?=20known=20message=20is=20missing=20a=20field?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, we `fatalError`ing when `JSONDecoder` failed to decode a message from the client. Instead of crashing, try recovering from such invalid messages as best as possible. If we know that the state might have gotten out of sync with the client, show a notification message to the user, asking them to file an issue. rdar://112991102 --- .../TestJSONRPCConnection.swift | 2 +- .../JSONRPCConnection.swift | 194 +++++++++++++----- .../MessageSplitting.swift | 29 ++- .../ConnectionTests.swift | 28 +++ .../MessageParsingTests.swift | 4 +- 5 files changed, 189 insertions(+), 68 deletions(-) diff --git a/Sources/LSPTestSupport/TestJSONRPCConnection.swift b/Sources/LSPTestSupport/TestJSONRPCConnection.swift index 0ba13f0bf..4bec75ba2 100644 --- a/Sources/LSPTestSupport/TestJSONRPCConnection.swift +++ b/Sources/LSPTestSupport/TestJSONRPCConnection.swift @@ -197,7 +197,7 @@ public final class TestServer: MessageHandler { private let testMessageRegistry = MessageRegistry( requests: [EchoRequest.self, EchoError.self], - notifications: [EchoNotification.self] + notifications: [EchoNotification.self, ShowMessageNotification.self] ) #if compiler(<5.11) diff --git a/Sources/LanguageServerProtocolJSONRPC/JSONRPCConnection.swift b/Sources/LanguageServerProtocolJSONRPC/JSONRPCConnection.swift index 22621cef1..961bf3839 100644 --- a/Sources/LanguageServerProtocolJSONRPC/JSONRPCConnection.swift +++ b/Sources/LanguageServerProtocolJSONRPC/JSONRPCConnection.swift @@ -56,7 +56,7 @@ public final class JSONRPCConnection: Connection { /// - `init`: Reference to `JSONRPCConnection` trivially can't have escaped to other isolation domains yet. /// - `start`: Is required to be called in the same serial region as the initializer, so /// `JSONRPCConnection` can't have escaped to other isolation domains yet. - /// - `_close`: Synchronized on `queue`. + /// - `closeAssumingOnQueue`: Synchronized on `queue`. /// - `readyToSend`: Synchronized on `queue`. /// - `deinit`: Can also only trivially be called once. private nonisolated(unsafe) var state: State @@ -230,6 +230,131 @@ public final class JSONRPCConnection: Connection { } } + /// Send a notification to the client that informs the user about a message decoding error and tells them to file an + /// issue. + /// + /// `message` describes what has gone wrong to the user. + /// + /// - Important: Must be called on `queue` + private func sendMessageDecodingErrorNotificationToClient(message: String) { + dispatchPrecondition(condition: .onQueue(queue)) + let showMessage = ShowMessageNotification( + type: .error, + message: """ + \(message). Please run 'sourcekit-lsp diagnose' to file an issue. + """ + ) + self.send(.notification(showMessage)) + } + + /// Decode a single JSONRPC message from the given `messageBytes`. + /// + /// `messageBytes` should be valid JSON, ie. this is the message sent from the client without the `Content-Length` + /// header. + /// + /// If an error occurs during message parsing, this tries to recover as gracefully as possible and returns `nil`. + /// Callers should consider the message handled and ignore it when this function returns `nil`. + /// + /// - Important: Must be called on `queue` + private func decodeJSONRPCMessage(messageBytes: Slice>) -> JSONRPCMessage? { + dispatchPrecondition(condition: .onQueue(queue)) + let decoder = JSONDecoder() + + // Set message registry to use for model decoding. + decoder.userInfo[.messageRegistryKey] = messageRegistry + + // Setup callback for response type. + decoder.userInfo[.responseTypeCallbackKey] = { (id: RequestID) -> ResponseType.Type? in + guard let outstanding = self.outstandingRequests[id] else { + logger.error("Unknown request for \(id, privacy: .public)") + return nil + } + return outstanding.responseType + } + + do { + let pointer = UnsafeMutableRawPointer(mutating: UnsafeBufferPointer(rebasing: messageBytes).baseAddress!) + return try decoder.decode( + JSONRPCMessage.self, + from: Data(bytesNoCopy: pointer, count: messageBytes.count, deallocator: .none) + ) + } catch let error as MessageDecodingError { + logger.fault("Failed to decode message: \(error.forLogging)") + logger.fault("Malformed message: \(String(bytes: messageBytes, encoding: .utf8) ?? "")") + + // We failed to decode the message. Under those circumstances try to behave as LSP-conforming as possible. + // Always log at the fault level so that we know something is going wrong from the logs. + // + // The pattern below is to handle the message in the best possible way and then `return nil` to acknowledge the + // handling. That way the compiler enforces that we handle all code paths. + switch error.messageKind { + case .request: + if let id = error.id { + // If we know it was a request and we have the request ID, simply reply to the request and tell the client + // that we couldn't parse it. That complies with LSP that all requests should eventually get a response. + logger.fault( + "Replying to request \(id, privacy: .public) with error response because we failed to decode the request" + ) + self.send(.errorResponse(ResponseError(error), id: id)) + return nil + } + // If we don't know the ID of the request, ignore it and show a notification to the user. + // That way the user at least knows that something is going wrong even if the client never gets a response + // for the request. + logger.fault("Ignoring request because we failed to decode the request and don't have a request ID") + sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode a request") + return nil + case .response: + if let id = error.id { + if let outstanding = self.outstandingRequests.removeValue(forKey: id) { + // If we received a response to a request we sent to the client, assume that the client responded with an + // error. That complies with LSP that all requests should eventually get a response. + logger.fault( + "Assuming an error response to request \(id, privacy: .public) because response from client could not be decoded" + ) + outstanding.replyHandler(.failure(ResponseError(error))) + return nil + } + // If there's an error in the response but we don't even know about the request, we can ignore it. + logger.fault( + "Ignoring response to request \(id, privacy: .public) because it could not be decoded and given request ID is unknown" + ) + return nil + } + // And if we can't even recover the ID the response is for, we drop it. This means that whichever code in + // sourcekit-lsp sent the request will probably never get a reply but there's nothing we can do about that. + // Ideally requests sent from sourcekit-lsp to the client would have some kind of timeout anyway. + logger.fault("Ignoring response because its request ID could not be recovered") + return nil + case .notification: + if error.code == .methodNotFound { + // If we receive a notification we don't know about, this might be a client sending a new LSP notification + // that we don't know about. It can't be very critical so we ignore it without bothering the user with an + // error notification. + logger.fault("Ignoring notification because we don't know about it's method") + return nil + } + // Ignoring any other notification might result in corrupted behavior. For example, ignoring a + // `textDocument/didChange` will result in an out-of-sync state between the editor and sourcekit-lsp. + // Warn the user about the error. + logger.fault("Ignoring notification that may cause corrupted behavior") + sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode a notification") + return nil + case .unknown: + // We don't know what has gone wrong. This could be any level of badness. Inform the user about it. + logger.fault("Ignoring unknown message") + sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode a message") + return nil + } + } catch { + // We don't know what has gone wrong. This could be any level of badness. Inform the user about it and ignore the + // message. + logger.fault("Ignoring unknown message") + sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode an unknown message") + return nil + } + } + /// Whether we can send messages in the current state. /// /// - parameter shouldLog: Whether to log an info message if not ready. @@ -250,69 +375,30 @@ public final class JSONRPCConnection: Connection { /// - Important: Must be called on `queue` func parseAndHandleMessages(from bytes: UnsafeBufferPointer) -> UnsafeBufferPointer.SubSequence { dispatchPrecondition(condition: .onQueue(queue)) - let decoder = JSONDecoder() - - // Set message registry to use for model decoding. - decoder.userInfo[.messageRegistryKey] = messageRegistry - - // Setup callback for response type. - decoder.userInfo[.responseTypeCallbackKey] = - { id in - guard let outstanding = self.outstandingRequests[id] else { - logger.error("Unknown request for \(id, privacy: .public)") - return nil - } - return outstanding.responseType - } as JSONRPCMessage.ResponseTypeCallback var bytes = bytes[...] MESSAGE_LOOP: while true { + // Split the messages based on the Content-Length header. + let messageBytes: Slice> do { - guard let ((messageBytes, _), rest) = try bytes.jsonrpcSplitMessage() else { + guard let (header: _, message: message, rest: rest) = try bytes.jsonrpcSplitMessage() else { return bytes } + messageBytes = message bytes = rest - - let pointer = UnsafeMutableRawPointer(mutating: UnsafeBufferPointer(rebasing: messageBytes).baseAddress!) - let message = try decoder.decode( - JSONRPCMessage.self, - from: Data(bytesNoCopy: pointer, count: messageBytes.count, deallocator: .none) - ) - - handle(message) - } catch let error as MessageDecodingError { - switch error.messageKind { - case .request: - if let id = error.id { - queue.async { - self.send(.errorResponse(ResponseError(error), id: id)) - } - continue MESSAGE_LOOP - } - case .response: - if let id = error.id { - if let outstanding = self.outstandingRequests.removeValue(forKey: id) { - outstanding.replyHandler(.failure(ResponseError(error))) - } else { - logger.error("error in response to unknown request \(id, privacy: .public) \(error.forLogging)") - } - continue MESSAGE_LOOP - } - case .notification: - if error.code == .methodNotFound { - logger.error("ignoring unknown notification \(error.forLogging)") - continue MESSAGE_LOOP - } - case .unknown: - break - } - // FIXME: graceful shutdown? - fatalError("fatal error encountered decoding message \(error)") } catch { - // FIXME: graceful shutdown? - fatalError("fatal error encountered decoding message \(error)") + // We failed to parse the message header. There isn't really much we can do to recover because we lost our + // anchor in the stream where new messages start. Crashing and letting ourselves be restarted by the client is + // probably the best option. + sendMessageDecodingErrorNotificationToClient(message: "Failed to find next message in connection to editor") + fatalError("fatal error encountered while splitting JSON RPC messages \(error)") + } + + guard let message = decodeJSONRPCMessage(messageBytes: messageBytes) else { + continue } + handle(message) } } diff --git a/Sources/LanguageServerProtocolJSONRPC/MessageSplitting.swift b/Sources/LanguageServerProtocolJSONRPC/MessageSplitting.swift index 0c1bdd6f4..6c965efd8 100644 --- a/Sources/LanguageServerProtocolJSONRPC/MessageSplitting.swift +++ b/Sources/LanguageServerProtocolJSONRPC/MessageSplitting.swift @@ -15,7 +15,7 @@ import LanguageServerProtocol public struct JSONRPCMessageHeader: Hashable { static let contentLengthKey: [UInt8] = [UInt8]("Content-Length".utf8) static let separator: [UInt8] = [UInt8]("\r\n".utf8) - static let colon: UInt8 = ":".utf8.first! + static let colon: UInt8 = UInt8(ascii: ":") static let invalidKeyBytes: [UInt8] = [colon] + separator public var contentLength: Int? = nil @@ -25,21 +25,29 @@ public struct JSONRPCMessageHeader: Hashable { } } -extension RandomAccessCollection where Element == UInt8 { - - /// Returns the first message range and header in `self`, or nil. - public func jsonrpcSplitMessage() - throws -> ((SubSequence, header: JSONRPCMessageHeader), SubSequence)? - { +extension RandomAccessCollection { + /// Tries to parse a single message from this collection of bytes. + /// + /// If an entire message could be found, returns + /// - header (representing `Content-Length:\r\n\r\n`) + /// - message: The data that represents the actual message as JSON + /// - rest: The remaining bytes that haven't weren't part of the first message in this collection + /// + /// If a `Content-Length` header could be found but the collection doesn't have enough bytes for the entire message + /// (eg. because the `Content-Length` header has been transmitted yet but not the entire message), returns `nil`. + /// Callers should call this method again once more data is available. + @_spi(Testing) + public func jsonrpcSplitMessage() throws -> (header: JSONRPCMessageHeader, message: SubSequence, rest: SubSequence)? { guard let (header, rest) = try jsonrcpParseHeader() else { return nil } guard let contentLength = header.contentLength else { throw MessageDecodingError.parseError("missing Content-Length header") } if contentLength > rest.count { return nil } - return ((rest.prefix(contentLength), header: header), rest.dropFirst(contentLength)) + return (header: header, message: rest.prefix(contentLength), rest: rest.dropFirst(contentLength)) } - public func jsonrcpParseHeader() throws -> (JSONRPCMessageHeader, SubSequence)? { + @_spi(Testing) + public func jsonrcpParseHeader() throws -> (header: JSONRPCMessageHeader, rest: SubSequence)? { var header = JSONRPCMessageHeader() var slice = self[...] while let (kv, rest) = try slice.jsonrpcParseHeaderField() { @@ -62,6 +70,7 @@ extension RandomAccessCollection where Element == UInt8 { return nil } + @_spi(Testing) public func jsonrpcParseHeaderField() throws -> ((key: SubSequence, value: SubSequence)?, SubSequence)? { if starts(with: JSONRPCMessageHeader.separator) { return (nil, dropFirst(JSONRPCMessageHeader.separator.count)) @@ -85,11 +94,9 @@ extension RandomAccessCollection where Element == UInt8 { } extension RandomAccessCollection where Element: Equatable { - /// Returns the first index where the specified subsequence appears or nil. @inlinable public func firstIndex(of pattern: some RandomAccessCollection) -> Index? { - if pattern.isEmpty { return startIndex } diff --git a/Tests/LanguageServerProtocolJSONRPCTests/ConnectionTests.swift b/Tests/LanguageServerProtocolJSONRPCTests/ConnectionTests.swift index 5dfb964fa..0f8662da3 100644 --- a/Tests/LanguageServerProtocolJSONRPCTests/ConnectionTests.swift +++ b/Tests/LanguageServerProtocolJSONRPCTests/ConnectionTests.swift @@ -279,4 +279,32 @@ class ConnectionTests: XCTestCase { } } } + + func testMessageWithMissingParameter() async throws { + let expectation = self.expectation(description: "Received ShowMessageNotification") + await connection.client.appendOneShotNotificationHandler { (note: ShowMessageNotification) in + XCTAssertEqual(note.type, .error) + expectation.fulfill() + } + + let messageContents = """ + { + "method": "test_server/echo_note", + "jsonrpc": "2.0", + "params": {} + } + """ + connection.clientToServerConnection.send(message: messageContents) + + try await self.fulfillmentOfOrThrow([expectation]) + } +} + +fileprivate extension JSONRPCConnection { + func send(message: String) { + let messageWithHeader = "Content-Length: \(message.utf8.count)\r\n\r\n\(message)".data(using: .utf8)! + messageWithHeader.withUnsafeBytes { bytes in + send(_rawData: DispatchData(bytes: bytes)) + } + } } diff --git a/Tests/LanguageServerProtocolJSONRPCTests/MessageParsingTests.swift b/Tests/LanguageServerProtocolJSONRPCTests/MessageParsingTests.swift index 9e9ea9b16..ff38ada3e 100644 --- a/Tests/LanguageServerProtocolJSONRPCTests/MessageParsingTests.swift +++ b/Tests/LanguageServerProtocolJSONRPCTests/MessageParsingTests.swift @@ -11,7 +11,7 @@ //===----------------------------------------------------------------------===// import LanguageServerProtocol -import LanguageServerProtocolJSONRPC +@_spi(Testing) import LanguageServerProtocolJSONRPC import XCTest final class MessageParsingTests: XCTestCase { @@ -25,7 +25,7 @@ final class MessageParsingTests: XCTestCase { line: UInt = #line ) throws { let bytes: [UInt8] = [UInt8](string.utf8) - guard let ((content, header), rest) = try bytes.jsonrpcSplitMessage() else { + guard let (header, content, rest) = try bytes.jsonrpcSplitMessage() else { XCTAssert(restLen == nil, "expected non-empty field", file: file, line: line) return }