Skip to content

Commit

Permalink
threads
Browse files Browse the repository at this point in the history
  • Loading branch information
latentvector committed Sep 4, 2024
1 parent de9fda4 commit 87bdcb2
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 200 deletions.
205 changes: 5 additions & 200 deletions commune/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,17 +386,8 @@ def id(self):
return self.key.ss58_address

@classmethod
def encrypt(cls,
data: Union[str, bytes],
key: str = None,
password: str = None,
**kwargs
) -> bytes:
"""
encrypt data with key
"""
key = c.get_key(key)
return key.encrypt(data, password=password,**kwargs)
def encrypt(cls,data: Union[str, bytes], key: str = None, password: str = None, **kwargs ) -> bytes:
return c.get_key(key).encrypt(data, password=password,**kwargs)

@classmethod
def decrypt(cls,
Expand All @@ -422,7 +413,6 @@ def keys(cls, search = None, ss58=False,*args, **kwargs):
keys = [c.get_key_address(k) for k in keys]
return keys


@classmethod
def set_key(self, key:str, **kwargs) -> None:
key = self.get_key(key)
Expand Down Expand Up @@ -1182,7 +1172,6 @@ def load_config(cls, path:str=None,
def save_config(cls, config:Union['Munch', Dict]= None, path:str=None) -> 'Munch':
from copy import deepcopy
from munch import Munch

'''
Saves the config to a yaml file
'''
Expand Down Expand Up @@ -2644,18 +2633,15 @@ def child_functions(cls, obj=None):
return methods

@classmethod
def locals2kwargs(cls,locals_dict:dict, kwargs_keys=['kwargs']) -> dict:
def locals2kwargs(cls,locals_dict:dict, kwargs_keys=['kwargs'], remove_arguments=['cls','self']) -> dict:
locals_dict = locals_dict or {}
kwargs = locals_dict or {}
kwargs.pop('cls', None)
kwargs.pop('self', None)

for k in remove_arguments:
kwargs.pop(k, None)
assert isinstance(kwargs, dict), f'kwargs must be a dict, got {type(kwargs)}'

# These lines are needed to remove the self and cls from the locals_dict
for k in kwargs_keys:
kwargs.update( locals_dict.pop(k, {}) or {})

return kwargs

def kwargs2attributes(self, kwargs:dict, ignore_error:bool = False):
Expand Down Expand Up @@ -2942,9 +2928,6 @@ def find_classes(cls, path='./', working=False):
libpath_objpath_prefix = cls.libpath.replace('/', '.')[1:] + '.'
classes = [c.replace(libpath_objpath_prefix, '') for c in classes]
return classes




@classmethod
def find_class2functions(cls, path, working=False):
Expand Down Expand Up @@ -4103,194 +4086,16 @@ def rm_api_keys(self):

thread_map = {}

@classmethod
def wait(cls, futures:list, timeout:int = None, generator:bool=False, return_dict:bool = True) -> list:
is_singleton = bool(not isinstance(futures, list))

futures = [futures] if is_singleton else futures
# if type(futures[0]) in [asyncio.Task, asyncio.Future]:
# return cls.gather(futures, timeout=timeout)

if len(futures) == 0:
return []
if cls.is_coroutine(futures[0]):
return cls.gather(futures, timeout=timeout)

future2idx = {future:i for i,future in enumerate(futures)}

if timeout == None:
if hasattr(futures[0], 'timeout'):
timeout = futures[0].timeout
else:
timeout = 30

if generator:
def get_results(futures):
try:
for future in concurrent.futures.as_completed(futures, timeout=timeout):
if return_dict:
idx = future2idx[future]
yield {'idx': idx, 'result': future.result()}
else:
yield future.result()
except Exception as e:
yield None

else:
def get_results(futures):
results = [None]*len(futures)
try:
for future in concurrent.futures.as_completed(futures, timeout=timeout):
idx = future2idx[future]
results[idx] = future.result()
del future2idx[future]
if is_singleton:
results = results[0]
except Exception as e:
unfinished_futures = [future for future in futures if future in future2idx]
cls.print(f'Error: {e}, {len(unfinished_futures)} unfinished futures with timeout {timeout} seconds')
return results

return get_results(futures)

@classmethod
def submit(cls,
fn,
params = None,
kwargs: dict = None,
args:list = None,
timeout:int = 40,
return_future:bool=True,
init_args : list = [],
init_kwargs:dict= {},
executor = None,
module: str = None,
mode:str='thread',
max_workers : int = 100,
):
kwargs = {} if kwargs == None else kwargs
args = [] if args == None else args
if params != None:
if isinstance(params, dict):
kwargs = {**kwargs, **params}
elif isinstance(params, list):
args = [*args, *params]
else:
raise ValueError('params must be a list or a dictionary')

fn = cls.get_fn(fn)
executor = cls.executor(max_workers=max_workers, mode=mode) if executor == None else executor
args = cls.copy(args)
kwargs = cls.copy(kwargs)
init_kwargs = cls.copy(init_kwargs)
init_args = cls.copy(init_args)
if module == None:
module = cls
else:
module = cls.module(module)
if isinstance(fn, str):
method_type = cls.classify_fn(getattr(module, fn))
elif callable(fn):
method_type = cls.classify_fn(fn)
else:
raise ValueError('fn must be a string or a callable')

if method_type == 'self':
module = module(*init_args, **init_kwargs)

future = executor.submit(fn=fn, args=args, kwargs=kwargs, timeout=timeout)

if not hasattr(cls, 'futures'):
cls.futures = []

cls.futures.append(future)

if return_future:
return future
else:
return cls.wait(future, timeout=timeout)

executor_cache = {}
@classmethod
def executor(cls, max_workers:int=None, mode:str="thread", maxsize=200, **kwargs):
return c.module(f'executor')(max_workers=max_workers, maxsize=maxsize ,mode=mode, **kwargs)

@classmethod
def as_completed(cls , futures:list, timeout:int=10, **kwargs):
return concurrent.futures.as_completed(futures, timeout=timeout)

@classmethod
def is_coroutine(cls, future):
"""
returns True if future is a coroutine
"""
return cls.obj2typestr(future) == 'coroutine'

@classmethod
def obj2typestr(cls, obj):
return str(type(obj)).split("'")[1]

@classmethod
def tasks(cls, task = None, mode='pm2',**kwargs) -> List[str]:
kwargs['network'] = 'local'
kwargs['update'] = False
modules = cls.servers( **kwargs)
tasks = getattr(cls, f'{mode}_list')(task)
tasks = list(filter(lambda x: x not in modules, tasks))
return tasks

thread_map = {}

@classmethod
def thread(cls,fn: Union['callable', str],
args:list = None,
kwargs:dict = None,
daemon:bool = True,
name = None,
tag = None,
start:bool = True,
tag_seperator:str='::',
**extra_kwargs):

if isinstance(fn, str):
fn = cls.get_fn(fn)
if args == None:
args = []
if kwargs == None:
kwargs = {}

assert callable(fn), f'target must be callable, got {fn}'
assert isinstance(args, list), f'args must be a list, got {args}'
assert isinstance(kwargs, dict), f'kwargs must be a dict, got {kwargs}'

# unique thread name
if name == None:
name = fn.__name__
cnt = 0
while name in cls.thread_map:
cnt += 1
if tag == None:
tag = ''
name = name + tag_seperator + tag + str(cnt)

if name in cls.thread_map:
cls.thread_map[name].join()

t = threading.Thread(target=fn, args=args, kwargs=kwargs, **extra_kwargs)
# set the time it starts
setattr(t, 'start_time', cls.time())
t.daemon = daemon
if start:
t.start()
cls.thread_map[name] = t
return t

@classmethod
def threads(cls, search:str = None):
threads = list(cls.thread_map.keys())
if search != None:
threads = [t for t in threads if search in t]
return threads

c.add_routes()
Module = c # Module is alias of c
Expand Down
6 changes: 6 additions & 0 deletions commune/routes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,11 @@ utils.asyncio:
- get_event_loop
- new_event_loop

utils.thread:
- thread
- threads
- submit
- as_completed



Loading

0 comments on commit 87bdcb2

Please sign in to comment.