diff --git a/src/contracting/execution/executor.py b/src/contracting/execution/executor.py index 7d1f1ff6..404cbfb5 100644 --- a/src/contracting/execution/executor.py +++ b/src/contracting/execution/executor.py @@ -128,6 +128,7 @@ def execute(self, sender, contract_name, function_name, kwargs, enable_restricted_imports() runtime.rt.set_up(stmps=stamps * 1000, meter=metering) result = func(**kwargs) + transaction_writes = deepcopy(driver.transaction_writes) runtime.rt.tracer.stop() disable_restricted_imports() @@ -140,11 +141,13 @@ def execute(self, sender, contract_name, function_name, kwargs, # Revert the writes if the transaction fails driver.pending_writes = current_driver_pending_writes + transaction_writes = {} if auto_commit: driver.flush_cache() finally: + driver.clear_transaction_writes() runtime.rt.tracer.stop() #runtime.rt.tracer.stop() @@ -184,7 +187,7 @@ def execute(self, sender, contract_name, function_name, kwargs, 'status_code': status_code, 'result': result, 'stamps_used': stamps_used, - 'writes': deepcopy(driver.pending_writes), + 'writes': transaction_writes, 'reads': driver.pending_reads } diff --git a/src/contracting/storage/driver.py b/src/contracting/storage/driver.py index 760c6578..f471c4ad 100644 --- a/src/contracting/storage/driver.py +++ b/src/contracting/storage/driver.py @@ -38,6 +38,7 @@ def __init__(self, bypass_cache=False, storage_home=constants.STORAGE_HOME): self.contract_state = storage_home.joinpath("contract_state") self.run_state = storage_home.joinpath("run_state") self.__build_directories() + self.transaction_writes = {} # New dictionary for transaction-specific writes def __build_directories(self): self.contract_state.mkdir(exist_ok=True, parents=True) @@ -76,7 +77,7 @@ def get(self, key: str, save: bool = True): """ Get a value from the cache, pending reads, or disk. If save is True, the value will be saved to pending_reads. - """ + """ # Parse the key to get the filename and group value = self.find(key) if save and self.pending_reads.get(key) is None: @@ -93,6 +94,8 @@ def set(self, key, value): if type(value) in [decimal.Decimal, float]: value = ContractingDecimal(str(value)) self.pending_writes[key] = value + self.transaction_writes[key] = value + def find(self, key: str): """ @@ -417,3 +420,10 @@ def get_run_state(self): run_state[full_key] = value return run_state + + + def clear_transaction_writes(self): + """ + Clear the transaction-specific writes. + """ + self.transaction_writes.clear()