Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Commit b72def2

Browse files
authored
Merge pull request #14186 from jamesqo/st-last
Implement & expose SkipLast & TakeLast for Enumerable & Queryable
2 parents c8c3ca8 + 8672da3 commit b72def2

File tree

17 files changed

+408
-18
lines changed

17 files changed

+408
-18
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections;
6+
using System.Collections.Generic;
7+
8+
namespace System.Linq.Tests
9+
{
10+
public class SkipTakeData
11+
{
12+
public static IEnumerable<object[]> EnumerableData()
13+
{
14+
IEnumerable<int> sourceCounts = new[] { 1, 2, 3, 5, 8, 13, 55, 100, 250 };
15+
16+
IEnumerable<int> counts = new[] { 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 100, 250, 500, int.MaxValue };
17+
counts = counts.Concat(counts.Select(c => -c)).Append(0).Append(int.MinValue);
18+
19+
return from sourceCount in sourceCounts
20+
let source = Enumerable.Range(0, sourceCount)
21+
from count in counts
22+
select new object[] { source, count };
23+
}
24+
25+
public static IEnumerable<object[]> EvaluationBehaviorData()
26+
{
27+
return Enumerable.Range(-1, 15).Select(count => new object[] { count });
28+
}
29+
30+
public static IEnumerable<object[]> QueryableData()
31+
{
32+
return EnumerableData().Select(array =>
33+
{
34+
var enumerable = (IEnumerable<int>)array[0];
35+
return new object[] { enumerable.AsQueryable(), array[1] };
36+
});
37+
}
38+
}
39+
}

src/System.Linq.Queryable/ref/System.Linq.Queryable.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ public static partial class Queryable
126126
public static TSource SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
127127
public static TSource SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
128128
public static System.Linq.IQueryable<TSource> Skip<TSource>(this System.Linq.IQueryable<TSource> source, int count) { throw null; }
129+
public static System.Linq.IQueryable<TSource> SkipLast<TSource>(this System.Linq.IQueryable<TSource> source, int count) { throw null; }
129130
public static System.Linq.IQueryable<TSource> SkipWhile<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
130131
public static System.Linq.IQueryable<TSource> SkipWhile<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, int, bool>> predicate) { throw null; }
131132
public static decimal Sum(this System.Linq.IQueryable<decimal> source) { throw null; }
@@ -149,6 +150,7 @@ public static partial class Queryable
149150
public static System.Nullable<float> Sum<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, System.Nullable<float>>> selector) { throw null; }
150151
public static float Sum<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, float>> selector) { throw null; }
151152
public static System.Linq.IQueryable<TSource> Take<TSource>(this System.Linq.IQueryable<TSource> source, int count) { throw null; }
153+
public static System.Linq.IQueryable<TSource> TakeLast<TSource>(this System.Linq.IQueryable<TSource> source, int count) { throw null; }
152154
public static System.Linq.IQueryable<TSource> TakeWhile<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
153155
public static System.Linq.IQueryable<TSource> TakeWhile<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, int, bool>> predicate) { throw null; }
154156
public static System.Linq.IOrderedQueryable<TSource> ThenBy<TSource, TKey>(this System.Linq.IOrderedQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, TKey>> keySelector) { throw null; }

src/System.Linq.Queryable/src/System/Linq/CachedReflection.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,5 +837,19 @@ public static MethodInfo Zip_TFirst_TSecond_TResult_3(Type TFirst, Type TSecond,
837837
(s_Zip_TFirst_TSecond_TResult_3 = new Func<IQueryable<object>, IEnumerable<object>, Expression<Func<object, object, object>>, IQueryable<object>>(Queryable.Zip).GetMethodInfo().GetGenericMethodDefinition()))
838838
.MakeGenericMethod(TFirst, TSecond, TResult);
839839

840+
841+
private static MethodInfo s_SkipLast_TSource_2;
842+
843+
public static MethodInfo SkipLast_TSource_2(Type TSource) =>
844+
(s_SkipLast_TSource_2 ??
845+
(s_SkipLast_TSource_2 = new Func<IQueryable<object>, int, IQueryable<object>>(Queryable.SkipLast).GetMethodInfo().GetGenericMethodDefinition()))
846+
.MakeGenericMethod(TSource);
847+
848+
private static MethodInfo s_TakeLast_TSource_2;
849+
850+
public static MethodInfo TakeLast_TSource_2(Type TSource) =>
851+
(s_TakeLast_TSource_2 ??
852+
(s_TakeLast_TSource_2 = new Func<IQueryable<object>, int, IQueryable<object>>(Queryable.TakeLast).GetMethodInfo().GetGenericMethodDefinition()))
853+
.MakeGenericMethod(TSource);
840854
}
841855
}

