Skip to content

Commit 56823fc

Browse files
committed
fixed mac build and minor torch sharp changes
1 parent 65c7ca9 commit 56823fc

File tree

5 files changed

+30
-10
lines changed

5 files changed

+30
-10
lines changed

build/vsts-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ stages:
100100
pool:
101101
vmImage: macOS-12
102102
steps:
103-
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 && brew update && rm '/usr/local/bin/2to3-3.11' && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
103+
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 && rm '/usr/local/bin/2to3-3.11' && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
104104
displayName: Install build dependencies
105105
# Only build native assets to avoid conflicts.
106106
- script: ./build.sh -projects $(Build.SourcesDirectory)/src/Native/Native.proj -configuration $(BuildConfig) /p:TargetArchitecture=x64 /p:CopyPackageAssets=true

src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,19 @@ namespace Microsoft.ML.TorchSharp.NasBert
5858
///
5959
public class SentenceSimilarityTrainer : NasBertTrainer<float, float>
6060
{
61-
internal SentenceSimilarityTrainer(IHostEnvironment env, Options options) : base(env, options)
61+
62+
public class SentenceSimilarityOptions : NasBertOptions
63+
{
64+
public SentenceSimilarityOptions()
65+
{
66+
BatchSize = 32;
67+
MaxEpoch = 10;
68+
TaskType = BertTaskType.SentenceRegression;
69+
LearningRate = new List<double>() { .0002 };
70+
WeightDecay = .01;
71+
}
72+
}
73+
internal SentenceSimilarityTrainer(IHostEnvironment env, SentenceSimilarityOptions options) : base(env, options)
6274
{
6375
}
6476

@@ -71,7 +83,7 @@ internal SentenceSimilarityTrainer(IHostEnvironment env,
7183
int maxEpochs = 10,
7284
IDataView validationSet = null,
7385
BertArchitecture architecture = BertArchitecture.Roberta) :
74-
this(env, new NasBertOptions
86+
this(env, new SentenceSimilarityOptions
7587
{
7688
ScoreColumnName = scoreColumnName,
7789
Sentence1ColumnName = sentence1ColumnName,

src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,17 @@ namespace Microsoft.ML.TorchSharp.NasBert
6060
///
6161
public class TextClassificationTrainer : NasBertTrainer<UInt32, long>
6262
{
63-
internal TextClassificationTrainer(IHostEnvironment env, NasBertOptions options) : base(env, options)
63+
public class TextClassificationOptions : NasBertTrainer.NasBertOptions
64+
{
65+
public TextClassificationOptions()
66+
{
67+
TaskType = BertTaskType.TextClassification;
68+
BatchSize = 32;
69+
MaxEpoch = 10;
70+
}
71+
}
72+
73+
internal TextClassificationTrainer(IHostEnvironment env, TextClassificationOptions options) : base(env, options)
6474
{
6575
}
6676

@@ -74,7 +84,7 @@ internal TextClassificationTrainer(IHostEnvironment env,
7484
int maxEpochs = 10,
7585
IDataView validationSet = null,
7686
BertArchitecture architecture = BertArchitecture.Roberta) :
77-
this(env, new NasBertOptions
87+
this(env, new TextClassificationOptions
7888
{
7989
PredictionColumnName = predictionColumnName,
8090
ScoreColumnName = scoreColumnName,

src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public static TextClassificationTrainer TextClassification(
5959
/// <returns></returns>
6060
public static TextClassificationTrainer TextClassification(
6161
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
62-
NasBertTrainer.NasBertOptions options)
62+
TextClassificationTrainer.TextClassificationOptions options)
6363
=> new TextClassificationTrainer(CatalogUtils.GetEnvironment(catalog), options);
6464

6565
/// <summary>
@@ -99,7 +99,7 @@ public static SentenceSimilarityTrainer SentenceSimilarity(
9999
/// <returns></returns>
100100
public static SentenceSimilarityTrainer SentenceSimilarity(
101101
this RegressionCatalog.RegressionTrainers catalog,
102-
NasBertTrainer.NasBertOptions options)
102+
SentenceSimilarityTrainer.SentenceSimilarityOptions options)
103103
=> new SentenceSimilarityTrainer(CatalogUtils.GetEnvironment(catalog), options);
104104

105105

test/Microsoft.ML.Tests/TextClassificationTests.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -432,14 +432,12 @@ public void TestSentenceSimilarityLargeFileGpu()
432432

433433
var dataSplit = ML.Data.TrainTestSplit(dataView, testFraction: 0.2);
434434

435-
var options = new NasBertTrainer.NasBertOptions()
435+
var options = new SentenceSimilarityTrainer.SentenceSimilarityOptions()
436436
{
437437
TaskType = BertTaskType.SentenceRegression,
438438
Sentence1ColumnName = "search_term",
439439
Sentence2ColumnName = "product_description",
440440
LabelColumnName = "relevance",
441-
LearningRate = new List<double>() { .0002 },
442-
WeightDecay = .01
443441
};
444442

445443
var estimator = ML.Regression.Trainers.SentenceSimilarity(options);

0 commit comments

Comments
 (0)