Skip to content

Commit d672e14

Browse files
committed
Stop using System.ComponentModel.Composition
Replace our MEF usage, which is only used by custom mapping transforms, with the ComponentCatalog class. Fix #1595 Fix #2422
1 parent 70830ed commit d672e14

File tree

15 files changed

+193
-89
lines changed

15 files changed

+193
-89
lines changed

build/Dependencies.props

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
<SystemMemoryVersion>4.5.1</SystemMemoryVersion>
99
<SystemReflectionEmitLightweightPackageVersion>4.3.0</SystemReflectionEmitLightweightPackageVersion>
1010
<SystemThreadingTasksDataflowPackageVersion>4.8.0</SystemThreadingTasksDataflowPackageVersion>
11-
<SystemComponentModelCompositionVersion>4.5.0</SystemComponentModelCompositionVersion>
1211
</PropertyGroup>
1312

1413
<!-- Other/Non-Core Product Dependencies -->

docs/code/MlNetCookBook.md

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -970,27 +970,27 @@ Please note that you need to make your `mapping` operation into a 'pure function
970970
- It should not have side effects (we may call it arbitrarily at any time, or omit the call)
971971

972972
One important caveat is: if you want your custom transformation to be part of your saved model, you will need to provide a `contractName` for it.
973-
At loading time, you will need to reconstruct the custom transformer and inject it into MLContext.
973+
At loading time, you will need to register the custom transformer with the MLContext.
974974

975975
Here is a complete example that saves and loads a model with a custom mapping.
976976
```csharp
977977
/// <summary>
978-
/// One class that contains all custom mappings that we need for our model.
978+
/// One class that contains the custom mapping functionality that we need for our model.
979+
///
980+
/// It has a [CustomMappingTransformerFactoryAttribute] on it and
981+
/// derives from CustomMappingTransformerFactory{TSrc, TDst}.
979982
/// </summary>
980-
public class CustomMappings
983+
[CustomMappingTransformerFactory(nameof(CustomMappings.IncomeMapping))]
984+
public class CustomMappings : CustomMappingTransformerFactory<InputRow, OutputRow>
981985
{
982986
// This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading.
983987
public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000;
984988

985-
// MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate
986-
// this property.
987-
[Import]
988-
public MLContext MLContext { get; set; }
989-
990-
// We are exporting the custom transformer by the name 'IncomeMapping'.
991-
[Export(nameof(IncomeMapping))]
992-
public ITransformer MyCustomTransformer
993-
=> MLContext.Transforms.CustomMappingTransformer<InputRow, OutputRow>(IncomeMapping, nameof(IncomeMapping));
989+
// This factory method will be called when loading the model to get the mapping operation.
990+
public override Action<InputRow, OutputRow> GetTransformer()
991+
{
992+
return IncomeMapping;
993+
}
994994
}
995995
```
996996

@@ -1013,8 +1013,9 @@ using (var fs = File.Create(modelPath))
10131013

10141014
// Now pretend we are in a different process.
10151015
1016-
// Create a custom composition container for all our custom mapping actions.
1017-
newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings)));
1016+
// Register the assembly that contains 'CustomMappings' with the ComponentCatalog
1017+
// so it can be found when loading the model.
1018+
newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);
10181019

10191020
// Now we can load the model.
10201021
ITransformer loadedModel;

pkg/Microsoft.ML/Microsoft.ML.nupkgproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
<PackageReference Include="System.CodeDom" Version="$(SystemCodeDomPackageVersion)" />
1616
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
1717
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
18-
<PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" />
1918
</ItemGroup>
2019

2120
<ItemGroup>

src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ internal ComponentCatalog()
3535
_entryPointMap = new Dictionary<string, EntryPointInfo>();
3636
_componentMap = new Dictionary<string, ComponentInfo>();
3737
_components = new List<ComponentInfo>();
38+
39+
_extensionsMap = new Dictionary<(Type AttributeType, string ContractName), Type>();
3840
}
3941

