Mercurial > hg > monetdb-java
changeset 802:5d04490bc58b monetdbs
Add tests using monetdb-tlstester
author | Joeri van Ruth <joeri.van.ruth@monetdbsolutions.com> |
---|---|
date | Mon, 11 Dec 2023 13:45:12 +0100 (16 months ago) |
parents | 88b3e8e89126 |
children | 1671f2eb130b |
files | tests/TLSTester.java |
diffstat | 1 files changed, 304 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
new file mode 100644 --- /dev/null +++ b/tests/TLSTester.java @@ -0,0 +1,304 @@ +import org.monetdb.mcl.net.Parameter; + +import java.io.*; +import java.net.URL; +import java.net.URLConnection; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Properties; + +public class TLSTester { + int verbose = 0; + String serverHost = null; + String altHost = null; + int serverPort = -1; + boolean enableTrusted = false; + File tempDir = null; + final HashMap<String, File> fileCache = new HashMap<>(); + + public TLSTester(String[] args) { + for (int i = 0; i < args.length; i++) { + String arg = args[i]; + if (arg.equals("-v")) { + verbose = 1; + } else if (arg.equals("-a")) { + altHost = args[++i]; + } else if (arg.equals("-t")) { + enableTrusted = true; + } else if (!arg.startsWith("-") && serverHost == null) { + int idx = arg.indexOf(':'); + if (idx > 0) { + serverHost = arg.substring(0, idx); + try { + serverPort = Integer.parseInt(arg.substring(idx + 1)); + if (serverPort > 0 && serverPort < 65536) + continue; + } catch (NumberFormatException ignored) { + } + } + // if we get here it wasn't very valid + throw new IllegalArgumentException("Invalid argument: " + arg); + } else { + throw new IllegalArgumentException("Unexpected argument: " + arg); + } + } + } + + public static void main(String[] args) throws IOException, SQLException, ClassNotFoundException { + Class.forName("org.monetdb.jdbc.MonetDriver"); + TLSTester main = new TLSTester(args); + main.run(); + } + + private HashMap<String,Integer> loadPortMap(String testName) throws IOException { + HashMap<String,Integer> portMap = new HashMap<>(); + InputStream in = fetchData("/?test=" + testName); + BufferedReader br = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8)); + for (String line = br.readLine(); line != null; line = br.readLine()) { + int idx = line.indexOf(':'); + String service = line.substring(0, idx); + int port; + try { + port = Integer.parseInt(line.substring(idx + 1)); + } catch (NumberFormatException e) { + throw new RuntimeException("Invalid port map line: " + line); + } + portMap.put(service, port); + } + return portMap; + } + + private File resource(String resource) throws IOException { + if (!fileCache.containsKey(resource)) + fetchResource(resource); + return fileCache.get(resource); + } + + private void fetchResource(String resource) throws IOException { + if (!resource.startsWith("/")) { + throw new IllegalArgumentException("Resource must start with slash: " + resource); + } + if (tempDir == null) { + tempDir = Files.createTempDirectory("tlstest").toFile(); + tempDir.deleteOnExit(); + } + File outPath = new File(tempDir, resource.substring(1)); + try (InputStream in = fetchData(resource); FileOutputStream out = new FileOutputStream(outPath)) { + byte[] buffer = new byte[12]; + while (true) { + int n = in.read(buffer); + if (n <= 0) + break; + out.write(buffer, 0, n); + } + } + fileCache.put(resource, outPath); + } + + private byte[] fetchBytes(String resource) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (InputStream in = fetchData(resource)) { + byte[] buffer = new byte[22]; + while (true) { + int nread = in.read(buffer); + if (nread <= 0) + break; + out.write(buffer, 0, nread); + } + return out.toByteArray(); + } + } + + private InputStream fetchData(String resource) throws IOException { + URL url = new URL("http://" + serverHost + ":" + serverPort + resource); + URLConnection conn = url.openConnection(); + conn.connect(); + return conn.getInputStream(); + } + + private void run() throws IOException, SQLException { + test_connect_plain(); + test_connect_tls(); + test_refuse_no_cert(); + test_refuse_wrong_cert(); + test_refuse_wrong_host(); + test_refuse_tlsv12(); + test_refuse_expired(); +// test_connect_client_auth1(); +// test_connect_client_auth2(); + test_fail_tls_to_plain(); +// test_fail_plain_to_tls(); +// test_connect_server_name(); +// test_connect_alpn_mapi9(); + test_connect_trusted(); + test_refuse_trusted_wrong_host(); + } + + private void test_connect_plain() throws IOException, SQLException { + attempt("connect_plain", "plain").with(Parameter.TLS, false).expectSuccess(); + } + + private void test_connect_tls() throws IOException, SQLException { + Attempt attempt = attempt("connect_tls", "server1"); + attempt.withFile(Parameter.CERT, "/ca1.crt").expectSuccess(); + } + + private void test_refuse_no_cert() throws IOException, SQLException { + attempt("refuse_no_cert", "server1").expectFailure("PKIX path building failed"); + } + + private void test_refuse_wrong_cert() throws IOException, SQLException { + Attempt attempt = attempt("refuse_wrong_cert", "server1"); + attempt.withFile(Parameter.CERT, "/ca2.crt").expectFailure("PKIX path building failed"); + } + + private void test_refuse_wrong_host() throws IOException, SQLException { + Attempt attempt = attempt("refuse_wrong_host", "server1").with(Parameter.HOST, altHost); + attempt.withFile(Parameter.CERT, "/ca1.crt").expectFailure("No subject alternative DNS name"); + } + + private void test_refuse_tlsv12() throws IOException, SQLException { + Attempt attempt = attempt("refuse_tlsv12", "tls12"); + attempt.withFile(Parameter.CERT, "/ca1.crt").expectFailure("protocol_version"); + } + + private void test_refuse_expired() throws IOException, SQLException { + Attempt attempt = attempt("refuse_expired", "expiredcert"); + attempt.withFile(Parameter.CERT, "/ca1.crt").expectFailure("PKIX path validation failed"); + } + + private void test_connect_client_auth1() throws IOException, SQLException { + attempt("connect_client_auth1", "clientauth") + .withFile(Parameter.CERT, "/ca1.crt") + .withFile(Parameter.CLIENTKEY, "/client2.keycrt") + .expectSuccess(); + } + + private void test_connect_client_auth2() throws IOException, SQLException { + attempt("connect_client_auth2", "clientauth") + .withFile(Parameter.CERT, "/ca1.crt") + .withFile(Parameter.CLIENTKEY, "/client2.key") + .withFile(Parameter.CLIENTCERT, "/client2.crt") + .expectSuccess(); + } + + private void test_fail_tls_to_plain() throws IOException, SQLException { + Attempt attempt = attempt("fail_tls_to_plain", "plain"); + attempt.withFile(Parameter.CERT, "/ca1.crt").expectFailure(""); + + } + + private void test_fail_plain_to_tls() throws IOException, SQLException { + attempt("fail_plain_to_tls", "server1").with(Parameter.TLS, false).expectFailure("asdf"); + } + + private void test_connect_server_name() throws IOException, SQLException { + Attempt attempt = attempt("connect_server_name", "sni"); + attempt.withFile(Parameter.CERT, "/ca1.crt").expectSuccess(); + } + + private void test_connect_alpn_mapi9() throws IOException, SQLException { + attempt("connect_alpn_mapi9", ""); + } + + private void test_connect_trusted() throws IOException, SQLException { + attempt("connect_trusted", "alpn_mapi9") + .with(Parameter.HOST, "monetdb.ergates.nl") + .with(Parameter.PORT, 50000) + .expectSuccess(); + } + + private void test_refuse_trusted_wrong_host() throws IOException, SQLException { + attempt("connect_trusted", null) + .with(Parameter.HOST, "monetdbxyz.ergates.nl") + .with(Parameter.PORT, 50000) + .expectFailure("No subject alternative DNS name"); + } + + private Attempt attempt(String testName, String portName) throws IOException { + return new Attempt(testName, portName); + } + + private class Attempt { + private final String testName; + private final Properties props = new Properties(); + boolean disabled = false; + + public Attempt(String testName, String portName) throws IOException { + HashMap<String, Integer> portMap = loadPortMap(testName); + + this.testName = testName; + with(Parameter.TLS, true); + with(Parameter.HOST, serverHost); + with(Parameter.SO_TIMEOUT, 3000); + if (portName != null) { + Integer port = portMap.get(portName); + if (port != null) { + with(Parameter.PORT, port); + } else { + throw new RuntimeException("Unknown port name: " + portName); + } + } + } + + private Attempt with(Parameter parm, String value) { + props.setProperty(parm.name, value); + return this; + } + + private Attempt with(Parameter parm, int value) { + props.setProperty(parm.name, Integer.toString(value)); + return this; + } + + private Attempt with(Parameter parm, boolean value) { + props.setProperty(parm.name, value ? "true" : "false"); + return this; + } + + private Attempt withFile(Parameter parm, String certResource) throws IOException { + File certFile = resource(certResource); + String path = certFile.getPath(); + with(parm, path); + return this; + } + + public void expectSuccess() throws SQLException { + if (disabled) + return; + try { + Connection conn = DriverManager.getConnection("jdbc:monetdb:", props); + conn.close(); + } catch (SQLException e) { + if (e.getMessage().startsWith("Sorry, this is not a real MonetDB instance")) { + // it looks like a failure but this is actually our success scenario + // because this is what the TLS Tester does when the connection succeeds. + return; + } + // other exceptions ARE errors and should be reported. + throw e; + } + } + + public void expectFailure(String... expectedMessages) throws SQLException { + if (disabled) + return; + try { + expectSuccess(); + throw new RuntimeException("Expected test " + testName + " to throw an exception but it didn't"); + } catch (SQLException e) { + for (String expected : expectedMessages) + if (e.getMessage().contains(expected)) + return; + String message = "Test " + testName + " threw the wrong exception: " + e.getMessage() + '\n' + "Expected:\n <" + String.join(">\n <", expectedMessages) + ">"; + throw new RuntimeException(message); + + } + } + + } +}