diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index 383052e145..a0be7945d6 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -38,6 +38,11 @@ private protected sealed class BindingsImpl : BindingsBase public readonly int ScoreColumnIndex; // The type of the derived column. public readonly DataViewType PredColType; + /// + /// The name of the column that contains the predicted labels. + /// This field is used in the scoring process to store or reference the predicted label column. + /// + public readonly string PredictedLabelColumnName; // The ScoreColumnKind metadata value for all score columns. public readonly string ScoreColumnKind; @@ -54,6 +59,7 @@ private BindingsImpl(DataViewSchema input, ISchemaBoundRowMapper mapper, string ScoreColumnIndex = scoreColIndex; ScoreColumnKind = scoreColumnKind; PredColType = predColType; + PredictedLabelColumnName = predictedLabelColumnName; _getScoreColumnKind = GetScoreColumnKind; _getScoreValueKind = GetScoreValueKind; @@ -113,7 +119,7 @@ public BindingsImpl ApplyToSchema(DataViewSchema input, ISchemaBindableMapper bi bool tmp = rowMapper.OutputSchema.TryGetColumnIndex(scoreCol, out mapperScoreColumn); env.Check(tmp, "Mapper doesn't have expected score column"); - return new BindingsImpl(input, rowMapper, Suffix, ScoreColumnKind, true, mapperScoreColumn, PredColType); + return new BindingsImpl(input, rowMapper, Suffix, ScoreColumnKind, true, mapperScoreColumn, PredColType, PredictedLabelColumnName); } public static BindingsImpl Create(ModelLoadContext ctx, DataViewSchema input,