-
Notifications
You must be signed in to change notification settings - Fork 329
Implement ML Features: Word2Vec #491
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
c6e873a
Word2Vec and Word2VecModel
0a8d821
Merge branch 'master' of github.com:dotnet/spark into ml/word2vec
1c2d55a
whitespace
b2245dc
tidying:
daf6f2d
whitespace
17325df
reverting csproj file change
f28e914
reverting csproj file change
0a969e3
reverting csproj file change
fe6fb61
reverting csproj file change
14f5312
reverting csproj file change
3b135a9
reverting csproj file change
18c3789
reverting csproj file change
29242f7
speeding up tests
f1bfb7b
disabling logging
5e36609
removing logging off
2b1ded3
disabling logging for test
c215d4b
tidying after review
6a2904b
incorrect indentation
77ce67e
Apply suggestions from code review
GoEddie 5b4a99a
feedback after review
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
53 changes: 53 additions & 0 deletions
53
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecModelTests.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.IO; | ||
using Microsoft.Spark.E2ETest.Utils; | ||
using Microsoft.Spark.ML.Feature; | ||
using Microsoft.Spark.Sql; | ||
using Xunit; | ||
|
||
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature | ||
{ | ||
/// <summary>
/// End-to-end test for <see cref="Word2VecModel"/>: fits a model from a small DataFrame,
/// queries it for synonyms and round-trips it through save/load.
/// </summary>
[Collection("Spark E2E Tests")]
public class Word2VecModelTests
{
    private readonly SparkSession _spark;

    /// <summary>Captures the shared Spark session supplied by the test fixture.</summary>
    /// <param name="fixture">Fixture exposing the <see cref="SparkSession"/> to use.</param>
    public Word2VecModelTests(SparkFixture fixture) => _spark = fixture.Spark;

    /// <summary>
    /// Fits a <see cref="Word2VecModel"/>, checks that FindSynonyms returns the requested
    /// number of rows, and checks that a saved model loads back with the same UID.
    /// </summary>
    [Fact]
    public void TestWord2VecModel()
    {
        // The second argument of FindSynonyms is the number of synonyms requested, so the
        // same value doubles as the expected row count.
        const int synonymLimit = 2;

        DataFrame sentences =
            _spark.Sql("SELECT split('Hi I heard about Spark', ' ') as text");

        Word2VecModel fittedModel = new Word2Vec()
            .SetInputCol("text")
            .SetOutputCol("result")
            .SetMinCount(1)
            .Fit(sentences);

        DataFrame nearestWords = fittedModel.FindSynonyms("Hi", synonymLimit);

        Assert.Equal(synonymLimit, nearestWords.Count());
        nearestWords.Show();

        using (var scratchDirectory = new TemporaryDirectory())
        {
            string modelPath = Path.Join(scratchDirectory.Path, "word2vecModel");
            fittedModel.Save(modelPath);

            // A reloaded model must report the same UID as the one that was saved.
            Word2VecModel reloadedModel = Word2VecModel.Load(modelPath);
            Assert.Equal(fittedModel.Uid(), reloadedModel.Uid());
        }
    }
}
} |
72 changes: 72 additions & 0 deletions
72
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecTests.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.IO; | ||
using Microsoft.Spark.E2ETest.Utils; | ||
using Microsoft.Spark.ML.Feature; | ||
using Microsoft.Spark.Sql; | ||
using Xunit; | ||
|
||
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature | ||
{ | ||
/// <summary>
/// End-to-end test for the <see cref="Word2Vec"/> estimator: verifies that every setter is
/// reflected by the matching getter and that the estimator survives a save/load round trip.
/// </summary>
[Collection("Spark E2E Tests")]
public class Word2VecTests
{
    private readonly SparkSession _spark;

    /// <summary>Captures the shared Spark session supplied by the test fixture.</summary>
    /// <param name="fixture">Fixture exposing the <see cref="SparkSession"/> to use.</param>
    public Word2VecTests(SparkFixture fixture)
    {
        _spark = fixture.Spark;
    }

    /// <summary>
    /// Sets every parameter on a <see cref="Word2Vec"/>, reads each one back, then saves
    /// and reloads the estimator and compares UIDs.
    /// </summary>
    [Fact]
    public void TestWord2Vec()
    {
        const string expectedInputCol = "text";
        const string expectedOutputCol = "result";
        const int expectedMinCount = 0;
        const int expectedMaxIter = 10;
        const int expectedMaxSentenceLength = 100;
        const int expectedNumPartitions = 1000;
        // Declared as long to match the SetSeed(long)/GetSeed() signatures.
        const long expectedSeed = 10000;
        const double expectedStepSize = 1.9;
        const int expectedVectorSize = 20;
        const int expectedWindowSize = 200;

        Word2Vec word2vec = new Word2Vec()
            .SetInputCol(expectedInputCol)
            .SetOutputCol(expectedOutputCol)
            .SetMinCount(expectedMinCount)
            .SetMaxIter(expectedMaxIter)
            .SetMaxSentenceLength(expectedMaxSentenceLength)
            .SetNumPartitions(expectedNumPartitions)
            .SetSeed(expectedSeed)
            .SetStepSize(expectedStepSize)
            .SetVectorSize(expectedVectorSize)
            .SetWindowSize(expectedWindowSize);

        Assert.Equal(expectedInputCol, word2vec.GetInputCol());
        Assert.Equal(expectedOutputCol, word2vec.GetOutputCol());
        Assert.Equal(expectedMinCount, word2vec.GetMinCount());
        Assert.Equal(expectedMaxIter, word2vec.GetMaxIter());
        Assert.Equal(expectedMaxSentenceLength, word2vec.GetMaxSentenceLength());
        Assert.Equal(expectedNumPartitions, word2vec.GetNumPartitions());
        Assert.Equal(expectedSeed, word2vec.GetSeed());
        Assert.Equal(expectedStepSize, word2vec.GetStepSize());
        Assert.Equal(expectedVectorSize, word2vec.GetVectorSize());
        Assert.Equal(expectedWindowSize, word2vec.GetWindowSize());

        using (var tempDirectory = new TemporaryDirectory())
        {
            string savePath = Path.Join(tempDirectory.Path, "word2vec");
            word2vec.Save(savePath);

            // A reloaded estimator must report the same UID as the one that was saved.
            Word2Vec loadedWord2Vec = Word2Vec.Load(savePath);
            Assert.Equal(word2vec.Uid(), loadedWord2Vec.Uid());
        }
    }
}
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.Spark.Interop; | ||
using Microsoft.Spark.Interop.Ipc; | ||
using Microsoft.Spark.Sql; | ||
|
||
namespace Microsoft.Spark.ML.Feature | ||
{ | ||
/// <summary>
/// .NET wrapper around the JVM class <c>org.apache.spark.ml.feature.Word2Vec</c>. Every
/// member forwards to the underlying JVM object through its <see cref="JvmObjectReference"/>.
/// </summary>
public class Word2Vec : IJvmObjectReferenceProvider
{
    private static readonly string s_word2VecClassName =
        "org.apache.spark.ml.feature.Word2Vec";

    private readonly JvmObjectReference _jvmObject;

    /// <summary>
    /// Create a <see cref="Word2Vec"/> without any parameters. Once you have created a
    /// <see cref="Word2Vec"/> you must call <see cref="SetInputCol(string)"/>,
    /// <see cref="SetOutputCol(string)"/>, and <see cref="SetMinCount(int)"/>.
    /// </summary>
    public Word2Vec()
    {
        _jvmObject = SparkEnvironment.JvmBridge.CallConstructor(s_word2VecClassName);
    }

    /// <summary>
    /// Create a <see cref="Word2Vec"/> with a UID that is used to give the
    /// <see cref="Word2Vec"/> a unique ID.
    /// </summary>
    /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
    public Word2Vec(string uid)
    {
        _jvmObject = SparkEnvironment.JvmBridge.CallConstructor(s_word2VecClassName, uid);
    }

    /// <summary>
    /// Wraps an existing JVM reference without constructing a new JVM object.
    /// </summary>
    /// <param name="jvmObject">Reference to the JVM-side Word2Vec instance.</param>
    internal Word2Vec(JvmObjectReference jvmObject)
    {
        _jvmObject = jvmObject;
    }

    JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;

    /// <summary>
    /// Gets the column that the <see cref="Word2Vec"/> should read from.
    /// </summary>
    /// <returns>The name of the input column.</returns>
    public string GetInputCol() => (string)_jvmObject.Invoke("getInputCol");

    /// <summary>
    /// Sets the column that the <see cref="Word2Vec"/> should read from.
    /// </summary>
    /// <param name="value">The name of the column to use as the source.</param>
    /// <returns><see cref="Word2Vec"/></returns>
    public Word2Vec SetInputCol(string value) =>
        WrapAsWord2Vec(_jvmObject.Invoke("setInputCol", value));

    /// <summary>
    /// The <see cref="Word2Vec"/> will create a new column in the DataFrame, this is the
    /// name of the new column.
    /// </summary>
    /// <returns>The name of the output column.</returns>
    public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol");

    /// <summary>
    /// The <see cref="Word2Vec"/> will create a new column in the DataFrame, this is the
    /// name of the new column.
    /// </summary>
    /// <param name="value">The name of the output column which will be created.</param>
    /// <returns>New <see cref="Word2Vec"/></returns>
    public Word2Vec SetOutputCol(string value) =>
        WrapAsWord2Vec(_jvmObject.Invoke("setOutputCol", value));

    /// <summary>
    /// Gets the vector size, the dimension of the code that you want to transform from words.
    /// </summary>
    /// <returns>
    /// The vector size, the dimension of the code that you want to transform from words.
    /// </returns>
    public int GetVectorSize() => (int)_jvmObject.Invoke("getVectorSize");

    /// <summary>
    /// Sets the vector size, the dimension of the code that you want to transform from words.
    /// </summary>
    /// <param name="value">
    /// The dimension of the code that you want to transform from words.
    /// </param>
    /// <returns><see cref="Word2Vec"/></returns>
    public Word2Vec SetVectorSize(int value) =>
        WrapAsWord2Vec(_jvmObject.Invoke("setVectorSize", value));

    /// <summary>
    /// Gets the minimum number of times a token must appear to be included in the word2vec
    /// model's vocabulary.
    /// </summary>
    /// <returns>
    /// The minimum number of times a token must appear to be included in the word2vec model's
    /// vocabulary.
    /// </returns>
    public int GetMinCount() => (int)_jvmObject.Invoke("getMinCount");

    /// <summary>
    /// The minimum number of times a token must appear to be included in the word2vec model's
    /// vocabulary.
    /// </summary>
    /// <param name="value">
    /// The minimum number of times a token must appear to be included in the word2vec model's
    /// vocabulary, the default is 5.
    /// </param>
    /// <returns><see cref="Word2Vec"/></returns>
    public virtual Word2Vec SetMinCount(int value) =>
        WrapAsWord2Vec(_jvmObject.Invoke("setMinCount", value));

    /// <summary>Gets the maximum number of iterations.</summary>
    /// <returns>The maximum number of iterations.</returns>
    public int GetMaxIter() => (int)_jvmObject.Invoke("getMaxIter");

    /// <summary>Maximum number of iterations (>= 0).</summary>
    /// <param name="value">The number of iterations.</param>
    /// <returns><see cref="Word2Vec"/></returns>
    public Word2Vec SetMaxIter(int value) =>
        WrapAsWord2Vec(_jvmObject.Invoke("setMaxIter", value));

    /// <summary>
    /// Gets the maximum length (in words) of each sentence in the input data.
    /// </summary>
    /// <returns>The maximum length (in words) of each sentence in the input data.</returns>
    public virtual int GetMaxSentenceLength() =>
        (int)_jvmObject.Invoke("getMaxSentenceLength");

    /// <summary>
    /// Sets the maximum length (in words) of each sentence in the input data.
    /// </summary>
    /// <param name="value">
    /// The maximum length (in words) of each sentence in the input data.
    /// </param>
    /// <returns><see cref="Word2Vec"/></returns>
    public Word2Vec SetMaxSentenceLength(int value) =>
        WrapAsWord2Vec(_jvmObject.Invoke("setMaxSentenceLength", value));

    /// <summary>Gets the number of partitions for sentences of words.</summary>
    /// <returns>The number of partitions for sentences of words.</returns>
    public int GetNumPartitions() => (int)_jvmObject.Invoke("getNumPartitions");

    /// <summary>Sets the number of partitions for sentences of words.</summary>
    /// <param name="value">
    /// The number of partitions for sentences of words, default is 1.
    /// </param>
    /// <returns><see cref="Word2Vec"/></returns>
    public Word2Vec SetNumPartitions(int value) =>
        WrapAsWord2Vec(_jvmObject.Invoke("setNumPartitions", value));

    /// <summary>Gets the value that is used for the random seed.</summary>
    /// <returns>The value that is used for the random seed.</returns>
    public long GetSeed() => (long)_jvmObject.Invoke("getSeed");

    /// <summary>Random seed.</summary>
    /// <param name="value">The value to use for the random seed.</param>
    /// <returns><see cref="Word2Vec"/></returns>
    public Word2Vec SetSeed(long value) =>
        WrapAsWord2Vec(_jvmObject.Invoke("setSeed", value));

    /// <summary>Gets the size to be used for each iteration of optimization.</summary>
    /// <returns>The size to be used for each iteration of optimization.</returns>
    public double GetStepSize() => (double)_jvmObject.Invoke("getStepSize");

    /// <summary>Step size to be used for each iteration of optimization (> 0).</summary>
    /// <param name="value">Value to use for the step size.</param>
    /// <returns><see cref="Word2Vec"/></returns>
    public Word2Vec SetStepSize(double value) =>
        WrapAsWord2Vec(_jvmObject.Invoke("setStepSize", value));

    /// <summary>Gets the window size (context words from [-window, window]).</summary>
    /// <returns>The window size.</returns>
    public int GetWindowSize() => (int)_jvmObject.Invoke("getWindowSize");

    /// <summary>The window size (context words from [-window, window]).</summary>
    /// <param name="value">
    /// The window size (context words from [-window, window]), default is 5.
    /// </param>
    /// <returns><see cref="Word2Vec"/></returns>
    public Word2Vec SetWindowSize(int value) =>
        WrapAsWord2Vec(_jvmObject.Invoke("setWindowSize", value));

    /// <summary>Fits a model to the input data.</summary>
    /// <param name="dataFrame">The <see cref="DataFrame"/> to fit the model to.</param>
    /// <returns><see cref="Word2VecModel"/></returns>
    public Word2VecModel Fit(DataFrame dataFrame) =>
        new Word2VecModel((JvmObjectReference)_jvmObject.Invoke("fit", dataFrame));

    /// <summary>
    /// The uid that was used to create the <see cref="Word2Vec"/>. If no UID is passed in
    /// when creating the <see cref="Word2Vec"/> then a random UID is created when the
    /// <see cref="Word2Vec"/> is created.
    /// </summary>
    /// <returns>string UID identifying the <see cref="Word2Vec"/>.</returns>
    public string Uid() => (string)_jvmObject.Invoke("uid");

    /// <summary>
    /// Loads the <see cref="Word2Vec"/> that was previously saved using
    /// <see cref="Save(string)"/>.
    /// </summary>
    /// <param name="path">The path the previous <see cref="Word2Vec"/> was saved to</param>
    /// <returns>New <see cref="Word2Vec"/> object, loaded from path.</returns>
    public static Word2Vec Load(string path) => WrapAsWord2Vec(
        SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_word2VecClassName, "load", path));

    /// <summary>
    /// Saves the <see cref="Word2Vec"/> so that it can be loaded later using
    /// <see cref="Load(string)"/>.
    /// </summary>
    /// <param name="path">The path to save the <see cref="Word2Vec"/> to.</param>
    /// <returns>This <see cref="Word2Vec"/> instance, to allow call chaining.</returns>
    public Word2Vec Save(string path)
    {
        // The JVM-side save is documented as returning Unit, so there is no object
        // reference to wrap. Wrapping the (empty) result would hand back a Word2Vec with
        // no usable JVM reference, so return this instance instead.
        _jvmObject.Invoke("save", path);
        return this;
    }

    /// <summary>
    /// Wraps the result of a JVM call in a new <see cref="Word2Vec"/>.
    /// </summary>
    /// <param name="obj">
    /// The value returned by the JVM call; it is cast to <see cref="JvmObjectReference"/>.
    /// </param>
    /// <returns>A <see cref="Word2Vec"/> backed by the given reference.</returns>
    private static Word2Vec WrapAsWord2Vec(object obj) =>
        new Word2Vec((JvmObjectReference)obj);
}
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we expect 2? (sorry I am not familiar with this model).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
findSynonyms takes the word to check and the maximum number of synonyms to return, so passing 2 and asserting a count of 2 checks that the result is limited to 2 rows.