src/System.Linq.Queryable/src/System/Linq/Queryable.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,5 +1609,29 @@ public static TResult Aggregate<TSource, TAccumulate, TResult>(this IQueryable<T
16091609
null,
16101610
CachedReflectionInfo.Aggregate_TSource_TAccumulate_TResult_4(typeof(TSource), typeof(TAccumulate), typeof(TResult)), source.Expression, Expression.Constant(seed), Expression.Quote(func), Expression.Quote(selector)));
16111611
}
1612+
1613+
public static IQueryable<TSource> SkipLast<TSource>(this IQueryable<TSource> source, int count)
1614+
{
1615+
if (source == null)
1616+
throw Error.ArgumentNull(nameof(source));
1617+
return source.Provider.CreateQuery<TSource>(
1618+
Expression.Call(
1619+
null,
1620+
CachedReflectionInfo.SkipLast_TSource_2(typeof(TSource)),
1621+
source.Expression, Expression.Constant(count)
1622+
));
1623+
}
1624+
1625+
public static IQueryable<TSource> TakeLast<TSource>(this IQueryable<TSource> source, int count)
1626+
{
1627+
if (source == null)
1628+
throw Error.ArgumentNull(nameof(source));
1629+
return source.Provider.CreateQuery<TSource>(
1630+
Expression.Call(
1631+
null,
1632+
CachedReflectionInfo.TakeLast_TSource_2(typeof(TSource)),
1633+
source.Expression, Expression.Constant(count)
1634+
));
1635+
}
16121636
}
16131637
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using Xunit;
7+
using static System.Linq.Tests.SkipTakeData;
8+
9+
namespace System.Linq.Tests
10+
{
11+
public class SkipLastTests : EnumerableBasedTests
12+
{
13+
[Theory, MemberData(nameof(QueryableData), MemberType = typeof(SkipTakeData))]
14+
public void SkipLast(IQueryable<int> source, int count)
15+
{
16+
IQueryable<int> expected = source.Reverse().Skip(count).Reverse();
17+
IQueryable<int> actual = source.SkipLast(count);
18+
19+
Assert.Equal(expected, actual);
20+
}
21+
}
22+
}

src/System.Linq.Queryable/tests/System.Linq.Queryable.Tests.csproj

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,21 @@
4545
<Compile Include="SequenceEqualTests.cs" />
4646
<Compile Include="SingleOrDefaultTests.cs" />
4747
<Compile Include="SingleTests.cs" />
48+
<Compile Include="SkipLastTests.cs" />
4849
<Compile Include="SkipTests.cs" />
4950
<Compile Include="SkipWhileTests.cs" />
5051
<Compile Include="SumTests.cs" />
52+
<Compile Include="TakeLastTests.cs" />
5153
<Compile Include="TakeTests.cs" />
5254
<Compile Include="TakeWhileTests.cs" />
5355
<Compile Include="ThenByDescendingTests.cs" />
5456
<Compile Include="ThenByTests.cs" />
5557
<Compile Include="UnionTests.cs" />
5658
<Compile Include="WhereTests.cs" />
5759
<Compile Include="ZipTests.cs" />
60+
<Compile Include="$(CommonTestPath)\System\Linq\SkipTakeData.cs">
61+
<Link>Common\System\Linq\SkipTakeData.cs</Link>
62+
</Compile>
5863
</ItemGroup>
5964
<Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
60-
</Project>
65+
</Project>
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using Xunit;
7+
using static System.Linq.Tests.SkipTakeData;
8+
9+
namespace System.Linq.Tests
10+
{
11+
public class TakeLastTests : EnumerableBasedTests
12+
{
13+
[Theory, MemberData(nameof(QueryableData), MemberType = typeof(SkipTakeData))]
14+
public void TakeLast(IQueryable<int> equivalent, int count)
15+
{
16+
IQueryable<int> expected = equivalent.Reverse().Take(count).Reverse();
17+
IQueryable<int> actual = equivalent.TakeLast(count);
18+
19+
Assert.Equal(expected, actual);
20+
}
21+
}
22+
}

