Skip to content

Commit 042d286

Browse files
committed
Allow requests to be sent from server to client
This change enables a server to send requests to a client and wait for a response. This is a new capability that has been added to the vscode-languageclient library which we'll start taking advantage of for sending prompt requests to the UI.
1 parent ff524df commit 042d286

File tree

12 files changed

+208
-209
lines changed

12 files changed

+208
-209
lines changed

src/PowerShellEditorServices.Channel.WebSocket/WebsocketServerChannel.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ protected EditorServiceWebSocketConnection()
110110
Channel = new WebSocketServerChannel(this);
111111
}
112112

113-
protected ProtocolServer Server { get; set; }
113+
protected ProtocolEndpoint Server { get; set; }
114114

115115
protected WebSocketServerChannel Channel { get; private set; }
116116

src/PowerShellEditorServices.Host/Program.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
44
//
55

6+
using Microsoft.PowerShell.EditorServices.Protocol.MessageProtocol;
67
using Microsoft.PowerShell.EditorServices.Protocol.Server;
78
using Microsoft.PowerShell.EditorServices.Utility;
89
using System;
@@ -41,7 +42,7 @@ static void Main(string[] args)
4142
// Catch unhandled exceptions for logging purposes
4243
AppDomain.CurrentDomain.UnhandledException += CurrentDomain_UnhandledException;
4344

44-
ProtocolServer server = null;
45+
ProtocolEndpoint server = null;
4546
if (runDebugAdapter)
4647
{
4748
logPath = logPath ?? "DebugAdapter.log";
@@ -105,7 +106,7 @@ static void Main(string[] args)
105106
fileVersionInfo.FileVersion));
106107

107108
// Start the server
108-
server.Start();
109+
server.Start().Wait();
109110
Logger.Write(LogLevel.Normal, "PowerShell Editor Services Host started!");
110111

111112
// Wait for the server to finish

