diff --git a/build/Dependencies.props b/build/Dependencies.props index 197fa167b2..d2a48f6865 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -8,7 +8,6 @@ 4.5.1 4.3.0 4.8.0 - 4.5.0 diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md index da770fd20b..f205684d3a 100644 --- a/docs/code/MlNetCookBook.md +++ b/docs/code/MlNetCookBook.md @@ -970,27 +970,27 @@ Please note that you need to make your `mapping` operation into a 'pure function - It should not have side effects (we may call it arbitrarily at any time, or omit the call) One important caveat is: if you want your custom transformation to be part of your saved model, you will need to provide a `contractName` for it. -At loading time, you will need to reconstruct the custom transformer and inject it into MLContext. +At loading time, you will need to register the custom transformer with the MLContext. Here is a complete example that saves and loads a model with a custom mapping. ```csharp /// -/// One class that contains all custom mappings that we need for our model. +/// One class that contains the custom mapping functionality that we need for our model. +/// +/// It has a on it and +/// derives from . /// -public class CustomMappings +[CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))] +public class CustomMappings : CustomMappingFactory { // This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading. public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000; - // MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate - // this property. - [Import] - public MLContext MLContext { get; set; } - - // We are exporting the custom transformer by the name 'IncomeMapping'. - [Export(nameof(IncomeMapping))] - public ITransformer MyCustomTransformer - => MLContext.Transforms.CustomMappingTransformer(IncomeMapping, nameof(IncomeMapping)); + // This factory method will be called when loading the model to get the mapping operation. + public override Action GetMapping() + { + return IncomeMapping; + } } ``` @@ -1013,8 +1013,9 @@ using (var fs = File.Create(modelPath)) // Now pretend we are in a different process. -// Create a custom composition container for all our custom mapping actions. -newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings))); +// Register the assembly that contains 'CustomMappings' with the ComponentCatalog +// so it can be found when loading the model. +newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly); // Now we can load the model. ITransformer loadedModel; diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj index 9e92abc840..1443a9f6b0 100644 --- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj +++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj @@ -15,7 +15,6 @@ - diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs index 1dd10e9f39..0c28b6c099 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs @@ -35,6 +35,8 @@ internal ComponentCatalog() _entryPointMap = new Dictionary(); _componentMap = new Dictionary(); _components = new List(); + + _extensionsMap = new Dictionary<(Type AttributeType, string ContractName), Type>(); } /// @@ -404,6 +406,8 @@ internal ComponentInfo(Type interfaceType, string kind, Type argumentType, TlcMo private readonly List _components; private readonly Dictionary _componentMap; + private readonly Dictionary<(Type AttributeType, string ContractName), Type> _extensionsMap; + private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes, out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment) { @@ -618,6 +622,8 @@ public void RegisterAssembly(Assembly assembly, bool throwOnError = true) AddClass(info, attr.LoadNames, throwOnError); } + + LoadExtensions(assembly, throwOnError); } } } @@ -980,5 +986,75 @@ private static void ParseArguments(IHostEnvironment env, object args, string set if (errorMsg != null) throw Contracts.Except(errorMsg); } + + private void LoadExtensions(Assembly assembly, bool throwOnError) + { + // don't waste time looking through all the types of an assembly + // that can't contain extensions + if (CanContainExtensions(assembly)) + { + foreach (Type type in assembly.GetTypes()) + { + if (type.IsClass) + { + foreach (ExtensionBaseAttribute attribute in type.GetCustomAttributes(typeof(ExtensionBaseAttribute))) + { + var key = (AttributeType: attribute.GetType(), attribute.ContractName); + if (_extensionsMap.TryGetValue(key, out var existingType)) + { + if (throwOnError) + { + throw Contracts.Except($"An extension for '{key.AttributeType.Name}' with contract '{key.ContractName}' has already been registered in the ComponentCatalog."); + } + } + else + { + _extensionsMap.Add(key, type); + } + } + } + } + } + } + + /// + /// Gets a value indicating whether can contain extensions. + /// + /// + /// All ML.NET product assemblies won't contain extensions. + /// + private static bool CanContainExtensions(Assembly assembly) + { + if (assembly.FullName.StartsWith("Microsoft.ML.", StringComparison.Ordinal) + && HasMLNetPublicKey(assembly)) + { + return false; + } + + return true; + } + + private static bool HasMLNetPublicKey(Assembly assembly) + { + return assembly.GetName().GetPublicKey().SequenceEqual( + typeof(ComponentCatalog).Assembly.GetName().GetPublicKey()); + } + + [BestFriend] + internal object GetExtensionValue(IHostEnvironment env, Type attributeType, string contractName) + { + object exportedValue = null; + if (_extensionsMap.TryGetValue((attributeType, contractName), out Type extensionType)) + { + exportedValue = Activator.CreateInstance(extensionType); + } + + if (exportedValue == null) + { + throw env.Except($"Unable to locate an extension for the contract '{contractName}'. Ensure you have called {nameof(ComponentCatalog)}.{nameof(ComponentCatalog.RegisterAssembly)} with the Assembly that contains a class decorated with a '{attributeType.FullName}'."); + } + + return exportedValue; + } } } diff --git a/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs b/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs new file mode 100644 index 0000000000..7d8d00a252 --- /dev/null +++ b/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML +{ + /// + /// The base attribute type for all attributes used for extensibility purposes. + /// + [AttributeUsage(AttributeTargets.Class)] + public abstract class ExtensionBaseAttribute : Attribute + { + public string ContractName { get; } + + [BestFriend] + private protected ExtensionBaseAttribute(string contractName) + { + ContractName = contractName; + } + } +} diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs index 0f095d02c5..78cd697811 100644 --- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs +++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.ComponentModel.Composition.Hosting; namespace Microsoft.ML { @@ -92,12 +91,6 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider [Obsolete("The host environment is not disposable, so it is inappropriate to use this method. " + "Please handle your own temporary files within the component yourself, including their proper disposal and deletion.")] IFileHandle CreateTempFile(string suffix = null, string prefix = null); - - /// - /// Get the MEF composition container. This can be used to instantiate user-provided 'parts' when the model - /// is being loaded, or the components are otherwise created via dependency injection. - /// - CompositionContainer GetCompositionContainer(); } /// diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs index 89ce4503c4..1127e9b115 100644 --- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs +++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.ComponentModel.Composition.Hosting; using System.IO; namespace Microsoft.ML.Data @@ -632,7 +631,5 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bo else if (!removeLastNewLine) writer.WriteLine(); } - - public virtual CompositionContainer GetCompositionContainer() => new CompositionContainer(); } } diff --git a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj index 0d6b288499..ccd18a42b7 100644 --- a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj +++ b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj @@ -12,7 +12,6 @@ - diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs index a466e37aa3..39281b36d6 100644 --- a/src/Microsoft.ML.Data/MLContext.cs +++ b/src/Microsoft.ML.Data/MLContext.cs @@ -3,8 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.ComponentModel.Composition; -using System.ComponentModel.Composition.Hosting; using Microsoft.ML.Data; namespace Microsoft.ML @@ -69,9 +67,9 @@ public sealed class MLContext : IHostEnvironment public event EventHandler Log; /// - /// This is a MEF composition container catalog to be used for model loading. + /// This is a catalog of components that will be used for model loading. /// - public CompositionContainer CompositionContainer { get; set; } + public ComponentCatalog ComponentCatalog => _env.ComponentCatalog; /// /// Create the ML context. @@ -80,7 +78,7 @@ public sealed class MLContext : IHostEnvironment /// Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically. public MLContext(int? seed = null, int conc = 0) { - _env = new LocalEnvironment(seed, conc, MakeCompositionContainer); + _env = new LocalEnvironment(seed, conc); _env.AddListener(ProcessMessage); BinaryClassification = new BinaryClassificationCatalog(_env); @@ -94,18 +92,6 @@ public MLContext(int? seed = null, int conc = 0) Data = new DataOperationsCatalog(_env); } - private CompositionContainer MakeCompositionContainer() - { - if (CompositionContainer == null) - return null; - - var mlContext = CompositionContainer.GetExportedValueOrDefault(); - if (mlContext == null) - CompositionContainer.ComposeExportedValue(this); - - return CompositionContainer; - } - private void ProcessMessage(IMessageSource source, ChannelMessage message) { var log = Log; @@ -120,7 +106,6 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message) int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor; bool IHostEnvironment.IsCancelled => _env.IsCancelled; - ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog; string IExceptionContext.ContextDescription => _env.ContextDescription; IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix); TException IExceptionContext.Process(TException ex) => _env.Process(ex); @@ -128,6 +113,5 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message) IChannel IChannelProvider.Start(string name) => _env.Start(name); IPipe IChannelProvider.StartPipe(string name) => _env.StartPipe(name); IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name); - CompositionContainer IHostEnvironment.GetCompositionContainer() => _env.GetCompositionContainer(); } } diff --git a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs index a9e3a4a80e..b17b2cf39f 100644 --- a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs +++ b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.ComponentModel.Composition.Hosting; namespace Microsoft.ML.Data { @@ -14,8 +13,6 @@ namespace Microsoft.ML.Data /// internal sealed class LocalEnvironment : HostEnvironmentBase { - private readonly Func _compositionContainerFactory; - private sealed class Channel : ChannelBase { public readonly Stopwatch Watch; @@ -49,11 +46,9 @@ protected override void Dispose(bool disposing) /// /// Random seed. Set to null for a non-deterministic environment. /// Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically. - /// The function to retrieve the composition container - public LocalEnvironment(int? seed = null, int conc = 0, Func compositionContainerFactory = null) + public LocalEnvironment(int? seed = null, int conc = 0) : base(RandomUtils.Create(seed), verbose: false, conc) { - _compositionContainerFactory = compositionContainerFactory; } /// @@ -96,13 +91,6 @@ protected override IPipe CreatePipe(ChannelProviderBase pare return new Pipe(parent, name, GetDispatchDelegate()); } - public override CompositionContainer GetCompositionContainer() - { - if (_compositionContainerFactory != null) - return _compositionContainerFactory(); - return base.GetCompositionContainer(); - } - private sealed class Host : HostBase { public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose, int? conc) diff --git a/src/Microsoft.ML.Transforms/CustomMappingFactory.cs b/src/Microsoft.ML.Transforms/CustomMappingFactory.cs new file mode 100644 index 0000000000..f1778f72f7 --- /dev/null +++ b/src/Microsoft.ML.Transforms/CustomMappingFactory.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.Data.DataView; + +namespace Microsoft.ML.Transforms +{ + /// + /// Place this attribute onto a type to cause it to be considered a custom mapping factory. + /// + [AttributeUsage(AttributeTargets.Class)] + public sealed class CustomMappingFactoryAttributeAttribute : ExtensionBaseAttribute + { + public CustomMappingFactoryAttributeAttribute(string contractName) + : base(contractName) + { + } + } + + internal interface ICustomMappingFactory + { + ITransformer CreateTransformer(IHostEnvironment env, string contractName); + } + + /// + /// The base type for custom mapping factories. + /// + /// The type that describes what 'source' columns are consumed from the input . + /// The type that describes what new columns are added by this transform. + public abstract class CustomMappingFactory : ICustomMappingFactory + where TSrc : class, new() + where TDst : class, new() + { + /// + /// Returns the mapping delegate that maps from inputs to outputs. + /// + public abstract Action GetMapping(); + + ITransformer ICustomMappingFactory.CreateTransformer(IHostEnvironment env, string contractName) + { + Action mapAction = GetMapping(); + return new CustomMappingTransformer(env, mapAction, contractName); + } + } +} diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs index 2f847797f7..6950d9f46a 100644 --- a/src/Microsoft.ML.Transforms/LambdaTransform.cs +++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs @@ -68,11 +68,13 @@ private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx) var contractName = ctx.LoadString(); - var composition = env.GetCompositionContainer(); - if (composition == null) - throw Contracts.Except("Unable to get the MEF composition container"); - ITransformer transformer = composition.GetExportedValue(contractName); - return transformer; + object factoryObject = env.ComponentCatalog.GetExtensionValue(env, typeof(CustomMappingFactoryAttributeAttribute), contractName); + if (!(factoryObject is ICustomMappingFactory mappingFactory)) + { + throw env.Except($"The class with contract '{contractName}' must derive from '{typeof(CustomMappingFactory<,>).FullName}'."); + } + + return mappingFactory.CreateTransformer(env, contractName); } /// diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index 4cf6548d6c..216f2c3179 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -38,7 +38,6 @@ #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.core.dll" #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.assert.dll" #r "System" -#r "System.ComponentModel.Composition" #r "System.Core" #r "System.Xml.Linq" diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index ce9a9e0385..cf37e1bb70 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -4,8 +4,6 @@ using System; using System.Collections.Generic; -using System.ComponentModel.Composition; -using System.ComponentModel.Composition.Hosting; using System.IO; using System.Linq; using Microsoft.Data.DataView; @@ -13,6 +11,7 @@ using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers; +using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Categorical; using Microsoft.ML.Transforms.Normalizers; using Microsoft.ML.Transforms.Text; @@ -491,22 +490,22 @@ public void CustomTransformer() } /// - /// One class that contains all custom mappings that we need for our model. + /// One class that contains the custom mapping functionality that we need for our model. + /// + /// It has a on it and + /// derives from . /// - public class CustomMappings + [CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))] + public class CustomMappings : CustomMappingFactory { // This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading. public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000; - // MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate - // this property. - [Import] - public MLContext MLContext { get; set; } - - // We are exporting the custom transformer by the name 'IncomeMapping'. - [Export(nameof(IncomeMapping))] - public ITransformer MyCustomTransformer - => MLContext.Transforms.CustomMappingTransformer(IncomeMapping, nameof(IncomeMapping)); + // This factory method will be called when loading the model to get the mapping operation. + public override Action GetMapping() + { + return IncomeMapping; + } } private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string modelPath) @@ -530,8 +529,9 @@ private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string // Now pretend we are in a different process. var newContext = new MLContext(); - // Create a custom composition container for all our custom mapping actions. - newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings))); + // Register the assembly that contains 'CustomMappings' with the ComponentCatalog + // so it can be found when loading the model. + newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly); // Now we can load the model. ITransformer loadedModel; diff --git a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs index 9a1c048ebf..3dcfcc991a 100644 --- a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs @@ -3,8 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.ComponentModel.Composition; -using System.ComponentModel.Composition.Hosting; using System.Linq; using Microsoft.Data.DataView; using Microsoft.ML.Data; @@ -32,18 +30,18 @@ public class MyOutput public string Together { get; set; } } - public class MyLambda + [CustomMappingFactoryAttribute("MyLambda")] + public class MyLambda : CustomMappingFactory { - [Export("MyLambda")] - public ITransformer MyTransformer => ML.Transforms.CustomMappingTransformer(MyAction, "MyLambda"); - - [Import] - public MLContext ML { get; set; } - public static void MyAction(MyInput input, MyOutput output) { output.Together = $"{input.Float1} + {string.Join(", ", input.Float4)}"; } + + public override Action GetMapping() + { + return MyAction; + } } [Fact] @@ -67,14 +65,14 @@ public void TestCustomTransformer() try { TestEstimatorCore(customEst, data); - Assert.True(false, "Cannot work without MEF injection"); + Assert.True(false, "Cannot work without RegisterAssembly"); } catch (InvalidOperationException ex) { if (!ex.IsMarked()) throw; } - ML.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(MyLambda))); + ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly); TestEstimatorCore(customEst, data); transformedData = customEst.Fit(data).Transform(data);