Skip to content

Commit ea84d42

Browse files
authored
Add QA sweepable estimator in AutoML (#6781)
* Add QA sweepable * clean
1 parent a823199 commit ea84d42

File tree

6 files changed

+103
-4
lines changed

6 files changed

+103
-4
lines changed

src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@
7373
"ForecastBySsa",
7474
"TextClassifcation",
7575
"SentenceSimilarity",
76-
"ObjectDetection"
76+
"ObjectDetection",
77+
"QuestionAnswering"
7778
]
7879
},
7980
"nugetDependencies": {
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
{
2+
"$schema": "./search-space-schema.json#",
3+
"name": "question_answering_option",
4+
"search_space": [
5+
{
6+
"name": "ContextColumnName",
7+
"type": "string",
8+
"default": "Context"
9+
},
10+
{
11+
"name": "QuestionColumnName",
12+
"type": "string",
13+
"default": "Question"
14+
},
15+
{
16+
"name": "TrainingAnswerColumnName",
17+
"type": "string",
18+
"default": "TrainingAnswer"
19+
},
20+
{
21+
"name": "AnswerIndexStartColumnName",
22+
"type": "string",
23+
"default": "AnswerStart"
24+
},
25+
{
26+
"name": "ScoreColumnName",
27+
"type": "string",
28+
"default": "Score"
29+
},
30+
{
31+
"name": "predictedAnswerColumnName",
32+
"type": "string",
33+
"default": "Answer"
34+
},
35+
{
36+
"name": "BatchSize",
37+
"type": "integer",
38+
"default": 4
39+
},
40+
{
41+
"name": "MaxEpochs",
42+
"type": "integer",
43+
"default": 10
44+
},
45+
{
46+
"name": "TopKAnswers",
47+
"type": "integer",
48+
"default": 3
49+
},
50+
{
51+
"name": "Architecture",
52+
"type": "bertArchitecture",
53+
"default": "BertArchitecture.Roberta"
54+
}
55+
]
56+
}

src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@
146146
"dnn_featurizer_image_option",
147147
"text_classification_option",
148148
"sentence_similarity_option",
149-
"object_detection_option"
149+
"object_detection_option",
150+
"question_answering_option"
150151
]
151152
},
152153
"option_name": {
@@ -210,7 +211,13 @@
210211
"Steps",
211212
"MaxEpoch",
212213
"InitLearningRate",
213-
"WeightDecay"
214+
"WeightDecay",
215+
"ContextColumnName",
216+
"QuestionColumnName",
217+
"TrainingAnswerColumnName",
218+
"AnswerIndexStartColumnName",
219+
"predictedAnswerColumnName",
220+
"TopKAnswers"
214221
]
215222
},
216223
"option_type": {

src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,13 @@
532532
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
533533
"searchOption": "object_detection_option"
534534
},
535+
{
536+
"functionName": "QuestionAnswering",
537+
"estimatorTypes": [ "MultiClassification" ],
538+
"nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
539+
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
540+
"searchOption": "question_answering_option"
541+
},
535542
{
536543
"functionName": "ForecastBySsa",
537544
"estimatorTypes": [ "Forecasting" ],

src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
<AdditionalFiles Include="CodeGen\code_gen_flag.json" />
7070
<AdditionalFiles Include="CodeGen\*-estimators.json" />
7171
</ItemGroup>
72-
72+
7373
<ItemGroup>
7474
<EmbeddedResource Include="Tuner\Portfolios.json">
7575
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
using Microsoft.ML.TorchSharp;
5+
using Microsoft.ML.TorchSharp.NasBert;
6+
using Microsoft.ML.TorchSharp.Roberta;
7+
8+
namespace Microsoft.ML.AutoML.CodeGen
9+
{
10+
/// <summary>
/// Sweepable question-answering estimator. Translates a
/// <see cref="QuestionAnsweringOption"/> into a call to the
/// <c>QuestionAnswer</c> trainer on the multiclass classification catalog.
/// </summary>
internal partial class QuestionAnsweringMulti
{
    /// <summary>
    /// Builds the question-answering trainer from the supplied option values.
    /// </summary>
    /// <param name="context">ML context whose multiclass trainer catalog creates the estimator.</param>
    /// <param name="param">Column names and training settings that are forwarded to the trainer.</param>
    /// <returns>The estimator returned by <c>QuestionAnswer</c>.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, QuestionAnsweringOption param)
    {
        return context.MulticlassClassification.Trainers.QuestionAnswer(
            contextColumnName: param.ContextColumnName,
            questionColumnName: param.QuestionColumnName,
            trainingAnswerColumnName: param.TrainingAnswerColumnName,
            answerIndexColumnName: param.AnswerIndexStartColumnName,
            predictedAnswerColumnName: param.PredictedAnswerColumnName,
            scoreColumnName: param.ScoreColumnName,
            batchSize: param.BatchSize,
            maxEpochs: param.MaxEpochs,
            topK: param.TopKAnswers,
            // Forward the Architecture option declared in question_answering_option
            // (default BertArchitecture.Roberta) rather than hard-coding the value,
            // so a caller-supplied architecture is no longer silently ignored.
            // NOTE(review): assumes the generated Architecture property is typed
            // as BertArchitecture - confirm against the generated option class.
            architecture: param.Architecture);
    }
}
28+
}

0 commit comments

Comments
 (0)