Skip to content

Commit

Permalink
Improve rollback / restore logic (#278)
Browse files Browse the repository at this point in the history
  • Loading branch information
DrexHD authored Aug 15, 2024
1 parent 50deb8b commit 988d2f9
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.github.quiltservertools.ledger.utility.HandlerWithContext;
import com.llamalad7.mixinextras.sugar.Local;
import kotlin.Pair;
import net.minecraft.inventory.SimpleInventory;
import net.minecraft.item.ItemStack;
import net.minecraft.screen.ScreenHandler;
import net.minecraft.server.network.ServerPlayerEntity;
Expand All @@ -18,6 +19,9 @@

import java.util.List;

import static com.github.quiltservertools.ledger.utility.ItemChangeLogicKt.addItem;
import static com.github.quiltservertools.ledger.utility.ItemChangeLogicKt.removeMatchingItem;

@Mixin(targets = "net.minecraft.server.network.ServerPlayerEntity$1")
public abstract class ServerPlayerEntityMixin {

Expand All @@ -40,28 +44,14 @@ private DefaultedList<ItemStack> modifyStacks(DefaultedList<ItemStack> stacks, @
if (preview == null) return stacks;
List<Pair<ItemStack, Boolean>> modifiedItems = preview.getModifiedItems().get(pos);
if (modifiedItems == null) return stacks;
// Copy original list
DefaultedList<ItemStack> previewStacks = DefaultedList.of();
previewStacks.addAll(stacks);
SimpleInventory inventory = new SimpleInventory(stacks.toArray(new ItemStack[]{}));
for (Pair<ItemStack, Boolean> modifiedItem : modifiedItems) {
if (modifiedItem.component2()) {
// Add item
for (int i = 0; i < previewStacks.size(); i++) {
if (previewStacks.get(i).isEmpty()) {
previewStacks.set(i, modifiedItem.component1());
break;
}
}
addItem(modifiedItem.component1(), inventory);
} else {
// Remove item
for (int i = 0; i < previewStacks.size(); i++) {
if (ItemStack.areItemsEqual(previewStacks.get(i), modifiedItem.component1())) {
previewStacks.set(i, ItemStack.EMPTY);
break;
}
}
removeMatchingItem(modifiedItem.component1(), inventory);
}
}
return previewStacks;
return inventory.getHeldStacks();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.time.Instant
import kotlin.time.ExperimentalTime

abstract class AbstractActionType : ActionType {
override var id: Int = -1
override var timestamp: Instant = Instant.now()
override var pos: BlockPos = BlockPos.ORIGIN
override var world: Identifier? = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import java.time.Instant
import kotlin.time.ExperimentalTime

interface ActionType {
var id: Int
val identifier: String
var timestamp: Instant
var pos: BlockPos
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package com.github.quiltservertools.ledger.actions
import com.github.quiltservertools.ledger.actionutils.Preview
import com.github.quiltservertools.ledger.utility.NbtUtils
import com.github.quiltservertools.ledger.utility.TextColorPallet
import com.github.quiltservertools.ledger.utility.addItem
import com.github.quiltservertools.ledger.utility.getOtherChestSide
import com.github.quiltservertools.ledger.utility.getWorld
import com.github.quiltservertools.ledger.utility.literal
import com.github.quiltservertools.ledger.utility.removeMatchingItem
import net.minecraft.block.Blocks
import net.minecraft.block.ChestBlock
import net.minecraft.block.InventoryProvider
Expand Down Expand Up @@ -108,15 +110,8 @@ abstract class ItemChangeActionType : AbstractActionType() {

if (world != null) {
val rollbackStack = getStack(server)

if (inventory != null) {
for (i in 0 until inventory.size()) {
val stack = inventory.getStack(i)
if (ItemStack.areItemsEqual(stack, rollbackStack)) {
inventory.setStack(i, ItemStack.EMPTY)
return true
}
}
return removeMatchingItem(rollbackStack, inventory)
} else if (rollbackStack.isOf(Items.WRITABLE_BOOK) || rollbackStack.isOf(Items.WRITTEN_BOOK)) {
val blockEntity = world.getBlockEntity(pos)
if (blockEntity is LecternBlockEntity) {
Expand All @@ -136,15 +131,8 @@ abstract class ItemChangeActionType : AbstractActionType() {

if (world != null) {
val rollbackStack = getStack(server)

if (inventory != null) {
for (i in 0 until inventory.size()) {
val stack = inventory.getStack(i)
if (stack.isEmpty) {
inventory.setStack(i, rollbackStack)
return true
}
}
return addItem(rollbackStack, inventory)
} else if (rollbackStack.isOf(Items.WRITABLE_BOOK) || rollbackStack.isOf(Items.WRITTEN_BOOK)) {
val blockEntity = world.getBlockEntity(pos)
if (blockEntity is LecternBlockEntity && !blockEntity.hasBook()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object RestoreCommand : BuildableCommand {
params.ensureSpecific()
Ledger.launch(Dispatchers.IO) {
MessageUtils.warnBusy(source)
val actions = DatabaseManager.restoreActions(params)
val actions = DatabaseManager.selectRestore(params)

if (actions.isEmpty()) {
source.sendError(Text.translatable("error.ledger.command.no_results"))
Expand All @@ -53,12 +53,16 @@ object RestoreCommand : BuildableCommand {

context.source.world.launchMain {
val fails = HashMap<String, Int>()

val actionIds = HashSet<Int>()
for (action in actions) {
if (!action.restore(context.source.server)) {
fails[action.identifier] = fails.getOrPut(action.identifier) { 0 } + 1
} else {
actionIds.add(action.id)
}
action.rolledBack = true
}
Ledger.launch(Dispatchers.IO) {
DatabaseManager.restoreActions(actionIds)
}

for (entry in fails.entries) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object RollbackCommand : BuildableCommand {
params.ensureSpecific()
Ledger.launch(Dispatchers.IO) {
MessageUtils.warnBusy(source)
val actions = DatabaseManager.rollbackActions(params)
val actions = DatabaseManager.selectRollback(params)

if (actions.isEmpty()) {
source.sendError(Text.translatable("error.ledger.command.no_results"))
Expand All @@ -53,12 +53,16 @@ object RollbackCommand : BuildableCommand {

context.source.world.launchMain {
val fails = HashMap<String, Int>()

val actionIds = HashSet<Int>()
for (action in actions) {
if (!action.rollback(context.source.server)) {
fails[action.identifier] = fails.getOrPut(action.identifier) { 0 } + 1
} else {
actionIds.add(action.id)
}
action.rolledBack = true
}
Ledger.launch(Dispatchers.IO) {
DatabaseManager.rollbackActions(actionIds)
}

for (entry in fails.entries) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.jetbrains.exposed.sql.Op
import org.jetbrains.exposed.sql.Query
import org.jetbrains.exposed.sql.SchemaUtils
import org.jetbrains.exposed.sql.SortOrder
import org.jetbrains.exposed.sql.SqlExpressionBuilder.eq
import org.jetbrains.exposed.sql.SqlExpressionBuilder.inSubQuery
import org.jetbrains.exposed.sql.SqlExpressionBuilder.lessEq
import org.jetbrains.exposed.sql.SqlLogger
Expand Down Expand Up @@ -148,18 +149,49 @@ object DatabaseManager {
}

suspend fun rollbackActions(params: ActionSearchParams): List<ActionType> = execute {
return@execute selectAndRollbackActions(params)
val actions = selectRollback(params)
val actionIds = actions.map { it.id }.toSet()
rollbackActions(actionIds)
return@execute actions
}

suspend fun rollbackActions(actionIds: Set<Int>) = execute {
return@execute rollbackActions(actionIds)
}

suspend fun restoreActions(params: ActionSearchParams): List<ActionType> = execute {
return@execute selectAndRestoreActions(params)
val actions = selectRestore(params)
val actionIds = actions.map { it.id }.toSet()
restoreActions(actionIds)
return@execute actions
}

suspend fun restoreActions(actionIds: Set<Int>) = execute {
return@execute restoreActions(actionIds)
}

suspend fun selectRollback(params: ActionSearchParams): List<ActionType> = execute {
val query = buildQuery()
.where(buildQueryParams(params) and (Tables.Actions.rolledBack eq false))
.orderBy(Tables.Actions.id, SortOrder.DESC)
return@execute getActionsFromQuery(query)
}

suspend fun selectRestore(params: ActionSearchParams): List<ActionType> = execute {
val query = buildQuery()
.where(buildQueryParams(params) and (Tables.Actions.rolledBack eq true))
.orderBy(Tables.Actions.id, SortOrder.ASC)
return@execute getActionsFromQuery(query)
}

suspend fun previewActions(
params: ActionSearchParams,
type: Preview.Type
): List<ActionType> = execute {
return@execute selectActionsPreview(params, type)
when (type) {
Preview.Type.ROLLBACK -> return@execute selectRollback(params)
Preview.Type.RESTORE -> return@execute selectRestore(params)
}
}

private fun getActionsFromQuery(query: Query): List<ActionType> {
Expand All @@ -173,6 +205,7 @@ object DatabaseManager {
}

val type = typeSupplier.get()
type.id = action[Tables.Actions.id].value
type.timestamp = action[Tables.Actions.timestamp]
type.pos = BlockPos(action[Tables.Actions.x], action[Tables.Actions.y], action[Tables.Actions.z])
type.world = Identifier.tryParse(action[Tables.Worlds.identifier])
Expand Down Expand Up @@ -448,23 +481,11 @@ object DatabaseManager {

private fun Transaction.selectActionsSearch(params: ActionSearchParams, page: Int): SearchResults {
val actions = mutableListOf<ActionType>()
var totalActions: Long

var query = Tables.Actions
.innerJoin(Tables.ActionIdentifiers)
.innerJoin(Tables.Worlds)
.leftJoin(Tables.Players)
.innerJoin(
Tables.oldObjectTable,
{ Tables.Actions.oldObjectId },
{ Tables.oldObjectTable[Tables.ObjectIdentifiers.id] }
)
.innerJoin(Tables.ObjectIdentifiers, { Tables.Actions.objectId }, { Tables.ObjectIdentifiers.id })
.innerJoin(Tables.Sources)
.selectAll()
var query = buildQuery()
.andWhere { buildQueryParams(params) }

totalActions = countActions(params)
val totalActions: Long = countActions(params)
if (totalActions == 0L) return SearchResults(actions, params, page, 0)

query = query.orderBy(Tables.Actions.id, SortOrder.DESC)
Expand All @@ -485,15 +506,8 @@ object DatabaseManager {
.andWhere { buildQueryParams(params) }
.count()

private fun Transaction.selectActionsPreview(
params: ActionSearchParams,
type: Preview.Type
): MutableList<ActionType> {
val actions = mutableListOf<ActionType>()

val isRestore = type == Preview.Type.RESTORE

val selectQuery = Tables.Actions
private fun Transaction.buildQuery(): Query {
return Tables.Actions
.innerJoin(Tables.ActionIdentifiers)
.innerJoin(Tables.Worlds)
.leftJoin(Tables.Players)
Expand All @@ -505,68 +519,20 @@ object DatabaseManager {
.innerJoin(Tables.ObjectIdentifiers, { Tables.Actions.objectId }, { Tables.ObjectIdentifiers.id })
.innerJoin(Tables.Sources)
.selectAll()
.andWhere { buildQueryParams(params) and (Tables.Actions.rolledBack eq isRestore) }
.orderBy(Tables.Actions.id, if (isRestore) SortOrder.ASC else SortOrder.DESC)
actions.addAll(getActionsFromQuery(selectQuery))

return actions
}

private fun Transaction.selectAndRollbackActions(params: ActionSearchParams): MutableList<ActionType> {
val actions = mutableListOf<ActionType>()

val selectQuery = Tables.Actions
.innerJoin(Tables.ActionIdentifiers)
.innerJoin(Tables.Worlds)
.leftJoin(Tables.Players)
.innerJoin(
Tables.oldObjectTable,
{ Tables.Actions.oldObjectId },
{ Tables.oldObjectTable[Tables.ObjectIdentifiers.id] }
)
.innerJoin(Tables.ObjectIdentifiers, { Tables.Actions.objectId }, { Tables.ObjectIdentifiers.id })
.innerJoin(Tables.Sources)
.selectAll()
.andWhere { buildQueryParams(params) and (Tables.Actions.rolledBack eq false) }
.orderBy(Tables.Actions.id, SortOrder.DESC)
val actionIds = selectQuery.map { it[Tables.Actions.id] }
.toSet() // SQLite doesn't support update where so select by ID. Might not be as efficent
actions.addAll(getActionsFromQuery(selectQuery))

private fun Transaction.rollbackActions(actionIds: Set<Int>) {
Tables.Actions
.update({ Tables.Actions.id inList actionIds and (Tables.Actions.rolledBack eq false) }) {
.update({ Tables.Actions.id inList actionIds }) {
it[rolledBack] = true
}

return actions
}

private fun Transaction.selectAndRestoreActions(params: ActionSearchParams): MutableList<ActionType> {
val actions = mutableListOf<ActionType>()

val selectQuery = Tables.Actions
.innerJoin(Tables.ActionIdentifiers)
.innerJoin(Tables.Worlds)
.leftJoin(Tables.Players)
.innerJoin(
Tables.oldObjectTable,
{ Tables.Actions.oldObjectId },
{ Tables.oldObjectTable[Tables.ObjectIdentifiers.id] }
)
.innerJoin(Tables.ObjectIdentifiers, { Tables.Actions.objectId }, { Tables.ObjectIdentifiers.id })
.innerJoin(Tables.Sources)
.selectAll()
.andWhere { buildQueryParams(params) and (Tables.Actions.rolledBack eq true) }
.orderBy(Tables.Actions.id, SortOrder.ASC)
val actionIds = selectQuery.map { it[Tables.Actions.id] }.toSet()
actions.addAll(getActionsFromQuery(selectQuery))

private fun Transaction.restoreActions(actionIds: Set<Int>) {
Tables.Actions
.update({ Tables.Actions.id inList actionIds and (Tables.Actions.rolledBack eq true) }) {
.update({ Tables.Actions.id inList actionIds }) {
it[rolledBack] = false
}

return actions
}

fun getKnownSources() =
Expand Down
Loading

0 comments on commit 988d2f9

Please sign in to comment.