Skip to content

Commit 407bdde

Browse files
committed
Update doc
1 parent 10106d2 commit 407bdde

File tree

4 files changed

+137
-2
lines changed

4 files changed

+137
-2
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
using System;
2+
using System.Linq;
3+
using Microsoft.ML.Data;
4+
using Microsoft.ML.Trainers;
5+
6+
namespace Microsoft.ML.Samples.Dynamic
7+
{
8+
public class SDCACalibrated_BinaryClassificationExample
9+
{
10+
/// <summary>
11+
/// Call <see cref="SdcaCalibratedBinaryTrainer"/> to train a logistic regression model.
12+
/// Training examples are generated by calling <see cref="Samples.Utils.DatasetUtils.GenerateBinaryLabelFloatFeatureVectorSamples"/>.
13+
/// The type of those examples is <see cref="SamplesUtils.DatasetUtils.FloatLabelFloatFeatureVectorSample"/>,
14+
/// which contains one bool label and a float feature vector.
15+
/// </summary>
16+
public static void SDCACalibrated_BinaryClassification()
17+
{
18+
// Generate C# objects as training examples.
19+
var rawData = SamplesUtils.DatasetUtils.GenerateBinaryLabelFloatFeatureVectorSamples(100);
20+
21+
// Information in first example.
22+
// Label: true
23+
Console.WriteLine("First example's label is {0}", rawData.First().Label);
24+
// Features is a 10-element float[]:
25+
// [0] 1.0173254 float
26+
// [1] 0.9680227 float
27+
// [2] 0.7581612 float
28+
// [3] 0.406033158 float
29+
// [4] 0.7588848 float
30+
// [5] 1.10602713 float
31+
// [6] 0.6421779 float
32+
// [7] 1.17754972 float
33+
// [8] 0.473704457 float
34+
// [9] 0.4919063 float
35+
Console.WriteLine("First example's feature vector is {0}", rawData.First().Features);
36+
37+
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
38+
// as a catalog of available operations and as the source of randomness.
39+
var mlContext = new MLContext();
40+
41+
// Step 1: Read the data as an IDataView.
42+
var data = mlContext.Data.ReadFromEnumerable(rawData);
43+
44+
// ML.NET doesn't cache data set by default. Caching is very helpful when working with iterative
45+
// algorithms which needs many data passes. Since SDCA is the case, we cache.
46+
data = mlContext.Data.Cache(data);
47+
48+
// Step 2: Create a binary classifier.
49+
// We set the "Label" column as the label of the dataset, and the "Features" column as the features column.
50+
var pipeline = mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscentCalibrated(labelColumn: "Label", featureColumn: "Features", l2Const: 0.001f);
51+
52+
// Step 3: Train the pipeline created.
53+
var model = pipeline.Fit(data);
54+
55+
// Step 4: Make prediction and evaluate its quality (on training set).
56+
var prediction = model.Transform(data);
57+
58+
var rawPrediction = mlContext.CreateEnumerable<SamplesUtils.DatasetUtils.CalibratedBinaryClassifierOutput>(prediction, false);
59+
60+
// Step 5: Inspect the prediction of the first example.
61+
var first = rawPrediction.First();
62+
63+
Console.WriteLine("The first example actual label is {0}." +
64+
" The trained model assigns it a score {1} and a probability of being positive class {2}.",
65+
first.Label /*true*/, first.Score /*around 3.2*/, first.Probability /*around around 0.95*/);
66+
}
67+
}
68+
}

src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,9 @@ public static IEnumerable<SampleVectorOfNumbersData> GetVectorOfNumbersData()
371371

372372
private const int _simpleBinaryClassSampleFeatureLength = 10;
373373

374+
/// <summary>
375+
/// Example with one binary label and 10 feature values.
376+
/// </summary>
374377
public class BinaryLabelFloatFeatureVectorSample
375378
{
376379
public bool Label;
@@ -379,6 +382,17 @@ public class BinaryLabelFloatFeatureVectorSample
379382
public float[] Features;
380383
}
381384