4042
/// <summary>
@@ -395,6 +397,8 @@ internal ComponentInfo(Type interfaceType, string kind, Type argumentType, TlcMo
395397
private readonly List<ComponentInfo> _components;
396398
private readonly Dictionary<string, ComponentInfo> _componentMap;
397399

400+
private readonly Dictionary<(Type AttributeType, string ContractName), Type> _extensionsMap;
401+
398402
private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes,
399403
out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment)
400404
{
@@ -609,6 +613,8 @@ public void RegisterAssembly(Assembly assembly, bool throwOnError = true)
609613

610614
AddClass(info, attr.LoadNames, throwOnError);
611615
}
616+
617+
LoadExtensions(assembly, throwOnError);
612618
}
613619
}
614620
}
@@ -971,5 +977,75 @@ private static void ParseArguments(IHostEnvironment env, object args, string set
971977
if (errorMsg != null)
972978
throw Contracts.Except(errorMsg);
973979
}
980+
981+
private void LoadExtensions(Assembly assembly, bool throwOnError)
982+
{
983+
// don't waste time looking through all the types of an assembly
984+
// that can't contain extensions
985+
if (CanContainExtensions(assembly))
986+
{
987+
foreach (Type type in assembly.GetTypes())
988+
{
989+
if (type.IsClass)
990+
{
991+
foreach (ExtensionBaseAttribute attribute in type.GetCustomAttributes(typeof(ExtensionBaseAttribute)))
992+
{
993+
var key = (AttributeType: attribute.GetType(), attribute.ContractName);
994+
if (_extensionsMap.TryGetValue(key, out var existingType))
995+
{
996+
if (throwOnError)
997+
{
998+
throw Contracts.Except($"An extension for '{key.AttributeType.Name}' with contract '{key.ContractName}' has already been registered in the ComponentCatalog.");
999+
}
1000+
}
1001+
else
1002+
{
1003+
_extensionsMap.Add(key, type);
1004+
}
1005+
}
1006+
}
1007+
}
1008+
}
1009+
}
1010+
1011+
/// <summary>
1012+
/// Gets a value indicating whether <paramref name="assembly"/> can contain extensions.
1013+
/// </summary>
1014+
/// <remarks>
1015+
/// All ML.NET product assemblies won't contain extensions.
1016+
/// </remarks>
1017+
private static bool CanContainExtensions(Assembly assembly)
1018+
{
1019+
if (assembly.FullName.StartsWith("Microsoft.ML.", StringComparison.Ordinal)
1020+
&& HasMLNetPublicKey(assembly))
1021+
{
1022+
return false;
1023+
}
1024+
1025+
return true;
1026+
}
1027+
1028+
private static bool HasMLNetPublicKey(Assembly assembly)
1029+
{
1030+
return assembly.GetName().GetPublicKey().SequenceEqual(
1031+
typeof(ComponentCatalog).Assembly.GetName().GetPublicKey());
1032+
}
1033+
1034+
[BestFriend]
1035+
internal object GetExtensionValue(IHostEnvironment env, Type attributeType, string contractName)
1036+
{
1037+
object exportedValue = null;
1038+
if (_extensionsMap.TryGetValue((attributeType, contractName), out Type extensionType))
1039+
{
1040+
exportedValue = Activator.CreateInstance(extensionType);
1041+
}
1042+
1043+
if (exportedValue == null)
1044+
{
1045+
throw env.Except($"Unable to locate an extension for the contract '{contractName}'. Ensure you have called {nameof(ComponentCatalog)}.{nameof(ComponentCatalog.RegisterAssembly)} with the Assembly that contains a class decorated with a '{attributeType.FullName}'.");
1046+
}
1047+
1048+
return exportedValue;
1049+
}
9741050
}
9751051
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
7+
namespace Microsoft.ML
8+
{
9+
/// <summary>
10+
/// The base attribute type for all attributes used for extensibility purposes.
11+
/// </summary>
12+
[AttributeUsage(AttributeTargets.Class)]
13+
public abstract class ExtensionBaseAttribute : Attribute
14+
{
15+
public string ContractName { get; }
16+
17+
[BestFriend]
18+
private protected ExtensionBaseAttribute(string contractName)
19+
{
20+
ContractName = contractName;
21+
}
22+
}
23+
}

src/Microsoft.ML.Core/Data/IHostEnvironment.cs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.ComponentModel.Composition.Hosting;
76

87
namespace Microsoft.ML
98
{
@@ -75,12 +74,6 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
7574
[Obsolete("The host environment is not disposable, so it is inappropriate to use this method. " +
7675
"Please handle your own temporary files within the component yourself, including their proper disposal and deletion.")]
7776
IFileHandle CreateTempFile(string suffix = null, string prefix = null);
78-
79-
/// <summary>
80-
/// Get the MEF composition container. This can be used to instantiate user-provided 'parts' when the model
81-
/// is being loaded, or the components are otherwise created via dependency injection.
82-
/// </summary>
83-
CompositionContainer GetCompositionContainer();
8477
}
8578

8679
/// <summary>

src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System;
66
using System.Collections.Concurrent;
77
using System.Collections.Generic;
8-
using System.ComponentModel.Composition.Hosting;
98
using System.IO;
109

1110
namespace Microsoft.ML.Data
@@ -675,7 +674,5 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bo
675674
else if (!removeLastNewLine)
676675
writer.WriteLine();
677676
}
678-
679-
public virtual CompositionContainer GetCompositionContainer() => new CompositionContainer();
680677
}
681678
}

src/Microsoft.ML.Core/Microsoft.ML.Core.csproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
<ProjectReference Include="..\Microsoft.Data.DataView\Microsoft.Data.DataView.csproj" />
1313

1414
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
15-
<PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" />
1615
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
1716
</ItemGroup>
1817

src/Microsoft.ML.Data/MLContext.cs

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.ComponentModel.Composition;
7-
using System.ComponentModel.Composition.Hosting;
86
using Microsoft.ML.Data;
97

108
namespace Microsoft.ML
@@ -69,9 +67,9 @@ public sealed class MLContext : IHostEnvironment
6967
public event EventHandler<LoggingEventArgs> Log;
7068

7169
/// <summary>
72-
/// This is a MEF composition container catalog to be used for model loading.
70+
/// This is a catalog of components that will be used for model loading.
7371
/// </summary>
74-
public CompositionContainer CompositionContainer { get; set; }
72+
public ComponentCatalog ComponentCatalog => _env.ComponentCatalog;
7573

7674
/// <summary>
7775
/// Create the ML context.
@@ -80,7 +78,7 @@ public sealed class MLContext : IHostEnvironment
8078
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
8179
public MLContext(int? seed = null, int conc = 0)
8280
{
83-
_env = new LocalEnvironment(seed, conc, MakeCompositionContainer);
81+
_env = new LocalEnvironment(seed, conc);
8482
_env.AddListener(ProcessMessage);
8583

8684
BinaryClassification = new BinaryClassificationCatalog(_env);
@@ -94,18 +92,6 @@ public MLContext(int? seed = null, int conc = 0)
9492
Data = new DataOperationsCatalog(_env);
9593
}
9694

97-
private CompositionContainer MakeCompositionContainer()
98-
{
99-
if (CompositionContainer == null)
100-
return null;
101-
102-
var mlContext = CompositionContainer.GetExportedValueOrDefault<MLContext>();
103-
if (mlContext == null)
104-
CompositionContainer.ComposeExportedValue<MLContext>(this);
105-
106-
return CompositionContainer;
107-
}
108-
10995
private void ProcessMessage(IMessageSource source, ChannelMessage message)
11096
{
11197
var log = Log;
@@ -120,7 +106,6 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
120106

121107
int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor;
122108
bool IHostEnvironment.IsCancelled => _env.IsCancelled;
123-
ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog;
124109
string IExceptionContext.ContextDescription => _env.ContextDescription;
125110
IFileHandle IHostEnvironment.CreateOutputFile(string path) => _env.CreateOutputFile(path);
126111
IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix);
@@ -130,6 +115,5 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
130115
IChannel IChannelProvider.Start(string name) => _env.Start(name);
131116
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
132117
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
133-
CompositionContainer IHostEnvironment.GetCompositionContainer() => _env.GetCompositionContainer();
134118
}
135119
}

src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.ComponentModel.Composition.Hosting;
76

87
namespace Microsoft.ML.Data
98
{
@@ -14,8 +13,6 @@ namespace Microsoft.ML.Data
1413
/// </summary>
1514
internal sealed class LocalEnvironment : HostEnvironmentBase<LocalEnvironment>
1615
{
17-
private readonly Func<CompositionContainer> _compositionContainerFactory;
18-
1916
private sealed class Channel : ChannelBase
2017
{
2118
public readonly Stopwatch Watch;
@@ -49,11 +46,9 @@ protected override void Dispose(bool disposing)
4946
/// </summary>
5047
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
5148
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
52-
/// <param name="compositionContainerFactory">The function to retrieve the composition container</param>
53-
public LocalEnvironment(int? seed = null, int conc = 0, Func<CompositionContainer> compositionContainerFactory = null)
49+
public LocalEnvironment(int? seed = null, int conc = 0)
5450
: base(RandomUtils.Create(seed), verbose: false, conc)
5551
{
56-
_compositionContainerFactory = compositionContainerFactory;
5752
}
5853

5954
/// <summary>
@@ -96,13 +91,6 @@ protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase pare
9691
return new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>());
9792
}
9893

99-
public override CompositionContainer GetCompositionContainer()
100-
{
101-
if (_compositionContainerFactory != null)
102-
return _compositionContainerFactory();
103-
return base.GetCompositionContainer();
104-
}
105-
10694
private sealed class Host : HostBase
10795
{
10896
public Host(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose, int? conc)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using Microsoft.Data.DataView;
7+
8+
namespace Microsoft.ML.Transforms
9+
{
10+
/// <summary>
11+
/// Place this attribute onto a type to cause it to be considered a custom mapping transformer factory.
12+
/// </summary>
13+
[AttributeUsage(AttributeTargets.Class)]
14+
public sealed class CustomMappingTransformerFactoryAttribute : ExtensionBaseAttribute
15+
{
16+
public CustomMappingTransformerFactoryAttribute(string contractName)
17+
: base(contractName)
18+
{
19+
}
20+
}
21+
22+
internal interface ICustomMappingTransformerFactory
23+
{
24+
ITransformer CreateTransformerObject(IHostEnvironment env, string contractName);
25+
}
26+
27+
/// <summary>
28+
/// The base type for custom mapping transformer factories.
29+
/// </summary>
30+
/// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the input <see cref="IDataView"/>.</typeparam>
31+
/// <typeparam name="TDst">The type that describes what new columns are added by this transform.</typeparam>
32+
public abstract class CustomMappingTransformerFactory<TSrc, TDst> : ICustomMappingTransformerFactory
33+
where TSrc : class, new()
34+
where TDst : class, new()
35+
{
36+
public abstract Action<TSrc, TDst> GetTransformer();
37+
38+
ITransformer ICustomMappingTransformerFactory.CreateTransformerObject(IHostEnvironment env, string contractName)
39+
{
40+
Action<TSrc, TDst> mapAction = GetTransformer();
41+
return new CustomMappingTransformer<TSrc, TDst>(env, mapAction, contractName);
42+
}
43+
}
44+
}

0 commit comments

Comments
 (0)