src/System.Linq/ref/System.Linq.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ public static partial class Enumerable
143143
public static TSource SingleOrDefault<TSource>(this System.Collections.Generic.IEnumerable<TSource> source) { throw null; }
144144
public static TSource SingleOrDefault<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }
145145
public static System.Collections.Generic.IEnumerable<TSource> Skip<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, int count) { throw null; }
146+
public static System.Collections.Generic.IEnumerable<TSource> SkipLast<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, int count) { throw null; }
146147
public static System.Collections.Generic.IEnumerable<TSource> SkipWhile<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }
147148
public static System.Collections.Generic.IEnumerable<TSource> SkipWhile<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, int, bool> predicate) { throw null; }
148149
public static decimal Sum(this System.Collections.Generic.IEnumerable<decimal> source) { throw null; }
@@ -166,6 +167,7 @@ public static partial class Enumerable
166167
public static System.Nullable<float> Sum<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, System.Nullable<float>> selector) { throw null; }
167168
public static float Sum<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, float> selector) { throw null; }
168169
public static System.Collections.Generic.IEnumerable<TSource> Take<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, int count) { throw null; }
170+
public static System.Collections.Generic.IEnumerable<TSource> TakeLast<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, int count) { throw null; }
169171
public static System.Collections.Generic.IEnumerable<TSource> TakeWhile<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }
170172
public static System.Collections.Generic.IEnumerable<TSource> TakeWhile<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, int, bool> predicate) { throw null; }
171173
public static System.Linq.IOrderedEnumerable<TSource> ThenBy<TSource, TKey>(this System.Linq.IOrderedEnumerable<TSource> source, System.Func<TSource, TKey> keySelector) { throw null; }

src/System.Linq/src/System.Linq.csproj

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
<PropertyGroup>
55
<ProjectGuid>{CA488507-3B6E-4494-B7BE-7B4EEEB2C4D1}</ProjectGuid>
66
<AssemblyName>System.Linq</AssemblyName>
7-
<AssemblyVersion>4.2.0.0</AssemblyVersion>
87
<RootNamespace>System.Linq</RootNamespace>
98
<IsPartialFacadeAssembly Condition="'$(TargetGroup)' == 'net461'">true</IsPartialFacadeAssembly>
109
<!-- The following line needs to be removed once we have a targeting pack for 4.6.3 -->
1110
<TargetingPackNugetPackageId Condition="'$(TargetGroup)' == 'net461'">Microsoft.TargetingPack.NETFramework.v4.6.1</TargetingPackNugetPackageId>
12-
<DefineConstants Condition="'$(TargetGroup)' == 'netcoreapp'">$(DefineConstants);netcoreapp11</DefineConstants>
1311
</PropertyGroup>
1412
<!-- Default configurations to help VS understand the configurations -->
1513
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='net461-Windows_NT-Debug|AnyCPU'" />

src/System.Linq/src/System/Linq/Skip.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.Collections.Generic;
6+
using System.Diagnostics;
67

78
namespace System.Linq
89
{
@@ -121,5 +122,49 @@ private static IEnumerable<TSource> SkipWhileIterator<TSource>(IEnumerable<TSour
121122
}
122123
}
123124
}
125+
126+
public static IEnumerable<TSource> SkipLast<TSource>(this IEnumerable<TSource> source, int count)
127+
{
128+
if (source == null)
129+
{
130+
throw Error.ArgumentNull(nameof(source));
131+
}
132+
133+
if (count <= 0)
134+
{
135+
return source.Skip(0);
136+
}
137+
138+
return SkipLastIterator(source, count);
139+
}
140+
141+
private static IEnumerable<TSource> SkipLastIterator<TSource>(IEnumerable<TSource> source, int count)
142+
{
143+
Debug.Assert(source != null);
144+
Debug.Assert(count > 0);
145+
146+
var queue = new Queue<TSource>();
147+
148+
using (IEnumerator<TSource> e = source.GetEnumerator())
149+
{
150+
while (e.MoveNext())
151+
{
152+
if (queue.Count == count)
153+
{
154+
do
155+
{
156+
yield return queue.Dequeue();
157+
queue.Enqueue(e.Current);
158+
}
159+
while (e.MoveNext());
160+
break;
161+
}
162+
else
163+
{
164+
queue.Enqueue(e.Current);
165+
}
166+
}
167+
}
168+
}
124169
}
125170
}

