diff --git a/src/main/java/com/aws/greengrass/disk/spool/DiskSpool.java b/src/main/java/com/aws/greengrass/disk/spool/DiskSpool.java index 482105f..0cec0b0 100644 --- a/src/main/java/com/aws/greengrass/disk/spool/DiskSpool.java +++ b/src/main/java/com/aws/greengrass/disk/spool/DiskSpool.java @@ -95,7 +95,6 @@ public Iterable getAllMessageIds() throws IOException { public void initializeSpooler() throws IOException { try { dao.initialize(); - dao.setUpDatabase(); logger.atInfo().log("Finished setting up Database"); } catch (SQLException e) { throw new IOException(e); diff --git a/src/main/java/com/aws/greengrass/disk/spool/DiskSpoolDAO.java b/src/main/java/com/aws/greengrass/disk/spool/DiskSpoolDAO.java index b6a0c94..fdc99a2 100644 --- a/src/main/java/com/aws/greengrass/disk/spool/DiskSpoolDAO.java +++ b/src/main/java/com/aws/greengrass/disk/spool/DiskSpoolDAO.java @@ -30,8 +30,10 @@ import java.sql.Statement; import java.sql.Types; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; import java.util.List; +import java.util.Objects; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantLock; @@ -42,10 +44,11 @@ public class DiskSpoolDAO { - private static final Logger logger = LogManager.getLogger(DiskSpoolDAO.class); + private static final Logger LOGGER = LogManager.getLogger(DiskSpoolDAO.class); private static final ObjectMapper MAPPER = SerializerFactory.getFailSafeJsonObjectMapper(); - protected static final String DATABASE_CONNECTION_URL = "jdbc:sqlite:%s"; - protected static final String DATABASE_FILE_NAME = "spooler.db"; + private static final String DATABASE_CONNECTION_URL = "jdbc:sqlite:%s"; + private static final String DATABASE_FILE_NAME = "spooler.db"; + private static final String KV_URL = "url"; private static final Set CORRUPTION_ERROR_CODES = new HashSet<>(); static { @@ -55,15 +58,28 @@ public class DiskSpoolDAO { private final Path databasePath; private final String url; + private final CreateSpoolerTable createSpoolerTable = new CreateSpoolerTable(); + private final GetAllSpoolMessageIds getAllSpoolMessageIds = new GetAllSpoolMessageIds(); + private final GetSpoolMessageById getSpoolMessageById = new GetSpoolMessageById(); + private final InsertSpoolMessage insertSpoolMessage = new InsertSpoolMessage(); + private final RemoveSpoolMessageById removeSpoolMessageById = new RemoveSpoolMessageById(); + private final List> allStatements = Arrays.asList( + createSpoolerTable, + getAllSpoolMessageIds, + getSpoolMessageById, + insertSpoolMessage, + removeSpoolMessageById + ); + private final ReentrantLock recoverDBLock = new ReentrantLock(); private final ReentrantReadWriteLock connectionLock = new ReentrantReadWriteLock(); private Connection connection; /** - * This method will construct the database path. + * Create a new DiskSpoolDAO. * - * @param paths The path to the working directory - * @throws IOException when fails to set up the database + * @param paths nucleus paths + * @throws IOException if unable to resolve database path */ @Inject public DiskSpoolDAO(NucleusPaths paths) throws IOException { @@ -79,182 +95,292 @@ public DiskSpoolDAO(NucleusPaths paths) throws IOException { /** * Initialize the database connection. * - * @throws SQLException if db is unable to be created + * @throws SQLException if database is unable to be created */ public void initialize() throws SQLException { try (LockScope ls = LockScope.lock(connectionLock.writeLock())) { close(); - logger.atDebug().kv("url", url).log("Creating DB connection"); - connection = getDbInstance(); + connection = createConnection(); + + // recreate the database table first + createSpoolerTable.replaceStatement(connection); + createSpoolerTable.execute(); + + // eagerly create remaining statements + for (CachedStatement statement : allStatements) { + if (Objects.equals(statement, createSpoolerTable)) { + continue; + } + statement.replaceStatement(connection); + } } } /** - * Close DAO resources. + * Close any open DAO resources, including database connections. */ public void close() { try (LockScope ls = LockScope.lock(connectionLock.writeLock())) { + for (CachedStatement statement : allStatements) { + try { + statement.close(); + } catch (SQLException e) { + LOGGER.atWarn() + .kv("statement", statement.getClass().getSimpleName()) + .log("Unable to close statement"); + } + } if (connection != null) { try { connection.close(); } catch (SQLException e) { - logger.atWarn().kv("url", url).log("Unable to close pre-existing connection"); + LOGGER.atWarn().kv(KV_URL, url).log("Unable to close pre-existing connection"); } } } } /** - * This method will query the existing database for the existing queue of MQTT request Ids - * and return them in order. + * Get ids of all messages from the database. * - * @return ordered list of the existing ids in the persistent queue - * @throws SQLException when fails to get SpoolMessage IDs + * @return ordered iterable of message ids + * @throws SQLException if statement failed to execute, or when unable to read results */ public Iterable getAllSpoolMessageIds() throws SQLException { - // TODO don't recreate prepared statements every time - return performSqlOperation(conn -> { - try (PreparedStatement stmt = getAllSpoolMessageIdsStatement(conn); - ResultSet rs = stmt.executeQuery()) { - return getIdsFromRs(rs); - } - }); + try (ResultSet rs = getAllSpoolMessageIds.execute()) { + return getAllSpoolMessageIds.mapResultToIds(rs); + } } - private PreparedStatement getAllSpoolMessageIdsStatement(Connection conn) throws SQLException { - String query = "SELECT message_id FROM spooler;"; - return conn.prepareStatement(query); + /** + * Get a single message by id from the database. + * + * @param id message id + * @return message + * @throws SQLException if statement failed to execute, or when unable to read results + */ + public SpoolMessage getSpoolMessageById(long id) throws SQLException { + try (ResultSet rs = getSpoolMessageById.executeWithParameters(id)) { + return getSpoolMessageById.mapResultToMessage(id, rs); + } catch (IOException e) { + throw new SQLException(e); + } } /** - * This method will query a SpoolMessage and return it given an id. + * Insert a message into the database. * - * @param messageId the id of the SpoolMessage - * @return SpoolMessage - * @throws SQLException when fails to get a SpoolMessage by id + * @param message message + * @throws SQLException if statement failed to execute */ - public SpoolMessage getSpoolMessageById(long messageId) throws SQLException { - return performSqlOperation(conn -> { - try (PreparedStatement pstmt = getSpoolMessageByIdStatement(conn, messageId); - ResultSet rs = pstmt.executeQuery()) { - try { - return getSpoolMessageFromRs(messageId, rs); - } catch (IOException e) { - throw new SQLException(e); - } - } - }); + public void insertSpoolMessage(SpoolMessage message) throws SQLException { + insertSpoolMessage.executeWithParameters(message); } - private PreparedStatement getSpoolMessageByIdStatement(Connection conn, long messageId) throws SQLException { - String query = "SELECT retried, topic, qos, retain, payload, userProperties, messageExpiryIntervalSeconds, " - + "correlationData, responseTopic, payloadFormat, contentType FROM spooler WHERE message_id = ?;"; - PreparedStatement pstmt = conn.prepareStatement(query); - pstmt.setLong(1, messageId); - return pstmt; + /** + * Remove a message by id from the database. + * + * @param id message id + * @throws SQLException if statement failed to execute + */ + public void removeSpoolMessageById(Long id) throws SQLException { + removeSpoolMessageById.executeWithParameters(id); } /** - * This method will insert a SpoolMessage into the database. + * Create a new database connection. * - * @param message instance of SpoolMessage - * @throws SQLException when fails to insert SpoolMessage in the database + * @return connection + * @throws SQLException if database access error occurs */ - public void insertSpoolMessage(SpoolMessage message) throws SQLException { - performSqlOperation(conn -> { - try (PreparedStatement pstmt = insertSpoolMessageStatement(conn, message)) { - return pstmt.executeUpdate(); - } - }); + protected Connection createConnection() throws SQLException { + LOGGER.atDebug().kv(KV_URL, url).log("Creating database connection"); + return DriverManager.getConnection(url); } - private PreparedStatement insertSpoolMessageStatement(Connection conn, SpoolMessage message) throws SQLException { - String query = - "INSERT INTO spooler (message_id, retried, topic, qos, retain, payload, userProperties, " - + "messageExpiryIntervalSeconds, correlationData, responseTopic, payloadFormat, contentType) " - + "VALUES (?,?,?,?,?,?,?,?,?,?,?,?);"; - PreparedStatement pstmt = conn.prepareStatement(query); - - Publish request = message.getRequest(); - pstmt.setLong(1, message.getId()); - pstmt.setInt(2, message.getRetried().get()); - - // MQTT 3 & 5 fields - pstmt.setString(3, request.getTopic()); - pstmt.setInt(4, request.getQos().getValue()); - pstmt.setBoolean(5, request.isRetain()); - pstmt.setBytes(6, request.getPayload()); + void recoverFromCorruption() throws SQLException { + if (!recoverDBLock.tryLock()) { + // corruption recovery in progress + return; + } - if (request.getUserProperties() == null) { - pstmt.setNull(7, Types.NULL); - } else { + // hold connection lock throughout recovery to prevent incoming operations from executing + try (LockScope ls = LockScope.lock(connectionLock.writeLock())) { + LOGGER.atWarn().kv(KV_URL, url).log("Database is corrupted, creating new database"); + close(); try { - pstmt.setString(7, MAPPER.writeValueAsString(request.getUserProperties())); - } catch (IOException e) { - throw new SQLException(e); + Files.deleteIfExists(databasePath); + } catch (IOException e2) { + throw new SQLException(e2); } + initialize(); + } finally { + recoverDBLock.unlock(); } - if (request.getMessageExpiryIntervalSeconds() == null) { - pstmt.setNull(8, Types.NULL); - } else { - pstmt.setLong(8, request.getMessageExpiryIntervalSeconds()); + } + + class GetAllSpoolMessageIds extends CachedStatement { + private static final String QUERY = "SELECT message_id FROM spooler;"; + + @Override + protected PreparedStatement createStatement(Connection connection) throws SQLException { + return connection.prepareStatement(QUERY); } - if (request.getCorrelationData() == null) { - pstmt.setNull(9, Types.NULL); - } else { - pstmt.setBytes(9, request.getCorrelationData()); + + @Override + protected ResultSet doExecute(PreparedStatement statement) throws SQLException { + return statement.executeQuery(); } - if (request.getResponseTopic() == null) { - pstmt.setNull(10, Types.NULL); - } else { - pstmt.setString(10, request.getResponseTopic()); + + List mapResultToIds(ResultSet rs) throws SQLException { + List ids = new ArrayList<>(); + while (rs.next()) { + ids.add(rs.getLong("message_id")); + } + return ids; } - if (request.getPayloadFormat() == null) { - pstmt.setNull(11, Types.NULL); - } else { - pstmt.setInt(11, request.getPayloadFormat().getValue()); + } + + class GetSpoolMessageById extends CachedStatement { + private static final String QUERY = + "SELECT retried, topic, qos, retain, payload, userProperties, messageExpiryIntervalSeconds, " + + "correlationData, responseTopic, payloadFormat, contentType " + + "FROM spooler WHERE message_id = ?;"; + + @Override + protected PreparedStatement createStatement(Connection connection) throws SQLException { + return connection.prepareStatement(QUERY); } - if (request.getContentType() == null) { - pstmt.setNull(12, Types.NULL); - } else { - pstmt.setString(12, request.getContentType()); + + @Override + protected ResultSet doExecute(PreparedStatement statement) throws SQLException { + return statement.executeQuery(); } - return pstmt; - } - /** - * This method will remove a SpoolMessage from the database given its id. - * - * @param messageId the id of the SpoolMessage - * @throws SQLException when fails to remove a SpoolMessage by id - */ - public void removeSpoolMessageById(Long messageId) throws SQLException { - performSqlOperation(conn -> { - try (PreparedStatement pstmt = removeSpoolMessageByIdStatement(conn, messageId)) { - return pstmt.executeUpdate(); + ResultSet executeWithParameters(Long id) throws SQLException { + return executeWithParameters(s -> { + s.setLong(1, id); + return null; + }); + } + + SpoolMessage mapResultToMessage(long id, ResultSet rs) throws SQLException, IOException { + if (!rs.next()) { + return null; } - }); + Publish request = Publish.builder() + .qos(QOS.fromInt(rs.getInt("qos"))) + .retain(rs.getBoolean("retain")) + .topic(rs.getString("topic")) + .payload(rs.getBytes("payload")) + .payloadFormat(rs.getObject("payloadFormat") == null + ? null : Publish.PayloadFormatIndicator.fromInt(rs.getInt("payloadFormat"))) + .messageExpiryIntervalSeconds(rs.getObject("messageExpiryIntervalSeconds") == null + ? null : rs.getLong("messageExpiryIntervalSeconds")) + .responseTopic(rs.getString("responseTopic")) + .correlationData(rs.getBytes("correlationData")) + .contentType(rs.getString("contentType")) + .userProperties(rs.getString("userProperties") == null + ? null : MAPPER.readValue(rs.getString("userProperties"), + new TypeReference>(){})).build(); + + return SpoolMessage.builder() + .id(id) + .retried(new AtomicInteger(rs.getInt("retried"))) + .request(request).build(); + } } - private PreparedStatement removeSpoolMessageByIdStatement(Connection conn, long messageId) throws SQLException { - String query = "DELETE FROM spooler WHERE message_id = ?;"; - PreparedStatement pstmt = conn.prepareStatement(query); - pstmt.setLong(1, messageId); - return pstmt; + class InsertSpoolMessage extends CachedStatement { + private static final String QUERY = + "INSERT INTO spooler (message_id, retried, topic, qos, retain, payload, userProperties, " + + "messageExpiryIntervalSeconds, correlationData, responseTopic, payloadFormat, contentType) " + + "VALUES (?,?,?,?,?,?,?,?,?,?,?,?);"; + + @Override + protected PreparedStatement createStatement(Connection connection) throws SQLException { + return connection.prepareStatement(QUERY); + } + + @Override + protected Integer doExecute(PreparedStatement statement) throws SQLException { + return statement.executeUpdate(); + } + + Integer executeWithParameters(SpoolMessage message) throws SQLException { + return executeWithParameters(s -> { + Publish request = message.getRequest(); + s.setLong(1, message.getId()); + s.setInt(2, message.getRetried().get()); + + // MQTT 3 & 5 fields + s.setString(3, request.getTopic()); + s.setInt(4, request.getQos().getValue()); + s.setBoolean(5, request.isRetain()); + s.setBytes(6, request.getPayload()); + + if (request.getUserProperties() == null) { + s.setNull(7, Types.NULL); + } else { + try { + s.setString(7, MAPPER.writeValueAsString(request.getUserProperties())); + } catch (IOException e) { + throw new SQLException(e); + } + } + if (request.getMessageExpiryIntervalSeconds() == null) { + s.setNull(8, Types.NULL); + } else { + s.setLong(8, request.getMessageExpiryIntervalSeconds()); + } + if (request.getCorrelationData() == null) { + s.setNull(9, Types.NULL); + } else { + s.setBytes(9, request.getCorrelationData()); + } + if (request.getResponseTopic() == null) { + s.setNull(10, Types.NULL); + } else { + s.setString(10, request.getResponseTopic()); + } + if (request.getPayloadFormat() == null) { + s.setNull(11, Types.NULL); + } else { + s.setInt(11, request.getPayloadFormat().getValue()); + } + if (request.getContentType() == null) { + s.setNull(12, Types.NULL); + } else { + s.setString(12, request.getContentType()); + } + return null; + }); + } } - /** - * This method creates a connection instance of the SQLite database. - * - * @return Connection for SQLite database instance - * @throws SQLException When fails to get Database Connection - */ - public Connection getDbInstance() throws SQLException { - return DriverManager.getConnection(url); + class RemoveSpoolMessageById extends CachedStatement { + private static final String QUERY = "DELETE FROM spooler WHERE message_id = ?;"; + + @Override + protected PreparedStatement createStatement(Connection connection) throws SQLException { + return connection.prepareStatement(QUERY); + } + + @Override + protected Integer doExecute(PreparedStatement statement) throws SQLException { + return statement.executeUpdate(); + } + + protected Integer executeWithParameters(Long id) throws SQLException { + return executeWithParameters(s -> { + s.setLong(1, id); + return null; + }); + } } - protected void setUpDatabase() throws SQLException { - String query = "CREATE TABLE IF NOT EXISTS spooler (" + class CreateSpoolerTable extends CachedStatement { + private static final String QUERY = "CREATE TABLE IF NOT EXISTS spooler (" + "message_id INTEGER PRIMARY KEY, " + "retried INTEGER NOT NULL, " + "topic STRING NOT NULL," @@ -268,85 +394,89 @@ protected void setUpDatabase() throws SQLException { + "payloadFormat INTEGER," + "contentType STRING" + ");"; - DriverManager.registerDriver(new org.sqlite.JDBC()); - performSqlOperation(conn -> { - try (Statement st = conn.createStatement()) { - // create new table if table doesn't exist - st.executeUpdate(query); - return null; - } - }); - } - @SuppressWarnings("PMD.AvoidCatchingGenericException") - private T performSqlOperation(CrashableFunction operation) throws SQLException { - try { - try (LockScope ls = LockScope.lock(connectionLock.readLock())) { - return operation.apply(connection); - } - } catch (SQLException e) { - if (CORRUPTION_ERROR_CODES.contains(e.getErrorCode())) { - recoverFromCorruption(); - } - throw e; - } catch (Exception e) { - throw new SQLException(e); + @Override + protected PreparedStatement createStatement(Connection connection) throws SQLException { + return connection.prepareStatement(QUERY); + } + + @Override + protected Integer doExecute(PreparedStatement statement) throws SQLException { + return statement.executeUpdate(); } } - void recoverFromCorruption() throws SQLException { - if (!recoverDBLock.tryLock()) { - // corruption recovery in progress - return; + /** + * A {@link Statement} wrapper that reuses the statement across executions. + * + * @param statement type + * @param execution result type + */ + abstract class CachedStatement { + private T statement; + + /** + * Create a new statement and replace the existing one, if present. + * + * @param connection connection + * @throws SQLException if unable to create statement + */ + public void replaceStatement(Connection connection) throws SQLException { + close(); // clean up old resources + statement = createStatement(connection); } - // hold connection lock throughout recovery to prevent incoming operations from executing - try (LockScope ls = LockScope.lock(connectionLock.writeLock())) { - logger.atWarn().log(String.format("Database %s is corrupted, creating new database", databasePath)); - close(); - try { - Files.deleteIfExists(databasePath); - } catch (IOException e2) { - throw new SQLException(e2); + /** + * Create a new statement. + * + * @param connection connection + * @return statement + * @throws SQLException if unable to create statement + */ + protected abstract T createStatement(Connection connection) throws SQLException; + + public void close() throws SQLException { + if (statement != null) { + statement.close(); } - initialize(); - setUpDatabase(); - } finally { - recoverDBLock.unlock(); } - } - private static List getIdsFromRs(ResultSet rs) throws SQLException { - List currentIds = new ArrayList<>(); - while (rs.next()) { - currentIds.add(rs.getLong("message_id")); + /** + * Execute the statement. This could be any type of execution, e.g. executeQuery, executeUpdate. + * + * @return execution results + * @throws SQLException if error occurs during execution + */ + R execute() throws SQLException { + return executeInternal(); + } + + /** + * Set parameters on the statement and execute it. + * + * @param decorator function that sets statement parameters + * @return execution results + * @throws SQLException if error occurs during execution + */ + R executeWithParameters(CrashableFunction decorator) throws SQLException { + decorator.apply(statement); + return executeInternal(); } - return currentIds; - } - private static SpoolMessage getSpoolMessageFromRs(long messageId, ResultSet rs) throws SQLException, IOException { - if (!rs.next()) { - return null; + @SuppressWarnings("PMD.AvoidCatchingGenericException") + private R executeInternal() throws SQLException { + try { + return doExecute(statement); + } catch (SQLException e) { + if (CORRUPTION_ERROR_CODES.contains(e.getErrorCode())) { + recoverFromCorruption(); + } + throw e; + } catch (Exception e) { + throw new SQLException(e); + } } - Publish request = Publish.builder() - .qos(QOS.fromInt(rs.getInt("qos"))) - .retain(rs.getBoolean("retain")) - .topic(rs.getString("topic")) - .payload(rs.getBytes("payload")) - .payloadFormat(rs.getObject("payloadFormat") == null - ? null : Publish.PayloadFormatIndicator.fromInt(rs.getInt("payloadFormat"))) - .messageExpiryIntervalSeconds(rs.getObject("messageExpiryIntervalSeconds") == null - ? null : rs.getLong("messageExpiryIntervalSeconds")) - .responseTopic(rs.getString("responseTopic")) - .correlationData(rs.getBytes("correlationData")) - .contentType(rs.getString("contentType")) - .userProperties(rs.getString("userProperties") == null - ? null : MAPPER.readValue(rs.getString("userProperties"), - new TypeReference>(){})).build(); - - return SpoolMessage.builder() - .id(messageId) - .retried(new AtomicInteger(rs.getInt("retried"))) - .request(request).build(); + + protected abstract R doExecute(T statement) throws SQLException; } } diff --git a/src/test/java/com/aws/greengrass/disk/spool/DiskSpoolDAOFake.java b/src/test/java/com/aws/greengrass/disk/spool/DiskSpoolDAOFake.java index b1d3d24..5b83c70 100644 --- a/src/test/java/com/aws/greengrass/disk/spool/DiskSpoolDAOFake.java +++ b/src/test/java/com/aws/greengrass/disk/spool/DiskSpoolDAOFake.java @@ -27,8 +27,8 @@ public DiskSpoolDAOFake(Path path) { @Override @SuppressWarnings("PMD.CloseResource") - public Connection getDbInstance() throws SQLException { - Connection conn = super.getDbInstance(); + public Connection createConnection() throws SQLException { + Connection conn = super.createConnection(); connection.setConnection(conn); return connection; } diff --git a/src/test/java/com/aws/greengrass/disk/spool/DiskSpoolDAOTest.java b/src/test/java/com/aws/greengrass/disk/spool/DiskSpoolDAOTest.java index 497a0d9..878bfda 100644 --- a/src/test/java/com/aws/greengrass/disk/spool/DiskSpoolDAOTest.java +++ b/src/test/java/com/aws/greengrass/disk/spool/DiskSpoolDAOTest.java @@ -36,9 +36,6 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; @ExtendWith({GGExtension.class, MockitoExtension.class}) class DiskSpoolDAOTest { @@ -79,9 +76,8 @@ class DiskSpoolDAOTest { @BeforeEach void setUp() throws SQLException { - dao = spy(new DiskSpoolDAOFake(currDir.resolve("spooler.db"))); + dao = new DiskSpoolDAOFake(currDir.resolve("spooler.db")); dao.initialize(); - dao.setUpDatabase(); } @AfterEach @@ -135,19 +131,17 @@ void GIVEN_spooler_WHEN_corruption_detected_during_operation_THEN_spooler_recove SQLException corruptionException = new SQLException("DB is corrupt", "some state", 11); dao.getConnection().addExceptionOnUpdate(corruptionException); assertThrows(SQLException.class, () -> operation.apply(dao)); - verify(dao).recoverFromCorruption(); operation.apply(dao); } @ParameterizedTest @MethodSource("allSpoolerOperations") - void GIVEN_spooler_WHEN_error_during_operation_THEN_exception_thrown(CrashableFunction operation, ExtensionContext context) throws SQLException { + void GIVEN_spooler_WHEN_error_during_operation_THEN_exception_thrown(CrashableFunction operation, ExtensionContext context) { ignoreExceptionOfType(context, SQLTransientException.class); SQLException transientException = new SQLTransientException("Some Transient Error"); dao.getConnection().addExceptionOnUpdate(transientException); dao.getConnection().addExceptionOnUpdate(transientException); assertThrows(SQLException.class, () -> operation.apply(dao)); - verify(dao, never()).recoverFromCorruption(); } public static Stream allSpoolerOperations() {