Skip to content

ImageClassificationTrainer PredictedLabelColumnName bug when the name is not default #7458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ private protected sealed class BindingsImpl : BindingsBase
public readonly int ScoreColumnIndex;
// The type of the derived column.
public readonly DataViewType PredColType;
/// <summary>
/// The name of the column that contains the predicted labels.
/// This field is used in the scoring process to store or reference the predicted label column.
/// </summary>
public readonly string PredictedLabelColumnName;
// The ScoreColumnKind metadata value for all score columns.
public readonly string ScoreColumnKind;

Expand All @@ -54,6 +59,7 @@ private BindingsImpl(DataViewSchema input, ISchemaBoundRowMapper mapper, string
ScoreColumnIndex = scoreColIndex;
ScoreColumnKind = scoreColumnKind;
PredColType = predColType;
PredictedLabelColumnName = predictedLabelColumnName;

_getScoreColumnKind = GetScoreColumnKind;
_getScoreValueKind = GetScoreValueKind;
Expand Down Expand Up @@ -113,7 +119,7 @@ public BindingsImpl ApplyToSchema(DataViewSchema input, ISchemaBindableMapper bi
bool tmp = rowMapper.OutputSchema.TryGetColumnIndex(scoreCol, out mapperScoreColumn);
env.Check(tmp, "Mapper doesn't have expected score column");

return new BindingsImpl(input, rowMapper, Suffix, ScoreColumnKind, true, mapperScoreColumn, PredColType);
return new BindingsImpl(input, rowMapper, Suffix, ScoreColumnKind, true, mapperScoreColumn, PredColType, PredictedLabelColumnName);
}

public static BindingsImpl Create(ModelLoadContext ctx, DataViewSchema input,
Expand Down