385+
/// <summary>
386+
/// Class used to capture prediction of <see cref="BinaryLabelFloatFeatureVectorSample"/> when
387+
/// calling <see cref="CursoringUtils.CreateEnumerable"/> via on <see cref="MLContext"/>.
388+
/// </summary>
389+
public class CalibratedBinaryClassifierOutput
390+
{
391+
public bool Label;
392+
public float Score;
393+
public float Probability;
394+
}
395+
382396
public static IEnumerable<BinaryLabelFloatFeatureVectorSample> GenerateBinaryLabelFloatFeatureVectorSamples(int exampleCount)
383397
{
384398
var rnd = new Random(0);

src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ public static SdcaBinaryTrainer StochasticDualCoordinateAscent(
135135

136136
/// <summary>
137137
/// Predict a target using a logistic regression model trained with the SDCA trainer.
138+
/// The trained model can produce probablity via feeding output value of the linear
139+
/// function to a <see cref="PlattCalibrator"/>.
138140
/// </summary>
139141
/// <param name="catalog">The binary classification catalog trainer object.</param>
140142
/// <param name="labelColumn">The labelColumn, or dependent variable.</param>
@@ -146,7 +148,7 @@ public static SdcaBinaryTrainer StochasticDualCoordinateAscent(
146148
/// <example>
147149
/// <format type="text/markdown">
148150
/// <![CDATA[
149-
/// [!code-csharp[SDCA](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs)]
151+
/// [!code-csharp[SDCA](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/SDCACalibrated.cs)]
150152
/// ]]></format>
151153
/// </example>
152154
public static SdcaCalibratedBinaryTrainer StochasticDualCoordinateAscentCalibrated(
@@ -164,7 +166,10 @@ public static SdcaCalibratedBinaryTrainer StochasticDualCoordinateAscentCalibrat
164166
}
165167

166168
/// <summary>
167-
/// Predict a target using a linear binary classification model trained with the SDCA trainer.
169+
/// Predict a target using a logistic regression model trained with the SDCA trainer.
170+
/// The trained model can produce probablity via feeding output value of the linear
171+
/// function to a <see cref="PlattCalibrator"/>. Comparing with <see cref="StochasticDualCoordinateAscentCalibrated(BinaryClassificationCatalog.BinaryClassificationTrainers, string, string, string, float?, float?, int?)"/>,
172+
/// this function allows more advanced settings via accepting <see cref="SdcaCalibratedBinaryTrainer.Options"/>.
168173
/// </summary>
169174
/// <param name="catalog">The binary classification catalog trainer object.</param>
170175
/// <param name="options">Advanced arguments to the algorithm.</param>

test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,54 @@ private void CrossValidationOn(string dataPath)
445445
Console.WriteLine(microAccuracies.Average());
446446
}
447447

448+
[Fact]
449+
public void SdcaLogisticRegression()
450+
{
451+
// Generate C# objects as training examples.
452+
var rawData = SamplesUtils.DatasetUtils.GenerateBinaryLabelFloatFeatureVectorSamples(100);
453+
454+
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
455+
// as a catalog of available operations and as the source of randomness.
456+
var mlContext = new MLContext();
457+
458+
// Step 1: Read the data as an IDataView.
459+
var data = mlContext.Data.ReadFromEnumerable(rawData);
460+
461+
// ML.NET doesn't cache data set by default. Caching is very helpful when working with iterative
462+
// algorithms which needs many data passes. Since SDCA is the case, we cache.
463+
data = mlContext.Data.Cache(data);
464+
465+
// Step 2: Create a binary classifier.
466+
// We set the "Label" column as the label of the dataset, and the "Features" column as the features column.
467+
var pipeline = mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscentCalibrated(labelColumn: "Label", featureColumn: "Features", l2Const: 0.001f);
468+
469+
// Step 3: Train the pipeline created.
470+
var model = pipeline.Fit(data);
471+
472+
// Step 4: Make prediction and evaluate its quality (on training set).
473+
var prediction = model.Transform(data);
474+
var metrics = mlContext.BinaryClassification.Evaluate(prediction);
475+
476+
// Check a few metrics to make sure the trained model is ok.
477+
Assert.InRange(metrics.Auc, 0.9, 1);
478+
Assert.InRange(metrics.LogLoss, 0, 0.5);
479+
480+
var rawPrediction = mlContext.CreateEnumerable<SamplesUtils.DatasetUtils.CalibratedBinaryClassifierOutput>(prediction, false);
481+
482+
// Step 5: Inspect the prediction of the first example.
483+
var first = rawPrediction.First();
484+
// This is a positive example.
485+
Assert.True(first.Label);
486+
// Positive example should have non-negative score.
487+
Assert.True(first.Score > 0);
488+
// Positive example should have high probability of belonging the positive class.
489+
Assert.InRange(first.Probability, 0.8, 1);
490+
491+
Console.WriteLine("The first example actual label is {0}." +
492+
" The trained model assigns it a score {1} and a probability of being positive class {2}.",
493+
first.Label /*true*/, first.Score /*around 3.2*/, first.Probability /*around around 0.95*/);
494+
}
495+
448496
[Fact]
449497
public void ReadData()
450498
{

0 commit comments

Comments
 (0)