diff --git a/build/Dependencies.props b/build/Dependencies.props
index 197fa167b2..d2a48f6865 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -8,7 +8,6 @@
4.5.1
4.3.0
4.8.0
- 4.5.0
diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md
index da770fd20b..f205684d3a 100644
--- a/docs/code/MlNetCookBook.md
+++ b/docs/code/MlNetCookBook.md
@@ -970,27 +970,27 @@ Please note that you need to make your `mapping` operation into a 'pure function
- It should not have side effects (we may call it arbitrarily at any time, or omit the call)
One important caveat is: if you want your custom transformation to be part of your saved model, you will need to provide a `contractName` for it.
-At loading time, you will need to reconstruct the custom transformer and inject it into MLContext.
+At loading time, you will need to register the custom transformer with the MLContext.
Here is a complete example that saves and loads a model with a custom mapping.
```csharp
///
-/// One class that contains all custom mappings that we need for our model.
+/// One class that contains the custom mapping functionality that we need for our model.
+///
+/// It has a on it and
+/// derives from .
///
-public class CustomMappings
+[CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))]
+public class CustomMappings : CustomMappingFactory
{
// This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading.
public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000;
- // MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate
- // this property.
- [Import]
- public MLContext MLContext { get; set; }
-
- // We are exporting the custom transformer by the name 'IncomeMapping'.
- [Export(nameof(IncomeMapping))]
- public ITransformer MyCustomTransformer
- => MLContext.Transforms.CustomMappingTransformer(IncomeMapping, nameof(IncomeMapping));
+ // This factory method will be called when loading the model to get the mapping operation.
+ public override Action GetMapping()
+ {
+ return IncomeMapping;
+ }
}
```
@@ -1013,8 +1013,9 @@ using (var fs = File.Create(modelPath))
// Now pretend we are in a different process.
-// Create a custom composition container for all our custom mapping actions.
-newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings)));
+// Register the assembly that contains 'CustomMappings' with the ComponentCatalog
+// so it can be found when loading the model.
+newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);
// Now we can load the model.
ITransformer loadedModel;
diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
index 9e92abc840..1443a9f6b0 100644
--- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
+++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
@@ -15,7 +15,6 @@
-
diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
index 1dd10e9f39..0c28b6c099 100644
--- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
@@ -35,6 +35,8 @@ internal ComponentCatalog()
_entryPointMap = new Dictionary();
_componentMap = new Dictionary();
_components = new List();
+
+ _extensionsMap = new Dictionary<(Type AttributeType, string ContractName), Type>();
}
///
@@ -404,6 +406,8 @@ internal ComponentInfo(Type interfaceType, string kind, Type argumentType, TlcMo
private readonly List _components;
private readonly Dictionary _componentMap;
+ private readonly Dictionary<(Type AttributeType, string ContractName), Type> _extensionsMap;
+
private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes,
out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment)
{
@@ -618,6 +622,8 @@ public void RegisterAssembly(Assembly assembly, bool throwOnError = true)
AddClass(info, attr.LoadNames, throwOnError);
}
+
+ LoadExtensions(assembly, throwOnError);
}
}
}
@@ -980,5 +986,75 @@ private static void ParseArguments(IHostEnvironment env, object args, string set
if (errorMsg != null)
throw Contracts.Except(errorMsg);
}
+
+ private void LoadExtensions(Assembly assembly, bool throwOnError)
+ {
+ // don't waste time looking through all the types of an assembly
+ // that can't contain extensions
+ if (CanContainExtensions(assembly))
+ {
+ foreach (Type type in assembly.GetTypes())
+ {
+ if (type.IsClass)
+ {
+ foreach (ExtensionBaseAttribute attribute in type.GetCustomAttributes(typeof(ExtensionBaseAttribute)))
+ {
+ var key = (AttributeType: attribute.GetType(), attribute.ContractName);
+ if (_extensionsMap.TryGetValue(key, out var existingType))
+ {
+ if (throwOnError)
+ {
+ throw Contracts.Except($"An extension for '{key.AttributeType.Name}' with contract '{key.ContractName}' has already been registered in the ComponentCatalog.");
+ }
+ }
+ else
+ {
+ _extensionsMap.Add(key, type);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ ///
+ /// Gets a value indicating whether can contain extensions.
+ ///
+ ///
+ /// All ML.NET product assemblies won't contain extensions.
+ ///
+ private static bool CanContainExtensions(Assembly assembly)
+ {
+ if (assembly.FullName.StartsWith("Microsoft.ML.", StringComparison.Ordinal)
+ && HasMLNetPublicKey(assembly))
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ private static bool HasMLNetPublicKey(Assembly assembly)
+ {
+ return assembly.GetName().GetPublicKey().SequenceEqual(
+ typeof(ComponentCatalog).Assembly.GetName().GetPublicKey());
+ }
+
+ [BestFriend]
+ internal object GetExtensionValue(IHostEnvironment env, Type attributeType, string contractName)
+ {
+ object exportedValue = null;
+ if (_extensionsMap.TryGetValue((attributeType, contractName), out Type extensionType))
+ {
+ exportedValue = Activator.CreateInstance(extensionType);
+ }
+
+ if (exportedValue == null)
+ {
+ throw env.Except($"Unable to locate an extension for the contract '{contractName}'. Ensure you have called {nameof(ComponentCatalog)}.{nameof(ComponentCatalog.RegisterAssembly)} with the Assembly that contains a class decorated with a '{attributeType.FullName}'.");
+ }
+
+ return exportedValue;
+ }
}
}
diff --git a/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs b/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs
new file mode 100644
index 0000000000..7d8d00a252
--- /dev/null
+++ b/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs
@@ -0,0 +1,23 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML
+{
+ ///
+ /// The base attribute type for all attributes used for extensibility purposes.
+ ///
+ [AttributeUsage(AttributeTargets.Class)]
+ public abstract class ExtensionBaseAttribute : Attribute
+ {
+ public string ContractName { get; }
+
+ [BestFriend]
+ private protected ExtensionBaseAttribute(string contractName)
+ {
+ ContractName = contractName;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
index 0f095d02c5..78cd697811 100644
--- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
+++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.ComponentModel.Composition.Hosting;
namespace Microsoft.ML
{
@@ -92,12 +91,6 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
[Obsolete("The host environment is not disposable, so it is inappropriate to use this method. " +
"Please handle your own temporary files within the component yourself, including their proper disposal and deletion.")]
IFileHandle CreateTempFile(string suffix = null, string prefix = null);
-
- ///
- /// Get the MEF composition container. This can be used to instantiate user-provided 'parts' when the model
- /// is being loaded, or the components are otherwise created via dependency injection.
- ///
- CompositionContainer GetCompositionContainer();
}
///
diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
index 89ce4503c4..1127e9b115 100644
--- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
+++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -5,7 +5,6 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
-using System.ComponentModel.Composition.Hosting;
using System.IO;
namespace Microsoft.ML.Data
@@ -632,7 +631,5 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bo
else if (!removeLastNewLine)
writer.WriteLine();
}
-
- public virtual CompositionContainer GetCompositionContainer() => new CompositionContainer();
}
}
diff --git a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
index 0d6b288499..ccd18a42b7 100644
--- a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
+++ b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
@@ -12,7 +12,6 @@
-
diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs
index a466e37aa3..39281b36d6 100644
--- a/src/Microsoft.ML.Data/MLContext.cs
+++ b/src/Microsoft.ML.Data/MLContext.cs
@@ -3,8 +3,6 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.ComponentModel.Composition;
-using System.ComponentModel.Composition.Hosting;
using Microsoft.ML.Data;
namespace Microsoft.ML
@@ -69,9 +67,9 @@ public sealed class MLContext : IHostEnvironment
public event EventHandler Log;
///
- /// This is a MEF composition container catalog to be used for model loading.
+ /// This is a catalog of components that will be used for model loading.
///
- public CompositionContainer CompositionContainer { get; set; }
+ public ComponentCatalog ComponentCatalog => _env.ComponentCatalog;
///
/// Create the ML context.
@@ -80,7 +78,7 @@ public sealed class MLContext : IHostEnvironment
/// Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.
public MLContext(int? seed = null, int conc = 0)
{
- _env = new LocalEnvironment(seed, conc, MakeCompositionContainer);
+ _env = new LocalEnvironment(seed, conc);
_env.AddListener(ProcessMessage);
BinaryClassification = new BinaryClassificationCatalog(_env);
@@ -94,18 +92,6 @@ public MLContext(int? seed = null, int conc = 0)
Data = new DataOperationsCatalog(_env);
}
- private CompositionContainer MakeCompositionContainer()
- {
- if (CompositionContainer == null)
- return null;
-
- var mlContext = CompositionContainer.GetExportedValueOrDefault();
- if (mlContext == null)
- CompositionContainer.ComposeExportedValue(this);
-
- return CompositionContainer;
- }
-
private void ProcessMessage(IMessageSource source, ChannelMessage message)
{
var log = Log;
@@ -120,7 +106,6 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor;
bool IHostEnvironment.IsCancelled => _env.IsCancelled;
- ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog;
string IExceptionContext.ContextDescription => _env.ContextDescription;
IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix);
TException IExceptionContext.Process(TException ex) => _env.Process(ex);
@@ -128,6 +113,5 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
IChannel IChannelProvider.Start(string name) => _env.Start(name);
IPipe IChannelProvider.StartPipe(string name) => _env.StartPipe(name);
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
- CompositionContainer IHostEnvironment.GetCompositionContainer() => _env.GetCompositionContainer();
}
}
diff --git a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
index a9e3a4a80e..b17b2cf39f 100644
--- a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
+++ b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
@@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.ComponentModel.Composition.Hosting;
namespace Microsoft.ML.Data
{
@@ -14,8 +13,6 @@ namespace Microsoft.ML.Data
///
internal sealed class LocalEnvironment : HostEnvironmentBase
{
- private readonly Func _compositionContainerFactory;
-
private sealed class Channel : ChannelBase
{
public readonly Stopwatch Watch;
@@ -49,11 +46,9 @@ protected override void Dispose(bool disposing)
///
/// Random seed. Set to null for a non-deterministic environment.
/// Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.
- /// The function to retrieve the composition container
- public LocalEnvironment(int? seed = null, int conc = 0, Func compositionContainerFactory = null)
+ public LocalEnvironment(int? seed = null, int conc = 0)
: base(RandomUtils.Create(seed), verbose: false, conc)
{
- _compositionContainerFactory = compositionContainerFactory;
}
///
@@ -96,13 +91,6 @@ protected override IPipe CreatePipe(ChannelProviderBase pare
return new Pipe(parent, name, GetDispatchDelegate());
}
- public override CompositionContainer GetCompositionContainer()
- {
- if (_compositionContainerFactory != null)
- return _compositionContainerFactory();
- return base.GetCompositionContainer();
- }
-
private sealed class Host : HostBase
{
public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc)
diff --git a/src/Microsoft.ML.Transforms/CustomMappingFactory.cs b/src/Microsoft.ML.Transforms/CustomMappingFactory.cs
new file mode 100644
index 0000000000..f1778f72f7
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/CustomMappingFactory.cs
@@ -0,0 +1,47 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.Data.DataView;
+
+namespace Microsoft.ML.Transforms
+{
+ ///
+ /// Place this attribute onto a type to cause it to be considered a custom mapping factory.
+ ///
+ [AttributeUsage(AttributeTargets.Class)]
+ public sealed class CustomMappingFactoryAttributeAttribute : ExtensionBaseAttribute
+ {
+ public CustomMappingFactoryAttributeAttribute(string contractName)
+ : base(contractName)
+ {
+ }
+ }
+
+ internal interface ICustomMappingFactory
+ {
+ ITransformer CreateTransformer(IHostEnvironment env, string contractName);
+ }
+
+ ///
+ /// The base type for custom mapping factories.
+ ///
+ /// The type that describes what 'source' columns are consumed from the input .
+ /// The type that describes what new columns are added by this transform.
+ public abstract class CustomMappingFactory : ICustomMappingFactory
+ where TSrc : class, new()
+ where TDst : class, new()
+ {
+ ///
+ /// Returns the mapping delegate that maps from inputs to outputs.
+ ///
+ public abstract Action GetMapping();
+
+ ITransformer ICustomMappingFactory.CreateTransformer(IHostEnvironment env, string contractName)
+ {
+ Action mapAction = GetMapping();
+ return new CustomMappingTransformer(env, mapAction, contractName);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs
index 2f847797f7..6950d9f46a 100644
--- a/src/Microsoft.ML.Transforms/LambdaTransform.cs
+++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs
@@ -68,11 +68,13 @@ private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
var contractName = ctx.LoadString();
- var composition = env.GetCompositionContainer();
- if (composition == null)
- throw Contracts.Except("Unable to get the MEF composition container");
- ITransformer transformer = composition.GetExportedValue(contractName);
- return transformer;
+ object factoryObject = env.ComponentCatalog.GetExtensionValue(env, typeof(CustomMappingFactoryAttributeAttribute), contractName);
+ if (!(factoryObject is ICustomMappingFactory mappingFactory))
+ {
+ throw env.Except($"The class with contract '{contractName}' must derive from '{typeof(CustomMappingFactory<,>).FullName}'.");
+ }
+
+ return mappingFactory.CreateTransformer(env, contractName);
}
///
diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs
index 4cf6548d6c..216f2c3179 100644
--- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs
+++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs
@@ -38,7 +38,6 @@
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.core.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.assert.dll"
#r "System"
-#r "System.ComponentModel.Composition"
#r "System.Core"
#r "System.Xml.Linq"
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
index ce9a9e0385..cf37e1bb70 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
@@ -4,8 +4,6 @@
using System;
using System.Collections.Generic;
-using System.ComponentModel.Composition;
-using System.ComponentModel.Composition.Hosting;
using System.IO;
using System.Linq;
using Microsoft.Data.DataView;
@@ -13,6 +11,7 @@
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFramework;
using Microsoft.ML.Trainers;
+using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.Categorical;
using Microsoft.ML.Transforms.Normalizers;
using Microsoft.ML.Transforms.Text;
@@ -491,22 +490,22 @@ public void CustomTransformer()
}
///
- /// One class that contains all custom mappings that we need for our model.
+ /// One class that contains the custom mapping functionality that we need for our model.
+ ///
+ /// It has a on it and
+ /// derives from .
///
- public class CustomMappings
+ [CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))]
+ public class CustomMappings : CustomMappingFactory
{
// This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading.
public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000;
- // MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate
- // this property.
- [Import]
- public MLContext MLContext { get; set; }
-
- // We are exporting the custom transformer by the name 'IncomeMapping'.
- [Export(nameof(IncomeMapping))]
- public ITransformer MyCustomTransformer
- => MLContext.Transforms.CustomMappingTransformer(IncomeMapping, nameof(IncomeMapping));
+ // This factory method will be called when loading the model to get the mapping operation.
+ public override Action GetMapping()
+ {
+ return IncomeMapping;
+ }
}
private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string modelPath)
@@ -530,8 +529,9 @@ private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string
// Now pretend we are in a different process.
var newContext = new MLContext();
- // Create a custom composition container for all our custom mapping actions.
- newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings)));
+ // Register the assembly that contains 'CustomMappings' with the ComponentCatalog
+ // so it can be found when loading the model.
+ newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);
// Now we can load the model.
ITransformer loadedModel;
diff --git a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
index 9a1c048ebf..3dcfcc991a 100644
--- a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
@@ -3,8 +3,6 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.ComponentModel.Composition;
-using System.ComponentModel.Composition.Hosting;
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;
@@ -32,18 +30,18 @@ public class MyOutput
public string Together { get; set; }
}
- public class MyLambda
+ [CustomMappingFactoryAttribute("MyLambda")]
+ public class MyLambda : CustomMappingFactory
{
- [Export("MyLambda")]
- public ITransformer MyTransformer => ML.Transforms.CustomMappingTransformer(MyAction, "MyLambda");
-
- [Import]
- public MLContext ML { get; set; }
-
public static void MyAction(MyInput input, MyOutput output)
{
output.Together = $"{input.Float1} + {string.Join(", ", input.Float4)}";
}
+
+ public override Action GetMapping()
+ {
+ return MyAction;
+ }
}
[Fact]
@@ -67,14 +65,14 @@ public void TestCustomTransformer()
try
{
TestEstimatorCore(customEst, data);
- Assert.True(false, "Cannot work without MEF injection");
+ Assert.True(false, "Cannot work without RegisterAssembly");
}
catch (InvalidOperationException ex)
{
if (!ex.IsMarked())
throw;
}
- ML.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(MyLambda)));
+ ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly);
TestEstimatorCore(customEst, data);
transformedData = customEst.Fit(data).Transform(data);