Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
latentvector committed Sep 21, 2024
1 parent 4de530e commit 775c5f8
Show file tree
Hide file tree
Showing 174 changed files with 320 additions and 815 deletions.
90 changes: 31 additions & 59 deletions commune/cli.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,23 @@

import commune as c
import json
import sys
import time
import os
import threading
import sys


class cli:
class Cli:
"""
Create and init the CLI class, which handles the coldkey, hotkey and tao transfer
"""
#

def __init__(self,
args = None,
module = 'module',
verbose = True,
forget_fns = ['module.key_info', 'module.save_keys'],
seperator = ' ',
buffer_size=4,
save: bool = False):
def __init__(self, module = 'module', args = None):

self.seperator = seperator
self.buffer_size = buffer_size
self.verbose = verbose
self.save = save
self.forget_fns = forget_fns
self.base_module = c.module(module) if isinstance(module, str) else module
self.base_module_attributes = list(set(self.base_module.functions() + self.base_module.attributes()))
self.forward(args)

def forward(self, argv=None):
t0 = time.time()
argv = argv or self.argv()
self.input_msg = 'c ' + ' '.join(argv)
input_msg =' '.join(argv)
output = None
init_kwargs = {}
if any([arg.startswith('--') for arg in argv]):
Expand All @@ -51,47 +33,50 @@ def forward(self, argv=None):
if '=' not in arg:
value = True
value = arg.split('=')[1]
init_kwargs[key] = self.determine_type(value)
init_kwargs[key] = self.get_value(value)

# any of the --flags are init kwargs
if argv[0].endswith('.py'):
argv[0] = argv[0][:-3]
if ':' in argv[0]:
# {module}:{fn} arg1 arg2 arg3 ... argn
if ':' in argv[0]: # {module}:{fn} arg1 arg2 arg3 ... argn
argv[0] = argv[0].replace(':', '/')
if '/' in argv[0]:
if '/' in argv[0]: # {module}/{fn} arg1 arg2 arg3 ... argn
# prioritize the module over the function
module = '.'.join(argv[0].split('/')[:-1])
fn = argv[0].split('/')[-1]
argv = [module , fn , *argv[1:]]
is_fn = False
else:
is_fn = argv[0] in self.base_module_attributes

if is_fn:
module = self.base_module
fn = argv.pop(0)
else:
module = argv.pop(0)
fn = argv.pop(0)

if isinstance(module, str):
module = c.get_module(module)

module_name = module.module_name()
fn_path = f'{module_name}/{fn}'
fn_obj = getattr(module, fn)
fn_class = c.classify_fn(fn_obj)

print(f'fn_class: {fn_class}')

if fn_class == 'self':
fn_obj = getattr(module(**init_kwargs), fn)


input_msg = f'[bold]fn[/bold]: {fn_path}'

if callable(fn_obj):
args, kwargs = self.parse_args(argv)
if len(args) > 0 or len(kwargs) > 0:
input_msg += ' ' + f'[purple][bold]params:[/bold] args:{args} kwargs:{kwargs}[/purple]'
output = fn_obj(*args, **kwargs)
else:
output = fn_obj
self.input_msg = input_msg

buffer = '⚡️'*4
c.print(buffer+input_msg+buffer, color='yellow')
latency = time.time() - t0
Expand All @@ -105,34 +90,25 @@ def forward(self, argv=None):
msg = f'Result(latency={latency:.3f})'

print(buffer + msg + buffer)

num_spacers = max(0, len(self.input_msg) - len(msg) )
num_spacers = max(0, len(input_msg) - len(msg) )
left_spacers = num_spacers//2
right_spacers = num_spacers - left_spacers
msg = self.seperator*left_spacers + msg + self.seperator*right_spacers
buffer = self.buffer_size * buffer
seperator = ' '
msg = seperator*left_spacers + msg + seperator*right_spacers
is_generator = c.is_generator(output)

if is_generator:
# print the items side by side instead of vertically
for item in output:
if isinstance(item, dict):
c.print(item)
else:
c.print(item, end='')
else:
c.print(output)

return output

# c.print( f'Result ✅ (latency={self.latency:.2f}) seconds ✅')



@classmethod
def is_property(cls, obj):
return isinstance(obj, property)


@classmethod
def parse_args(cls, argv = None):
Expand All @@ -145,30 +121,24 @@ def parse_args(cls, argv = None):
if '=' in arg:
parsing_kwargs = True
key, value = arg.split('=')
kwargs[key] = cls.determine_type(value)

kwargs[key] = cls.get_value(value)
else:
assert parsing_kwargs is False, 'Cannot mix positional and keyword arguments'
args.append(cls.determine_type(arg))
args.append(cls.get_value(arg))
return args, kwargs

@classmethod
def determine_type(cls, x):
def get_value(cls, x):

if x.startswith('py(') and x.endswith(')'):
try:
return eval(x[3:-1])
except:
return x
if x.lower() in ['null'] or x == 'None': # convert 'null' or 'None' to None
if x.lower() in ['null', 'None']: # convert 'null' or 'None' to None
return None
elif x.lower() in ['true', 'false']: # convert 'true' or 'false' to bool
return bool(x.lower() == 'true')
elif x.startswith('[') and x.endswith(']'): # this is a list
try:
list_items = x[1:-1].split(',')
# try to convert each item to its actual type
x = [cls.determine_type(item.strip()) for item in list_items]
x = [cls.get_value(item.strip()) for item in list_items]
if len(x) == 1 and x[0] == '':
x = []
return x
Expand All @@ -183,7 +153,7 @@ def determine_type(cls, x):
try:
dict_items = x[1:-1].split(',')
# try to convert each item to a key-value pair
return {key.strip(): cls.determine_type(value.strip()) for key, value in [item.split(':', 1) for item in dict_items]}
return {key.strip(): cls.get_value(value.strip()) for key, value in [item.split(':', 1) for item in dict_items]}
except:
# if conversion fails, return as string
return x
Expand All @@ -192,14 +162,16 @@ def determine_type(cls, x):
try:
return int(x)
except ValueError:
try:
return float(x)
except ValueError:
return x

