Skip to content

Commit b6135a9

Browse files
committed
Sdca entry points also strongly-typed based on trained models.
1. Remove arguments to set up calibrator in all SDCA trainer 2. Clean up entry point base on argument's changes. 3. Update test files
1 parent ce8f768 commit b6135a9

11 files changed

+1766
-1486
lines changed

src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,30 +1399,9 @@ public void Add(Double summand)
13991399
}
14001400
}
14011401

1402-
public abstract class SdcaBinaryTrainerBase<TModelParameters> : SdcaTrainerBase<SdcaBinaryTrainerBase<TModelParameters>.Options, BinaryPredictionTransformer<TModelParameters>, TModelParameters>
1402+
public abstract class SdcaBinaryTrainerBase<TModelParameters> : SdcaTrainerBase<SdcaBinaryTrainerBase<TModelParameters>.BinaryArgumentBase, BinaryPredictionTransformer<TModelParameters>, TModelParameters>
14031403
where TModelParameters : class, IPredictorProducing<float>
14041404
{
1405-
public sealed class Options : ArgumentsBase
1406-
{
1407-
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
1408-
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory();
1409-
1410-
[Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")]
1411-
public float PositiveInstanceWeight = 1;
1412-
1413-
[Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
1414-
public ICalibratorTrainerFactory Calibrator = null;
1415-
1416-
[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
1417-
public int MaxCalibrationExamples = 1000000;
1418-
1419-
internal override void Check(IHostEnvironment env)
1420-
{
1421-
base.Check(env);
1422-
env.CheckUserArg(PositiveInstanceWeight > 0, nameof(PositiveInstanceWeight), "Weight for positive instances must be positive");
1423-
}
1424-
}
1425-
14261405
private readonly ISupportSdcaClassificationLoss _loss;
14271406
private readonly float _positiveInstanceWeight;
14281407

@@ -1436,6 +1415,18 @@ internal override void Check(IHostEnvironment env)
14361415

14371416
public override TrainerInfo Info { get; }
14381417

1418+
public class BinaryArgumentBase : ArgumentsBase
1419+
{
1420+
[Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")]
1421+
public float PositiveInstanceWeight = 1;
1422+
1423+
internal override void Check(IHostEnvironment env)
1424+
{
1425+
base.Check(env);
1426+
env.CheckUserArg(PositiveInstanceWeight > 0, nameof(PositiveInstanceWeight), "Weight for positive instances must be positive");
1427+
}
1428+
}
1429+
14391430
/// <summary>
14401431
/// Initializes a new instance of <see cref="SdcaBinaryTrainerBase{TModelParameters}"/>
14411432
/// </summary>
@@ -1460,19 +1451,19 @@ protected SdcaBinaryTrainerBase(IHostEnvironment env,
14601451
{
14611452
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
14621453
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
1463-
_loss = loss ?? Args.LossFunction.CreateComponent(env);
1454+
_loss = loss ?? new LogLossFactory().CreateComponent(env);
14641455
Loss = _loss;
1465-
Info = new TrainerInfo(calibration: !(_loss is LogLoss));
1456+
Info = new TrainerInfo(calibration: false);
14661457
_positiveInstanceWeight = Args.PositiveInstanceWeight;
14671458
_outputColumns = ComputeSdcaBinaryClassifierSchemaShape();
14681459
}
14691460

1470-
protected SdcaBinaryTrainerBase(IHostEnvironment env, Options options)
1461+
protected SdcaBinaryTrainerBase(IHostEnvironment env, BinaryArgumentBase options, ISupportSdcaClassificationLoss loss = null)
14711462
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
14721463
{
1473-
_loss = options.LossFunction.CreateComponent(env);
1464+
_loss = loss ?? new LogLossFactory().CreateComponent(env);
14741465
Loss = _loss;
1475-
Info = new TrainerInfo(calibration: !(_loss is LogLoss));
1466+
Info = new TrainerInfo(calibration: false);
14761467
_positiveInstanceWeight = Args.PositiveInstanceWeight;
14771468
_outputColumns = ComputeSdcaBinaryClassifierSchemaShape();
14781469
}
@@ -1528,6 +1519,13 @@ public sealed class SdcaCalibratedBinaryTrainer :
15281519
internal const string LoadNameValue = "SDCALR";
15291520
internal const string UserNameValue = "Fast Linear (SA-SDCA) Logistic Regression";
15301521

1522+
/// <summary>
1523+
/// Configuration to training logistic regression using SDCA.
1524+
/// </summary>
1525+
public sealed class Options : BinaryArgumentBase
1526+
{
1527+
}
1528+
15311529
internal SdcaCalibratedBinaryTrainer(IHostEnvironment env,
15321530
string labelColumn = DefaultColumnNames.Label,
15331531
string featureColumn = DefaultColumnNames.Features,
@@ -1540,7 +1538,7 @@ internal SdcaCalibratedBinaryTrainer(IHostEnvironment env,
15401538
}
15411539

15421540
internal SdcaCalibratedBinaryTrainer(IHostEnvironment env, Options options)
1543-
: base(env, options)
1541+
: base(env, options, new LogLoss())
15441542
{
15451543
}
15461544

@@ -1583,6 +1581,15 @@ public sealed class SdcaBinaryTrainer : SdcaBinaryTrainerBase<LinearBinaryModelP
15831581
internal const string LoadNameValue = "SDCA";
15841582
internal const string UserNameValue = "Fast Linear (SA-SDCA)";
15851583

1584+
/// <summary>
1585+
/// General Configuration to training linear model using SDCA.
1586+
/// </summary>
1587+
public sealed class Options : BinaryArgumentBase
1588+
{
1589+
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
1590+
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory();
1591+
}
1592+
15861593
internal SdcaBinaryTrainer(IHostEnvironment env,
15871594
string labelColumn = DefaultColumnNames.Label,
15881595
string featureColumn = DefaultColumnNames.Features,
@@ -1596,7 +1603,7 @@ internal SdcaBinaryTrainer(IHostEnvironment env,
15961603
}
15971604

15981605
internal SdcaBinaryTrainer(IHostEnvironment env, Options options)
1599-
: base(env, options)
1606+
: base(env, options, options.LossFunction.CreateComponent(env))
16001607
{
16011608
}
16021609

@@ -2025,8 +2032,23 @@ internal static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnviro
20252032

20262033
return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
20272034
() => new SdcaBinaryTrainer(host, input),
2028-
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
2029-
maxCalibrationExamples: input.MaxCalibrationExamples);
2035+
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
2036+
}
2037+
2038+
[TlcModule.EntryPoint(Name = "Trainers.StochasticDualCoordinateAscentCalibratedBinaryClassifier",
2039+
Desc = "Train logistic regression using SDCA.",
2040+
UserName = SdcaCalibratedBinaryTrainer.UserNameValue,
2041+
ShortName = SdcaCalibratedBinaryTrainer.LoadNameValue)]
2042+
internal static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaCalibratedBinaryTrainer.Options input)
2043+
{
2044+
Contracts.CheckValue(env, nameof(env));
2045+
var host = env.Register("TrainSDCA");
2046+
host.CheckValue(input, nameof(input));
2047+
EntryPointUtils.CheckInputArgs(host, input);
2048+
2049+
return LearnerEntryPointsUtils.Train<SdcaCalibratedBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
2050+
() => new SdcaCalibratedBinaryTrainer(host, input),
2051+
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
20302052
}
20312053
}
20322054
}

test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ Trainers.OnlineGradientDescentRegressor Train a Online gradient descent perceptr
6565
Trainers.OrdinaryLeastSquaresRegressor Train an OLS regression model. Microsoft.ML.Trainers.HalLearners.OlsLinearRegressionTrainer TrainRegression Microsoft.ML.Trainers.HalLearners.OlsLinearRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
6666
Trainers.PcaAnomalyDetector Train an PCA Anomaly model. Microsoft.ML.Trainers.PCA.RandomizedPcaTrainer TrainPcaAnomaly Microsoft.ML.Trainers.PCA.RandomizedPcaTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+AnomalyDetectionOutput
6767
Trainers.PoissonRegressor Train an Poisson regression model. Microsoft.ML.Trainers.PoissonRegression TrainRegression Microsoft.ML.Trainers.PoissonRegression+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
68-
Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary model. Microsoft.ML.Trainers.Sdca TrainBinary Microsoft.ML.Trainers.SdcaBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
68+
Trainers.StochasticDualCoordinateAscentBinaryClassifier Train a linear model to binary classification using SDCA. Microsoft.ML.Trainers.Sdca TrainBinary Microsoft.ML.Trainers.SdcaBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
69+
Trainers.StochasticDualCoordinateAscentCalibratedBinaryClassifier Train logistic regression using SDCA. Microsoft.ML.Trainers.Sdca TrainBinary Microsoft.ML.Trainers.SdcaCalibratedBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
6970
Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Trainers.Sdca TrainMultiClass Microsoft.ML.Trainers.SdcaMultiClassTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
7071
Trainers.StochasticDualCoordinateAscentRegressor The SDCA linear regression trainer. Microsoft.ML.Trainers.Sdca TrainRegression Microsoft.ML.Trainers.SdcaRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
7172
Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer TrainBinary Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput

0 commit comments

Comments
 (0)