Skip to content

Commit 422cbce

Browse files
authored
Expose FeatureHasher (#652)
1 parent 6838868 commit 422cbce

File tree

12 files changed

+296
-32
lines changed

12 files changed

+296
-32
lines changed

src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,24 @@
1313
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
1414
{
1515
[Collection("Spark E2E Tests")]
16-
public class BucketizerTests
16+
public class BucketizerTests : FeatureBaseTests<Bucketizer>
1717
{
1818
private readonly SparkSession _spark;
1919

20-
public BucketizerTests(SparkFixture fixture)
20+
public BucketizerTests(SparkFixture fixture) : base(fixture)
2121
{
2222
_spark = fixture.Spark;
2323
}
2424

25+
/// <summary>
26+
/// Create a <see cref="DataFrame"/>, create a <see cref="Bucketizer"/> and test the
27+
/// available methods. Test the FeatureBase methods using <see cref="FeatureBaseTests"/>.
28+
/// </summary>
2529
[Fact]
2630
public void TestBucketizer()
2731
{
28-
var expectedSplits = new double[] { double.MinValue, 0.0, 10.0, 50.0, double.MaxValue };
32+
var expectedSplits =
33+
new double[] { double.MinValue, 0.0, 10.0, 50.0, double.MaxValue };
2934

3035
string expectedHandle = "skip";
3136
string expectedUid = "uid";
@@ -60,18 +65,7 @@ public void TestBucketizer()
6065
Assert.Equal(bucketizer.Uid(), loadedBucketizer.Uid());
6166
}
6267

63-
Assert.NotEmpty(bucketizer.ExplainParams());
64-
65-
Param handleInvalidParam = bucketizer.GetParam("handleInvalid");
66-
Assert.NotEmpty(handleInvalidParam.Doc);
67-
Assert.NotEmpty(handleInvalidParam.Name);
68-
Assert.Equal(handleInvalidParam.Parent, bucketizer.Uid());
69-
70-
Assert.NotEmpty(bucketizer.ExplainParam(handleInvalidParam));
71-
bucketizer.Set(handleInvalidParam, "keep");
72-
Assert.Equal("keep", bucketizer.GetHandleInvalid());
73-
74-
Assert.Equal("error", bucketizer.Clear(handleInvalidParam).GetHandleInvalid());
68+
TestFeatureBase(bucketizer, "handleInvalid", "keep");
7569
}
7670

7771
[Fact]

src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerModelTests.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
1313
{
1414
[Collection("Spark E2E Tests")]
15-
public class CountVectorizerModelTests
15+
public class CountVectorizerModelTests : FeatureBaseTests<CountVectorizerModel>
1616
{
1717
private readonly SparkSession _spark;
1818

19-
public CountVectorizerModelTests(SparkFixture fixture)
19+
public CountVectorizerModelTests(SparkFixture fixture) : base(fixture)
2020
{
2121
_spark = fixture.Spark;
2222
}
2323

2424
/// <summary>
25-
/// Test that we can create a CountVectorizerModel, pass in a specifc vocabulary to use
25+
/// Test that we can create a CountVectorizerModel, pass in a specific vocabulary to use
2626
/// when creating the model. Verify the standard features methods as well as load/save.
2727
/// </summary>
2828
[Fact]
@@ -68,6 +68,8 @@ public void TestCountVectorizerModel()
6868
Assert.IsType<int>(countVectorizerModel.GetVocabSize());
6969
Assert.NotEmpty(countVectorizerModel.ExplainParams());
7070
Assert.NotEmpty(countVectorizerModel.ToString());
71+
72+
TestFeatureBase(countVectorizerModel, "minDF", 100);
7173
}
7274
}
7375
}

src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerTests.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
1414
{
1515
[Collection("Spark E2E Tests")]
16-
public class CountVectorizerTests
16+
public class CountVectorizerTests : FeatureBaseTests<CountVectorizer>
1717
{
1818
private readonly SparkSession _spark;
1919

20-
public CountVectorizerTests(SparkFixture fixture)
20+
public CountVectorizerTests(SparkFixture fixture) : base(fixture)
2121
{
2222
_spark = fixture.Spark;
2323
}
@@ -67,6 +67,8 @@ public void TestCountVectorizer()
6767

6868
Assert.NotEmpty(countVectorizer.ExplainParams());
6969
Assert.NotEmpty(countVectorizer.ToString());
70+
71+
TestFeatureBase(countVectorizer, "minDF", 0.4);
7072
}
7173

7274
/// <summary>
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Spark.ML.Feature;
6+
using Microsoft.Spark.ML.Feature.Param;
7+
using Microsoft.Spark.Sql;
8+
using Xunit;
9+
10+
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
11+
{
12+
public class FeatureBaseTests<T>
13+
{
14+
private readonly SparkSession _spark;
15+
16+
protected FeatureBaseTests(SparkFixture fixture)
17+
{
18+
_spark = fixture.Spark;
19+
}
20+
21+
/// <summary>
22+
/// Tests the common functionality across all ML.Feature classes.
23+
/// </summary>
24+
/// <param name="testObject">The object that implemented FeatureBase</param>
25+
/// <param name="paramName">The name of a parameter that can be set on this object</param>
26+
/// <param name="paramValue">A parameter value that can be set on this object</param>
27+
public void TestFeatureBase(
28+
FeatureBase<T> testObject,
29+
string paramName,
30+
object paramValue)
31+
{
32+
Assert.NotEmpty(testObject.ExplainParams());
33+
34+
Param param = testObject.GetParam(paramName);
35+
Assert.NotEmpty(param.Doc);
36+
Assert.NotEmpty(param.Name);
37+
Assert.Equal(param.Parent, testObject.Uid());
38+
39+
Assert.NotEmpty(testObject.ExplainParam(param));
40+
testObject.Set(param, paramValue);
41+
Assert.IsAssignableFrom<Identifiable>(testObject.Clear(param));
42+
43+
Assert.IsType<string>(testObject.Uid());
44+
}
45+
}
46+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using Microsoft.Spark.ML.Feature;
8+
using Microsoft.Spark.Sql;
9+
using Microsoft.Spark.Sql.Types;
10+
using Xunit;
11+
12+
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
13+
{
14+
[Collection("Spark E2E Tests")]
15+
public class FeatureHasherTests : FeatureBaseTests<FeatureHasher>
16+
{
17+
private readonly SparkSession _spark;
18+
19+
public FeatureHasherTests(SparkFixture fixture) : base(fixture)
20+
{
21+
_spark = fixture.Spark;
22+
}
23+
24+
/// <summary>
25+
/// Create a <see cref="DataFrame"/>, create a <see cref="FeatureHasher"/> and test the
26+
/// available methods. Test the FeatureBase methods using <see cref="FeatureBaseTests"/>.
27+
/// </summary>
28+
[Fact]
29+
public void TestFeatureHasher()
30+
{
31+
DataFrame dataFrame = _spark.CreateDataFrame(
32+
new List<GenericRow>
33+
{
34+
new GenericRow(new object[] { 2.0D, true, "1", "foo" }),
35+
new GenericRow(new object[] { 3.0D, false, "2", "bar" })
36+
},
37+
new StructType(new List<StructField>
38+
{
39+
new StructField("real", new DoubleType()),
40+
new StructField("bool", new BooleanType()),
41+
new StructField("stringNum", new StringType()),
42+
new StructField("string", new StringType())
43+
}));
44+
45+
FeatureHasher hasher = new FeatureHasher()
46+
.SetInputCols(new List<string>() { "real", "bool", "stringNum", "string" })
47+
.SetOutputCol("features")
48+
.SetCategoricalCols(new List<string>() { "real", "string" })
49+
.SetNumFeatures(10);
50+
51+
Assert.IsType<string>(hasher.GetOutputCol());
52+
Assert.IsType<string[]>(hasher.GetInputCols());
53+
Assert.IsType<string[]>(hasher.GetCategoricalCols());
54+
Assert.IsType<int>(hasher.GetNumFeatures());
55+
Assert.IsType<StructType>(hasher.TransformSchema(dataFrame.Schema()));
56+
Assert.IsType<DataFrame>(hasher.Transform(dataFrame));
57+
58+
TestFeatureBase(hasher, "numFeatures", 1000);
59+
}
60+
}
61+
}

src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/HashingTFTests.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
1212
{
1313
[Collection("Spark E2E Tests")]
14-
public class HashingTFTests
14+
public class HashingTFTests : FeatureBaseTests<HashingTF>
1515
{
1616
private readonly SparkSession _spark;
1717

18-
public HashingTFTests(SparkFixture fixture)
18+
public HashingTFTests(SparkFixture fixture) : base(fixture)
1919
{
2020
_spark = fixture.Spark;
2121
}
@@ -57,6 +57,8 @@ public void TestHashingTF()
5757

5858
hashingTf.SetBinary(true);
5959
Assert.True(hashingTf.GetBinary());
60+
61+
TestFeatureBase(hashingTf, "numFeatures", 1000);
6062
}
6163
}
6264
}

src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFModelTests.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
1212
{
1313
[Collection("Spark E2E Tests")]
14-
public class IDFModelTests
14+
public class IDFModelTests : FeatureBaseTests<IDFModel>
1515
{
1616
private readonly SparkSession _spark;
1717

18-
public IDFModelTests(SparkFixture fixture)
18+
public IDFModelTests(SparkFixture fixture) : base(fixture)
1919
{
2020
_spark = fixture.Spark;
2121
}
@@ -65,6 +65,8 @@ public void TestIDFModel()
6565
IDFModel loadedModel = IDFModel.Load(modelPath);
6666
Assert.Equal(idfModel.Uid(), loadedModel.Uid());
6767
}
68+
69+
TestFeatureBase(idfModel, "minDocFreq", 1000);
6870
}
6971
}
7072
}

