Skip to content

Don’t crash sourcekit-lsp if a known message is missing a field #1154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Sources/LSPTestSupport/TestJSONRPCConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ public final class TestServer: MessageHandler {

private let testMessageRegistry = MessageRegistry(
requests: [EchoRequest.self, EchoError.self],
notifications: [EchoNotification.self]
notifications: [EchoNotification.self, ShowMessageNotification.self]
)

#if compiler(<5.11)
Expand Down
194 changes: 140 additions & 54 deletions Sources/LanguageServerProtocolJSONRPC/JSONRPCConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public final class JSONRPCConnection: Connection {
/// - `init`: Reference to `JSONRPCConnection` trivially can't have escaped to other isolation domains yet.
/// - `start`: Is required to be called in the same serial region as the initializer, so
/// `JSONRPCConnection` can't have escaped to other isolation domains yet.
/// - `_close`: Synchronized on `queue`.
/// - `closeAssumingOnQueue`: Synchronized on `queue`.
/// - `readyToSend`: Synchronized on `queue`.
/// - `deinit`: Can also only trivially be called once.
private nonisolated(unsafe) var state: State
Expand Down Expand Up @@ -230,6 +230,131 @@ public final class JSONRPCConnection: Connection {
}
}

/// Tells the client that a message could not be decoded by sending it an error-level `ShowMessageNotification`.
///
/// The text shown to the user is `message` followed by a request to run `sourcekit-lsp diagnose` and file an issue.
///
/// - Parameter message: Description of what has gone wrong. A period is appended to it, so it should not end in
///   punctuation itself.
/// - Important: Must be called on `queue`
private func sendMessageDecodingErrorNotificationToClient(message: String) {
  dispatchPrecondition(condition: .onQueue(queue))
  let userFacingText = "\(message). Please run 'sourcekit-lsp diagnose' to file an issue."
  send(.notification(ShowMessageNotification(type: .error, message: userFacingText)))
}

