Skip to content

Commit

Permalink
chore: Update version 0.6.0 (#37)
Browse files Browse the repository at this point in the history
* feat: translated examples to new Nillion Python Client
  • Loading branch information
jcabrero authored Nov 21, 2024
1 parent 38e68a5 commit ba96a3b
Show file tree
Hide file tree
Showing 13 changed files with 2,213 additions and 1,541 deletions.
203 changes: 102 additions & 101 deletions examples/complex_model/main.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,77 @@
"""Complex model example"""

import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
"""Dot Product example script"""

import asyncio
import os

import nada_numpy as na
import nada_numpy.client as na_client
import numpy as np
import py_nillion_client as nillion
import torch
from common.utils import compute, store_program, store_secrets
from cosmpy.aerial.client import LedgerClient
from cosmpy.aerial.wallet import LocalWallet
from cosmpy.crypto.keypairs import PrivateKey
from dotenv import load_dotenv
from nillion_python_helpers import (create_nillion_client,
create_payments_config)
from py_nillion_client import NodeKey, UserKey
from nillion_client import (InputPartyBinding, Network, NilChainPayer,
NilChainPrivateKey, OutputPartyBinding,
Permissions, PrivateKey, SecretInteger, VmClient)

from nada_ai.client import TorchClient

home = os.getenv("HOME")
load_dotenv(f"{home}/.config/nillion/nillion-devnet.env")


async def new_client(network, id: int, private_key: str = None):
# Create payments config and set up Nillion wallet with a private key to pay for operations
nilchain_key: str = os.getenv(f"NILLION_NILCHAIN_PRIVATE_KEY_{id}") # type: ignore
payer = NilChainPayer(
network,
wallet_private_key=NilChainPrivateKey(bytes.fromhex(nilchain_key)),
gas_limit=10000000,
)

# Use a random key to identify ourselves
signing_key = PrivateKey(private_key)
print(signing_key.private_key)
client = await VmClient.create(signing_key, network, payer)
return client


# 1 Party running simple addition on 1 stored secret and 1 compute time secret
async def main() -> None:
"""Main nada program"""
network = Network.from_config("devnet")

# WARNING: In a real use case, the Provider and User would never have access to the Private Key
# This is just for demonstration purposes
# Provider and User should only exchange their IDs
model_provider_name = "Party0"
model_provider = await new_client(
network,
0,
b'\xbf\xdf7\xa9\x1eL\x10i"\xd8\x1f\xbb\xe8\r;\x1b`\x1a\xd1\xa1;\xef\xd8\xbbf|\xf9\x12\xe9\xef\x03\xc7',
)
model_user_name = "Party1"
model_user = await new_client(
network,
1,
b"\x15\xa0\xc1\xcc\x12\xb5r\xf9\xcb\x89\x95\x8d\x94\xfb\xfe)\xdf\xfe\xbd3\x00\x18\x80\xc1\xd9W\x8b\xf7\xc0\x92S\xe9",
)

cluster_id = os.getenv("NILLION_CLUSTER_ID")
grpc_endpoint = os.getenv("NILLION_NILCHAIN_GRPC")
chain_id = os.getenv("NILLION_NILCHAIN_CHAIN_ID")
seed = "my_seed"
userkey = UserKey.from_seed((seed))
nodekey = NodeKey.from_seed((seed))
client = create_nillion_client(userkey, nodekey)
party_id = client.party_id
user_id = client.user_id

party_names = na_client.parties(2)
program_name = "complex_model"
program_mir_path = f"target/{program_name}.nada.bin"

# Configure payments
payments_config = create_payments_config(chain_id, grpc_endpoint)
payments_client = LedgerClient(payments_config)
payments_wallet = LocalWallet(
PrivateKey(bytes.fromhex(os.getenv("NILLION_NILCHAIN_PRIVATE_KEY_0"))),
prefix="nillion",
)
program_mir_path = f"./target/{program_name}.nada.bin"

##### STORE PROGRAM
print("-----STORE PROGRAM")

# Store program
program_id = await store_program(
client,
payments_wallet,
payments_client,
user_id,
cluster_id,
program_name,
program_mir_path,
)
program_mir = open(
os.path.join(os.path.dirname(os.path.abspath(__file__)), program_mir_path), "rb"
).read()
program_id = await model_provider.store_program(program_name, program_mir).invoke()

# Print details about stored program
print(f"Stored program_id: {program_id}")

##### STORE SECRETS
print("-----STORE SECRETS Party 0")

# Create custom torch Module
class MyConvModule(torch.nn.Module):
Expand Down Expand Up @@ -109,76 +119,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# Create and store model secrets via ModelClient
model_client = TorchClient(my_model)
model_secrets = nillion.NadaValues(
model_client.export_state_as_secrets("my_model", na.SecretRational)
)
permissions = nillion.Permissions.default_for_user(client.user_id)
permissions.add_compute_permissions({client.user_id: {program_id}})

model_store_id = await store_secrets(
client,
payments_wallet,
payments_client,
cluster_id,
model_secrets,
1,
permissions,
model_secrets = model_client.export_state_as_secrets(
"my_model", na_client.SecretRational
)

# Store inputs to perform inference for
my_input = na_client.array(np.ones((3, 4, 3)), "my_input", na.SecretRational)
input_secrets = nillion.NadaValues(my_input)

data_store_id = await store_secrets(
client,
payments_wallet,
payments_client,
cluster_id,
input_secrets,
1,
permissions,
# Create a permissions object to attach to the stored secret
permissions = Permissions.defaults_for_user(model_provider.user_id).allow_compute(
model_user.user_id, program_id
)

# Set up the compute bindings for the parties
compute_bindings = nillion.ProgramBindings(program_id)
# Store a secret, passing in the receipt that shows proof of payment
my_nn_store_id = await model_provider.store_values(
model_secrets, ttl_days=5, permissions=permissions
).invoke()

for party_name in party_names:
compute_bindings.add_input_party(party_name, party_id)
compute_bindings.add_output_party(party_names[-1], party_id)
print("-----STORE SECRETS Party 1")

print(f"Computing using program {program_id}")
print(f"Use secret store_id: {model_store_id}, {data_store_id}")
# Create a secret
my_input = na_client.array(np.ones((3, 4, 3)), "my_input", na.SecretRational)
# Create a permissions object to attach to the stored secret
permissions = Permissions.defaults_for_user(model_user.user_id).allow_compute(
model_user.user_id, program_id
)

# Create a computation time secret to use
computation_time_secrets = nillion.NadaValues({})
# Store a secret, passing in the receipt that shows proof of payment
my_inputs_store_id = await model_user.store_values(
my_input, ttl_days=5, permissions=permissions
).invoke()

# Compute, passing all params including the receipt that shows proof of payment
result = await compute(
client,
payments_wallet,
payments_client,
program_id,
cluster_id,
compute_bindings,
[model_store_id, data_store_id],
computation_time_secrets,
verbose=True,
)
##### COMPUTE
print("-----COMPUTE")

# Sort & rescale the obtained results by the quantization scale
outputs = [
na_client.float_from_rational(result[1])
for result in sorted(
result.items(),
key=lambda x: int(x[0].replace("my_output", "").replace("_", "")),
)
# Bind the parties in the computation to the client to set input and output parties
input_bindings = [
InputPartyBinding(model_provider_name, model_provider.user_id),
InputPartyBinding(model_user_name, model_user.user_id),
]
output_bindings = [OutputPartyBinding(model_user_name, [model_user.user_id])]

print(f"🖥️ The processed result is {outputs} @ {na.get_log_scale()}-bit precision")

expected = my_model.forward(torch.ones((3, 4, 3))).detach().numpy().tolist()
# Create a computation time secret to use
compute_time_values = {
# "my_int2": SecretInteger(10)
}

print(f"🖥️ VS expected result {expected}")
# Compute, passing in the compute time values as well as the previously uploaded value.
print(
f"Invoking computation using program {program_id} and values id {my_nn_store_id}, {my_inputs_store_id}"
)
compute_id = await model_user.compute(
program_id,
input_bindings,
output_bindings,
values=compute_time_values,
value_ids=[my_nn_store_id, my_inputs_store_id],
).invoke()

# Print compute result
print(f"The computation was sent to the network. compute_id: {compute_id}")
result = await model_user.retrieve_compute_results(compute_id).invoke()
print(f"✅ Compute complete for compute_id {compute_id}")
print(f"🖥️ The result is {result}")
return result


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit ba96a3b

Please sign in to comment.