Skip to content

Commit 6e09b54

Browse files
authored
Add SslOverTdsStream tests (#497)
1 parent 9e13645 commit 6e09b54

File tree

2 files changed

+364
-0
lines changed

2 files changed

+364
-0
lines changed

src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
<ItemGroup Condition="'$(TargetFramework)' != 'netcoreapp2.1' AND '$(TargetGroup)' == 'netcoreapp'">
6060
<PackageReference Include="System.Data.Odbc" Version="$(SystemDataOdbcVersion)" />
6161
</ItemGroup>
62+
<ItemGroup Condition="$(TargetFramework.StartsWith('netcoreapp'))">
63+
<Compile Include="SslOverTdsStreamTest.cs" />
64+
</ItemGroup>
6265
<ItemGroup>
6366
<ProjectReference Include="$(TestsPath)ManualTests\SQL\UdtTest\UDTs\Address\Address.csproj">
6467
<Name>Address</Name>
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Diagnostics;
7+
using System.IO;
8+
using System.Reflection;
9+
using System.Threading;
10+
using System.Threading.Tasks;
11+
using Xunit;
12+
13+
namespace Microsoft.Data.SqlClient.Tests
14+
{
15+
public static class SslOverTdsStreamTest
16+
{
17+
public static TheoryData<int, int, int> PacketSizes
18+
{
19+
get
20+
{
21+
const int EncapsulatedPacketCount = 4;
22+
const int PassThroughPacketCount = 5;
23+
24+
TheoryData<int, int, int> data = new TheoryData<int, int, int>();
25+
26+
data.Add(EncapsulatedPacketCount, PassThroughPacketCount, 0);
27+
data.Add(EncapsulatedPacketCount, PassThroughPacketCount, 2);
28+
data.Add(EncapsulatedPacketCount, PassThroughPacketCount, 128);
29+
data.Add(EncapsulatedPacketCount, PassThroughPacketCount, 2048);
30+
data.Add(EncapsulatedPacketCount, PassThroughPacketCount, 8192);
31+
32+
return data;
33+
}
34+
}
35+
36+
37+
[Theory]
38+
[MemberData(nameof(PacketSizes))]
39+
public static void SyncTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength)
40+
{
41+
byte[] input;
42+
byte[] output;
43+
SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output);
44+
45+
byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount,
46+
(Stream stream, int index) =>
47+
{
48+
stream.Write(input, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE);
49+
}
50+
);
51+
52+
ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output,
53+
(Stream stream, byte[] bytes, int offset, int count) =>
54+
{
55+
return stream.Read(bytes, offset, count);
56+
}
57+
);
58+
59+
Validate(input, output);
60+
}
61+
62+
[Theory]
63+
[MemberData(nameof(PacketSizes))]
64+
public static void AsyncTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength)
65+
{
66+
byte[] input;
67+
byte[] output;
68+
SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output);
69+
byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount,
70+
async (Stream stream, int index) =>
71+
{
72+
await stream.WriteAsync(input, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE);
73+
}
74+
);
75+
76+
ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output,
77+
async (Stream stream, byte[] bytes, int offset, int count) =>
78+
{
79+
return await stream.ReadAsync(bytes, offset, count);
80+
}
81+
);
82+
83+
Validate(input, output);
84+
}
85+
86+
[Theory]
87+
[MemberData(nameof(PacketSizes))]
88+
public static void SyncCoreTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength)
89+
{
90+
byte[] input;
91+
byte[] output;
92+
SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output);
93+
94+
byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount,
95+
(Stream stream, int index) =>
96+
{
97+
stream.Write(input.AsSpan(TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE));
98+
}
99+
);
100+
101+
ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output,
102+
(Stream stream, byte[] bytes, int offset, int count) =>
103+
{
104+
return stream.Read(bytes.AsSpan(offset, count));
105+
}
106+
);
107+
108+
Validate(input, output);
109+
}
110+
111+
[Theory]
112+
[MemberData(nameof(PacketSizes))]
113+
public static void AsyncCoreTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength)
114+
{
115+
byte[] input;
116+
byte[] output;
117+
SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output);
118+
119+
byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount,
120+
async (Stream stream, int index) =>
121+
{
122+
await stream.WriteAsync(
123+
new ReadOnlyMemory<byte>(input, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE)
124+
);
125+
}
126+
);
127+
128+
ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output,
129+
async (Stream stream, byte[] bytes, int offset, int count) =>
130+
{
131+
return await stream.ReadAsync(
132+
new Memory<byte>(bytes, offset, count)
133+
);
134+
}
135+
);
136+
137+
Validate(input, output);
138+
}
139+
140+
141+
private static void ReadPackets(byte[] buffer, int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength, byte[] output, Func<Stream, byte[], int, int, Task<int>> action)
142+
{
143+
using (LimitedMemoryStream stream = new LimitedMemoryStream(buffer, maxPacketReadLength))
144+
using (Stream tdsStream = CreateSslOverTdsStream(stream))
145+
{
146+
int offset = 0;
147+
byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE];
148+
for (int index = 0; index < encapsulatedPacketCount; index++)
149+
{
150+
Array.Clear(bytes, 0, bytes.Length);
151+
int packetBytes = ReadPacket(tdsStream, action, bytes).GetAwaiter().GetResult();
152+
Array.Copy(bytes, 0, output, offset, packetBytes);
153+
offset += packetBytes;
154+
}
155+
InvokeFinishHandshake(tdsStream);
156+
for (int index = 0; index < passthroughPacketCount; index++)
157+
{
158+
Array.Clear(bytes, 0, bytes.Length);
159+
int packetBytes = ReadPacket(tdsStream, action, bytes).GetAwaiter().GetResult();
160+
Array.Copy(bytes, 0, output, offset, packetBytes);
161+
offset += packetBytes;
162+
}
163+
}
164+
}
165+
166+
private static void InvokeFinishHandshake(Stream stream)
167+
{
168+
MethodInfo method = stream.GetType().GetMethod("FinishHandshake", BindingFlags.Public | BindingFlags.Instance);
169+
method.Invoke(stream, null);
170+
}
171+
172+
private static Stream CreateSslOverTdsStream(Stream stream)
173+
{
174+
Type type = typeof(SqlClientFactory).Assembly.GetType("Microsoft.Data.SqlClient.SNI.SslOverTdsStream");
175+
ConstructorInfo ctor = type.GetConstructor(new Type[] { typeof(Stream) });
176+
Stream instance = (Stream)ctor.Invoke(new object[] { stream });
177+
return instance;
178+
}
179+
180+
private static void ReadPackets(byte[] buffer, int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength, byte[] output, Func<Stream, byte[], int, int, int> action)
181+
{
182+
using (LimitedMemoryStream stream = new LimitedMemoryStream(buffer, maxPacketReadLength))
183+
using (Stream tdsStream = CreateSslOverTdsStream(stream))
184+
{
185+
int offset = 0;
186+
byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE];
187+
for (int index = 0; index < encapsulatedPacketCount; index++)
188+
{
189+
Array.Clear(bytes, 0, bytes.Length);
190+
int packetBytes = ReadPacket(tdsStream, action, bytes);
191+
Array.Copy(bytes, 0, output, offset, packetBytes);
192+
offset += packetBytes;
193+
}
194+
InvokeFinishHandshake(tdsStream);
195+
for (int index = 0; index < passthroughPacketCount; index++)
196+
{
197+
Array.Clear(bytes, 0, bytes.Length);
198+
int packetBytes = ReadPacket(tdsStream, action, bytes);
199+
Array.Copy(bytes, 0, output, offset, packetBytes);
200+
offset += packetBytes;
201+
}
202+
}
203+
}
204+
205+
private static int ReadPacket(Stream tdsStream, Func<Stream, byte[], int, int, int> action, byte[] output)
206+
{
207+
int readCount;
208+
int offset = 0;
209+
byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE];
210+
do
211+
{
212+
readCount = action(tdsStream, bytes, offset, bytes.Length - offset);
213+
if (readCount > 0)
214+
{
215+
offset += readCount;
216+
}
217+
}
218+
while (readCount > 0 && offset < bytes.Length);
219+
Array.Copy(bytes, 0, output, 0, offset);
220+
return offset;
221+
}
222+
223+
private static async Task<int> ReadPacket(Stream tdsStream, Func<Stream, byte[], int, int, Task<int>> action, byte[] output)
224+
{
225+
int readCount;
226+
int offset = 0;
227+
byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE];
228+
do
229+
{
230+
readCount = await action(tdsStream, bytes, offset, bytes.Length - offset);
231+
if (readCount > 0)
232+
{
233+
offset += readCount;
234+
}
235+
}
236+
while (readCount > 0 && offset < bytes.Length);
237+
Array.Copy(bytes, 0, output, 0, offset);
238+
return offset;
239+
}
240+
241+
private static byte[] WritePackets(int encapsulatedPacketCount, int passthroughPacketCount, Action<Stream, int> action)
242+
{
243+
byte[] buffer = null;
244+
using (LimitedMemoryStream stream = new LimitedMemoryStream())
245+
{
246+
using (Stream tdsStream = CreateSslOverTdsStream(stream))
247+
{
248+
for (int index = 0; index < encapsulatedPacketCount; index++)
249+
{
250+
action(tdsStream, index);
251+
}
252+
InvokeFinishHandshake(tdsStream);//tdsStream.FinishHandshake();
253+
for (int index = 0; index < passthroughPacketCount; index++)
254+
{
255+
action(tdsStream, encapsulatedPacketCount + index);
256+
}
257+
}
258+
buffer = stream.ToArray();
259+
}
260+
return buffer;
261+
}
262+
263+
private static void SetupArrays(int packetCount, out byte[] input, out byte[] output)
264+
{
265+
byte[] pattern = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13 };
266+
input = new byte[packetCount * TdsEnums.DEFAULT_LOGIN_PACKET_SIZE];
267+
output = new byte[input.Length];
268+
for (int index = 0; index < packetCount; index++)
269+
{
270+
int position = 0;
271+
while (position < TdsEnums.DEFAULT_LOGIN_PACKET_SIZE)
272+
{
273+
int copyCount = Math.Min(pattern.Length, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE - position);
274+
Array.Copy(pattern, 0, input, (TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index) + position, copyCount);
275+
position += copyCount;
276+
}
277+
}
278+
}
279+
280+
private static void Validate(byte[] input, byte[] output)
281+
{
282+
Assert.True(input.AsSpan().SequenceEqual(output.AsSpan()));
283+
}
284+
285+
internal static class TdsEnums
286+
{
287+
public const int DEFAULT_LOGIN_PACKET_SIZE = 4096;
288+
}
289+
}
290+
291+
[DebuggerStepThrough]
292+
public sealed partial class LimitedMemoryStream : MemoryStream
293+
{
294+
private readonly int _readLimit;
295+
private readonly int _delay;
296+
297+
public LimitedMemoryStream(int readLimit = 0, int delay = 0)
298+
{
299+
_readLimit = readLimit;
300+
_delay = delay;
301+
}
302+
303+
public LimitedMemoryStream(byte[] buffer, int readLimit = 0, int delay = 0)
304+
: base(buffer)
305+
{
306+
_readLimit = readLimit;
307+
_delay = delay;
308+
}
309+
310+
public override int Read(byte[] buffer, int offset, int count)
311+
{
312+
if (_readLimit > 0)
313+
{
314+
return base.Read(buffer, offset, Math.Min(_readLimit, count));
315+
}
316+
else
317+
{
318+
return base.Read(buffer, offset, count);
319+
}
320+
}
321+
322+
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
323+
{
324+
if (_delay > 0)
325+
{
326+
await Task.Delay(_delay, cancellationToken);
327+
}
328+
if (_readLimit > 0)
329+
{
330+
return await base.ReadAsync(buffer, offset, Math.Min(_readLimit, count), cancellationToken).ConfigureAwait(false);
331+
}
332+
else
333+
{
334+
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
335+
}
336+
}
337+
public override int Read(Span<byte> destination)
338+
{
339+
if (_readLimit > 0)
340+
{
341+
return base.Read(destination.Slice(0, Math.Min(_readLimit, destination.Length)));
342+
}
343+
else
344+
{
345+
return base.Read(destination);
346+
}
347+
}
348+
349+
public override ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
350+
{
351+
if (_readLimit > 0)
352+
{
353+
return base.ReadAsync(destination.Slice(0, Math.Min(_readLimit, destination.Length)), cancellationToken);
354+
}
355+
else
356+
{
357+
return base.ReadAsync(destination, cancellationToken);
358+
}
359+
}
360+
}
361+
}

0 commit comments

Comments
 (0)