/// Decode a single JSONRPC message from the given `messageBytes`.
///
/// `messageBytes` should be valid JSON, ie. this is the message sent from the client without the `Content-Length`
/// header.
///
/// If an error occurs during message parsing, this tries to recover as gracefully as possible and returns `nil`.
/// Callers should consider the message handled and ignore it when this function returns `nil`.
///
/// - Important: Must be called on `queue`
private func decodeJSONRPCMessage(messageBytes: Slice<UnsafeBufferPointer<UInt8>>) -> JSONRPCMessage? {
dispatchPrecondition(condition: .onQueue(queue))
let decoder = JSONDecoder()

// Set message registry to use for model decoding.
decoder.userInfo[.messageRegistryKey] = messageRegistry

// Setup callback for response type.
decoder.userInfo[.responseTypeCallbackKey] = { (id: RequestID) -> ResponseType.Type? in
guard let outstanding = self.outstandingRequests[id] else {
logger.error("Unknown request for \(id, privacy: .public)")
return nil
}
return outstanding.responseType
}

do {
let pointer = UnsafeMutableRawPointer(mutating: UnsafeBufferPointer(rebasing: messageBytes).baseAddress!)
return try decoder.decode(
JSONRPCMessage.self,
from: Data(bytesNoCopy: pointer, count: messageBytes.count, deallocator: .none)
)
} catch let error as MessageDecodingError {
logger.fault("Failed to decode message: \(error.forLogging)")
logger.fault("Malformed message: \(String(bytes: messageBytes, encoding: .utf8) ?? "<invalid UTF-8>")")

// We failed to decode the message. Under those circumstances try to behave as LSP-conforming as possible.
// Always log at the fault level so that we know something is going wrong from the logs.
//
// The pattern below is to handle the message in the best possible way and then `return nil` to acknowledge the
// handling. That way the compiler enforces that we handle all code paths.
switch error.messageKind {
case .request:
if let id = error.id {
// If we know it was a request and we have the request ID, simply reply to the request and tell the client
// that we couldn't parse it. That complies with LSP that all requests should eventually get a response.
logger.fault(
"Replying to request \(id, privacy: .public) with error response because we failed to decode the request"
)
self.send(.errorResponse(ResponseError(error), id: id))
return nil
}
// If we don't know the ID of the request, ignore it and show a notification to the user.
// That way the user at least knows that something is going wrong even if the client never gets a response
// for the request.
logger.fault("Ignoring request because we failed to decode the request and don't have a request ID")
sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode a request")
return nil
case .response:
if let id = error.id {
if let outstanding = self.outstandingRequests.removeValue(forKey: id) {
// If we received a response to a request we sent to the client, assume that the client responded with an
// error. That complies with LSP that all requests should eventually get a response.
logger.fault(
"Assuming an error response to request \(id, privacy: .public) because response from client could not be decoded"
)
outstanding.replyHandler(.failure(ResponseError(error)))
return nil
}
// If there's an error in the response but we don't even know about the request, we can ignore it.
logger.fault(
"Ignoring response to request \(id, privacy: .public) because it could not be decoded and given request ID is unknown"
)
return nil
}
// And if we can't even recover the ID the response is for, we drop it. This means that whichever code in
// sourcekit-lsp sent the request will probably never get a reply but there's nothing we can do about that.
// Ideally requests sent from sourcekit-lsp to the client would have some kind of timeout anyway.
logger.fault("Ignoring response because its request ID could not be recovered")
return nil
case .notification:
if error.code == .methodNotFound {
// If we receive a notification we don't know about, this might be a client sending a new LSP notification
// that we don't know about. It can't be very critical so we ignore it without bothering the user with an
// error notification.
logger.fault("Ignoring notification because we don't know about it's method")
return nil
}
// Ignoring any other notification might result in corrupted behavior. For example, ignoring a
// `textDocument/didChange` will result in an out-of-sync state between the editor and sourcekit-lsp.
// Warn the user about the error.
logger.fault("Ignoring notification that may cause corrupted behavior")
sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode a notification")
return nil
case .unknown:
// We don't know what has gone wrong. This could be any level of badness. Inform the user about it.
logger.fault("Ignoring unknown message")
sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode a message")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you purposefully making this slightly different to the catch case so that we know the difference from eg. a screenshot?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly.

return nil
}
} catch {
// We don't know what has gone wrong. This could be any level of badness. Inform the user about it and ignore the
// message.
logger.fault("Ignoring unknown message")
sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode an unknown message")
return nil
}
}

