Mercurial > hg > monetdb-java
changeset 801:88b3e8e89126 monetdbs
TLS seems to work
Need to add tests though
author | Joeri van Ruth <joeri.van.ruth@monetdbsolutions.com> |
---|---|
date | Fri, 08 Dec 2023 15:17:50 +0100 (16 months ago) |
parents | 09f463444dde |
children | 5d04490bc58b |
files | src/main/java/org/monetdb/jdbc/MonetConnection.java src/main/java/org/monetdb/mcl/net/MapiSocket.java src/main/java/org/monetdb/mcl/net/SecureSocket.java |
diffstat | 3 files changed, 169 insertions(+), 14 deletions(-) [+] |
line wrap: on
line diff
--- a/src/main/java/org/monetdb/jdbc/MonetConnection.java +++ b/src/main/java/org/monetdb/jdbc/MonetConnection.java @@ -41,6 +41,8 @@ import org.monetdb.mcl.parser.HeaderLine import org.monetdb.mcl.parser.MCLParseException; import org.monetdb.mcl.parser.StartOfHeaderParser; +import javax.net.ssl.SSLException; + /** *<pre> * A {@link Connection} suitable for the MonetDB database. @@ -215,8 +217,10 @@ public class MonetConnection final String error = in.discardRemainder(); if (error != null) throw new SQLNonTransientConnectionException((error.length() > 6) ? error.substring(6) : error, "08001"); + } catch (SSLException e) { + throw new SQLNonTransientConnectionException("Cannot establish secure connection: " + e.getMessage(), e); } catch (IOException e) { - throw new SQLNonTransientConnectionException("Cannot connect: " + e.getMessage(), "08006"); + throw new SQLNonTransientConnectionException("Cannot connect: " + e.getMessage(), "08006", e); } catch (MCLParseException e) { throw new SQLNonTransientConnectionException(e.getMessage(), "08001"); } catch (org.monetdb.mcl.MCLException e) {
--- a/src/main/java/org/monetdb/mcl/net/MapiSocket.java +++ b/src/main/java/org/monetdb/mcl/net/MapiSocket.java @@ -346,10 +346,11 @@ public final class MapiSocket { } private Socket wrapTLS(Socket sock, Target.Validated validated) throws IOException { - if (validated.getTls()) - return SecureSocket.wrap(validated, sock); - return sock; - } + if (validated.getTls()) + return SecureSocket.wrap(validated, sock); + else + return sock; + } private boolean handshake(Target.Validated validated, OptionsCallback callback, ArrayList<String> warnings) throws IOException, MCLException { String challenge = reader.getLine();
--- a/src/main/java/org/monetdb/mcl/net/SecureSocket.java +++ b/src/main/java/org/monetdb/mcl/net/SecureSocket.java @@ -1,22 +1,172 @@ package org.monetdb.mcl.net; -import org.monetdb.mcl.MCLException; - -import javax.net.ssl.SSLSocket; -import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.*; +import java.io.FileInputStream; import java.io.IOException; import java.net.Socket; +import java.security.*; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; public class SecureSocket { + private static final String[] ENABLED_PROTOCOLS = {"TLSv1.3"}; + final String[] APPLICATION_PROTOCOLS = {"mapi/9"}; + public static Socket wrap(Target.Validated validated, Socket inner) throws IOException { - SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); - String host = validated.connectTcp(); - int port = validated.connectPort(); - boolean autoclose = true; - SSLSocket sock = (SSLSocket) factory.createSocket(inner, host, port, autoclose); + Target.Verify verify = validated.connectVerify(); + SSLSocketFactory socketFactory; + try { + switch (verify) { + case System: + socketFactory = (SSLSocketFactory) SSLSocketFactory.getDefault(); + return wrapSocket(inner, validated, socketFactory, true); + case Cert: + KeyStore keyStore = keyStoreForCert(validated.getCert()); + socketFactory = certBasedSocketFactory(keyStore); + return wrapSocket(inner, validated, socketFactory, true); + case Hash: + return wrapHash(validated, inner); + default: + throw new RuntimeException("unreachable: unexpected verification strategy " + verify.name()); + } + } catch (CertificateException e) { + throw new SSLException(e.getMessage(), e); + } + } + + private static Socket wrapHash(Target.Validated validated, Socket inner) throws IOException, CertificateException { + SSLSocketFactory socketFactory = hashBasedSocketFactory(validated.connectCertHashDigits()); + SSLSocket sock = wrapSocket(inner, validated, socketFactory, false); + + return sock; + } + + private static SSLSocket wrapSocket(Socket inner, Target.Validated validated, SSLSocketFactory socketFactory, boolean checkName) throws IOException { + SSLSocket sock = (SSLSocket) socketFactory.createSocket(inner, validated.connectTcp(), validated.connectPort(), true); + sock.setUseClientMode(true); + sock.setEnabledProtocols(ENABLED_PROTOCOLS); + if (checkName) { + SSLParameters parameters = sock.getSSLParameters(); + parameters.setEndpointIdentificationAlgorithm("HTTPS"); + sock.setSSLParameters(parameters); + } sock.startHandshake(); return sock; } + + private static X509Certificate loadCertificate(String path) throws CertificateException, IOException { + CertificateFactory factory = CertificateFactory.getInstance("X509"); + try (FileInputStream s = new FileInputStream(path)) { + return (X509Certificate) factory.generateCertificate(s); + } + } + + private static SSLSocketFactory certBasedSocketFactory(KeyStore store) throws IOException, CertificateException { + TrustManagerFactory trustManagerFactory; + try { + trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(store); + } catch (NoSuchAlgorithmException | KeyStoreException e) { + throw new RuntimeException("Could not create TrustManagerFactory", e); + } + + SSLContext context; + try { + context = SSLContext.getInstance("TLS"); + context.init(null, trustManagerFactory.getTrustManagers(), null); + } catch (NoSuchAlgorithmException | KeyManagementException e) { + throw new RuntimeException("Could not create SSLContext", e); + } + + return context.getSocketFactory(); + } + + private static KeyStore keyStoreForCert(String path) throws IOException, CertificateException { + try { + X509Certificate cert = loadCertificate(path); + KeyStore store = emptyKeyStore(); + store.setCertificateEntry("root", cert); + return store; + } catch (KeyStoreException e) { + throw new RuntimeException("Could not create KeyStore for certificate", e); + } + } + + private static KeyStore emptyKeyStore() throws IOException, CertificateException { + KeyStore store; + try { + store = KeyStore.getInstance("PKCS12"); + store.load(null, null); + return store; + } catch (KeyStoreException | NoSuchAlgorithmException e) { + throw new RuntimeException("Could not create KeyStore for certificate", e); + } + } + + private static SSLSocketFactory hashBasedSocketFactory(String hashDigits) { + TrustManager trustManager = new HashBasedTrustManager(hashDigits); + try { + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, new TrustManager[]{ trustManager}, null); + return context.getSocketFactory(); + } catch (NoSuchAlgorithmException | KeyManagementException e) { + throw new RuntimeException("Could not create SSLContext", e); + } + + } + + private static class HashBasedTrustManager implements X509TrustManager { + private static final char[] HEXDIGITS = "0123456789abcdef".toCharArray(); + private final String hashDigits; + + public HashBasedTrustManager(String hashDigits) { + this.hashDigits = hashDigits; + } + + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + throw new RuntimeException("this TrustManager is only suitable for client side connections"); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + X509Certificate cert = x509Certificates[0]; + byte[] certBytes = cert.getEncoded(); + + // for now it's always SHA256. + byte[] hashBytes; + try { + MessageDigest hasher = MessageDigest.getInstance("SHA-256"); + hasher.update(certBytes); + hashBytes = hasher.digest(); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("failed to instantiate hash digest"); + } + + // convert to hex digits + StringBuilder buffer = new StringBuilder(2 * hashBytes.length); + for (byte b: hashBytes) { + int hi = (b & 0xF0) >> 4; + int lo = b & 0x0F; + buffer.append(HEXDIGITS[hi]); + buffer.append(HEXDIGITS[lo]); + } + String certDigits = buffer.toString(); + + if (!certDigits.startsWith(hashDigits)) { + throw new CertificateException("Certificate hash does not start with '" + hashDigits + "': " + certDigits); + } + + + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + } }