Skip to content

Commit b043695

Browse files
suhsteveimback82
authored and committed
SerDe referenced assemblies (#180)
1 parent 972825c commit b043695

File tree

9 files changed

+304
-161
lines changed

9 files changed

+304
-161
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
7+
namespace Microsoft.Spark.E2ETest.ExternalLibrary
8+
{
9+
/// <summary>
/// Small serializable helper type that lives in its own assembly
/// (Microsoft.Spark.E2ETest.ExternalLibrary) so that tests can reference a type
/// defined outside of the test assembly from within a UDF.
/// </summary>
[Serializable]
public class ExternalClass
{
    // Prefix used by Concat. Assigned once in the constructor and never mutated.
    private readonly string _s;

    /// <summary>
    /// Creates an instance that remembers <paramref name="s"/> as the prefix
    /// used by <see cref="Concat(string)"/>.
    /// </summary>
    /// <param name="s">The prefix string to store.</param>
    public ExternalClass(string s)
    {
        _s = s;
    }

    /// <summary>
    /// Returns a fixed greeting; does not depend on any instance state.
    /// </summary>
    /// <returns>The literal string "Hello World".</returns>
    public static string HelloWorld()
    {
        return "Hello World";
    }

    /// <summary>
    /// Appends <paramref name="s"/> to the prefix supplied at construction time.
    /// </summary>
    /// <param name="s">The string to append.</param>
    /// <returns>The stored prefix followed by <paramref name="s"/>.</returns>
    public string Concat(string s)
    {
        return _s + s;
    }
}
29+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netstandard2.0</TargetFramework>
5+
<PublicSign>false</PublicSign>
6+
</PropertyGroup>
7+
8+
</Project>

src/csharp/Microsoft.Spark.E2ETest/Microsoft.Spark.E2ETest.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
</ItemGroup>
1212

1313
<ItemGroup>
14+
<ProjectReference Include="..\Microsoft.Spark.E2ETest.ExternalLibrary\Microsoft.Spark.E2ETest.ExternalLibrary.csproj" />
1415
<ProjectReference Include="..\Microsoft.Spark.Experimental\Microsoft.Spark.Experimental.csproj" />
1516
<ProjectReference Include="..\Microsoft.Spark.Worker\Microsoft.Spark.Worker.csproj" />
1617
<ProjectReference Include="..\Microsoft.Spark\Microsoft.Spark.csproj" />
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
using Microsoft.Spark.E2ETest.ExternalLibrary;
9+
using Microsoft.Spark.Sql;
10+
using Xunit;
11+
using static Microsoft.Spark.Sql.Functions;
12+
13+
namespace Microsoft.Spark.E2ETest.UdfTests
14+
{
15+
/// <summary>
/// End-to-end tests that run UDFs which reference <see cref="ExternalClass"/>,
/// a type defined in a separate assembly, against the "name" column of the
/// people.json test resource.
/// </summary>
[Collection("Spark E2E Tests")]
public class UdfSerDeTests
{
    private readonly SparkSession _spark;
    private readonly DataFrame _df;

    /// <summary>
    /// Reads people.json from the test resource directory with an explicit
    /// "age INT, name STRING" schema.
    /// </summary>
    /// <param name="fixture">Fixture that supplies the Spark session.</param>
    public UdfSerDeTests(SparkFixture fixture)
    {
        _spark = fixture.Spark;

        string peopleJsonPath = $"{TestEnvironment.ResourceDirectory}people.json";
        _df = _spark.Read().Schema("age INT, name STRING").Json(peopleJsonPath);
    }

    /// <summary>
    /// The UDF captures an <see cref="ExternalClass"/> instance created outside
    /// of the lambda.
    /// </summary>
    [Fact]
    public void TestUdfClosure()
    {
        var ec = new ExternalClass("Hello");
        Func<Column, Column> udf = Udf<string, string>(str => ec.Concat(str));

        Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();

        VerifySingleStringColumn(
            rows,
            new[] { "HelloMichael", "HelloAndy", "HelloJustin" });
    }

    /// <summary>
    /// The UDF calls a static method on <see cref="ExternalClass"/> and ignores
    /// its input.
    /// </summary>
    [Fact]
    public void TestExternalStaticMethodCall()
    {
        Func<Column, Column> udf =
            Udf<string, string>(str => ExternalClass.HelloWorld());

        Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();

        // Every one of the three rows is expected to carry the same greeting.
        VerifySingleStringColumn(
            rows,
            Enumerable.Repeat("Hello World", 3).ToArray());
    }

    /// <summary>
    /// The UDF constructs the <see cref="ExternalClass"/> instance inside the
    /// lambda body.
    /// </summary>
    [Fact]
    public void TestInitExternalClassInUdf()
    {
        // Instantiate external assembly class within body of Udf.
        Func<Column, Column> udf = Udf<string, string>(
            str => new ExternalClass("Hello").Concat(str));

        Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();

        VerifySingleStringColumn(
            rows,
            new[] { "HelloMichael", "HelloAndy", "HelloJustin" });
    }

    /// <summary>
    /// Asserts that <paramref name="rows"/> has as many rows as
    /// <paramref name="expected"/> has entries, and that each row holds exactly
    /// one column whose string value matches the entry at the same position.
    /// </summary>
    /// <param name="rows">Rows collected from the DataFrame.</param>
    /// <param name="expected">Expected column values, in row order.</param>
    private static void VerifySingleStringColumn(Row[] rows, string[] expected)
    {
        Assert.Equal(expected.Length, rows.Length);

        int index = 0;
        foreach (Row row in rows)
        {
            Assert.Equal(1, row.Size());
            Assert.Equal(expected[index], row.GetAs<string>(0));
            index++;
        }
    }
}
94+
}

src/csharp/Microsoft.Spark.Worker/Processor/CommandProcessor.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using static Microsoft.Spark.Utils.UdfUtils;
1111

1212
#if NETCOREAPP
13+
using System.Reflection;
1314
using System.Runtime.Loader;
1415
#endif
1516

@@ -19,12 +20,19 @@ internal sealed class CommandProcessor
1920
{
2021
private readonly Version _version;
2122

22-
#if NETCOREAPP
2323
static CommandProcessor()
2424
{
25-
UdfSerDe.AssemblyLoader = AssemblyLoadContext.Default.LoadFromAssemblyPath;
26-
}
25+
#if NETCOREAPP
26+
AssemblyLoader.LoadFromFile = AssemblyLoadContext.Default.LoadFromAssemblyPath;
27+
AssemblyLoader.LoadFromName = (asmName) =>
28+
AssemblyLoadContext.Default.LoadFromAssemblyName(new AssemblyName(asmName));
29+
AssemblyLoadContext.Default.Resolving += (assemblyLoadContext, assemblyName) =>
30+
AssemblyLoader.ResolveAssembly(assemblyName.FullName);
31+
#else
32+
AppDomain.CurrentDomain.AssemblyResolve += (object sender, ResolveEventArgs args) =>
33+
AssemblyLoader.ResolveAssembly(args.Name);
2734
#endif
35+
}
2836

2937
internal CommandProcessor(Version version)
3038
{

src/csharp/Microsoft.Spark.sln

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Microsoft Visual Studio Solution File, Format Version 12.00
2-
# Visual Studio 15
3-
VisualStudioVersion = 15.0.28010.2046
2+
# Visual Studio Version 16
3+
VisualStudioVersion = 16.0.29009.5
44
MinimumVisualStudioVersion = 10.0.40219.1
55
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Spark", "Microsoft.Spark\Microsoft.Spark.csproj", "{2B4236DD-00A9-4B24-9041-6F9727F3B025}"
66
EndProject
@@ -25,6 +25,8 @@ Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "Microsoft.Spark.FSharp.Exam
2525
EndProject
2626
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Spark.Experimental", "Microsoft.Spark.Experimental\Microsoft.Spark.Experimental.csproj", "{7F276D07-6D94-49E3-A305-4B965DE7A700}"
2727
EndProject
28+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Spark.E2ETest.ExternalLibrary", "Microsoft.Spark.E2ETest.ExternalLibrary\Microsoft.Spark.E2ETest.ExternalLibrary.csproj", "{920A2CC8-4075-41D1-BC30-C496430BD9E4}"
29+
EndProject
2830
Global
2931
GlobalSection(SolutionConfigurationPlatforms) = preSolution
3032
Debug|Any CPU = Debug|Any CPU
@@ -63,6 +65,10 @@ Global
6365
{7F276D07-6D94-49E3-A305-4B965DE7A700}.Debug|Any CPU.Build.0 = Debug|Any CPU
6466
{7F276D07-6D94-49E3-A305-4B965DE7A700}.Release|Any CPU.ActiveCfg = Release|Any CPU
6567
{7F276D07-6D94-49E3-A305-4B965DE7A700}.Release|Any CPU.Build.0 = Release|Any CPU
68+
{920A2CC8-4075-41D1-BC30-C496430BD9E4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
69+
{920A2CC8-4075-41D1-BC30-C496430BD9E4}.Debug|Any CPU.Build.0 = Debug|Any CPU
70+
{920A2CC8-4075-41D1-BC30-C496430BD9E4}.Release|Any CPU.ActiveCfg = Release|Any CPU
71+
{920A2CC8-4075-41D1-BC30-C496430BD9E4}.Release|Any CPU.Build.0 = Release|Any CPU
6672
EndGlobalSection
6773
GlobalSection(SolutionProperties) = preSolution
6874
HideSolutionNode = FALSE
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.IO;
8+
using System.Reflection;
9+
using System.Runtime.InteropServices;
10+
11+
namespace Microsoft.Spark.Utils
12+
{
13+
/// <summary>
/// Loads assemblies by probing a fixed list of directories and caches every
/// successfully loaded assembly by the name it was requested under.
/// </summary>
/// <remarks>
/// NOTE(review): <see cref="ResolveAssembly(string)"/> throws when nothing is found.
/// It is also registered as an assembly-resolve event handler by the worker, and
/// the .NET documentation for those events expects a handler to return null for
/// assemblies it cannot resolve. Confirm that throwing from the handler is intended.
/// </remarks>
internal static class AssemblyLoader
{
    /// <summary>
    /// Delegate used to load an assembly from a file path. Defaults to
    /// <see cref="Assembly.LoadFrom(string)"/> and can be replaced by the host.
    /// </summary>
    internal static Func<string, Assembly> LoadFromFile { get; set; } = Assembly.LoadFrom;

    /// <summary>
    /// Delegate used to load an assembly from its name. Defaults to
    /// <see cref="Assembly.Load(string)"/> and can be replaced by the host.
    /// </summary>
    internal static Func<string, Assembly> LoadFromName { get; set; } = Assembly.Load;

    // Assemblies loaded so far, keyed by the assembly name used in the request.
    private static readonly Dictionary<string, Assembly> s_assemblyCache =
        new Dictionary<string, Assembly>();

    // Probing directories, in priority order. Both values are captured once,
    // when this type is initialized.
    private static readonly string[] s_searchPaths =
        new[] { Directory.GetCurrentDirectory(), AppDomain.CurrentDomain.BaseDirectory };

    // File extensions tried by ResolveAssembly, in order. The ".exe" variants
    // are only probed on Windows.
    private static readonly string[] s_extensions =
        RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ?
            new[] { ".dll", ".exe", ".ni.dll", ".ni.exe" } :
            new[] { ".dll", ".ni.dll" };

    // Guards s_assemblyCache; held for the whole lookup-and-load sequence.
    private static readonly object s_cacheLock = new object();

    /// <summary>
    /// Return the cached assembly, otherwise attempt to load and cache the assembly
    /// by searching for the assembly filename in the search paths.
    /// </summary>
    /// <param name="assemblyName">The full name of the assembly</param>
    /// <param name="assemblyFileName">Name of the file that contains the assembly</param>
    /// <returns>Cached or Loaded Assembly</returns>
    /// <exception cref="FileNotFoundException">Thrown if the assembly is not
    /// found.</exception>
    internal static Assembly LoadAssembly(string assemblyName, string assemblyFileName)
    {
        lock (s_cacheLock)
        {
            if (s_assemblyCache.TryGetValue(assemblyName, out Assembly assembly))
            {
                return assembly;
            }

            if (TryLoadAssembly(assemblyFileName, out assembly))
            {
                s_assemblyCache[assemblyName] = assembly;
                return assembly;
            }

            throw new FileNotFoundException(
                $"Assembly '{assemblyName}' file not found: '{assemblyFileName}'");
        }
    }

    /// <summary>
    /// Return the cached assembly, otherwise look in the following probing paths,
    /// searching for the simple assembly name combined with each candidate extension.
    /// 1) The working directory
    /// 2) The directory of the application
    /// </summary>
    /// <param name="assemblyName">The fullname of the assembly to load</param>
    /// <returns>The loaded assembly</returns>
    /// <exception cref="FileNotFoundException">Thrown if the assembly is not
    /// found.</exception>
    internal static Assembly ResolveAssembly(string assemblyName)
    {
        lock (s_cacheLock)
        {
            if (s_assemblyCache.TryGetValue(assemblyName, out Assembly assembly))
            {
                return assembly;
            }

            // Only the simple name (no version/culture/token) is used to build
            // the candidate file names.
            string simpleAsmName = new AssemblyName(assemblyName).Name;
            foreach (string extension in s_extensions)
            {
                string assemblyFileName = $"{simpleAsmName}{extension}";
                if (TryLoadAssembly(assemblyFileName, out assembly))
                {
                    // Cache under the full requested name, not the simple name.
                    s_assemblyCache[assemblyName] = assembly;
                    return assembly;
                }
            }

            throw new FileNotFoundException($"Assembly file not found: '{assemblyName}'");
        }
    }

    /// <summary>
    /// Returns the loaded assembly by probing the following locations in order:
    /// 1) The working directory
    /// 2) The directory of the application
    /// </summary>
    /// <remarks>
    /// The probing order is important in cases when spark is launched on
    /// YARN. The executors are run inside 'containers' and files that are passed
    /// via 'spark-submit --files' will be pushed to these 'containers'. This path
    /// is the working directory and the 1st probing path that will be checked.
    /// </remarks>
    /// <param name="assemblyFileName">Name of the file that contains the assembly</param>
    /// <param name="assembly">The loaded assembly, or null if none could be
    /// loaded.</param>
    /// <returns>True if assembly is loaded, false otherwise.</returns>
    private static bool TryLoadAssembly(string assemblyFileName, out Assembly assembly)
    {
        foreach (string searchPath in s_searchPaths)
        {
            string assemblyPath = Path.Combine(searchPath, assemblyFileName);
            if (!File.Exists(assemblyPath))
            {
                continue;
            }

            try
            {
                assembly = LoadFromFile(assemblyPath);
                return true;
            }
            catch (Exception ex) when (
                ex is FileLoadException ||
                ex is BadImageFormatException)
            {
                // Ignore invalid assemblies and keep probing the remaining paths.
            }
        }

        assembly = null;
        return false;
    }
}
132+
}

src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ private static void SerializeUdfs(
212212

213213
foreach (UdfSerDe.FieldData field in fields)
214214
{
215-
SerializeUdfs((Delegate)field.ValueData.Value, curNode, udfWrapperNodes, udfs);
215+
SerializeUdfs((Delegate)field.Value, curNode, udfWrapperNodes, udfs);
216216
}
217217
}
218218

0 commit comments

Comments
 (0)