/// Whether we can send messages in the current state.
///
/// - parameter shouldLog: Whether to log an info message if not ready.
Expand All @@ -250,69 +375,30 @@ public final class JSONRPCConnection: Connection {
/// - Important: Must be called on `queue`
func parseAndHandleMessages(from bytes: UnsafeBufferPointer<UInt8>) -> UnsafeBufferPointer<UInt8>.SubSequence {
dispatchPrecondition(condition: .onQueue(queue))
let decoder = JSONDecoder()

// Set message registry to use for model decoding.
decoder.userInfo[.messageRegistryKey] = messageRegistry

// Setup callback for response type.
decoder.userInfo[.responseTypeCallbackKey] =
{ id in
guard let outstanding = self.outstandingRequests[id] else {
logger.error("Unknown request for \(id, privacy: .public)")
return nil
}
return outstanding.responseType
} as JSONRPCMessage.ResponseTypeCallback

var bytes = bytes[...]

MESSAGE_LOOP: while true {
// Split the messages based on the Content-Length header.
let messageBytes: Slice<UnsafeBufferPointer<UInt8>>
do {
guard let ((messageBytes, _), rest) = try bytes.jsonrpcSplitMessage() else {
guard let (header: _, message: message, rest: rest) = try bytes.jsonrpcSplitMessage() else {
return bytes
}
messageBytes = message
bytes = rest

let pointer = UnsafeMutableRawPointer(mutating: UnsafeBufferPointer(rebasing: messageBytes).baseAddress!)
let message = try decoder.decode(
JSONRPCMessage.self,
from: Data(bytesNoCopy: pointer, count: messageBytes.count, deallocator: .none)
)

handle(message)
} catch let error as MessageDecodingError {
switch error.messageKind {
case .request:
if let id = error.id {
queue.async {
self.send(.errorResponse(ResponseError(error), id: id))
}
continue MESSAGE_LOOP
}
case .response:
if let id = error.id {
if let outstanding = self.outstandingRequests.removeValue(forKey: id) {
outstanding.replyHandler(.failure(ResponseError(error)))
} else {
logger.error("error in response to unknown request \(id, privacy: .public) \(error.forLogging)")
}
continue MESSAGE_LOOP
}
case .notification:
if error.code == .methodNotFound {
logger.error("ignoring unknown notification \(error.forLogging)")
continue MESSAGE_LOOP
}
case .unknown:
break
}
// FIXME: graceful shutdown?
fatalError("fatal error encountered decoding message \(error)")
} catch {
// FIXME: graceful shutdown?
fatalError("fatal error encountered decoding message \(error)")
// We failed to parse the message header. There isn't really much we can do to recover because we lost our
// anchor in the stream where new messages start. Crashing and letting ourselves be restarted by the client is
// probably the best option.
sendMessageDecodingErrorNotificationToClient(message: "Failed to find next message in connection to editor")
fatalError("fatal error encountered while splitting JSON RPC messages \(error)")
}

guard let message = decodeJSONRPCMessage(messageBytes: messageBytes) else {
continue
}
handle(message)
}
}

Expand Down
29 changes: 18 additions & 11 deletions Sources/LanguageServerProtocolJSONRPC/MessageSplitting.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import LanguageServerProtocol
public struct JSONRPCMessageHeader: Hashable {
static let contentLengthKey: [UInt8] = [UInt8]("Content-Length".utf8)
static let separator: [UInt8] = [UInt8]("\r\n".utf8)
static let colon: UInt8 = ":".utf8.first!
static let colon: UInt8 = UInt8(ascii: ":")
static let invalidKeyBytes: [UInt8] = [colon] + separator

public var contentLength: Int? = nil
Expand All @@ -25,21 +25,29 @@ public struct JSONRPCMessageHeader: Hashable {
}
}

extension RandomAccessCollection where Element == UInt8 {

/// Returns the first message range and header in `self`, or nil.
public func jsonrpcSplitMessage()
throws -> ((SubSequence, header: JSONRPCMessageHeader), SubSequence)?
{
extension RandomAccessCollection<UInt8> {
/// Tries to parse a single message from this collection of bytes.
///
/// If an entire message could be found, returns
/// - header (representing `Content-Length:<length>\r\n\r\n`)
/// - message: The data that represents the actual message as JSON
/// - rest: The remaining bytes that weren't part of the first message in this collection
///
/// If a `Content-Length` header could be found but the collection doesn't have enough bytes for the entire message
/// (e.g. because the `Content-Length` header has already been transmitted but not the entire message), returns `nil`.
/// Callers should call this method again once more data is available.
@_spi(Testing)
public func jsonrpcSplitMessage() throws -> (header: JSONRPCMessageHeader, message: SubSequence, rest: SubSequence)? {
guard let (header, rest) = try jsonrcpParseHeader() else { return nil }
guard let contentLength = header.contentLength else {
throw MessageDecodingError.parseError("missing Content-Length header")
}
if contentLength > rest.count { return nil }
return ((rest.prefix(contentLength), header: header), rest.dropFirst(contentLength))
return (header: header, message: rest.prefix(contentLength), rest: rest.dropFirst(contentLength))
}

public func jsonrcpParseHeader() throws -> (JSONRPCMessageHeader, SubSequence)? {
@_spi(Testing)
public func jsonrcpParseHeader() throws -> (header: JSONRPCMessageHeader, rest: SubSequence)? {
var header = JSONRPCMessageHeader()
var slice = self[...]
while let (kv, rest) = try slice.jsonrpcParseHeaderField() {
Expand All @@ -62,6 +70,7 @@ extension RandomAccessCollection where Element == UInt8 {
return nil
}

@_spi(Testing)
public func jsonrpcParseHeaderField() throws -> ((key: SubSequence, value: SubSequence)?, SubSequence)? {
if starts(with: JSONRPCMessageHeader.separator) {
return (nil, dropFirst(JSONRPCMessageHeader.separator.count))
Expand All @@ -85,11 +94,9 @@ extension RandomAccessCollection where Element == UInt8 {
}

extension RandomAccessCollection where Element: Equatable {

/// Returns the first index where the specified subsequence appears or nil.
@inlinable
public func firstIndex(of pattern: some RandomAccessCollection<Element>) -> Index? {

if pattern.isEmpty {
return startIndex
}
Expand Down
28 changes: 28 additions & 0 deletions Tests/LanguageServerProtocolJSONRPCTests/ConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,32 @@ class ConnectionTests: XCTestCase {
}
}
}

/// Sends a `test_server/echo_note` notification whose `params` object is empty and checks that the client
/// receives a `ShowMessageNotification` of type `.error` in response.
///
/// NOTE(review): presumably `EchoNotification` has a required parameter that is absent from `{}`, which is what
/// makes decoding fail on the server side — confirm against its definition.
func testMessageWithMissingParameter() async throws {
let expectation = self.expectation(description: "Received ShowMessageNotification")
// Install the handler before sending anything so that the notification from the server cannot be missed.
await connection.client.appendOneShotNotificationHandler { (note: ShowMessageNotification) in
XCTAssertEqual(note.type, .error)
expectation.fulfill()
}

// Raw JSON is used instead of a typed notification so that a payload with empty `params` can be sent.
let messageContents = """
{
"method": "test_server/echo_note",
"jsonrpc": "2.0",
"params": {}
}
"""
connection.clientToServerConnection.send(message: messageContents)

// Fails the test if the error notification never arrives.
try await self.fulfillmentOfOrThrow([expectation])
}
}

fileprivate extension JSONRPCConnection {
  /// Frames `message` with a `Content-Length` header and passes the resulting bytes to `send(_rawData:)`.
  ///
  /// The header value is the UTF-8 byte count of `message`, and `message` is written verbatim after the
  /// `\r\n\r\n` separator. The contents of `message` are not validated, so this can be used to transmit payloads
  /// that a typed message could not represent.
  ///
  /// - Parameter message: The message body to send, without any header.
  func send(message: String) {
    // Building `Data` from the UTF-8 view cannot fail, unlike `String.data(using:)`, which returns an optional
    // that would need to be force-unwrapped.
    let framedMessage = Data("Content-Length: \(message.utf8.count)\r\n\r\n\(message)".utf8)
    framedMessage.withUnsafeBytes { bytes in
      send(_rawData: DispatchData(bytes: bytes))
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//

import LanguageServerProtocol
import LanguageServerProtocolJSONRPC
@_spi(Testing) import LanguageServerProtocolJSONRPC
import XCTest

final class MessageParsingTests: XCTestCase {
Expand All @@ -25,7 +25,7 @@ final class MessageParsingTests: XCTestCase {
line: UInt = #line
) throws {
let bytes: [UInt8] = [UInt8](string.utf8)
guard let ((content, header), rest) = try bytes.jsonrpcSplitMessage() else {
guard let (header, content, rest) = try bytes.jsonrpcSplitMessage() else {
XCTAssert(restLen == nil, "expected non-empty field", file: file, line: line)
return
}
Expand Down