Skip to content

Commit

Permalink
Improvements and optimizations for execution
Browse files Browse the repository at this point in the history
  • Loading branch information
Endogen committed Oct 4, 2024
1 parent 507c483 commit d30b1b2
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 91 deletions.
16 changes: 4 additions & 12 deletions src/contracting/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@
from contracting.stdlib.bridge.decimal import ContractingDecimal, CONTEXT
from contracting.stdlib.bridge.random import Seeded
from contracting import constants
from loguru import logger
import re
from copy import deepcopy

import importlib
import decimal
import traceback


class Executor:
def __init__(self,
Expand Down Expand Up @@ -96,10 +92,6 @@ def execute(self, sender, contract_name, function_name, kwargs,
runtime.rt.env.update(environment)
status_code = 0

# TODO: Why do we do this?
# Multiply stamps by 1000 because we divide by it later
# runtime.rt.set_up(stmps=stamps * 1000, meter=metering)

runtime.rt.context._base_state = {
'signer': sender,
'caller': sender,
Expand Down Expand Up @@ -140,16 +132,16 @@ def execute(self, sender, contract_name, function_name, kwargs,

# Revert the writes if the transaction fails
driver.pending_writes = current_driver_pending_writes

if auto_commit:
driver.flush_cache()

finally:
runtime.rt.tracer.stop()
# Clear the module cache to prevent holding contracts in memory
from contracting.execution.module import DatabaseLoader
DatabaseLoader.module_cache.clear()

#runtime.rt.tracer.stop()

# Deduct the stamps if that is enabled
stamps_used = runtime.rt.tracer.get_stamp_used()

stamps_used = stamps_used // 1000
Expand Down
59 changes: 22 additions & 37 deletions src/contracting/execution/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,11 @@
import sys
import importlib.util

# This function overrides the __import__ function, which is the builtin function that is called whenever Python runs
# an 'import' statement. If the globals dictionary contains {'__contract__': True}, then this function will make sure
# that the module being imported comes from the database and not from builtins or site packages.
#
# For all exec statements, we add the {'__contract__': True} _key to the globals to protect against unwanted imports.
#
# Note: anything installed with pip or in site-packages will also not work, so contract package names *must* be unique.


def is_valid_import(name):
spec = importlib.util.find_spec(name)
if not isinstance(spec.loader, DatabaseLoader):
raise ImportError("module {} cannot be imported in a smart contract.".format(name))


def restricted_import(name, globals=None, locals=None, fromlist=(), level=0):
if globals is not None and globals.get('__contract__') is True:
spec = importlib.util.find_spec(name)
Expand All @@ -33,60 +23,49 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0):

return __import__(name, globals, locals, fromlist, level)


def enable_restricted_imports():
builtins.__import__ = restricted_import
# builtins.float = ContractingDecimal


def disable_restricted_imports():
builtins.__import__ = __import__


def uninstall_builtins():
sys.meta_path.clear()
sys.path_hooks.clear()
sys.path.clear()
sys.path_importer_cache.clear()
invalidate_caches()


def install_database_loader(driver=Driver()):
DatabaseFinder.driver = driver
if DatabaseFinder not in sys.meta_path:
sys.meta_path.insert(0, DatabaseFinder)

if DatabaseFinder() not in sys.meta_path:
sys.meta_path.insert(0, DatabaseFinder())

def uninstall_database_loader():
sys.meta_path = list(set(sys.meta_path))
if DatabaseFinder in sys.meta_path:
sys.meta_path.remove(DatabaseFinder)

if DatabaseFinder() in sys.meta_path:
sys.meta_path.remove(DatabaseFinder())

def install_system_contracts(directory=''):
pass


'''
Is this where interaction with the database occurs with the interface of code strings, etc?
IE: pushing a contract does sanity checks here?
'''


class DatabaseFinder:
driver = Driver()

def find_spec(self, fullname, path=None, target=None):
if MODULE_CACHE.get(self) is None:
if DatabaseFinder.driver.get_contract(self) is None:
return None
return ModuleSpec(self, DatabaseLoader(DatabaseFinder.driver))

if self.driver.get_contract(fullname) is None:
return None
return ModuleSpec(fullname, DatabaseLoader(self.driver))

MODULE_CACHE = {}
def __eq__(self, other):
return isinstance(other, DatabaseFinder)

def __hash__(self):
return hash('DatabaseFinder')

class DatabaseLoader(Loader):
module_cache = {}

def __init__(self, d=Driver()):
self.d = d

Expand All @@ -95,9 +74,9 @@ def create_module(self, spec):

def exec_module(self, module):
# fetch the individual contract
code = MODULE_CACHE.get(module.__name__)
code = self.module_cache.get(module.__name__)

if MODULE_CACHE.get(module.__name__) is None:
if code is None:
code = self.d.get_compiled(module.__name__)
if code is None:
raise ImportError("Module {} not found".format(module.__name__))
Expand All @@ -106,7 +85,7 @@ def exec_module(self, module):
code = bytes.fromhex(code)

code = marshal.loads(code)
MODULE_CACHE[module.__name__] = code
self.module_cache[module.__name__] = code

if code is None:
raise ImportError("Module {} not found".format(module.__name__))
Expand All @@ -127,3 +106,9 @@ def exec_module(self, module):

def module_repr(self, module):
return '<module {!r} (smart contract)>'.format(module.__name__)

def __eq__(self, other):
return isinstance(other, DatabaseLoader) and self.d == other.d

def __hash__(self):
return hash(('DatabaseLoader', self.d))
8 changes: 4 additions & 4 deletions src/contracting/execution/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys
import os


class Context:
def __init__(self, base_state, maxlen=constants.RECURSION_LIMIT):
self._state = []
Expand Down Expand Up @@ -57,7 +56,6 @@ def entry(self):
def submission_name(self):
return self._get_state()['submission_name']


_context = Context({
'this': None,
'caller': None,
Expand All @@ -69,7 +67,6 @@ def submission_name(self):

WRITE_MAX = 1024 * 128


class Runtime:
cu_path = contracting.__path__[0]
cu_path = os.path.join(cu_path, 'execution', 'metering', 'cu_costs.const')
Expand Down Expand Up @@ -114,6 +111,10 @@ def clean_up(cls):
cls.loaded_modules = []
cls.env = {}

# Clear module cache to prevent holding contracts in memory
from contracting.execution.module import DatabaseLoader
DatabaseLoader.module_cache.clear()

@classmethod
def deduct_read(cls, key, value):
if cls.tracer.is_started():
Expand All @@ -131,5 +132,4 @@ def deduct_write(cls, key, value):
stamp_cost = cost * constants.WRITE_COST_PER_BYTE
cls.tracer.add_cost(stamp_cost)


rt = Runtime()
Loading

0 comments on commit d30b1b2

Please sign in to comment.