comparison src/main/java/org/monetdb/mcl/net/SecureSocket.java @ 834:5aa19bbed0d6 monetdbs

Comments and formatting
author Joeri van Ruth <joeri.van.ruth@monetdbsolutions.com>
date Wed, 13 Dec 2023 15:39:47 +0100 (17 months ago)
parents 2fee4b71baac
children a2b1ae53565e
comparison
equal deleted inserted replaced
813:a71afa48f269 834:5aa19bbed0d6
11 import java.security.cert.CertificateFactory; 11 import java.security.cert.CertificateFactory;
12 import java.security.cert.X509Certificate; 12 import java.security.cert.X509Certificate;
13 import java.util.Collections; 13 import java.util.Collections;
14 14
15 public class SecureSocket { 15 public class SecureSocket {
16 private static final String[] ENABLED_PROTOCOLS = {"TLSv1.3"}; 16 private static final String[] ENABLED_PROTOCOLS = {"TLSv1.3"};
17 private static final String[] APPLICATION_PROTOCOLS = {"mapi/9"}; 17 private static final String[] APPLICATION_PROTOCOLS = {"mapi/9"};
18 18
19 public static Socket wrap(Target.Validated validated, Socket inner) throws IOException { 19 public static Socket wrap(Target.Validated validated, Socket inner) throws IOException {
20 Target.Verify verify = validated.connectVerify(); 20 Target.Verify verify = validated.connectVerify();
21 SSLSocketFactory socketFactory; 21 SSLSocketFactory socketFactory;
22 boolean checkName = true; 22 boolean checkName = true;
23 try { 23 try {
24 switch (verify) { 24 switch (verify) {
25 case System: 25 case System:
26 socketFactory = (SSLSocketFactory) SSLSocketFactory.getDefault(); 26 socketFactory = (SSLSocketFactory) SSLSocketFactory.getDefault();
27 break; 27 break;
28 case Cert: 28 case Cert:
29 KeyStore keyStore = keyStoreForCert(validated.getCert()); 29 KeyStore keyStore = keyStoreForCert(validated.getCert());
30 socketFactory = certBasedSocketFactory(keyStore); 30 socketFactory = certBasedSocketFactory(keyStore);
31 break; 31 break;
32 case Hash: 32 case Hash:
33 socketFactory = hashBasedSocketFactory(validated.connectCertHashDigits()); 33 socketFactory = hashBasedSocketFactory(validated.connectCertHashDigits());
34 checkName = false; 34 checkName = false;
35 break; 35 break;
36 default: 36 default:
37 throw new RuntimeException("unreachable: unexpected verification strategy " + verify.name()); 37 throw new RuntimeException("unreachable: unexpected verification strategy " + verify.name());
38 } 38 }
39 return wrapSocket(inner, validated, socketFactory, checkName); 39 return wrapSocket(inner, validated, socketFactory, checkName);
40 } catch (CertificateException e) { 40 } catch (CertificateException e) {
41 throw new SSLException(e.getMessage(), e); 41 throw new SSLException(e.getMessage(), e);
42 } 42 }
43 } 43 }
44 44
45 private static SSLSocket wrapSocket(Socket inner, Target.Validated validated, SSLSocketFactory socketFactory, boolean checkName) throws IOException { 45 private static SSLSocket wrapSocket(Socket inner, Target.Validated validated, SSLSocketFactory socketFactory, boolean checkName) throws IOException {
46 SSLSocket sock = (SSLSocket) socketFactory.createSocket(inner, validated.connectTcp(), validated.connectPort(), true); 46 SSLSocket sock = (SSLSocket) socketFactory.createSocket(inner, validated.connectTcp(), validated.connectPort(), true);
47 sock.setUseClientMode(true); 47 sock.setUseClientMode(true);
48 SSLParameters parameters = sock.getSSLParameters(); 48 SSLParameters parameters = sock.getSSLParameters();
49 49
50 parameters.setProtocols(ENABLED_PROTOCOLS); 50 parameters.setProtocols(ENABLED_PROTOCOLS);
51 51
52 parameters.setServerNames(Collections.singletonList(new SNIHostName(validated.connectTcp()))); 52 parameters.setServerNames(Collections.singletonList(new SNIHostName(validated.connectTcp())));
53 53
54 if (checkName) { 54 if (checkName) {
55 parameters.setEndpointIdentificationAlgorithm("HTTPS"); 55 parameters.setEndpointIdentificationAlgorithm("HTTPS");
56 } 56 }
57 57
58 // Unfortunately, SSLParameters.setApplicationProtocols is only available 58 // Unfortunately, SSLParameters.setApplicationProtocols is only available
59 // since language level 9 and currently we're on 8. 59 // since language level 9 and currently we're on 8.
60 // Still call it if it happens to be available. 60 // Still call it if it happens to be available.
61 try { 61 try {
62 Method setApplicationProtocols = SSLParameters.class.getMethod("setApplicationProtocols", String[].class); 62 Method setApplicationProtocols = SSLParameters.class.getMethod("setApplicationProtocols", String[].class);
63 setApplicationProtocols.invoke(parameters, (Object) APPLICATION_PROTOCOLS); 63 setApplicationProtocols.invoke(parameters, (Object) APPLICATION_PROTOCOLS);
64 } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException ignored) {} 64 } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException ignored) {
65 }
65 66
66 sock.setSSLParameters(parameters); 67 sock.setSSLParameters(parameters);
67 sock.startHandshake(); 68 sock.startHandshake();
68 return sock; 69 return sock;
69 } 70 }
70 71
71 private static X509Certificate loadCertificate(String path) throws CertificateException, IOException { 72 private static X509Certificate loadCertificate(String path) throws CertificateException, IOException {
72 CertificateFactory factory = CertificateFactory.getInstance("X509"); 73 CertificateFactory factory = CertificateFactory.getInstance("X509");
73 try (FileInputStream s = new FileInputStream(path)) { 74 try (FileInputStream s = new FileInputStream(path)) {
74 return (X509Certificate) factory.generateCertificate(s); 75 return (X509Certificate) factory.generateCertificate(s);
75 } 76 }
76 } 77 }
77 78
78 private static SSLSocketFactory certBasedSocketFactory(KeyStore store) throws IOException, CertificateException { 79 private static SSLSocketFactory certBasedSocketFactory(KeyStore store) throws IOException, CertificateException {
79 TrustManagerFactory trustManagerFactory; 80 TrustManagerFactory trustManagerFactory;
80 try { 81 try {
81 trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); 82 trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
82 trustManagerFactory.init(store); 83 trustManagerFactory.init(store);
83 } catch (NoSuchAlgorithmException | KeyStoreException e) { 84 } catch (NoSuchAlgorithmException | KeyStoreException e) {
84 throw new RuntimeException("Could not create TrustManagerFactory", e); 85 throw new RuntimeException("Could not create TrustManagerFactory", e);
85 } 86 }
86 87
87 SSLContext context; 88 SSLContext context;
88 try { 89 try {
89 context = SSLContext.getInstance("TLS"); 90 context = SSLContext.getInstance("TLS");
90 context.init(null, trustManagerFactory.getTrustManagers(), null); 91 context.init(null, trustManagerFactory.getTrustManagers(), null);
91 } catch (NoSuchAlgorithmException | KeyManagementException e) { 92 } catch (NoSuchAlgorithmException | KeyManagementException e) {
92 throw new RuntimeException("Could not create SSLContext", e); 93 throw new RuntimeException("Could not create SSLContext", e);
93 } 94 }
94 95
95 return context.getSocketFactory(); 96 return context.getSocketFactory();
96 } 97 }
97 98
98 private static KeyStore keyStoreForCert(String path) throws IOException, CertificateException { 99 private static KeyStore keyStoreForCert(String path) throws IOException, CertificateException {
99 try { 100 try {
100 X509Certificate cert = loadCertificate(path); 101 X509Certificate cert = loadCertificate(path);
101 KeyStore store = emptyKeyStore(); 102 KeyStore store = emptyKeyStore();
102 store.setCertificateEntry("root", cert); 103 store.setCertificateEntry("root", cert);
103 return store; 104 return store;
104 } catch (KeyStoreException e) { 105 } catch (KeyStoreException e) {
105 throw new RuntimeException("Could not create KeyStore for certificate", e); 106 throw new RuntimeException("Could not create KeyStore for certificate", e);
106 } 107 }
107 } 108 }
108 109
109 private static KeyStore emptyKeyStore() throws IOException, CertificateException { 110 private static KeyStore emptyKeyStore() throws IOException, CertificateException {
110 KeyStore store; 111 KeyStore store;
111 try { 112 try {
112 store = KeyStore.getInstance("PKCS12"); 113 store = KeyStore.getInstance("PKCS12");
113 store.load(null, null); 114 store.load(null, null);
114 return store; 115 return store;
115 } catch (KeyStoreException | NoSuchAlgorithmException e) { 116 } catch (KeyStoreException | NoSuchAlgorithmException e) {
116 throw new RuntimeException("Could not create KeyStore for certificate", e); 117 throw new RuntimeException("Could not create KeyStore for certificate", e);
117 } 118 }
118 } 119 }
119 120
120 private static SSLSocketFactory hashBasedSocketFactory(String hashDigits) { 121 private static SSLSocketFactory hashBasedSocketFactory(String hashDigits) {
121 TrustManager trustManager = new HashBasedTrustManager(hashDigits); 122 TrustManager trustManager = new HashBasedTrustManager(hashDigits);
122 try { 123 try {
123 SSLContext context = SSLContext.getInstance("TLS"); 124 SSLContext context = SSLContext.getInstance("TLS");
124 context.init(null, new TrustManager[]{ trustManager}, null); 125 context.init(null, new TrustManager[]{trustManager}, null);
125 return context.getSocketFactory(); 126 return context.getSocketFactory();
126 } catch (NoSuchAlgorithmException | KeyManagementException e) { 127 } catch (NoSuchAlgorithmException | KeyManagementException e) {
127 throw new RuntimeException("Could not create SSLContext", e); 128 throw new RuntimeException("Could not create SSLContext", e);
128 } 129 }
129 130
130 } 131 }
131 132
132 private static class HashBasedTrustManager implements X509TrustManager { 133 private static class HashBasedTrustManager implements X509TrustManager {
133 private static final char[] HEXDIGITS = "0123456789abcdef".toCharArray(); 134 private static final char[] HEXDIGITS = "0123456789abcdef".toCharArray();
134 private final String hashDigits; 135 private final String hashDigits;
135 136
136 public HashBasedTrustManager(String hashDigits) { 137 public HashBasedTrustManager(String hashDigits) {
137 this.hashDigits = hashDigits; 138 this.hashDigits = hashDigits;
138 } 139 }
139 140
140 141
141 @Override 142 @Override
142 public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { 143 public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException {
143 throw new RuntimeException("this TrustManager is only suitable for client side connections"); 144 throw new RuntimeException("this TrustManager is only suitable for client side connections");
144 } 145 }
145 146
146 @Override 147 @Override
147 public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { 148 public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException {
148 X509Certificate cert = x509Certificates[0]; 149 X509Certificate cert = x509Certificates[0];
149 byte[] certBytes = cert.getEncoded(); 150 byte[] certBytes = cert.getEncoded();
150 151
151 // for now it's always SHA256. 152 // for now it's always SHA256.
152 byte[] hashBytes; 153 byte[] hashBytes;
153 try { 154 try {
154 MessageDigest hasher = MessageDigest.getInstance("SHA-256"); 155 MessageDigest hasher = MessageDigest.getInstance("SHA-256");
155 hasher.update(certBytes); 156 hasher.update(certBytes);
156 hashBytes = hasher.digest(); 157 hashBytes = hasher.digest();
157 } catch (NoSuchAlgorithmException e) { 158 } catch (NoSuchAlgorithmException e) {
158 throw new RuntimeException("failed to instantiate hash digest"); 159 throw new RuntimeException("failed to instantiate hash digest");
159 } 160 }
160 161
161 // convert to hex digits 162 // convert to hex digits
162 StringBuilder buffer = new StringBuilder(2 * hashBytes.length); 163 StringBuilder buffer = new StringBuilder(2 * hashBytes.length);
163 for (byte b: hashBytes) { 164 for (byte b : hashBytes) {
164 int hi = (b & 0xF0) >> 4; 165 int hi = (b & 0xF0) >> 4;
165 int lo = b & 0x0F; 166 int lo = b & 0x0F;
166 buffer.append(HEXDIGITS[hi]); 167 buffer.append(HEXDIGITS[hi]);
167 buffer.append(HEXDIGITS[lo]); 168 buffer.append(HEXDIGITS[lo]);
168 } 169 }
169 String certDigits = buffer.toString(); 170 String certDigits = buffer.toString();
170 171
171 if (!certDigits.startsWith(hashDigits)) { 172 if (!certDigits.startsWith(hashDigits)) {
172 throw new CertificateException("Certificate hash does not start with '" + hashDigits + "': " + certDigits); 173 throw new CertificateException("Certificate hash does not start with '" + hashDigits + "': " + certDigits);
173 } 174 }
174 175
175 176
176 } 177 }
177 178
178 @Override 179 @Override
179 public X509Certificate[] getAcceptedIssuers() { 180 public X509Certificate[] getAcceptedIssuers() {
180 return new X509Certificate[0]; 181 return new X509Certificate[0];
181 } 182 }
182 } 183 }
183 } 184 }