diff --git a/src/mono/wasm/Wasm.Build.Tests/PInvokeTableGeneratorTests.cs b/src/mono/wasm/Wasm.Build.Tests/PInvokeTableGeneratorTests.cs index 4c42a30373f09b..5473667f294c9c 100644 --- a/src/mono/wasm/Wasm.Build.Tests/PInvokeTableGeneratorTests.cs +++ b/src/mono/wasm/Wasm.Build.Tests/PInvokeTableGeneratorTests.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Text; using Xunit; using Xunit.Abstractions; @@ -551,6 +552,72 @@ public void BuildNativeInNonEnglishCulture(BuildArgs buildArgs, string culture, Assert.Contains("square: 25", output); } + [Theory] + [BuildAndRun(host: RunHost.Chrome, parameters: new object[] { new object[] { + "with-hyphen", + "with#hash-and-hyphen", + "with.per.iod", + "with🚀unicode#" + } })] + + public void CallIntoLibrariesWithNonAlphanumericCharactersInTheirNames(BuildArgs buildArgs, string[] libraryNames, RunHost host, string id) + { + buildArgs = ExpandBuildArgs(buildArgs, + extraItems: @$"", + extraProperties: buildArgs.AOT + ? string.Empty + : "true"); + + int baseArg = 10; + (_, string output) = BuildProject(buildArgs, + id: id, + new BuildProjectOptions( + InitProject: () => GenerateSourceFiles(_projectDir!, baseArg), + Publish: buildArgs.AOT, + DotnetWasmFromRuntimePack: false + )); + + output = RunAndTestWasmApp(buildArgs, + buildDir: _projectDir, + expectedExitCode: 42, + host: host, + id: id); + + for (int i = 0; i < libraryNames.Length; i ++) + { + Assert.Contains($"square_{i}: {(i + baseArg) * (i + baseArg)}", output); + } + + void GenerateSourceFiles(string outputPath, int baseArg) + { + StringBuilder csBuilder = new($@" + using System; + using System.Runtime.InteropServices; + "); + + StringBuilder dllImportsBuilder = new(); + for (int i = 0; i < libraryNames.Length; i ++) + { + dllImportsBuilder.AppendLine($"[DllImport(\"{libraryNames[i]}\")] static extern int square_{i}(int x);"); + csBuilder.AppendLine($@"Console.WriteLine($""square_{i}: {{square_{i}({i + baseArg})}}"");"); + + string nativeCode = $@" + #include + + int square_{i}(int x) + {{ + return x * x; + }}"; + File.WriteAllText(Path.Combine(outputPath, $"{libraryNames[i]}.c"), nativeCode); + } + + csBuilder.AppendLine("return 42;"); + csBuilder.Append(dllImportsBuilder); + + File.WriteAllText(Path.Combine(outputPath, "Program.cs"), csBuilder.ToString()); + } + } + private (BuildArgs, string) BuildForVariadicFunctionTests(string programText, BuildArgs buildArgs, string id, string? verbosity = null, string extraProperties = "") { extraProperties += "true<_WasmDevel>true"; diff --git a/src/tasks/WasmAppBuilder/IcallTableGenerator.cs b/src/tasks/WasmAppBuilder/IcallTableGenerator.cs index c40b6b4c6ec20a..ba6f90026c371c 100644 --- a/src/tasks/WasmAppBuilder/IcallTableGenerator.cs +++ b/src/tasks/WasmAppBuilder/IcallTableGenerator.cs @@ -3,9 +3,6 @@ using System; using System.Collections.Generic; -using System.Collections.Immutable; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.IO; using System.Linq; using System.Text; @@ -23,8 +20,13 @@ internal sealed class IcallTableGenerator private Dictionary _runtimeIcalls = new Dictionary(); private TaskLoggingHelper Log { get; set; } + private readonly Func _fixupSymbolName; - public IcallTableGenerator(TaskLoggingHelper log) => Log = log; + public IcallTableGenerator(Func fixupSymbolName, TaskLoggingHelper log) + { + Log = log; + _fixupSymbolName = fixupSymbolName; + } // // Given the runtime generated icall table, and a set of assemblies, generate @@ -86,7 +88,7 @@ private void EmitTable(StreamWriter w) if (assembly == "System.Private.CoreLib") aname = "corlib"; else - aname = assembly.Replace(".", "_"); + aname = _fixupSymbolName(assembly); w.WriteLine($"#define ICALL_TABLE_{aname} 1\n"); w.WriteLine($"static int {aname}_icall_indexes [] = {{"); diff --git a/src/tasks/WasmAppBuilder/ManagedToNativeGenerator.cs b/src/tasks/WasmAppBuilder/ManagedToNativeGenerator.cs index df48afaa52f84a..1dff74a9dcc1c1 100644 --- a/src/tasks/WasmAppBuilder/ManagedToNativeGenerator.cs +++ b/src/tasks/WasmAppBuilder/ManagedToNativeGenerator.cs @@ -1,21 +1,13 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Collections.Generic; -using System.Collections.Immutable; -using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.IO; using System.Linq; using System.Text; -using System.Text.Json; -using System.Reflection; using Microsoft.Build.Framework; using Microsoft.Build.Utilities; -#nullable enable - public class ManagedToNativeGenerator : Task { [Required] @@ -37,6 +29,11 @@ public class ManagedToNativeGenerator : Task [Output] public string[]? FileWrites { get; private set; } + private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' }; + + // Avoid sharing this cache with all the invocations of this task throughout the build + private readonly Dictionary _symbolNameFixups = new(); + public override bool Execute() { if (Assemblies!.Length == 0) @@ -65,8 +62,8 @@ public override bool Execute() private void ExecuteInternal() { - var pinvoke = new PInvokeTableGenerator(Log); - var icall = new IcallTableGenerator(Log); + var pinvoke = new PInvokeTableGenerator(FixupSymbolName, Log); + var icall = new IcallTableGenerator(FixupSymbolName, Log); IEnumerable cookies = Enumerable.Concat( pinvoke.Generate(PInvokeModules, Assemblies!, PInvokeOutputPath!), @@ -80,4 +77,37 @@ private void ExecuteInternal() ? new string[] { PInvokeOutputPath, IcallOutputPath, InterpToNativeOutputPath } : new string[] { PInvokeOutputPath, InterpToNativeOutputPath }; } + + public string FixupSymbolName(string name) + { + if (_symbolNameFixups.TryGetValue(name, out string? fixedName)) + return fixedName; + + UTF8Encoding utf8 = new(); + byte[] bytes = utf8.GetBytes(name); + StringBuilder sb = new(); + + foreach (byte b in bytes) + { + if ((b >= (byte)'0' && b <= (byte)'9') || + (b >= (byte)'a' && b <= (byte)'z') || + (b >= (byte)'A' && b <= (byte)'Z') || + (b == (byte)'_')) + { + sb.Append((char)b); + } + else if (s_charsToReplace.Contains((char)b)) + { + sb.Append('_'); + } + else + { + sb.Append($"_{b:X}_"); + } + } + + fixedName = sb.ToString(); + _symbolNameFixups[name] = fixedName; + return fixedName; + } } diff --git a/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs b/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs index 74349b38166dde..92784d8b209155 100644 --- a/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs +++ b/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs @@ -13,12 +13,16 @@ internal sealed class PInvokeTableGenerator { - private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' }; private readonly Dictionary _assemblyDisableRuntimeMarshallingAttributeCache = new(); private TaskLoggingHelper Log { get; set; } + private readonly Func _fixupSymbolName; - public PInvokeTableGenerator(TaskLoggingHelper log) => Log = log; + public PInvokeTableGenerator(Func fixupSymbolName, TaskLoggingHelper log) + { + Log = log; + _fixupSymbolName = fixupSymbolName; + } public IEnumerable Generate(string[] pinvokeModules, string[] assemblies, string outputPath) { @@ -234,14 +238,14 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary modules foreach (var module in modules.Keys) { - string symbol = ModuleNameToId(module) + "_imports"; + string symbol = _fixupSymbolName(module) + "_imports"; w.WriteLine("static PinvokeImport " + symbol + " [] = {"); var assemblies_pinvokes = pinvokes. Where(l => l.Module == module && !l.Skip). OrderBy(l => l.EntryPoint). GroupBy(d => d.EntryPoint). - Select(l => "{\"" + FixupSymbolName(l.Key) + "\", " + FixupSymbolName(l.Key) + "}, " + + Select(l => "{\"" + _fixupSymbolName(l.Key) + "\", " + _fixupSymbolName(l.Key) + "}, " + "// " + string.Join(", ", l.Select(c => c.Method.DeclaringType!.Module!.Assembly!.GetName()!.Name!).Distinct().OrderBy(n => n))); foreach (var pinvoke in assemblies_pinvokes) @@ -255,7 +259,7 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary modules w.Write("static void *pinvoke_tables[] = { "); foreach (var module in modules.Keys) { - string symbol = ModuleNameToId(module) + "_imports"; + string symbol = _fixupSymbolName(module) + "_imports"; w.Write(symbol + ","); } w.WriteLine("};"); @@ -266,18 +270,6 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary modules } w.WriteLine("};"); - static string ModuleNameToId(string name) - { - if (name.IndexOfAny(s_charsToReplace) < 0) - return name; - - string fixedName = name; - foreach (char c in s_charsToReplace) - fixedName = fixedName.Replace(c, '_'); - - return fixedName; - } - static bool ShouldTreatAsVariadic(PInvoke[] candidates) { if (candidates.Length < 2) @@ -295,35 +287,7 @@ static bool ShouldTreatAsVariadic(PInvoke[] candidates) } } - private static string FixupSymbolName(string name) - { - UTF8Encoding utf8 = new(); - byte[] bytes = utf8.GetBytes(name); - StringBuilder sb = new(); - - foreach (byte b in bytes) - { - if ((b >= (byte)'0' && b <= (byte)'9') || - (b >= (byte)'a' && b <= (byte)'z') || - (b >= (byte)'A' && b <= (byte)'Z') || - (b == (byte)'_')) - { - sb.Append((char)b); - } - else if (s_charsToReplace.Contains((char)b)) - { - sb.Append('_'); - } - else - { - sb.Append($"_{b:X}_"); - } - } - - return sb.ToString(); - } - - private static string SymbolNameForMethod(MethodInfo method) + private string SymbolNameForMethod(MethodInfo method) { StringBuilder sb = new(); Type? type = method.DeclaringType; @@ -331,7 +295,7 @@ private static string SymbolNameForMethod(MethodInfo method) sb.Append($"{(type!.IsNested ? type!.FullName : type!.Name)}_"); sb.Append(method.Name); - return FixupSymbolName(sb.ToString()); + return _fixupSymbolName(sb.ToString()); } private static string MapType(Type t) => t.Name switch @@ -374,7 +338,7 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN { // FIXME: System.Reflection.MetadataLoadContext can't decode function pointer types // https://github.com/dotnet/runtime/issues/43791 - sb.Append($"int {FixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);"); + sb.Append($"int {_fixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);"); return sb.ToString(); } @@ -390,7 +354,7 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN } sb.Append(MapType(method.ReturnType)); - sb.Append($" {FixupSymbolName(pinvoke.EntryPoint)} ("); + sb.Append($" {_fixupSymbolName(pinvoke.EntryPoint)} ("); int pindex = 0; var pars = method.GetParameters(); foreach (var p in pars) @@ -404,7 +368,7 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN return sb.ToString(); } - private static void EmitNativeToInterp(StreamWriter w, ref List callbacks) + private void EmitNativeToInterp(StreamWriter w, ref List callbacks) { // Generate native->interp entry functions // These are called by native code, so they need to obtain @@ -450,7 +414,7 @@ private static void EmitNativeToInterp(StreamWriter w, ref List bool is_void = method.ReturnType.Name == "Void"; - string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_"); + string module_symbol = _fixupSymbolName(method.DeclaringType!.Module!.Assembly!.GetName()!.Name!); uint token = (uint)method.MetadataToken; string class_name = method.DeclaringType.Name; string method_name = method.Name; @@ -517,7 +481,7 @@ private static void EmitNativeToInterp(StreamWriter w, ref List foreach (var cb in callbacks) { var method = cb.Method; - string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_"); + string module_symbol = _fixupSymbolName(method.DeclaringType!.Module!.Assembly!.GetName()!.Name!); string class_name = method.DeclaringType.Name; string method_name = method.Name; w.WriteLine($"\"{module_symbol}_{class_name}_{method_name}\",");