diff --git a/src/Authentication/Authentication.Test/Helpers/AuthenticationHelpersTests.cs b/src/Authentication/Authentication.Test/Helpers/AuthenticationHelpersTests.cs index 10273314154..3bf8906e43d 100644 --- a/src/Authentication/Authentication.Test/Helpers/AuthenticationHelpersTests.cs +++ b/src/Authentication/Authentication.Test/Helpers/AuthenticationHelpersTests.cs @@ -3,6 +3,7 @@ using Microsoft.Graph.Auth; using Microsoft.Graph.PowerShell.Authentication; using Microsoft.Graph.PowerShell.Authentication.Helpers; + using System; using System.Linq; using System.Net; @@ -10,6 +11,7 @@ using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; + using Xunit; public class AuthenticationHelpersTests { @@ -78,7 +80,7 @@ public void ShouldUseClientCredentialProviderWhenAppOnlyContextIsProvided() CertificateName = "cn=dummyCert", ContextScope = ContextScope.Process }; - CreateSelfSignedCert(appOnlyAuthContext.CertificateName); + CreateAndStoreSelfSignedCert(appOnlyAuthContext.CertificateName); // Act IAuthenticationProvider authProvider = AuthenticationHelpers.GetAuthProvider(appOnlyAuthContext); @@ -87,12 +89,155 @@ public void ShouldUseClientCredentialProviderWhenAppOnlyContextIsProvided() Assert.IsType(authProvider); // reset - DeleteSelfSignedCert(appOnlyAuthContext.CertificateName); + DeleteSelfSignedCertByName(appOnlyAuthContext.CertificateName); + GraphSession.Reset(); + + } + + [Fact] + public void ShouldUseInMemoryCertificateWhenProvided() + { + // Arrange + var certificate = CreateSelfSignedCert("cn=inmemorycert"); + AuthContext appOnlyAuthContext = new AuthContext + { + AuthType = AuthenticationType.AppOnly, + ClientId = Guid.NewGuid().ToString(), + Certificate = certificate, + ContextScope = ContextScope.Process + }; + // Act + IAuthenticationProvider authProvider = AuthenticationHelpers.GetAuthProvider(appOnlyAuthContext); + + // Assert + Assert.IsType(authProvider); + var clientCredentialProvider = (ClientCredentialProvider)authProvider; + // Assert: That the certificate created and set above is the same as used here. + Assert.Equal(clientCredentialProvider.ClientApplication.AppConfig.ClientCredentialCertificate, certificate); GraphSession.Reset(); - } - private void CreateSelfSignedCert(string certName) + [Fact] + public void ShouldUseCertNameInsteadOfPassedInCertificateWhenBothAreSpecified() + { + // Arrange + var dummyCertName = "CN=dummycert"; + var inMemoryCertName = "CN=inmemorycert"; + CreateAndStoreSelfSignedCert(dummyCertName); + var inMemoryCertificate = CreateSelfSignedCert(inMemoryCertName); + AuthContext appOnlyAuthContext = new AuthContext + { + AuthType = AuthenticationType.AppOnly, + ClientId = Guid.NewGuid().ToString(), + CertificateName = dummyCertName, + Certificate = inMemoryCertificate, + ContextScope = ContextScope.Process + }; + // Act + IAuthenticationProvider authProvider = AuthenticationHelpers.GetAuthProvider(appOnlyAuthContext); + + // Assert + Assert.IsType(authProvider); + var clientCredentialProvider = (ClientCredentialProvider)authProvider; + // Assert: That the certificate used is dummycert, that is in the store + Assert.NotEqual(inMemoryCertName, clientCredentialProvider.ClientApplication.AppConfig.ClientCredentialCertificate.SubjectName.Name); + Assert.Equal(appOnlyAuthContext.CertificateName, clientCredentialProvider.ClientApplication.AppConfig.ClientCredentialCertificate.SubjectName.Name); + + //CleanUp + DeleteSelfSignedCertByName(appOnlyAuthContext.CertificateName); + GraphSession.Reset(); + } + + [Fact] + public void ShouldUseCertThumbPrintInsteadOfPassedInCertificateWhenBothAreSpecified() + { + // Arrange + var dummyCertName = "CN=dummycert"; + var inMemoryCertName = "CN=inmemorycert"; + var storedDummyCertificate = CreateAndStoreSelfSignedCert(dummyCertName); + var inMemoryCertificate = CreateSelfSignedCert(inMemoryCertName); + AuthContext appOnlyAuthContext = new AuthContext + { + AuthType = AuthenticationType.AppOnly, + ClientId = Guid.NewGuid().ToString(), + CertificateThumbprint = storedDummyCertificate.Thumbprint, + Certificate = inMemoryCertificate, + ContextScope = ContextScope.Process + }; + // Act + IAuthenticationProvider authProvider = AuthenticationHelpers.GetAuthProvider(appOnlyAuthContext); + + // Assert + Assert.IsType(authProvider); + var clientCredentialProvider = (ClientCredentialProvider)authProvider; + // Assert: That the certificate used is dummycert (Thumbprint), that is in the store + Assert.NotEqual(inMemoryCertName, clientCredentialProvider.ClientApplication.AppConfig.ClientCredentialCertificate.SubjectName.Name); + Assert.Equal(appOnlyAuthContext.CertificateThumbprint, clientCredentialProvider.ClientApplication.AppConfig.ClientCredentialCertificate.Thumbprint); + + //CleanUp + DeleteSelfSignedCertByThumbprint(appOnlyAuthContext.CertificateThumbprint); + GraphSession.Reset(); + } + + [Fact] + public void ShouldThrowIfNonExistentCertNameIsProvided() + { + // Arrange + var dummyCertName = "CN=NonExistingCert"; + AuthContext appOnlyAuthContext = new AuthContext + { + AuthType = AuthenticationType.AppOnly, + ClientId = Guid.NewGuid().ToString(), + CertificateName = dummyCertName, + ContextScope = ContextScope.Process + }; + // Act + Action action = () => AuthenticationHelpers.GetAuthProvider(appOnlyAuthContext); + + //Assert + Assert.ThrowsAny(action); + } + + [Fact] + public void ShouldThrowIfNullInMemoryCertIsProvided() + { + // Arrange + AuthContext appOnlyAuthContext = new AuthContext + { + AuthType = AuthenticationType.AppOnly, + ClientId = Guid.NewGuid().ToString(), + Certificate = null, + ContextScope = ContextScope.Process + }; + // Act + Action action = () => AuthenticationHelpers.GetAuthProvider(appOnlyAuthContext); + + //Assert + Assert.Throws(action); + } + + /// + /// Create and Store a Self Signed Certificate + /// + /// + private static X509Certificate2 CreateAndStoreSelfSignedCert(string certName) + { + var cert = CreateSelfSignedCert(certName); + using (var store = new X509Store(StoreName.My, StoreLocation.CurrentUser)) + { + store.Open(OpenFlags.ReadWrite); + store.Add(cert); + } + + return cert; + } + + /// + /// Create a Self Signed Certificate + /// + /// + /// + private static X509Certificate2 CreateSelfSignedCert(string certName) { ECDsa ecdsaKey = ECDsa.Create(); CertificateRequest certificateRequest = new CertificateRequest(certName, ecdsaKey, HashAlgorithmName.SHA256); @@ -108,14 +253,11 @@ private void CreateSelfSignedCert(string certName) { dummyCert = new X509Certificate2(cert.Export(X509ContentType.Pfx, "P@55w0rd"), "P@55w0rd", X509KeyStorageFlags.PersistKeySet); } - using (X509Store store = new X509Store(StoreName.My, StoreLocation.CurrentUser)) - { - store.Open(OpenFlags.ReadWrite); - store.Add(dummyCert); - } + + return dummyCert; } - private void DeleteSelfSignedCert(string certificateName) + private static void DeleteSelfSignedCertByName(string certificateName) { using (X509Store xStore = new X509Store(StoreName.My, StoreLocation.CurrentUser)) { @@ -134,6 +276,25 @@ private void DeleteSelfSignedCert(string certificateName) xStore.Remove(xCertificate); } } + private static void DeleteSelfSignedCertByThumbprint(string certificateThumbPrint) + { + using (X509Store xStore = new X509Store(StoreName.My, StoreLocation.CurrentUser)) + { + xStore.Open(OpenFlags.ReadWrite); + + X509Certificate2Collection unexpiredCerts = xStore.Certificates + .Find(X509FindType.FindByTimeValid, DateTime.Now, false) + .Find(X509FindType.FindByThumbprint, certificateThumbPrint, false); + + // Only return current cert. + var xCertificate = unexpiredCerts + .OfType() + .OrderByDescending(c => c.NotBefore) + .FirstOrDefault(); + + xStore.Remove(xCertificate); + } + } #endif } diff --git a/src/Authentication/Authentication/Cmdlets/ConnectMgGraph.cs b/src/Authentication/Authentication/Cmdlets/ConnectMgGraph.cs index 04c83aabfde..75a55d6f353 100644 --- a/src/Authentication/Authentication/Cmdlets/ConnectMgGraph.cs +++ b/src/Authentication/Authentication/Cmdlets/ConnectMgGraph.cs @@ -18,6 +18,7 @@ namespace Microsoft.Graph.PowerShell.Authentication.Cmdlets using System.Globalization; using Microsoft.Graph.PowerShell.Authentication.Interfaces; using Microsoft.Graph.PowerShell.Authentication.Common; + using System.Security.Cryptography.X509Certificates; [Cmdlet(VerbsCommunications.Connect, "MgGraph", DefaultParameterSetName = Constants.UserParameterSet)] [Alias("Connect-Graph")] @@ -45,7 +46,7 @@ public class ConnectMgGraph : PSCmdlet, IModuleAssemblyInitializer, IModuleAssem Position = 3, HelpMessage = "The thumbprint of your certificate. The Certificate will be retrieved from the current user's certificate store.")] public string CertificateThumbprint { get; set; } - + [Parameter(ParameterSetName = Constants.AccessTokenParameterSet, Position = 1, HelpMessage = "Specifies a bearer token for Microsoft Graph service. Access tokens do timeout and you'll have to handle their refresh.")] @@ -69,6 +70,9 @@ public class ConnectMgGraph : PSCmdlet, IModuleAssemblyInitializer, IModuleAssem [Alias("EnvironmentName", "NationalCloud")] public string Environment { get; set; } + [Parameter(ParameterSetName = Constants.AppParameterSet, Mandatory = false, HelpMessage = "An x509 Certificate supplied during invocation")] + public X509Certificate2 Certificate { get; set; } + private CancellationTokenSource cancellationTokenSource; private IGraphEnvironment environment; @@ -125,6 +129,7 @@ protected override void ProcessRecord() authContext.ClientId = ClientId; authContext.CertificateThumbprint = CertificateThumbprint; authContext.CertificateName = CertificateName; + authContext.Certificate = Certificate; // Default to Process but allow the customer to change this via `ContextScope` param. authContext.ContextScope = this.IsParameterBound(nameof(ContextScope)) ? ContextScope : ContextScope.Process; } @@ -256,10 +261,10 @@ private void ValidateParameters() this.ThrowParameterError(nameof(ClientId)); } - // Certificate Thumbprint or name - if (string.IsNullOrEmpty(CertificateThumbprint) && string.IsNullOrEmpty(CertificateName)) + // Certificate Thumbprint, Name or Actual Certificate + if (string.IsNullOrEmpty(CertificateThumbprint) && string.IsNullOrEmpty(CertificateName) && this.Certificate == null) { - this.ThrowParameterError($"{nameof(CertificateThumbprint)} or {nameof(CertificateName)}"); + this.ThrowParameterError($"{nameof(CertificateThumbprint)} or {nameof(CertificateName)} or {nameof(Certificate)}"); } // Tenant Id diff --git a/src/Authentication/Authentication/Helpers/AuthenticationHelpers.cs b/src/Authentication/Authentication/Helpers/AuthenticationHelpers.cs index afef14b889e..096c0350aa6 100644 --- a/src/Authentication/Authentication/Helpers/AuthenticationHelpers.cs +++ b/src/Authentication/Authentication/Helpers/AuthenticationHelpers.cs @@ -7,6 +7,7 @@ namespace Microsoft.Graph.PowerShell.Authentication.Helpers using Microsoft.Graph.PowerShell.Authentication.Models; using Microsoft.Graph.PowerShell.Authentication.TokenCache; using Microsoft.Identity.Client; + using System; using System.Linq; using System.Net; @@ -14,6 +15,7 @@ namespace Microsoft.Graph.PowerShell.Authentication.Helpers using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; + using AuthenticationException = System.Security.Authentication.AuthenticationException; internal static class AuthenticationHelpers @@ -40,7 +42,8 @@ internal static IAuthenticationProvider GetAuthProvider(IAuthContext authContext .Build(); ConfigureTokenCache(publicClientApp.UserTokenCache, authContext); - authProvider = new DeviceCodeProvider(publicClientApp, authContext.Scopes, async (result) => { + authProvider = new DeviceCodeProvider(publicClientApp, authContext.Scopes, async (result) => + { await Console.Out.WriteLineAsync(result.Message); }); break; @@ -51,7 +54,7 @@ internal static IAuthenticationProvider GetAuthProvider(IAuthContext authContext .Create(authContext.ClientId) .WithTenantId(authContext.TenantId) .WithAuthority(authorityUrl) - .WithCertificate(string.IsNullOrEmpty(authContext.CertificateThumbprint) ? GetCertificateByName(authContext.CertificateName) : GetCertificateByThumbprint(authContext.CertificateThumbprint)) + .WithCertificate(GetCertificate(authContext)) .Build(); ConfigureTokenCache(confidentialClientApp.AppTokenCache, authContext); @@ -61,7 +64,8 @@ internal static IAuthenticationProvider GetAuthProvider(IAuthContext authContext } case AuthenticationType.UserProvidedAccessToken: { - authProvider = new DelegateAuthenticationProvider((requestMessage) => { + authProvider = new DelegateAuthenticationProvider((requestMessage) => + { requestMessage.Headers.Authorization = new AuthenticationHeaderValue("Bearer", new NetworkCredential(string.Empty, GraphSession.Instance.UserProvidedToken).Password); return Task.CompletedTask; @@ -71,6 +75,35 @@ internal static IAuthenticationProvider GetAuthProvider(IAuthContext authContext } return authProvider; } + /// + /// Gets a certificate based on the current context. + /// Priority is Name, ThumbPrint, then In-Memory Cert + /// + /// Current context + /// A based on provided context + private static X509Certificate2 GetCertificate(IAuthContext context) + { + X509Certificate2 certificate; + if (!string.IsNullOrWhiteSpace(context.CertificateName)) + { + certificate = GetCertificateByName(context.CertificateName); + } + else if (!string.IsNullOrWhiteSpace(context.CertificateThumbprint)) + { + certificate = GetCertificateByThumbprint(context.CertificateThumbprint); + } + else + { + certificate = context.Certificate; + } + + if (certificate == null) + { + throw new ArgumentNullException(nameof(certificate), $"Certificate with the Specified ThumbPrint {context.CertificateThumbprint}, Name {context.CertificateName} or In-Memory could not be found"); + } + + return certificate; + } private static string GetAuthorityUrl(IAuthContext authContext) { @@ -108,7 +141,8 @@ internal static void Logout(IAuthContext authConfig) private static void ConfigureTokenCache(ITokenCache tokenCache, IAuthContext authContext) { - tokenCache.SetBeforeAccess((TokenCacheNotificationArgs args) => { + tokenCache.SetBeforeAccess((TokenCacheNotificationArgs args) => + { try { _cacheLock.EnterReadLock(); @@ -120,7 +154,8 @@ private static void ConfigureTokenCache(ITokenCache tokenCache, IAuthContext aut } }); - tokenCache.SetAfterAccess((TokenCacheNotificationArgs args) => { + tokenCache.SetAfterAccess((TokenCacheNotificationArgs args) => + { if (args.HasStateChanged) { try @@ -164,8 +199,8 @@ private static X509Certificate2 GetCertificateByThumbprint(string CertificateThu .FirstOrDefault(); } return xCertificate; - } - + } + /// /// Gets unexpired certificate of the specified certificate subject name for the current user in My store.. /// @@ -184,7 +219,7 @@ private static X509Certificate2 GetCertificateByName(string CertificateName) .Find(X509FindType.FindByTimeValid, DateTime.Now, false) .Find(X509FindType.FindBySubjectDistinguishedName, CertificateName, false); - if (unexpiredCerts == null) + if (unexpiredCerts.Count < 1) throw new Exception($"{CertificateName} certificate was not found or has expired."); // Only return current cert. diff --git a/src/Authentication/Authentication/Interfaces/IAuthContext.cs b/src/Authentication/Authentication/Interfaces/IAuthContext.cs index 9dd1f1ddaed..98ce72b2b87 100644 --- a/src/Authentication/Authentication/Interfaces/IAuthContext.cs +++ b/src/Authentication/Authentication/Interfaces/IAuthContext.cs @@ -2,6 +2,8 @@ // Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information. // ------------------------------------------------------------------------------ +using System.Security.Cryptography.X509Certificates; + namespace Microsoft.Graph.PowerShell.Authentication { public enum AuthenticationType @@ -28,5 +30,6 @@ public interface IAuthContext string Account { get; set; } string AppName { get; set; } ContextScope ContextScope { get; set; } + X509Certificate2 Certificate { get; set; } } } diff --git a/src/Authentication/Authentication/Models/AuthContext.cs b/src/Authentication/Authentication/Models/AuthContext.cs index 4df327d8218..8a9545d5478 100644 --- a/src/Authentication/Authentication/Models/AuthContext.cs +++ b/src/Authentication/Authentication/Models/AuthContext.cs @@ -1,6 +1,9 @@ // ------------------------------------------------------------------------------ // Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information. // ------------------------------------------------------------------------------ + +using System.Security.Cryptography.X509Certificates; + namespace Microsoft.Graph.PowerShell.Authentication { public class AuthContext: IAuthContext @@ -15,6 +18,7 @@ public class AuthContext: IAuthContext public string Account { get; set; } public string AppName { get; set; } public ContextScope ContextScope { get ; set ; } + public X509Certificate2 Certificate { get; set; } public AuthContext() {