Skip to content

Commit

Permalink
Add SSL option support across all drivers.
Browse files Browse the repository at this point in the history
  • Loading branch information
LibbyKen committed Jan 8, 2025
1 parent 4aa9883 commit 4dab51d
Show file tree
Hide file tree
Showing 19 changed files with 476 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ jacoco.exec

# VS Code
.vscode/

.java-version
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import com.amazonaws.secretsmanager.util.SQLExceptionUtils;
import com.amazonaws.secretsmanager.util.URLBuilder;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClientBuilder;
import software.amazon.awssdk.utils.StringUtils;
Expand Down Expand Up @@ -113,6 +114,17 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
if (StringUtils.isNotBlank(dbname)) {
url += "/" + dbname;
}
else {
url += "/";
}
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
if("true".equalsIgnoreCase(sslMode)) {
return new URLBuilder(url).appendProperty("sslConnection", "true").build();
}
return url;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ public abstract class AWSSecretsManagerDriver implements Driver {
*/
public static final String INVALID_SECRET_STRING_JSON = "Could not parse SecretString JSON";

/**
* Logger for the AWSSecretsManagerDriver class.
*/
private static final Logger logger = Logger.getLogger(AWSSecretsManagerDriver.class.getName());

private SecretCache secretCache;

private String realDriverClass;
Expand Down Expand Up @@ -319,6 +324,32 @@ public boolean acceptsURL(String url) throws SQLException {
*/
public abstract String getDefaultDriverClass();

/**
* Enforce SSL on the given database URL based on the specified SSL mode.
* This method is called when the <code>connect</code> method is called with a secret ID instead of a URL.
*
* @param url The database URL to enforce SSL on.
* @param sslMode The SSL mode to enforce.
*
* @return String The database URL with SSL enforced.
*/
public abstract String enforceSSL(String url, String sslMode);

private String getSSLConfig(JsonNode jsonObject) {
JsonNode sslNode = jsonObject.get("ssl");

if(sslNode == null) {
return "true";
}

if (sslNode.isBoolean()) {
return sslNode.asBoolean() ? "true" : "false";
} else if (sslNode.isTextual()) {
return sslNode.asText();
}
return "true";
}

/**
* Calls the real driver's <code>connect</code> method using credentials from a secret stored in AWS Secrets
* Manager.
Expand All @@ -329,30 +360,42 @@ public boolean acceptsURL(String url) throws SQLException {
* credentials retrieved from Secrets Manager.
* @param credentialsSecretId The friendly name or ARN of the secret that stores the
* login credentials.
* @param isSecretId A flag indicating if the connection uses a secret ID.
*
* @return Connection A database connection.
*
* @throws SQLException If there is an error from the driver or underlying
* database.
* @throws InterruptedException If there was an interruption during secret refresh.
*/
private Connection connectWithSecret(String unwrappedUrl, Properties info, String credentialsSecretId)
private Connection connectWithSecret(String unwrappedUrl, Properties info, String credentialsSecretId, boolean isSecretId)
throws SQLException, InterruptedException {
int retryCount = 0;
String sslMode = null;
while (retryCount++ <= MAX_RETRY) {
String secretString = secretCache.getSecretString(credentialsSecretId);
Properties updatedInfo = new Properties(info);
try {
JsonNode jsonObject = mapper.readTree(secretString);
updatedInfo.setProperty("user", jsonObject.get("username").asText());
updatedInfo.setProperty("password", jsonObject.get("password").asText());
sslMode = isSecretId ? getSSLConfig(jsonObject) : null;
} catch (IOException e) {
// Most likely to occur in the event that the data is not JSON.
// Or the secret's username and/or password fields have been
// removed entirely. Either scenario is most often a user error.
throw new RuntimeException(INVALID_SECRET_STRING_JSON);
}

if (sslMode != null && !"false".equalsIgnoreCase(sslMode)) {
try {
return getWrappedDriver().connect(enforceSSL(unwrappedUrl, sslMode), updatedInfo);
} catch (SQLException e) {
// If SSL connection fails, fall back to non-SSL
logger.warning("SSL connection failed. Falling back to non-SSL connection. Error: " + e.getMessage());
}
}

try {
return getWrappedDriver().connect(unwrappedUrl, updatedInfo);
} catch (Exception e) {
Expand All @@ -379,6 +422,7 @@ public Connection connect(String url, Properties info) throws SQLException {
}

String unwrappedUrl = "";
boolean isSecretId = false;
if (url.startsWith(SCHEME)) { // If this is a URL in the correct scheme, unwrap it
unwrappedUrl = unwrapUrl(url);
} else { // Else, assume this is a secret ID and try to retrieve it
Expand All @@ -395,6 +439,7 @@ public Connection connect(String url, Properties info) throws SQLException {
JsonNode dbnameNode = jsonObject.get("dbname");
String dbname = dbnameNode == null ? null : dbnameNode.asText();
unwrappedUrl = constructUrlFromEndpointPortDatabase(endpoint, port, dbname);
isSecretId = true;
} catch (IOException e) {
// Most likely to occur in the event that the data is not JSON.
// Or the secret has been modified and is no longer valid.
Expand All @@ -406,7 +451,7 @@ public Connection connect(String url, Properties info) throws SQLException {
if (info != null && info.getProperty("user") != null) {
String credentialsSecretId = info.getProperty("user");
try {
return connectWithSecret(unwrappedUrl, info, credentialsSecretId);
return connectWithSecret(unwrappedUrl, info, credentialsSecretId, isSecretId);
} catch (InterruptedException e) {
// User driven exception. Throw a runtime exception.
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package com.amazonaws.secretsmanager.sql;

import java.sql.SQLException;
import com.amazonaws.secretsmanager.util.URLBuilder;

import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
Expand Down Expand Up @@ -126,6 +127,21 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
URLBuilder builder = new URLBuilder(url);
switch(sslMode) {
case "TLS":
case "TLSv1":
case "TLSv1.1":
case "TLSv1.2":
return builder.appendProperty("sslProtocol", sslMode).build();
default:
break;
}
return builder.appendProperty("sslProtocol", "TLS").build();
}

@Override
public String getDefaultDriverClass() {
return "com.microsoft.sqlserver.jdbc.SQLServerDriver";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import com.amazonaws.secretsmanager.util.SQLExceptionUtils;
import com.amazonaws.secretsmanager.util.URLBuilder;

import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClientBuilder;
Expand Down Expand Up @@ -121,6 +122,22 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
URLBuilder builder = new URLBuilder(url);
switch(sslMode) {
case "disable":
break;
case "trust":
case "verify-full":
case "verify-ca":
return builder.appendParameter("sslMode", sslMode, !url.contains("?")).build();
default:
return builder.appendParameter("sslMode", "trust", !url.contains("?")).build();
}
return url;
}

@Override
public String getDefaultDriverClass() {
return "org.mariadb.jdbc.Driver";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import com.amazonaws.secretsmanager.util.SQLExceptionUtils;
import com.amazonaws.secretsmanager.util.URLBuilder;

import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClientBuilder;
Expand Down Expand Up @@ -121,6 +122,23 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
URLBuilder builder = new URLBuilder(url);
switch(sslMode) {
case "DISABLED":
break;
case "PREFERRED":
case "REQUIRED":
case "VERIFY_CA":
case "VERIFY_IDENTITY":
return builder.appendParameter("sslMode", sslMode, !url.contains("?")).build();
default:
return builder.appendParameter("sslMode", "PREFERRED", !url.contains("?")).build();
}
return url;
}

@Override
public String getDefaultDriverClass() {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
if("true".equalsIgnoreCase(sslMode)) {
if (url.startsWith("jdbc:oracle:thin:@//")) {
return url.replace("jdbc:oracle:thin:@//", "jdbc:oracle:thin:@tcps://");
}
}
return url;
}

@Override
public String getDefaultDriverClass() {
return "oracle.jdbc.OracleDriver";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package com.amazonaws.secretsmanager.sql;

import java.sql.SQLException;
import com.amazonaws.secretsmanager.util.URLBuilder;

import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
Expand Down Expand Up @@ -134,6 +135,34 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {

if (url.endsWith("/")) {
url = url.substring(0, url.length() - 1);
}

URLBuilder builder = new URLBuilder(url);

if ("true".equalsIgnoreCase(sslMode)) {
return builder.appendParameter("sslMode", "verify-full", !url.contains("?")).build();
} else {
switch(sslMode) {
case "disable":
break;
case "allow":
case "prefer":
case "require":
case "verify-ca":
case "verify-full":
return builder.appendParameter("sslMode", sslMode, !url.contains("?")).build();
default:
return builder.appendParameter("sslMode", "prefer", !url.contains("?")).build();
}
}
return url;
}

@Override
public String getDefaultDriverClass() {
return "org.postgresql.Driver";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package com.amazonaws.secretsmanager.sql;

import java.sql.SQLException;
import com.amazonaws.secretsmanager.util.URLBuilder;

import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
Expand Down Expand Up @@ -128,6 +129,24 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
URLBuilder builder = new URLBuilder(url);
switch(sslMode) {
case "disable":
break;
case "allow":
case "prefer":
case "require":
case "verify-ca":
case "verify-full":
return builder.appendProperty("sslmode", sslMode).build();
default:
return builder.appendProperty("sslmode", "prefer").build();
}
return url;
}

@Override
public String getDefaultDriverClass() {
return "com.amazon.redshift.Driver";
Expand Down
Loading

0 comments on commit 4dab51d

Please sign in to comment.