changeset 802:5d04490bc58b monetdbs

Add tests using monetdb-tlstester
author Joeri van Ruth <joeri.van.ruth@monetdbsolutions.com>
date Mon, 11 Dec 2023 13:45:12 +0100 (16 months ago)
parents 88b3e8e89126
children 1671f2eb130b
files tests/TLSTester.java
diffstat 1 files changed, 304 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
new file mode 100644
--- /dev/null
+++ b/tests/TLSTester.java
@@ -0,0 +1,304 @@
+import org.monetdb.mcl.net.Parameter;
+
+import java.io.*;
+import java.net.URL;
+import java.net.URLConnection;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+import java.util.HashMap;
+import java.util.Properties;
+
+public class TLSTester {
+    int verbose = 0;
+    String serverHost = null;
+    String altHost = null;
+    int serverPort = -1;
+    boolean enableTrusted = false;
+    File tempDir = null;
+    final HashMap<String, File> fileCache = new HashMap<>();
+
+    public TLSTester(String[] args) {
+        for (int i = 0; i < args.length; i++) {
+            String arg = args[i];
+            if (arg.equals("-v")) {
+                verbose = 1;
+            } else if (arg.equals("-a")) {
+                altHost = args[++i];
+            } else if (arg.equals("-t")) {
+                enableTrusted = true;
+            } else if (!arg.startsWith("-") && serverHost == null) {
+                int idx = arg.indexOf(':');
+                if (idx > 0) {
+                    serverHost = arg.substring(0, idx);
+                    try {
+                        serverPort = Integer.parseInt(arg.substring(idx + 1));
+                        if (serverPort > 0 && serverPort < 65536)
+                            continue;
+                    } catch (NumberFormatException ignored) {
+                    }
+                }
+                // if we get here it wasn't very valid
+                throw new IllegalArgumentException("Invalid argument: " + arg);
+            } else {
+                throw new IllegalArgumentException("Unexpected argument: " + arg);
+            }
+        }
+    }
+
+    public static void main(String[] args) throws IOException, SQLException, ClassNotFoundException {
+        Class.forName("org.monetdb.jdbc.MonetDriver");
+        TLSTester main = new TLSTester(args);
+        main.run();
+    }
+
+    private HashMap<String,Integer> loadPortMap(String testName) throws IOException {
+        HashMap<String,Integer> portMap = new HashMap<>();
+        InputStream in = fetchData("/?test=" + testName);
+        BufferedReader br = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8));
+        for (String line = br.readLine(); line != null; line = br.readLine()) {
+            int idx = line.indexOf(':');
+            String service = line.substring(0, idx);
+            int port;
+            try {
+                port = Integer.parseInt(line.substring(idx + 1));
+            } catch (NumberFormatException e) {
+                throw new RuntimeException("Invalid port map line: " + line);
+            }
+            portMap.put(service, port);
+        }
+        return portMap;
+    }
+
+    private File resource(String resource) throws IOException {
+        if (!fileCache.containsKey(resource))
+            fetchResource(resource);
+        return fileCache.get(resource);
+    }
+
+    private void fetchResource(String resource) throws IOException {
+        if (!resource.startsWith("/")) {
+            throw new IllegalArgumentException("Resource must start with slash: " + resource);
+        }
+        if (tempDir == null) {
+            tempDir = Files.createTempDirectory("tlstest").toFile();
+            tempDir.deleteOnExit();
+        }
+        File outPath = new File(tempDir, resource.substring(1));
+        try (InputStream in = fetchData(resource); FileOutputStream out = new FileOutputStream(outPath)) {
+            byte[] buffer = new byte[12];
+            while (true) {
+                int n = in.read(buffer);
+                if (n <= 0)
+                    break;
+                out.write(buffer, 0, n);
+            }
+        }
+        fileCache.put(resource, outPath);
+    }
+
+    private byte[] fetchBytes(String resource) throws IOException {
+        ByteArrayOutputStream out = new ByteArrayOutputStream();
+        try (InputStream in = fetchData(resource)) {
+            byte[] buffer = new byte[22];
+            while (true) {
+                int nread = in.read(buffer);
+                if (nread <= 0)
+                    break;
+                out.write(buffer, 0, nread);
+            }
+            return out.toByteArray();
+        }
+    }
+
+    private InputStream fetchData(String resource) throws IOException {
+        URL url = new URL("http://" + serverHost + ":" + serverPort + resource);
+        URLConnection conn = url.openConnection();
+        conn.connect();
+        return conn.getInputStream();
+    }
+
+    private void run() throws IOException, SQLException {
+        test_connect_plain();
+        test_connect_tls();
+        test_refuse_no_cert();
+        test_refuse_wrong_cert();
+        test_refuse_wrong_host();
+        test_refuse_tlsv12();
+        test_refuse_expired();
+//        test_connect_client_auth1();
+//        test_connect_client_auth2();
+        test_fail_tls_to_plain();
+//        test_fail_plain_to_tls();
+//        test_connect_server_name();
+//        test_connect_alpn_mapi9();
+        test_connect_trusted();
+        test_refuse_trusted_wrong_host();
+    }
+
+    private void test_connect_plain() throws IOException, SQLException {
+        attempt("connect_plain", "plain").with(Parameter.TLS, false).expectSuccess();
+    }
+
+    private void test_connect_tls() throws IOException, SQLException {
+        Attempt attempt = attempt("connect_tls", "server1");
+        attempt.withFile(Parameter.CERT, "/ca1.crt").expectSuccess();
+    }
+
+    private void test_refuse_no_cert() throws IOException, SQLException {
+        attempt("refuse_no_cert", "server1").expectFailure("PKIX path building failed");
+    }
+
+    private void test_refuse_wrong_cert() throws IOException, SQLException {
+        Attempt attempt = attempt("refuse_wrong_cert", "server1");
+        attempt.withFile(Parameter.CERT, "/ca2.crt").expectFailure("PKIX path building failed");
+    }
+
+    private void test_refuse_wrong_host() throws IOException, SQLException {
+        Attempt attempt = attempt("refuse_wrong_host", "server1").with(Parameter.HOST, altHost);
+        attempt.withFile(Parameter.CERT, "/ca1.crt").expectFailure("No subject alternative DNS name");
+    }
+
+    private void test_refuse_tlsv12() throws IOException, SQLException {
+        Attempt attempt = attempt("refuse_tlsv12", "tls12");
+        attempt.withFile(Parameter.CERT, "/ca1.crt").expectFailure("protocol_version");
+    }
+
+    private void test_refuse_expired() throws IOException, SQLException {
+        Attempt attempt = attempt("refuse_expired", "expiredcert");
+        attempt.withFile(Parameter.CERT, "/ca1.crt").expectFailure("PKIX path validation failed");
+    }
+
+    private void test_connect_client_auth1() throws IOException, SQLException {
+        attempt("connect_client_auth1", "clientauth")
+                .withFile(Parameter.CERT, "/ca1.crt")
+                .withFile(Parameter.CLIENTKEY, "/client2.keycrt")
+                .expectSuccess();
+    }
+
+    private void test_connect_client_auth2() throws IOException, SQLException {
+        attempt("connect_client_auth2", "clientauth")
+                .withFile(Parameter.CERT, "/ca1.crt")
+                .withFile(Parameter.CLIENTKEY, "/client2.key")
+                .withFile(Parameter.CLIENTCERT, "/client2.crt")
+                .expectSuccess();
+    }
+
+    private void test_fail_tls_to_plain() throws IOException, SQLException {
+        Attempt attempt = attempt("fail_tls_to_plain", "plain");
+        attempt.withFile(Parameter.CERT, "/ca1.crt").expectFailure("");
+
+    }
+
+    private void test_fail_plain_to_tls() throws IOException, SQLException {
+        attempt("fail_plain_to_tls", "server1").with(Parameter.TLS, false).expectFailure("asdf");
+    }
+
+    private void test_connect_server_name() throws IOException, SQLException {
+        Attempt attempt = attempt("connect_server_name", "sni");
+        attempt.withFile(Parameter.CERT, "/ca1.crt").expectSuccess();
+    }
+
+    private void test_connect_alpn_mapi9() throws IOException, SQLException {
+        attempt("connect_alpn_mapi9", "");
+    }
+
+    private void test_connect_trusted() throws IOException, SQLException {
+        attempt("connect_trusted", "alpn_mapi9")
+                .with(Parameter.HOST, "monetdb.ergates.nl")
+                .with(Parameter.PORT, 50000)
+                .expectSuccess();
+    }
+
+    private void test_refuse_trusted_wrong_host() throws IOException, SQLException {
+        attempt("connect_trusted", null)
+                .with(Parameter.HOST, "monetdbxyz.ergates.nl")
+                .with(Parameter.PORT, 50000)
+                .expectFailure("No subject alternative DNS name");
+    }
+
+    private Attempt attempt(String testName, String portName) throws IOException {
+        return new Attempt(testName, portName);
+    }
+
+    private class Attempt {
+        private final String testName;
+        private final Properties props = new Properties();
+        boolean disabled = false;
+
+        public Attempt(String testName, String portName) throws IOException {
+            HashMap<String, Integer> portMap = loadPortMap(testName);
+
+            this.testName = testName;
+            with(Parameter.TLS, true);
+            with(Parameter.HOST, serverHost);
+            with(Parameter.SO_TIMEOUT, 3000);
+            if (portName != null) {
+                Integer port = portMap.get(portName);
+                if (port != null) {
+                    with(Parameter.PORT, port);
+                } else {
+                    throw new RuntimeException("Unknown port name: " + portName);
+                }
+            }
+        }
+
+        private Attempt with(Parameter parm, String value) {
+            props.setProperty(parm.name, value);
+            return this;
+        }
+
+        private Attempt with(Parameter parm, int value) {
+            props.setProperty(parm.name, Integer.toString(value));
+            return this;
+        }
+
+        private Attempt with(Parameter parm, boolean value) {
+            props.setProperty(parm.name, value ? "true" : "false");
+            return this;
+        }
+
+        private Attempt withFile(Parameter parm, String certResource) throws IOException {
+            File certFile = resource(certResource);
+            String path = certFile.getPath();
+            with(parm, path);
+            return this;
+        }
+
+        public void expectSuccess() throws SQLException {
+            if (disabled)
+                return;
+            try {
+                Connection conn = DriverManager.getConnection("jdbc:monetdb:", props);
+                conn.close();
+            } catch (SQLException e) {
+                if (e.getMessage().startsWith("Sorry, this is not a real MonetDB instance")) {
+                    // it looks like a failure but this is actually our success scenario
+                    // because this is what the TLS Tester does when the connection succeeds.
+                    return;
+                }
+                // other exceptions ARE errors and should be reported.
+                throw e;
+            }
+        }
+
+        public void expectFailure(String... expectedMessages) throws SQLException {
+            if (disabled)
+                return;
+            try {
+                expectSuccess();
+                throw new RuntimeException("Expected test " + testName + " to throw an exception but it didn't");
+            } catch (SQLException e) {
+                for (String expected : expectedMessages)
+                    if (e.getMessage().contains(expected))
+                        return;
+                String message = "Test " + testName + " threw the wrong exception: " + e.getMessage() + '\n' + "Expected:\n        <" + String.join(">\n        <", expectedMessages) + ">";
+                throw new RuntimeException(message);
+
+            }
+        }
+
+    }
+}