diff --git a/src/Confluent.SchemaRegistry/CachedSchemaRegistryClient.cs b/src/Confluent.SchemaRegistry/CachedSchemaRegistryClient.cs index f757e113e..e0a0eded3 100644 --- a/src/Confluent.SchemaRegistry/CachedSchemaRegistryClient.cs +++ b/src/Confluent.SchemaRegistry/CachedSchemaRegistryClient.cs @@ -250,7 +250,14 @@ public CachedSchemaRegistryClient(IEnumerable> conf try { sslVerify = sslVerificationMaybe.Value == null ? DefaultEnableSslCertificateVerification : bool.Parse(sslVerificationMaybe.Value); } catch (FormatException) { throw new ArgumentException($"Configured value for {SchemaRegistryConfig.PropertyNames.EnableSslCertificateVerification} must be a bool."); } - this.restService = new RestService(schemaRegistryUris, timeoutMs, authenticationHeaderValueProvider, SetSslConfig(config), sslVerify); + var sslCaLocation = config.FirstOrDefault(prop => prop.Key.ToLower() == SchemaRegistryConfig.PropertyNames.SslCaLocation).Value; + if (string.IsNullOrEmpty(sslCaLocation)) + { + this.restService = new RestService(schemaRegistryUris, timeoutMs, authenticationHeaderValueProvider, SetSslConfig(config), sslVerify); + } else + { + this.restService = new RestService(schemaRegistryUris, timeoutMs, authenticationHeaderValueProvider, SetSslConfig(config), sslVerify, new X509Certificate2(sslCaLocation)); + } } /// @@ -306,12 +313,6 @@ private bool CleanCacheIfFull()                 certificates.Add(new X509Certificate2(certificateLocation, certificatePassword));             } -            var caLocation = config.FirstOrDefault(prop => prop.Key.ToLower() == SchemaRegistryConfig.PropertyNames.SslCaLocation).Value ?? ""; -            if (!String.IsNullOrEmpty(caLocation)) -            { -                certificates.Add(new X509Certificate2(caLocation)); -            } -             return certificates;         } diff --git a/src/Confluent.SchemaRegistry/Rest/RestService.cs b/src/Confluent.SchemaRegistry/Rest/RestService.cs index 9f7ec21ae..a9438fd8f 100644 --- a/src/Confluent.SchemaRegistry/Rest/RestService.cs +++ b/src/Confluent.SchemaRegistry/Rest/RestService.cs @@ -24,6 +24,7 @@ using System.Text; using System.Threading.Tasks; using System.Security.Cryptography.X509Certificates; +using System.Net.Security; namespace Confluent.SchemaRegistry { @@ -57,7 +58,7 @@ internal class RestService : IRestService /// /// Initializes a new instance of the RestService class. /// - public RestService(string schemaRegistryUrl, int timeoutMs, IAuthenticationHeaderValueProvider authenticationHeaderValueProvider, List certificates, bool enableSslCertificateVerification) + public RestService(string schemaRegistryUrl, int timeoutMs, IAuthenticationHeaderValueProvider authenticationHeaderValueProvider, List certificates, bool enableSslCertificateVerification, X509Certificate2 sslCaCertificate = null) { this.authenticationHeaderValueProvider = authenticationHeaderValueProvider; @@ -69,7 +70,7 @@ internal class RestService : IRestService HttpClient client;                     if (certificates.Count > 0)                     { -                        client = new HttpClient(CreateHandler(certificates, enableSslCertificateVerification)) { BaseAddress = new Uri(uri, UriKind.Absolute), Timeout = TimeSpan.FromMilliseconds(timeoutMs) }; +                        client = new HttpClient(CreateHandler(certificates, enableSslCertificateVerification, sslCaCertificate)) { BaseAddress = new Uri(uri, UriKind.Absolute), Timeout = TimeSpan.FromMilliseconds(timeoutMs) };                     }                     else                     { @@ -86,17 +87,53 @@ private static string SanitizeUri(string uri) return $"{sanitized.TrimEnd('/')}/"; } - private static HttpClientHandler CreateHandler(List certificates, bool enableSslCertificateVerification) + private static HttpClientHandler CreateHandler(List certificates, bool enableSslCertificateVerification, X509Certificate2 sslCaCertificate) {     var handler = new HttpClientHandler(); handler.ClientCertificateOptions = ClientCertificateOption.Manual; +     certificates.ForEach(c => handler.ClientCertificates.Add(c)); + if (!enableSslCertificateVerification) { - handler.ServerCertificateCustomValidationCallback = (httpRequestMessage, cert, certChain, policyErrors) => { return true; }; - } + handler.ServerCertificateCustomValidationCallback = (_, __, ___, ____) => { return true; }; + } + else if (sslCaCertificate != null) + { + handler.ServerCertificateCustomValidationCallback = (_, __, chain, policyErrors) => { + + if (policyErrors == SslPolicyErrors.None) + { + return true; + } -     certificates.ForEach(c => handler.ClientCertificates.Add(c)); + + //The second element of the chain should be the issuer of the certificate + if (chain.ChainElements.Count < 2) + { + return false; + } + var connectionCertHash = chain.ChainElements[1].Certificate.GetCertHash(); + + + var expectedCertHash = sslCaCertificate.GetCertHash(); + + if (connectionCertHash.Length != expectedCertHash.Length) + { + return false; + } + + for (int i = 0; i < connectionCertHash.Length; i++) + { + if (connectionCertHash[i] != expectedCertHash[i]) + { + return false; + } + } + return true; + }; + } +     return handler; }