Mercurial > hg > monetdb-java
view src/main/java/org/monetdb/mcl/net/SecureSocket.java @ 918:2543e24eb79a
Add final to more classes.
author | Martin van Dinther <martin.van.dinther@monetdbsolutions.com> |
---|---|
date | Wed, 24 Jul 2024 19:38:33 +0200 (8 months ago) |
parents | b80758ef25db |
children | d416e9b6b3d0 |
line wrap: on
line source
/* * SPDX-License-Identifier: MPL-2.0 * * This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. * * Copyright 2024 MonetDB Foundation; * Copyright August 2008 - 2023 MonetDB B.V.; * Copyright 1997 - July 2008 CWI. */ package org.monetdb.mcl.net; import javax.net.ssl.*; import java.io.FileInputStream; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.Socket; import java.security.*; import java.security.cert.CertificateException; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import java.util.Collections; public final class SecureSocket { private static final String[] ENABLED_PROTOCOLS = {"TLSv1.3"}; private static final String[] APPLICATION_PROTOCOLS = {"mapi/9"}; // Cache for the default SSL factory. It must load all trust roots // so it's worthwhile to cache. // Only access this through #getDefaultSocketFactory() private static SSLSocketFactory vanillaFactory = null; private static synchronized SSLSocketFactory getDefaultSocketFactory() { if (vanillaFactory == null) { vanillaFactory = (SSLSocketFactory) SSLSocketFactory.getDefault(); } return vanillaFactory; } public static Socket wrap(Target.Validated validated, Socket inner) throws IOException { Target.Verify verify = validated.connectVerify(); SSLSocketFactory socketFactory; boolean checkName = true; try { switch (verify) { case System: socketFactory = getDefaultSocketFactory(); break; case Cert: KeyStore keyStore = keyStoreForCert(validated.getCert()); socketFactory = certBasedSocketFactory(keyStore); break; case Hash: socketFactory = hashBasedSocketFactory(validated.connectCertHashDigits()); checkName = false; break; default: throw new RuntimeException("unreachable: unexpected verification strategy " + verify.name()); } return wrapSocket(inner, validated, socketFactory, checkName); } catch (CertificateException e) { throw new SSLException("TLS certificate rejected", e); } } private static SSLSocket wrapSocket(Socket inner, Target.Validated validated, SSLSocketFactory socketFactory, boolean checkName) throws IOException { SSLSocket sock = (SSLSocket) socketFactory.createSocket(inner, validated.connectTcp(), validated.connectPort(), true); sock.setUseClientMode(true); SSLParameters parameters = sock.getSSLParameters(); parameters.setProtocols(ENABLED_PROTOCOLS); parameters.setServerNames(Collections.singletonList(new SNIHostName(validated.connectTcp()))); if (checkName) { parameters.setEndpointIdentificationAlgorithm("HTTPS"); } // Unfortunately, SSLParameters.setApplicationProtocols is only available // since language level 9 and currently we're on 8. // Still call it if it happens to be available. try { Method setApplicationProtocols = SSLParameters.class.getMethod("setApplicationProtocols", String[].class); setApplicationProtocols.invoke(parameters, (Object) APPLICATION_PROTOCOLS); } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException ignored) { } sock.setSSLParameters(parameters); sock.startHandshake(); return sock; } private static X509Certificate loadCertificate(String path) throws CertificateException, IOException { CertificateFactory factory = CertificateFactory.getInstance("X509"); try (FileInputStream s = new FileInputStream(path)) { return (X509Certificate) factory.generateCertificate(s); } } private static SSLSocketFactory certBasedSocketFactory(KeyStore store) throws IOException, CertificateException { TrustManagerFactory trustManagerFactory; try { trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); trustManagerFactory.init(store); } catch (NoSuchAlgorithmException | KeyStoreException e) { throw new RuntimeException("Could not create TrustManagerFactory", e); } SSLContext context; try { context = SSLContext.getInstance("TLS"); context.init(null, trustManagerFactory.getTrustManagers(), null); } catch (NoSuchAlgorithmException | KeyManagementException e) { throw new RuntimeException("Could not create SSLContext", e); } return context.getSocketFactory(); } private static KeyStore keyStoreForCert(String path) throws IOException, CertificateException { try { X509Certificate cert = loadCertificate(path); KeyStore store = emptyKeyStore(); store.setCertificateEntry("root", cert); return store; } catch (KeyStoreException e) { throw new RuntimeException("Could not create KeyStore for certificate", e); } } private static KeyStore emptyKeyStore() throws IOException, CertificateException { KeyStore store; try { store = KeyStore.getInstance("PKCS12"); store.load(null, null); return store; } catch (KeyStoreException | NoSuchAlgorithmException e) { throw new RuntimeException("Could not create KeyStore for certificate", e); } } private static SSLSocketFactory hashBasedSocketFactory(String hashDigits) { TrustManager trustManager = new HashBasedTrustManager(hashDigits); try { SSLContext context = SSLContext.getInstance("TLS"); context.init(null, new TrustManager[]{trustManager}, null); return context.getSocketFactory(); } catch (NoSuchAlgorithmException | KeyManagementException e) { throw new RuntimeException("Could not create SSLContext", e); } } private static class HashBasedTrustManager implements X509TrustManager { private static final char[] HEXDIGITS = "0123456789abcdef".toCharArray(); private final String hashDigits; public HashBasedTrustManager(String hashDigits) { this.hashDigits = hashDigits; } @Override public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { throw new RuntimeException("this TrustManager is only suitable for client side connections"); } @Override public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { X509Certificate cert = x509Certificates[0]; byte[] certBytes = cert.getEncoded(); // for now it's always SHA256. byte[] hashBytes; try { MessageDigest hasher = MessageDigest.getInstance("SHA-256"); hasher.update(certBytes); hashBytes = hasher.digest(); } catch (NoSuchAlgorithmException e) { throw new RuntimeException("failed to instantiate hash digest"); } // convert to hex digits StringBuilder buffer = new StringBuilder(2 * hashBytes.length); for (byte b : hashBytes) { int hi = (b & 0xF0) >> 4; int lo = b & 0x0F; buffer.append(HEXDIGITS[hi]); buffer.append(HEXDIGITS[lo]); } String certDigits = buffer.toString(); if (!certDigits.startsWith(hashDigits)) { throw new CertificateException("Certificate hash does not start with '" + hashDigits + "': " + certDigits); } } @Override public X509Certificate[] getAcceptedIssuers() { return new X509Certificate[0]; } } }