-
Notifications
You must be signed in to change notification settings - Fork 329
Implement ML/CountVectorizer and ML/CountVectorizerModel #608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
6bab996
CountVectorizer
e2a566b
moving private methods to bottom
5f682a6
changing wrap method
31371db
setting min version required
60eb82f
undoing csproj change
ed36375
member doesnt need to be internal
c7baf72
too many lines
d13303c
removing whitespace change
f5b477c
removing whitespace change
73db52b
ionide
a766146
Merge branch 'master' into ml/countvectorizer
GoEddie 2ce91db
changes after review
59923ab
Merge branch 'master' into ml/countvectorizer
GoEddie 85f24bc
chnages after review
591adbb
Merge branch 'master' of github.com:dotnet/spark into ml/countvectorizer
ef6ad6b
Merge branch 'ml/countvectorizer' of github.com:GoEddie/spark into ml…
ed01370
merge
8e4a87d
changes after feedback
159a34f
Merge branch 'master' into ml/countvectorizer
GoEddie 10881bb
Merge branch 'master' into ml/countvectorizer
GoEddie 1de32e7
Merge branch 'master' into ml/countvectorizer
imback82 9d59992
changes after feedback
785a3da
changes after feedback
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
73 changes: 73 additions & 0 deletions
73
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerModelTests.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.Collections.Generic; | ||
using System.IO; | ||
using Microsoft.Spark.ML.Feature; | ||
using Microsoft.Spark.Sql; | ||
using Microsoft.Spark.UnitTest.TestUtils; | ||
using Xunit; | ||
|
||
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature | ||
{ | ||
[Collection("Spark E2E Tests")] | ||
public class CountVectorizerModelTests | ||
{ | ||
private readonly SparkSession _spark; | ||
|
||
public CountVectorizerModelTests(SparkFixture fixture) | ||
{ | ||
_spark = fixture.Spark; | ||
} | ||
|
||
/// <summary> | ||
/// Test that we can create a CountVectorizerModel, pass in a specifc vocabulary to use | ||
/// when creating the model. Verify the standard features methods as well as load/save. | ||
/// </summary> | ||
[Fact] | ||
public void TestCountVectorizerModel() | ||
{ | ||
const string inputColumn = "input"; | ||
const string outputColumn = "output"; | ||
const double minTf = 10.0; | ||
const bool binary = false; | ||
|
||
var vocabulary = new List<string>() | ||
{ | ||
"hello", | ||
"I", | ||
"AM", | ||
"TO", | ||
"TOKENIZE" | ||
}; | ||
|
||
var countVectorizerModel = new CountVectorizerModel(vocabulary); | ||
|
||
Assert.IsType<CountVectorizerModel>(new CountVectorizerModel("my-uid", vocabulary)); | ||
|
||
countVectorizerModel = countVectorizerModel | ||
.SetInputCol(inputColumn) | ||
.SetOutputCol(outputColumn) | ||
.SetMinTF(minTf) | ||
.SetBinary(binary); | ||
|
||
Assert.Equal(inputColumn, countVectorizerModel.GetInputCol()); | ||
Assert.Equal(outputColumn, countVectorizerModel.GetOutputCol()); | ||
Assert.Equal(minTf, countVectorizerModel.GetMinTF()); | ||
Assert.Equal(binary, countVectorizerModel.GetBinary()); | ||
using (var tempDirectory = new TemporaryDirectory()) | ||
{ | ||
string savePath = Path.Join(tempDirectory.Path, "countVectorizerModel"); | ||
countVectorizerModel.Save(savePath); | ||
|
||
CountVectorizerModel loadedModel = CountVectorizerModel.Load(savePath); | ||
Assert.Equal(countVectorizerModel.Uid(), loadedModel.Uid()); | ||
} | ||
|
||
Assert.IsType<int>(countVectorizerModel.GetVocabSize()); | ||
Assert.NotEmpty(countVectorizerModel.ExplainParams()); | ||
Assert.NotEmpty(countVectorizerModel.ToString()); | ||
} | ||
} | ||
} |
83 changes: 83 additions & 0 deletions
83
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerTests.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.IO; | ||
using Microsoft.Spark.E2ETest.Utils; | ||
using Microsoft.Spark.ML.Feature; | ||
using Microsoft.Spark.Sql; | ||
using Microsoft.Spark.UnitTest.TestUtils; | ||
using Xunit; | ||
|
||
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature | ||
{ | ||
[Collection("Spark E2E Tests")] | ||
public class CountVectorizerTests | ||
{ | ||
private readonly SparkSession _spark; | ||
|
||
public CountVectorizerTests(SparkFixture fixture) | ||
{ | ||
_spark = fixture.Spark; | ||
} | ||
|
||
/// <summary> | ||
/// Test that we can create a CountVectorizer. Verify the standard features methods as well | ||
/// as load/save. | ||
/// </summary> | ||
[Fact] | ||
public void TestCountVectorizer() | ||
{ | ||
DataFrame input = _spark.Sql("SELECT array('hello', 'I', 'AM', 'a', 'string', 'TO', " + | ||
"'TOKENIZE') as input from range(100)"); | ||
|
||
const string inputColumn = "input"; | ||
const string outputColumn = "output"; | ||
const double minDf = 1; | ||
const double minTf = 10; | ||
const int vocabSize = 10000; | ||
const bool binary = false; | ||
|
||
var countVectorizer = new CountVectorizer(); | ||
|
||
countVectorizer | ||
.SetInputCol(inputColumn) | ||
.SetOutputCol(outputColumn) | ||
.SetMinDF(minDf) | ||
.SetMinTF(minTf) | ||
.SetVocabSize(vocabSize); | ||
|
||
Assert.IsType<CountVectorizerModel>(countVectorizer.Fit(input)); | ||
Assert.Equal(inputColumn, countVectorizer.GetInputCol()); | ||
Assert.Equal(outputColumn, countVectorizer.GetOutputCol()); | ||
Assert.Equal(minDf, countVectorizer.GetMinDF()); | ||
Assert.Equal(minTf, countVectorizer.GetMinTF()); | ||
Assert.Equal(vocabSize, countVectorizer.GetVocabSize()); | ||
Assert.Equal(binary, countVectorizer.GetBinary()); | ||
|
||
using (var tempDirectory = new TemporaryDirectory()) | ||
{ | ||
string savePath = Path.Join(tempDirectory.Path, "countVectorizer"); | ||
countVectorizer.Save(savePath); | ||
|
||
CountVectorizer loadedVectorizer = CountVectorizer.Load(savePath); | ||
Assert.Equal(countVectorizer.Uid(), loadedVectorizer.Uid()); | ||
} | ||
|
||
Assert.NotEmpty(countVectorizer.ExplainParams()); | ||
Assert.NotEmpty(countVectorizer.ToString()); | ||
} | ||
|
||
/// <summary> | ||
/// Test signatures for APIs introduced in Spark 2.4.*. | ||
/// </summary> | ||
[SkipIfSparkVersionIsLessThan(Versions.V2_4_0)] | ||
public void TestSignaturesV2_4_X() | ||
{ | ||
const double maxDf = 100; | ||
CountVectorizer countVectorizer = new CountVectorizer().SetMaxDF(maxDf); | ||
Assert.Equal(maxDf, countVectorizer.GetMaxDF()); | ||
} | ||
} | ||
} |
198 changes: 198 additions & 0 deletions
198
src/csharp/Microsoft.Spark/ML/Feature/CountVectorizer.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.Spark.Interop; | ||
using Microsoft.Spark.Interop.Ipc; | ||
using Microsoft.Spark.Sql; | ||
|
||
namespace Microsoft.Spark.ML.Feature | ||
{ | ||
public class CountVectorizer : FeatureBase<CountVectorizer>, IJvmObjectReferenceProvider | ||
{ | ||
private static readonly string s_countVectorizerClassName = | ||
"org.apache.spark.ml.feature.CountVectorizer"; | ||
|
||
/// <summary> | ||
/// Creates a <see cref="CountVectorizer"/> without any parameters. | ||
/// </summary> | ||
public CountVectorizer() : base(s_countVectorizerClassName) | ||
{ | ||
} | ||
|
||
/// <summary> | ||
/// Creates a <see cref="CountVectorizer"/> with a UID that is used to give the | ||
/// <see cref="CountVectorizer"/> a unique ID. | ||
/// </summary> | ||
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param> | ||
public CountVectorizer(string uid) : base(s_countVectorizerClassName, uid) | ||
{ | ||
} | ||
|
||
internal CountVectorizer(JvmObjectReference jvmObject) : base(jvmObject) | ||
{ | ||
} | ||
|
||
JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject; | ||
|
||
/// <summary>Fits a model to the input data.</summary> | ||
/// <param name="dataFrame">The <see cref="DataFrame"/> to fit the model to.</param> | ||
/// <returns><see cref="CountVectorizerModel"/></returns> | ||
public CountVectorizerModel Fit(DataFrame dataFrame) => | ||
new CountVectorizerModel((JvmObjectReference)_jvmObject.Invoke("fit", dataFrame)); | ||
|
||
/// <summary> | ||
/// Loads the <see cref="CountVectorizer"/> that was previously saved using Save. | ||
/// </summary> | ||
/// <param name="path"> | ||
/// The path the previous <see cref="CountVectorizer"/> was saved to. | ||
/// </param> | ||
/// <returns>New <see cref="CountVectorizer"/> object</returns> | ||
public static CountVectorizer Load(string path) => | ||
WrapAsCountVectorizer((JvmObjectReference) | ||
SparkEnvironment.JvmBridge.CallStaticJavaMethod( | ||
s_countVectorizerClassName,"load", path)); | ||
|
||
/// <summary> | ||
/// Gets the binary toggle to control the output vector values. If True, all nonzero counts | ||
/// (after minTF filter applied) are set to 1. This is useful for discrete probabilistic | ||
/// models that model binary events rather than integer counts. Default: false | ||
/// </summary> | ||
/// <returns>boolean</returns> | ||
public bool GetBinary() => (bool)_jvmObject.Invoke("getBinary"); | ||
|
||
/// <summary> | ||
/// Sets the binary toggle to control the output vector values. If True, all nonzero counts | ||
/// (after minTF filter applied) are set to 1. This is useful for discrete probabilistic | ||
/// models that model binary events rather than integer counts. Default: false | ||
/// </summary> | ||
/// <param name="value">Turn the binary toggle on or off</param> | ||
/// <returns><see cref="CountVectorizer"/> with the new binary toggle value set</returns> | ||
public CountVectorizer SetBinary(bool value) => | ||
WrapAsCountVectorizer((JvmObjectReference)_jvmObject.Invoke("setBinary", value)); | ||
|
||
/// <summary> | ||
/// Gets the column that the <see cref="CountVectorizer"/> should read from and convert | ||
/// into buckets. This would have been set by SetInputCol. | ||
/// </summary> | ||
/// <returns>The input column of type string</returns> | ||
public string GetInputCol() => (string)_jvmObject.Invoke("getInputCol"); | ||
|
||
/// <summary> | ||
/// Sets the column that the <see cref="CountVectorizer"/> should read from. | ||
/// </summary> | ||
/// <param name="value">The name of the column to use as the source.</param> | ||
/// <returns><see cref="CountVectorizer"/> with the input column set</returns> | ||
public CountVectorizer SetInputCol(string value) => | ||
WrapAsCountVectorizer((JvmObjectReference)_jvmObject.Invoke("setInputCol", value)); | ||
|
||
/// <summary> | ||
/// Gets the name of the new column the <see cref="CountVectorizer"/> creates in the | ||
/// DataFrame. | ||
/// </summary> | ||
/// <returns>The name of the output column.</returns> | ||
public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol"); | ||
|
||
/// <summary> | ||
/// Sets the name of the new column the <see cref="CountVectorizer"/> creates in the | ||
/// DataFrame. | ||
/// </summary> | ||
/// <param name="value">The name of the output column which will be created.</param> | ||
/// <returns>New <see cref="CountVectorizer"/> with the output column set</returns> | ||
public CountVectorizer SetOutputCol(string value) => | ||
WrapAsCountVectorizer((JvmObjectReference)_jvmObject.Invoke("setOutputCol", value)); | ||
|
||
/// <summary> | ||
/// Gets the maximum number of different documents a term could appear in to be included in | ||
/// the vocabulary. A term that appears more than the threshold will be ignored. If this is | ||
/// an integer greater than or equal to 1, this specifies the maximum number of documents | ||
/// the term could appear in; if this is a double in [0,1), then this specifies the maximum | ||
/// fraction of documents the term could appear in. | ||
/// </summary> | ||
/// <returns>The maximum document term frequency</returns> | ||
[Since(Versions.V2_4_0)] | ||
public double GetMaxDF() => (double)_jvmObject.Invoke("getMaxDF"); | ||
|
||
/// <summary> | ||
/// Sets the maximum number of different documents a term could appear in to be included in | ||
/// the vocabulary. A term that appears more than the threshold will be ignored. If this is | ||
/// an integer greater than or equal to 1, this specifies the maximum number of documents | ||
/// the term could appear in; if this is a double in [0,1), then this specifies the maximum | ||
/// fraction of documents the term could appear in. | ||
/// </summary> | ||
/// <param name="value">The maximum document term frequency</param> | ||
/// <returns>New <see cref="CountVectorizer"/> with the max df value set</returns> | ||
[Since(Versions.V2_4_0)] | ||
public CountVectorizer SetMaxDF(double value) => | ||
WrapAsCountVectorizer((JvmObjectReference)_jvmObject.Invoke("setMaxDF", value)); | ||
|
||
/// <summary> | ||
/// Gets the minimum number of different documents a term must appear in to be included in | ||
/// the vocabulary. If this is an integer greater than or equal to 1, this specifies the | ||
/// number of documents the term must appear in; if this is a double in [0,1), then this | ||
/// specifies the fraction of documents. | ||
/// </summary> | ||
/// <returns>The minimum document term frequency</returns> | ||
public double GetMinDF() => (double)_jvmObject.Invoke("getMinDF"); | ||
|
||
/// <summary> | ||
/// Sets the minimum number of different documents a term must appear in to be included in | ||
/// the vocabulary. If this is an integer greater than or equal to 1, this specifies the | ||
/// number of documents the term must appear in; if this is a double in [0,1), then this | ||
/// specifies the fraction of documents. | ||
/// </summary> | ||
/// <param name="value">The minimum document term frequency</param> | ||
/// <returns>New <see cref="CountVectorizer"/> with the min df value set</returns> | ||
public CountVectorizer SetMinDF(double value) => | ||
WrapAsCountVectorizer((JvmObjectReference)_jvmObject.Invoke("setMinDF", value)); | ||
|
||
/// <summary> | ||
/// Gets the filter to ignore rare words in a document. For each document, terms with | ||
/// frequency/count less than the given threshold are ignored. If this is an integer | ||
/// greater than or equal to 1, then this specifies a count (of times the term must appear | ||
/// in the document); if this is a double in [0,1), then this specifies a fraction (out of | ||
/// the document's token count). | ||
/// | ||
/// Note that the parameter is only used in transform of CountVectorizerModel and does not | ||
/// affect fitting. | ||
/// </summary> | ||
/// <returns>Minimum term frequency</returns> | ||
public double GetMinTF() => (double)_jvmObject.Invoke("getMinTF"); | ||
|
||
/// <summary> | ||
/// Sets the filter to ignore rare words in a document. For each document, terms with | ||
/// frequency/count less than the given threshold are ignored. If this is an integer | ||
/// greater than or equal to 1, then this specifies a count (of times the term must appear | ||
/// in the document); if this is a double in [0,1), then this specifies a fraction (out of | ||
/// the document's token count). | ||
/// | ||
/// Note that the parameter is only used in transform of CountVectorizerModel and does not | ||
/// affect fitting. | ||
/// </summary> | ||
/// <param name="value">Minimum term frequency</param> | ||
/// <returns>New <see cref="CountVectorizer"/> with the min term frequency set</returns> | ||
public CountVectorizer SetMinTF(double value) => | ||
WrapAsCountVectorizer((JvmObjectReference)_jvmObject.Invoke("setMinTF", value)); | ||
|
||
/// <summary> | ||
/// Gets the max size of the vocabulary. <see cref="CountVectorizer"/> will build a | ||
/// vocabulary that only considers the top vocabSize terms ordered by term frequency across | ||
/// the corpus. | ||
/// </summary> | ||
/// <returns>The max size of the vocabulary of type int.</returns> | ||
public int GetVocabSize() => (int)_jvmObject.Invoke("getVocabSize"); | ||
|
||
/// <summary> | ||
/// Sets the max size of the vocabulary. <see cref="CountVectorizer"/> will build a | ||
/// vocabulary that only considers the top vocabSize terms ordered by term frequency across | ||
/// the corpus. | ||
/// </summary> | ||
/// <param name="value">The max vocabulary size</param> | ||
/// <returns><see cref="CountVectorizer"/> with the max vocab value set</returns> | ||
public CountVectorizer SetVocabSize(int value) => | ||
WrapAsCountVectorizer(_jvmObject.Invoke("setVocabSize", value)); | ||
|
||
private static CountVectorizer WrapAsCountVectorizer(object obj) => | ||
new CountVectorizer((JvmObjectReference)obj); | ||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.