pass
try:
return float(x)
except ValueError:
pass
return x


def argv(self):
return sys.argv[1:]

def main():
cli()
Cli()
3 changes: 3 additions & 0 deletions commune/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def resolve_module_address(self,

else:
url = module
ip = c.ip()
if ip in url:
url = url.replace(ip, '0.0.0.0')
url = f'{mode}://' + url if not url.startswith(f'{mode}://') else url
return url

Expand Down
84 changes: 7 additions & 77 deletions commune/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,41 +234,31 @@ def get_key(cls,
print(path)
raise ValueError(f'key does not exist at --> {path}')
key_json = cls.get(path)

# if key is encrypted, decrypt it
if cls.is_encrypted(key_json):
key_json = c.decrypt(data=key_json, password=password)
if key_json == None:
c.print({'status': 'error', 'message': f'key is encrypted, please {path} provide password'}, color='red')
return None


if isinstance(key_json, str):
key_json = c.jload(key_json)


if json:
key_json['path'] = path
return key_json
else:
return cls.from_json(key_json)



@classmethod
def get_keys(cls, search=None, clean_failed_keys=False):
def get_keys(cls, search=None):
keys = {}
for key in cls.keys():
if str(search) in key or search == None:
try:
keys[key] = cls.get_key(key)
key_obj = cls.get_key(key)
if hasattr(key_obj, 'ss58_address'):
keys[key] = key_obj
except Exception as e:
c.print(f'failed to get key {key} due to {e}', color='red')
continue
if keys[key] == None:
if clean_failed_keys:
cls.rm_key(key)
keys.pop(key)
print('Key error')

return keys

Expand Down Expand Up @@ -854,8 +844,6 @@ def verify(self,
True if data is signed with this Key, otherwise False
"""
data = c.copy(data)
if isinstance(data, str) and seperator in data:
data, signature = data.split(seperator)

if isinstance(data, dict):
if self.is_ticket(data):
Expand All @@ -870,9 +858,7 @@ def verify(self,
assert address != None, 'address not found in data'

if max_age != None:
if isinstance(data, int):
staleness = c.timestamp() - int(data)
elif 'timestamp' in data or 'time' in data:
if 'timestamp' in data or 'time' in data:
timestamp = data.get('timestamp', data.get('time'))
staleness = c.timestamp() - int(timestamp)
else:
Expand All @@ -890,7 +876,6 @@ def verify(self,
public_key = self.ss58_decode(public_key)
if isinstance(public_key, str):
public_key = bytes.fromhex(public_key.replace('0x', ''))

if type(data) is ScaleBytes:
data = bytes(data.data)
elif data[0:2] == '0x':
Expand All @@ -911,9 +896,7 @@ def verify(self,
crypto_verify_fn = ecdsa_verify
else:
raise ConfigurationError("Crypto type not supported")

verified = crypto_verify_fn(signature, data, public_key)

if not verified:
# Another attempt with the data wrapped, as discussed in https://github.com/polkadot-js/extension/pull/743
# Note: As Python apps are trusted sources on its own, no need to wrap data when signing from this lib
Expand All @@ -933,7 +916,6 @@ def resolve_encryption_password(self, password:str):
def resolve_encryption_data(self, data):
if not isinstance(data, str):
data = str(data)

return data

def encrypt(self, data, password=None):
Expand Down Expand Up @@ -1067,58 +1049,12 @@ def state_dict(self):
return self.__dict__

to_dict = state_dict
@classmethod
def dashboard(cls):
import streamlit as st
self = cls.new_key()
keys = self.keys()
selected_keys = st.multiselect('Keys', keys)
buttons = {}
for key_name in selected_keys:
key = cls.get_key(key_name)
with st.expander('Key Info'):
st.write(key.to_dict())
buttons[key_name] = {}
buttons[key_name]['sign'] = st.button('Sign', key_name)
st.write(self.keys())

@classmethod
def key2type(cls):
keys = cls.keys(object=True)
return {k.path: k.key_type for k in keys}
@classmethod
def key2mem(cls, search=None):
keys = cls.keys(search, object=True)
key2mem = {k.path: k.mnemonic for k in keys}
return key2mem

@classmethod
def type2keys(cls):
type2keys = {}
key2type = cls.key2type()
for k,t in key2type.items():
type2keys[t] = type2keys.get(t, []) + [k]
return type2keys

@classmethod
def pubkey2multihash(cls, pk:bytes) -> str:
import multihash
hashed_public_key = multihash.encode(pk, code=multihash.SHA2_256)
return hashed_public_key.hex()

@classmethod
def duplicate_keys(cls) -> dict:

key2address = cls.key2address()
duplicate_keys = {}

for k,a in key2address.items():
if a not in duplicate_keys:
duplicate_keys[a] = []

duplicate_keys[a] += [k]

return {k:v for k,v in duplicate_keys.items() if len(v) > 1}

@classmethod
def from_private_key(cls, private_key:str):
Expand Down Expand Up @@ -1226,7 +1162,7 @@ def is_ss58(address):
return False

return True

@classmethod
def is_encrypted(cls, data, prefix=encrypted_prefix):
if isinstance(data, str):
Expand Down Expand Up @@ -1282,12 +1218,6 @@ def resolve_key_address(cls, key):
else:
address = key
return address



if __name__ == "__main__":
Key.run()




Expand Down
Loading

0 comments on commit 775c5f8

Please sign in to comment.