src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFTests.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
1212
{
1313
[Collection("Spark E2E Tests")]
14-
public class IDFTests
14+
public class IDFTests : FeatureBaseTests<IDF>
1515
{
1616
private readonly SparkSession _spark;
1717

18-
public IDFTests(SparkFixture fixture)
18+
public IDFTests(SparkFixture fixture) : base(fixture)
1919
{
2020
_spark = fixture.Spark;
2121
}
@@ -44,6 +44,8 @@ public void TestIDFModel()
4444
IDF loadedIdf = IDF.Load(savePath);
4545
Assert.Equal(idf.Uid(), loadedIdf.Uid());
4646
}
47+
48+
TestFeatureBase(idf, "minDocFreq", 1000);
4749
}
4850
}
4951
}

src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/TokenizerTests.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
1212
{
1313
[Collection("Spark E2E Tests")]
14-
public class TokenizerTests
14+
public class TokenizerTests : FeatureBaseTests<Tokenizer>
1515
{
1616
private readonly SparkSession _spark;
1717

18-
public TokenizerTests(SparkFixture fixture)
18+
public TokenizerTests(SparkFixture fixture) : base(fixture)
1919
{
2020
_spark = fixture.Spark;
2121
}
@@ -50,6 +50,8 @@ public void TestTokenizer()
5050
}
5151

5252
Assert.Equal(expectedUid, tokenizer.Uid());
53+
54+
TestFeatureBase(tokenizer, "inputCol", "input_col");
5355
}
5456
}
5557
}

src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecModelTests.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
1212
{
1313
[Collection("Spark E2E Tests")]
14-
public class Word2VecModelTests
14+
public class Word2VecModelTests : FeatureBaseTests<Word2VecModel>
1515
{
1616
private readonly SparkSession _spark;
1717

18-
public Word2VecModelTests(SparkFixture fixture)
18+
public Word2VecModelTests(SparkFixture fixture) : base(fixture)
1919
{
2020
_spark = fixture.Spark;
2121
}
@@ -47,6 +47,8 @@ public void TestWord2VecModel()
4747
Word2VecModel loadedModel = Word2VecModel.Load(savePath);
4848
Assert.Equal(model.Uid(), loadedModel.Uid());
4949
}
50+
51+
TestFeatureBase(model, "maxIter", 2);
5052
}
5153
}
5254
}

0 commit comments

Comments
 (0)