Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SSL option support across all drivers. #258

Open
wants to merge 1 commit into
base: v2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ jacoco.exec

# VS Code
.vscode/

.java-version
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import com.amazonaws.secretsmanager.util.SQLExceptionUtils;
import com.amazonaws.secretsmanager.util.URLBuilder;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClientBuilder;
import software.amazon.awssdk.utils.StringUtils;
Expand Down Expand Up @@ -113,6 +114,17 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
if (StringUtils.isNotBlank(dbname)) {
url += "/" + dbname;
}
else {
url += "/";
}
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
if("true".equalsIgnoreCase(sslMode)) {
return new URLBuilder(url).appendProperty("sslConnection", "true").build();
}
return url;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ public abstract class AWSSecretsManagerDriver implements Driver {
*/
public static final String INVALID_SECRET_STRING_JSON = "Could not parse SecretString JSON";

/**
* Logger for the AWSSecretsManagerDriver class.
*/
private static final Logger logger = Logger.getLogger(AWSSecretsManagerDriver.class.getName());

private SecretCache secretCache;

private String realDriverClass;
Expand Down Expand Up @@ -319,6 +324,32 @@ public boolean acceptsURL(String url) throws SQLException {
*/
public abstract String getDefaultDriverClass();

/**
* Enforce SSL on the given database URL based on the specified SSL mode.
* This method is called when the <code>connect</code> method is called with a secret ID instead of a URL.
*
* @param url The database URL to enforce SSL on.
* @param sslMode The SSL mode to enforce.
*
* @return String The database URL with SSL enforced.
*/
public abstract String enforceSSL(String url, String sslMode);

private String getSSLConfig(JsonNode jsonObject) {
JsonNode sslNode = jsonObject.get("ssl");

if(sslNode == null) {
return "true";
}

if (sslNode.isBoolean()) {
return sslNode.asBoolean() ? "true" : "false";
} else if (sslNode.isTextual()) {
return sslNode.asText();
}
return "true";
}

/**
* Calls the real driver's <code>connect</code> method using credentials from a secret stored in AWS Secrets
* Manager.
Expand All @@ -329,30 +360,42 @@ public boolean acceptsURL(String url) throws SQLException {
* credentials retrieved from Secrets Manager.
* @param credentialsSecretId The friendly name or ARN of the secret that stores the
* login credentials.
* @param isSecretId A flag indicating if the connection uses a secret ID.
*
* @return Connection A database connection.
*
* @throws SQLException If there is an error from the driver or underlying
* database.
* @throws InterruptedException If there was an interruption during secret refresh.
*/
private Connection connectWithSecret(String unwrappedUrl, Properties info, String credentialsSecretId)
private Connection connectWithSecret(String unwrappedUrl, Properties info, String credentialsSecretId, boolean isSecretId)
throws SQLException, InterruptedException {
int retryCount = 0;
String sslMode = null;
while (retryCount++ <= MAX_RETRY) {
String secretString = secretCache.getSecretString(credentialsSecretId);
Properties updatedInfo = new Properties(info);
try {
JsonNode jsonObject = mapper.readTree(secretString);
updatedInfo.setProperty("user", jsonObject.get("username").asText());
updatedInfo.setProperty("password", jsonObject.get("password").asText());
sslMode = isSecretId ? getSSLConfig(jsonObject) : null;
} catch (IOException e) {
// Most likely to occur in the event that the data is not JSON.
// Or the secret's username and/or password fields have been
// removed entirely. Either scenario is most often a user error.
throw new RuntimeException(INVALID_SECRET_STRING_JSON);
}

if (sslMode != null && !"false".equalsIgnoreCase(sslMode)) {
try {
return getWrappedDriver().connect(enforceSSL(unwrappedUrl, sslMode), updatedInfo);
} catch (SQLException e) {
// If SSL connection fails, fall back to non-SSL
logger.warning("SSL connection failed. Falling back to non-SSL connection. Error: " + e.getMessage());
}
}

try {
return getWrappedDriver().connect(unwrappedUrl, updatedInfo);
} catch (Exception e) {
Expand All @@ -379,6 +422,7 @@ public Connection connect(String url, Properties info) throws SQLException {
}

String unwrappedUrl = "";
boolean isSecretId = false;
if (url.startsWith(SCHEME)) { // If this is a URL in the correct scheme, unwrap it
unwrappedUrl = unwrapUrl(url);
} else { // Else, assume this is a secret ID and try to retrieve it
Expand All @@ -395,6 +439,7 @@ public Connection connect(String url, Properties info) throws SQLException {
JsonNode dbnameNode = jsonObject.get("dbname");
String dbname = dbnameNode == null ? null : dbnameNode.asText();
unwrappedUrl = constructUrlFromEndpointPortDatabase(endpoint, port, dbname);
isSecretId = true;
} catch (IOException e) {
// Most likely to occur in the event that the data is not JSON.
// Or the secret has been modified and is no longer valid.
Expand All @@ -406,7 +451,7 @@ public Connection connect(String url, Properties info) throws SQLException {
if (info != null && info.getProperty("user") != null) {
String credentialsSecretId = info.getProperty("user");
try {
return connectWithSecret(unwrappedUrl, info, credentialsSecretId);
return connectWithSecret(unwrappedUrl, info, credentialsSecretId, isSecretId);
} catch (InterruptedException e) {
// User driven exception. Throw a runtime exception.
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package com.amazonaws.secretsmanager.sql;

import java.sql.SQLException;
import com.amazonaws.secretsmanager.util.URLBuilder;

import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
Expand Down Expand Up @@ -126,6 +127,21 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
URLBuilder builder = new URLBuilder(url);
switch(sslMode) {
case "TLS":
case "TLSv1":
case "TLSv1.1":
case "TLSv1.2":
return builder.appendProperty("sslProtocol", sslMode).build();
default:
break;
}
return builder.appendProperty("sslProtocol", "TLS").build();
}

@Override
public String getDefaultDriverClass() {
return "com.microsoft.sqlserver.jdbc.SQLServerDriver";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import com.amazonaws.secretsmanager.util.SQLExceptionUtils;
import com.amazonaws.secretsmanager.util.URLBuilder;

import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClientBuilder;
Expand Down Expand Up @@ -121,6 +122,22 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
URLBuilder builder = new URLBuilder(url);
switch(sslMode) {
case "disable":
break;
case "trust":
case "verify-full":
case "verify-ca":
return builder.appendParameter("sslMode", sslMode, !url.contains("?")).build();
default:
return builder.appendParameter("sslMode", "trust", !url.contains("?")).build();
}
return url;
}

@Override
public String getDefaultDriverClass() {
return "org.mariadb.jdbc.Driver";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
import com.amazonaws.secretsmanager.util.SQLExceptionUtils;
import com.amazonaws.secretsmanager.util.URLBuilder;

import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClientBuilder;
Expand Down Expand Up @@ -121,6 +122,23 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
URLBuilder builder = new URLBuilder(url);
switch(sslMode) {
case "DISABLED":
break;
case "PREFERRED":
case "REQUIRED":
case "VERIFY_CA":
case "VERIFY_IDENTITY":
return builder.appendParameter("sslMode", sslMode, !url.contains("?")).build();
default:
return builder.appendParameter("sslMode", "PREFERRED", !url.contains("?")).build();
}
return url;
}

@Override
public String getDefaultDriverClass() {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
if("true".equalsIgnoreCase(sslMode)) {
if (url.startsWith("jdbc:oracle:thin:@//")) {
return url.replace("jdbc:oracle:thin:@//", "jdbc:oracle:thin:@tcps://");
}
}
return url;
}

@Override
public String getDefaultDriverClass() {
return "oracle.jdbc.OracleDriver";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package com.amazonaws.secretsmanager.sql;

import java.sql.SQLException;
import com.amazonaws.secretsmanager.util.URLBuilder;

import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
Expand Down Expand Up @@ -134,6 +135,34 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {

if (url.endsWith("/")) {
url = url.substring(0, url.length() - 1);
}

URLBuilder builder = new URLBuilder(url);

if ("true".equalsIgnoreCase(sslMode)) {
return builder.appendParameter("sslMode", "verify-full", !url.contains("?")).build();
} else {
switch(sslMode) {
case "disable":
break;
case "allow":
case "prefer":
case "require":
case "verify-ca":
case "verify-full":
return builder.appendParameter("sslMode", sslMode, !url.contains("?")).build();
default:
return builder.appendParameter("sslMode", "prefer", !url.contains("?")).build();
}
}
return url;
}

@Override
public String getDefaultDriverClass() {
return "org.postgresql.Driver";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package com.amazonaws.secretsmanager.sql;

import java.sql.SQLException;
import com.amazonaws.secretsmanager.util.URLBuilder;

import com.amazonaws.secretsmanager.caching.SecretCache;
import com.amazonaws.secretsmanager.caching.SecretCacheConfiguration;
Expand Down Expand Up @@ -128,6 +129,24 @@ public String constructUrlFromEndpointPortDatabase(String endpoint, String port,
return url;
}

@Override
public String enforceSSL(String url, String sslMode) {
URLBuilder builder = new URLBuilder(url);
switch(sslMode) {
case "disable":
break;
case "allow":
case "prefer":
case "require":
case "verify-ca":
case "verify-full":
return builder.appendProperty("sslmode", sslMode).build();
default:
return builder.appendProperty("sslmode", "prefer").build();
}
return url;
}

@Override
public String getDefaultDriverClass() {
return "com.amazon.redshift.Driver";
Expand Down
Loading
Loading