src/System.Linq/src/System/Linq/Take.cs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.Collections.Generic;
6+
using System.Diagnostics;
67

78
namespace System.Linq
89
{
@@ -96,5 +97,56 @@ private static IEnumerable<TSource> TakeWhileIterator<TSource>(IEnumerable<TSour
9697
yield return element;
9798
}
9899
}
100+
101+
public static IEnumerable<TSource> TakeLast<TSource>(this IEnumerable<TSource> source, int count)
102+
{
103+
if (source == null)
104+
{
105+
throw Error.ArgumentNull(nameof(source));
106+
}
107+
108+
if (count <= 0)
109+
{
110+
return EmptyPartition<TSource>.Instance;
111+
}
112+
113+
return TakeLastIterator(source, count);
114+
}
115+
116+
private static IEnumerable<TSource> TakeLastIterator<TSource>(IEnumerable<TSource> source, int count)
117+
{
118+
Debug.Assert(source != null);
119+
Debug.Assert(count > 0);
120+
121+
var queue = new Queue<TSource>();
122+
123+
using (IEnumerator<TSource> e = source.GetEnumerator())
124+
{
125+
while (e.MoveNext())
126+
{
127+
if (queue.Count < count)
128+
{
129+
queue.Enqueue(e.Current);
130+
}
131+
else
132+
{
133+
do
134+
{
135+
queue.Dequeue();
136+
queue.Enqueue(e.Current);
137+
}
138+
while (e.MoveNext());
139+
break;
140+
}
141+
}
142+
}
143+
144+
Debug.Assert(queue.Count <= count);
145+
do
146+
{
147+
yield return queue.Dequeue();
148+
}
149+
while (queue.Count > 0);
150+
}
99151
}
100152
}

src/System.Linq/src/System/Linq/ToCollection.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,7 @@ private static Dictionary<TKey, TElement> ToDictionary<TSource, TKey, TElement>(
177177
return d;
178178
}
179179

180-
#if netcoreapp11
181-
public static HashSet<TSource> ToHashSet<TSource>(this IEnumerable<TSource> source) => source.ToHashSet(null);
180+
public static HashSet<TSource> ToHashSet<TSource>(this IEnumerable<TSource> source) => source.ToHashSet(comparer: null);
182181

183182
public static HashSet<TSource> ToHashSet<TSource>(this IEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
184183
{
@@ -187,9 +186,8 @@ public static HashSet<TSource> ToHashSet<TSource>(this IEnumerable<TSource> sour
187186
throw new ArgumentNullException(nameof(source));
188187
}
189188

190-
// Don't pre-allocate based on knowledge of size as potentially many elements will be dropped.
189+
// Don't pre-allocate based on knowledge of size, as potentially many elements will be dropped.
191190
return new HashSet<TSource>(source, comparer);
192191
}
193-
#endif
194192
}
195193
}

src/System.Linq/tests/ConsistencyTests.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ public static void MatchSequencePattern()
2222
typeof(Enumerable),
2323
typeof(Queryable),
2424
new[] {
25-
"ToLookup",
26-
"ToDictionary",
27-
"ToArray",
28-
"AsEnumerable",
29-
"ToList",
25+
nameof(Enumerable.ToLookup),
26+
nameof(Enumerable.ToDictionary),
27+
nameof(Enumerable.ToArray),
28+
nameof(Enumerable.AsEnumerable),
29+
nameof(Enumerable.ToList),
3030
"Fold",
3131
"LeftJoin",
32-
"Append",
33-
"Prepend",
32+
nameof(Enumerable.Append),
33+
nameof(Enumerable.Prepend),
3434
"ToHashSet"
3535
}
3636
);
@@ -41,7 +41,7 @@ public static void MatchSequencePattern()
4141
typeof(Queryable),
4242
typeof(Enumerable),
4343
new[] {
44-
"AsQueryable"
44+
nameof(Queryable.AsQueryable)
4545
}
4646
);
4747

0 commit comments

Comments
 (0)