src/PowerShellEditorServices.Protocol/Client/DebugAdapterClientBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
namespace Microsoft.PowerShell.EditorServices.Protocol.Client
1212
{
13-
public class DebugAdapterClient : ProtocolClient
13+
public class DebugAdapterClient : ProtocolEndpoint
1414
{
1515
public DebugAdapterClient(ChannelBase clientChannel)
1616
: base(clientChannel, MessageProtocolType.DebugAdapter)

src/PowerShellEditorServices.Protocol/Client/LanguageClientBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace Microsoft.PowerShell.EditorServices.Protocol.Client
1313
/// <summary>
1414
/// Provides a base implementation for language server clients.
1515
/// </summary>
16-
public abstract class LanguageClientBase : ProtocolClient
16+
public abstract class LanguageClientBase : ProtocolEndpoint
1717
{
1818
/// <summary>
1919
/// Initializes an instance of the language client using the

src/PowerShellEditorServices.Protocol/Server/IEventWriter.cs renamed to src/PowerShellEditorServices.Protocol/MessageProtocol/IMessageSender.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,20 @@
33
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
44
//
55

6-
using Microsoft.PowerShell.EditorServices.Protocol.MessageProtocol;
76
using System.Threading.Tasks;
87

9-
namespace Microsoft.PowerShell.EditorServices.Protocol.Server
8+
namespace Microsoft.PowerShell.EditorServices.Protocol.MessageProtocol
109
{
11-
internal interface IEventWriter
10+
internal interface IMessageSender
1211
{
1312
Task SendEvent<TParams>(
1413
EventType<TParams> eventType,
1514
TParams eventParams);
15+
16+
Task<TResult> SendRequest<TParams, TResult>(
17+
RequestType<TParams, TResult> requestType,
18+
TParams requestParams,
19+
bool waitForResponse);
1620
}
1721
}
1822

src/PowerShellEditorServices.Protocol/Client/ProtocolClient.cs renamed to src/PowerShellEditorServices.Protocol/MessageProtocol/ProtocolEndpoint.cs

Lines changed: 118 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,41 @@
1111
using System.Threading;
1212
using System.Threading.Tasks;
1313

14-
namespace Microsoft.PowerShell.EditorServices.Protocol.Client
14+
namespace Microsoft.PowerShell.EditorServices.Protocol.MessageProtocol
1515
{
16-
public class ProtocolClient
16+
/// <summary>
17+
/// Provides behavior for a client or server endpoint that
18+
/// communicates using the specified protocol.
19+
/// </summary>
20+
public class ProtocolEndpoint : IMessageSender
1721
{
1822
private bool isStarted;
1923
private int currentMessageId;
20-
private ChannelBase clientChannel;
24+
private ChannelBase protocolChannel;
2125
private MessageProtocolType messageProtocolType;
26+
private TaskCompletionSource<bool> endpointExitedTask;
2227
private SynchronizationContext originalSynchronizationContext;
2328

2429
private Dictionary<string, TaskCompletionSource<Message>> pendingRequests =
2530
new Dictionary<string, TaskCompletionSource<Message>>();
2631

2732
/// <summary>
28-
/// Initializes an instance of the protocol client using the
33+
/// Initializes an instance of the protocol server using the
2934
/// specified channel for communication.
3035
/// </summary>
31-
/// <param name="clientChannel">The channel to use for communication with the server.</param>
32-
/// <param name="messageProtocolType">The type of message protocol used by the server.</param>
33-
public ProtocolClient(
34-
ChannelBase clientChannel,
36+
/// <param name="protocolChannel">
37+
/// The channel to use for communication with the connected endpoint.
38+
/// </param>
39+
/// <param name="messageProtocolType">
40+
/// The type of message protocol used by the endpoint.
41+
/// </param>
42+
public ProtocolEndpoint(
43+
ChannelBase protocolChannel,
3544
MessageProtocolType messageProtocolType)
3645
{
37-
this.clientChannel = clientChannel;
46+
this.protocolChannel = protocolChannel;
3847
this.messageProtocolType = messageProtocolType;
48+
this.originalSynchronizationContext = SynchronizationContext.Current;
3949
}
4050

4151
/// <summary>
@@ -46,35 +56,50 @@ public async Task Start()
4656
{
4757
if (!this.isStarted)
4858
{
49-
// Start the provided client channel
50-
this.clientChannel.Start(this.messageProtocolType);
59+
// Start the provided protocol channel
60+
this.protocolChannel.Start(this.messageProtocolType);
5161

5262
// Set the handler for any message responses that come back
53-
this.clientChannel.MessageDispatcher.SetResponseHandler(this.HandleResponse);
63+
this.protocolChannel.MessageDispatcher.SetResponseHandler(this.HandleResponse);
5464

5565
// Listen for unhandled exceptions from the dispatcher
56-
this.clientChannel.MessageDispatcher.UnhandledException += MessageDispatcher_UnhandledException;
66+
this.protocolChannel.MessageDispatcher.UnhandledException += MessageDispatcher_UnhandledException;
5767

58-
// Notify implementation about client start
68+
// Notify implementation about endpoint start
5969
await this.OnStart();
6070

61-
// Client is now started
71+
// Endpoint is now started
6272
this.isStarted = true;
6373
}
6474
}
6575

76+
public void WaitForExit()
77+
{
78+
this.endpointExitedTask = new TaskCompletionSource<bool>();
79+
this.endpointExitedTask.Task.Wait();
80+
}
81+
6682
public async Task Stop()
6783
{
6884
if (this.isStarted)
6985
{
86+
// Make sure no future calls try to stop the endpoint during shutdown
87+
this.isStarted = false;
88+
7089
// Stop the implementation first
7190
await this.OnStop();
91+
this.protocolChannel.Stop();
7292

73-
this.clientChannel.Stop();
74-
this.isStarted = false;
93+
// Notify anyone waiting for exit
94+
if (this.endpointExitedTask != null)
95+
{
96+
this.endpointExitedTask.SetResult(true);
97+
}
7598
}
7699
}
77100

101+
#region Message Sending
102+
78103
/// <summary>
79104
/// Sends a request to the server
80105
/// </summary>
@@ -107,7 +132,7 @@ public async Task<TResult> SendRequest<TParams, TResult>(
107132
responseTask);
108133
}
109134

110-
await this.clientChannel.MessageWriter.WriteRequest<TParams, TResult>(
135+
await this.protocolChannel.MessageWriter.WriteRequest<TParams, TResult>(
111136
requestType,
112137
requestParams,
113138
this.currentMessageId);
@@ -128,19 +153,64 @@ await this.clientChannel.MessageWriter.WriteRequest<TParams, TResult>(
128153
}
129154
}
130155

131-
public async Task SendEvent<TParams>(EventType<TParams> eventType, TParams eventParams)
156+
/// <summary>
157+
/// Sends an event to the channel's endpoint.
158+
/// </summary>
159+
/// <typeparam name="TParams">The event parameter type.</typeparam>
160+
/// <param name="eventType">The type of event being sent.</param>
161+
/// <param name="eventParams">The event parameters being sent.</param>
162+
/// <returns>A Task that tracks completion of the send operation.</returns>
163+
public Task SendEvent<TParams>(
164+
EventType<TParams> eventType,
165+
TParams eventParams)
166+
{
167+
// Some events could be raised from a different thread.
168+
// To ensure that messages are written serially, dispatch
169+
// dispatch the SendEvent call to the message loop thread.
170+
171+
if (!this.protocolChannel.MessageDispatcher.InMessageLoopThread)
172+
{
173+
TaskCompletionSource<bool> writeTask = new TaskCompletionSource<bool>();
174+
175+
this.protocolChannel.MessageDispatcher.SynchronizationContext.Post(
176+
async (obj) =>
177+
{
178+
await this.protocolChannel.MessageWriter.WriteEvent(
179+
eventType,
180+
eventParams);
181+
182+
writeTask.SetResult(true);
183+
}, null);
184+
185+
return writeTask.Task;
186+
}
187+
else
188+
{
189+
return this.protocolChannel.MessageWriter.WriteEvent(
190+
eventType,
191+
eventParams);
192+
}
193+
}
194+
195+
#endregion
196+
197+
#region Message Handling
198+
199+
public void SetRequestHandler<TParams, TResult>(
200+
RequestType<TParams, TResult> requestType,
201+
Func<TParams, RequestContext<TResult>, Task> requestHandler)
132202
{
133-
await this.clientChannel.MessageWriter.WriteMessage(
134-
Message.Event(
135-
eventType.MethodName,
136-
JToken.FromObject(eventParams)));
203+
this.protocolChannel.MessageDispatcher.SetRequestHandler(
204+
requestType,
205+
requestHandler);
137206
}
138207

208+
139209
public void SetEventHandler<TParams>(
140210
EventType<TParams> eventType,
141211
Func<TParams, EventContext, Task> eventHandler)
142212
{
143-
this.clientChannel.MessageDispatcher.SetEventHandler(
213+
this.protocolChannel.MessageDispatcher.SetEventHandler(
144214
eventType,
145215
eventHandler,
146216
false);
@@ -151,20 +221,12 @@ public void SetEventHandler<TParams>(
151221
Func<TParams, EventContext, Task> eventHandler,
152222
bool overrideExisting)
153223
{
154-
this.clientChannel.MessageDispatcher.SetEventHandler(
224+
this.protocolChannel.MessageDispatcher.SetEventHandler(
155225
eventType,
156226
eventHandler,
157227
overrideExisting);
158228
}
159229

160-
private void MessageDispatcher_UnhandledException(object sender, Exception e)
161-
{
162-
if (this.originalSynchronizationContext != null)
163-
{
164-
this.originalSynchronizationContext.Post(o => { throw e; }, null);
165-
}
166-
}
167-
168230
private void HandleResponse(Message responseMessage)
169231
{
170232
TaskCompletionSource<Message> pendingRequestTask = null;
@@ -176,6 +238,10 @@ private void HandleResponse(Message responseMessage)
176238
}
177239
}
178240

241+
#endregion
242+
243+
#region Subclass Lifetime Methods
244+
179245
protected virtual Task OnStart()
180246
{
181247
return Task.FromResult(true);
@@ -185,6 +251,25 @@ protected virtual Task OnStop()
185251
{
186252
return Task.FromResult(true);
187253
}
254+
255+
#endregion
256+
257+
#region Event Handlers
258+
259+
private void MessageDispatcher_UnhandledException(object sender, Exception e)
260+
{
261+
if (this.endpointExitedTask != null)
262+
{
263+
this.endpointExitedTask.SetException(e);
264+
}
265+
266+
else if (this.originalSynchronizationContext != null)
267+
{
268+
this.originalSynchronizationContext.Post(o => { throw e; }, null);
269+
}
270+
}
271+
272+
#endregion
188273
}
189274
}
190275

src/PowerShellEditorServices.Protocol/PowerShellEditorServices.Protocol.csproj

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@
4848
<Reference Include="System.Management.Automation" />
4949
</ItemGroup>
5050
<ItemGroup>
51-
<Compile Include="Client\ProtocolClient.cs" />
5251
<Compile Include="DebugAdapter\AttachRequest.cs" />
5352
<Compile Include="DebugAdapter\Breakpoint.cs" />
5453
<Compile Include="DebugAdapter\ContinueRequest.cs" />
5554
<Compile Include="LanguageServer\FindModuleRequest.cs" />
5655
<Compile Include="LanguageServer\InstallModuleRequest.cs" />
56+
<Compile Include="MessageProtocol\IMessageSender.cs" />
57+
<Compile Include="MessageProtocol\ProtocolEndpoint.cs" />
5758
<Compile Include="Messages\PromptEvents.cs" />
5859
<Compile Include="Server\DebugAdapter.cs" />
5960
<Compile Include="Server\DebugAdapterBase.cs" />
@@ -92,7 +93,6 @@
9293
<Compile Include="LanguageServer\ExpandAliasRequest.cs" />
9394
<Compile Include="LanguageServer\Hover.cs" />
9495
<Compile Include="Client\LanguageClientBase.cs" />
95-
<Compile Include="Server\IEventWriter.cs" />
9696
<Compile Include="Server\LanguageServer.cs" />
9797
<Compile Include="Server\LanguageServerBase.cs" />
9898
<Compile Include="LanguageServer\ShowOnlineHelpRequest.cs" />
@@ -125,7 +125,6 @@
125125
<Compile Include="LanguageServer\References.cs" />
126126
<Compile Include="Server\LanguageServerSettings.cs" />
127127
<Compile Include="Server\PromptHandlers.cs" />
128-
<Compile Include="Server\ProtocolServer.cs" />
129128
</ItemGroup>
130129
<ItemGroup>
131130
<None Include="packages.config" />

src/PowerShellEditorServices.Protocol/Server/DebugAdapterBase.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
namespace Microsoft.PowerShell.EditorServices.Protocol.Server
1212
{
13-
public abstract class DebugAdapterBase : ProtocolServer
13+
public abstract class DebugAdapterBase : ProtocolEndpoint
1414
{
1515
public DebugAdapterBase(ChannelBase serverChannel)
1616
: base (serverChannel, MessageProtocolType.DebugAdapter)
@@ -32,18 +32,22 @@ protected virtual void Shutdown()
3232
// No default implementation yet.
3333
}
3434

35-
protected override void OnStart()
35+
protected override Task OnStart()
3636
{
3737
// Register handlers for server lifetime messages
3838
this.SetRequestHandler(InitializeRequest.Type, this.HandleInitializeRequest);
3939

4040
// Initialize the implementation class
4141
this.Initialize();
42+
43+
return Task.FromResult(true);
4244
}
4345

44-
protected override void OnStop()
46+
protected override Task OnStop()
4547
{
4648
this.Shutdown();
49+
50+
return Task.FromResult(true);
4751
}
4852

4953
private async Task HandleInitializeRequest(

0 commit comments

Comments
 (0)