Skip to content

Implement ML Features: Word2Vec #491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
[Collection("Spark E2E Tests")]
public class Word2VecModelTests
{
private readonly SparkSession _spark;

public Word2VecModelTests(SparkFixture fixture)
{
_spark = fixture.Spark;
}

[Fact]
public void TestWord2VecModel()
{
DataFrame documentDataFrame =
_spark.Sql("SELECT split('Hi I heard about Spark', ' ') as text");

Word2Vec word2vec = new Word2Vec()
.SetInputCol("text")
.SetOutputCol("result")
.SetMinCount(1);

Word2VecModel model = word2vec.Fit(documentDataFrame);

const int expectedSynonyms = 2;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we expect 2? (sorry I am not familiar with this model).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

findSynonyms takes the word to check and the maximum amount of synonyms to return so 2 is checking that the result is limited to 2 rows.

DataFrame synonyms = model.FindSynonyms("Hi", expectedSynonyms);

Assert.Equal(expectedSynonyms, synonyms.Count());
synonyms.Show();

using (var tempDirectory = new TemporaryDirectory())
{
string savePath = Path.Join(tempDirectory.Path, "word2vecModel");
model.Save(savePath);

Word2VecModel loadedModel = Word2VecModel.Load(savePath);
Assert.Equal(model.Uid(), loadedModel.Uid());
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.IO;
using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
[Collection("Spark E2E Tests")]
public class Word2VecTests
{
private readonly SparkSession _spark;

public Word2VecTests(SparkFixture fixture)
{
_spark = fixture.Spark;
}

[Fact]
public void TestWord2Vec()
{
DataFrame documentDataFrame = _spark.Sql("SELECT split('Spark dotnet is cool', ' ')");

const string expectedInputCol = "text";
const string expectedOutputCol = "result";
const int expectedMinCount = 0;
const int expectedMaxIter = 10;
const int expectedMaxSentenceLength = 100;
const int expectedNumPartitions = 1000;
const int expectedSeed = 10000;
const double expectedStepSize = 1.9;
const int expectedVectorSize = 20;
const int expectedWindowSize = 200;

Word2Vec word2vec = new Word2Vec()
.SetInputCol(expectedInputCol)
.SetOutputCol(expectedOutputCol)
.SetMinCount(expectedMinCount)
.SetMaxIter(expectedMaxIter)
.SetMaxSentenceLength(expectedMaxSentenceLength)
.SetNumPartitions(expectedNumPartitions)
.SetSeed(expectedSeed)
.SetStepSize(expectedStepSize)
.SetVectorSize(expectedVectorSize)
.SetWindowSize(expectedWindowSize);

Assert.Equal(expectedInputCol, word2vec.GetInputCol());
Assert.Equal(expectedOutputCol, word2vec.GetOutputCol());
Assert.Equal(expectedMinCount, word2vec.GetMinCount());
Assert.Equal(expectedMaxIter, word2vec.GetMaxIter());
Assert.Equal(expectedMaxSentenceLength, word2vec.GetMaxSentenceLength());
Assert.Equal(expectedNumPartitions, word2vec.GetNumPartitions());
Assert.Equal(expectedSeed, word2vec.GetSeed());
Assert.Equal(expectedStepSize, word2vec.GetStepSize());
Assert.Equal(expectedVectorSize, word2vec.GetVectorSize());
Assert.Equal(expectedWindowSize, word2vec.GetWindowSize());

using (var tempDirectory = new TemporaryDirectory())
{
string savePath = Path.Join(tempDirectory.Path, "word2vec");
word2vec.Save(savePath);

Word2Vec loadedWord2Vec = Word2Vec.Load(savePath);
Assert.Equal(word2vec.Uid(), loadedWord2Vec.Uid());
}
}
}
}
4 changes: 4 additions & 0 deletions src/csharp/Microsoft.Spark.E2ETest/SparkFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public class EnvironmentVariableNames

private readonly Process _process = new Process();
private readonly TemporaryDirectory _tempDirectory = new TemporaryDirectory();

public const string DefaultLogLevel = "ERROR";

internal SparkSession Spark { get; }

Expand Down Expand Up @@ -106,6 +108,8 @@ public SparkFixture()
.Config("spark.ui.showConsoleProgress", false)
.AppName("Microsoft.Spark.E2ETest")
.GetOrCreate();

Spark.SparkContext.SetLogLevel(DefaultLogLevel);

Jvm = ((IJvmObjectReferenceProvider)Spark).Reference.Jvm;
}
Expand Down
220 changes: 220 additions & 0 deletions src/csharp/Microsoft.Spark/ML/Feature/Word2Vec.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;

namespace Microsoft.Spark.ML.Feature
{
public class Word2Vec : IJvmObjectReferenceProvider
{
private static readonly string s_word2VecClassName =
"org.apache.spark.ml.feature.Word2Vec";

private readonly JvmObjectReference _jvmObject;

/// <summary>
/// Create a <see cref="Word2Vec"/> without any parameters. Once you have created a
/// <see cref="Word2Vec"/> you must call <see cref="SetInputCol(string)"/>,
/// <see cref="SetOutputCol(string)"/>, and <see cref="SetMinCount(int)"/>.
/// </summary>
public Word2Vec()
{
_jvmObject = SparkEnvironment.JvmBridge.CallConstructor(s_word2VecClassName);
}

/// <summary>
/// Create a <see cref="Word2Vec"/> with a UID that is used to give the
/// <see cref="Word2Vec"/> a unique ID.
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public Word2Vec(string uid)
{
_jvmObject = SparkEnvironment.JvmBridge.CallConstructor(s_word2VecClassName, uid);
}

internal Word2Vec(JvmObjectReference jvmObject)
{
_jvmObject = jvmObject;
}

JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;

/// <summary>
/// Gets the column that the <see cref="Word2Vec"/> should read from.
/// </summary>
/// <returns>The name of the input column.</returns>
public string GetInputCol() => (string)(_jvmObject.Invoke("getInputCol"));

/// <summary>
/// Sets the column that the <see cref="Word2Vec"/> should read from.
/// </summary>
/// <param name="value">The name of the column to as the source.</param>
/// <returns><see cref="Word2Vec"/></returns>
public Word2Vec SetInputCol(string value) =>
WrapAsWord2Vec(_jvmObject.Invoke("setInputCol", value));

/// <summary>
/// The <see cref="Word2Vec"/> will create a new column in the DataFrame, this is the
/// name of the new column.
/// </summary>
/// <returns>The name of the output column.</returns>
public string GetOutputCol() => (string)(_jvmObject.Invoke("getOutputCol"));

/// <summary>
/// The <see cref="Word2Vec"/> will create a new column in the DataFrame, this is the
/// name of the new column.
/// </summary>
/// <param name="value">The name of the output column which will be created.</param>
/// <returns>New <see cref="Word2Vec"/></returns>
public Word2Vec SetOutputCol(string value) =>
WrapAsWord2Vec(_jvmObject.Invoke("setOutputCol", value));

/// <summary>
/// Gets the vector size, the dimension of the code that you want to transform from words.
/// </summary>
/// <returns>
/// The vector size, the dimension of the code that you want to transform from words.
/// </returns>
public int GetVectorSize() => (int)(_jvmObject.Invoke("getVectorSize"));

/// <summary>
/// Sets the vector size, the dimension of the code that you want to transform from words.
/// </summary>
/// <param name="value">
/// The dimension of the code that you want to transform from words.
/// </param>
/// <returns><see cref="Word2Vec"/></returns>
public Word2Vec SetVectorSize(int value) =>
WrapAsWord2Vec(_jvmObject.Invoke("setVectorSize", value));

/// <summary>
/// Gets the minimum number of times a token must appear to be included in the word2vec
/// model's vocabulary.
/// </summary>
/// <returns>
/// The minimum number of times a token must appear to be included in the word2vec model's
/// vocabulary.
/// </returns>
public int GetMinCount() => (int)_jvmObject.Invoke("getMinCount");

/// <summary>
/// The minimum number of times a token must appear to be included in the word2vec model's
/// vocabulary.
/// </summary>
/// <param name="value">
/// The minimum number of times a token must appear to be included in the word2vec model's
/// vocabulary, the default is 5.
/// </param>
/// <returns><see cref="Word2Vec"/></returns>
public virtual Word2Vec SetMinCount(int value) =>
WrapAsWord2Vec(_jvmObject.Invoke("setMinCount", value));

/// <summary>Gets the maximum number of iterations.</summary>
/// <returns>The maximum number of iterations.</returns>
public int GetMaxIter() => (int)_jvmObject.Invoke("getMaxIter");

/// <summary>Maximum number of iterations (&gt;= 0).</summary>
/// <param name="value">The number of iterations.</param>
/// <returns><see cref="Word2Vec"/></returns>
public Word2Vec SetMaxIter(int value) =>
WrapAsWord2Vec(_jvmObject.Invoke("setMaxIter", value));

/// <summary>
/// Gets the maximum length (in words) of each sentence in the input data.
/// </summary>
/// <returns>The maximum length (in words) of each sentence in the input data.</returns>
public virtual int GetMaxSentenceLength() =>
(int)_jvmObject.Invoke("getMaxSentenceLength");

/// <summary>
/// Sets the maximum length (in words) of each sentence in the input data.
/// </summary>
/// <param name="value">
/// The maximum length (in words) of each sentence in the input data.
/// </param>
/// <returns><see cref="Word2Vec"/></returns>
public Word2Vec SetMaxSentenceLength(int value) =>
WrapAsWord2Vec(_jvmObject.Invoke("setMaxSentenceLength", value));

/// <summary>Gets the number of partitions for sentences of words.</summary>
/// <returns>The number of partitions for sentences of words.</returns>
public int GetNumPartitions() => (int)_jvmObject.Invoke("getNumPartitions");

/// <summary>Sets the number of partitions for sentences of words.</summary>
/// <param name="value">
/// The number of partitions for sentences of words, default is 1.
/// </param>
/// <returns><see cref="Word2Vec"/></returns>
public Word2Vec SetNumPartitions(int value) =>
WrapAsWord2Vec(_jvmObject.Invoke("setNumPartitions", value));

/// <summary>Gets the value that is used for the random seed.</summary>
/// <returns>The value that is used for the random seed.</returns>
public long GetSeed() => (long)_jvmObject.Invoke("getSeed");

/// <summary>Random seed.</summary>
/// <param name="value">The value to use for the random seed.</param>
/// <returns><see cref="Word2Vec"/></returns>
public Word2Vec SetSeed(long value) =>
WrapAsWord2Vec(_jvmObject.Invoke("setSeed", value));

/// <summary>Gets the size to be used for each iteration of optimization.</summary>
/// <returns>The size to be used for each iteration of optimization.</returns>
public double GetStepSize() => (double)_jvmObject.Invoke("getStepSize");

/// <summary>Step size to be used for each iteration of optimization (&gt; 0).</summary>
/// <param name="value">Value to use for the step size.</param>
/// <returns><see cref="Word2Vec"/></returns>
public Word2Vec SetStepSize(double value) =>
WrapAsWord2Vec(_jvmObject.Invoke("setStepSize", value));

/// <summary>Gets the window size (context words from [-window, window]).</summary>
/// <returns>The window size.</returns>
public int GetWindowSize() => (int)_jvmObject.Invoke("getWindowSize");

/// <summary>The window size (context words from [-window, window]).</summary>
/// <param name="value">
/// The window size (context words from [-window, window]), default is 5.
/// </param>
/// <returns><see cref="Word2Vec"/></returns>
public Word2Vec SetWindowSize(int value) =>
WrapAsWord2Vec(_jvmObject.Invoke("setWindowSize", value));

/// <summary>Fits a model to the input data.</summary>
/// <param name="dataFrame">The <see cref="DataFrame"/> to fit the model to.</param>
/// <returns><see cref="Word2VecModel"/></returns>
public Word2VecModel Fit(DataFrame dataFrame) =>
new Word2VecModel((JvmObjectReference)_jvmObject.Invoke("fit", dataFrame));

/// <summary>
/// The uid that was used to create the <see cref="Word2Vec"/>. If no UID is passed in
/// when creating the <see cref="Word2Vec"/> then a random UID is created when the
/// <see cref="Word2Vec"/> is created.
/// </summary>
/// <returns>string UID identifying the <see cref="Word2Vec"/>.</returns>
public string Uid() => (string)_jvmObject.Invoke("uid");

/// <summary>
/// Loads the <see cref="Word2Vec"/> that was previously saved using
/// <see cref="Save(string)"/>.
/// </summary>
/// <param name="path">The path the previous <see cref="Word2Vec"/> was saved to</param>
/// <returns>New <see cref="Word2Vec"/> object, loaded from path.</returns>
public static Word2Vec Load(string path) => WrapAsWord2Vec(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_word2VecClassName, "load", path));

/// <summary>
/// Saves the <see cref="Word2Vec"/> so that it can be loaded later using
/// <see cref="Load(string)"/>.
/// </summary>
/// <param name="path">The path to save the <see cref="Word2Vec"/> to.</param>
/// <returns>New <see cref="Word2Vec"/> object.</returns>
public Word2Vec Save(string path) => WrapAsWord2Vec(_jvmObject.Invoke("save", path));

private static Word2Vec WrapAsWord2Vec(object obj) =>
new Word2Vec((JvmObjectReference)obj);
}
}
Loading