Skip to content

Commit

Permalink
Merge pull request #175 from macrocosm-os/weight_thread_dev
Browse files Browse the repository at this point in the history
Split out a separate thread to set weights and use different subtensors.
  • Loading branch information
Sid-Data-Universe authored Dec 23, 2024
2 parents e7a8c8d + 6cf6436 commit bbbb561
Showing 1 changed file with 92 additions and 43 deletions.
135 changes: 92 additions & 43 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(self):
# === Bittensor objects ====
self.wallet = bt.wallet(config=self.config)
self.subtensor = bt.subtensor(config=self.config)
self.weights_subtensor = bt.subtensor(config=self.config)
# If running on testnet, default to using finney for the dataset subtensor.
if self.config.using_test_subtensor:
self.dataset_subtensor = bt.subtensor()
Expand All @@ -159,8 +160,9 @@ def __init__(self):
torch.backends.cudnn.benchmark = True

# Setup metagraph syncer for the subnet based on config. This is non-lite for getting weights by vali.
syncer_subtensor = bt.subtensor(config=self.config)
self.subnet_metagraph_syncer = MetagraphSyncer(
self.subtensor,
syncer_subtensor,
config={
self.config.netuid: dt.timedelta(minutes=20).total_seconds(),
},
Expand Down Expand Up @@ -218,9 +220,9 @@ def __init__(self):
self._new_wandb_run()

# === Running args ===
self.weight_lock = threading.RLock()
self.weights = torch.zeros_like(torch.from_numpy(self.metagraph.S))
self.global_step = 0
self.last_epoch = self.metagraph.block.item()

self.uids_to_eval: typing.Dict[CompetitionId, typing.Set] = defaultdict(set)

Expand Down Expand Up @@ -300,6 +302,22 @@ def __init__(self):
f"Failed to load competition tracker state. Reason: {e}. Starting from scratch."
)

# Also update our internal weights based on the tracker.
cur_block = self._get_current_block()

# Get the competition schedule for the current block.
# This is a list of competitions
competition_schedule: typing.List[Competition] = (
competition_utils.get_competition_schedule_for_block(
block=cur_block,
schedule_by_block=constants.COMPETITION_SCHEDULE_BY_BLOCK,
)
)
with self.weight_lock:
self.weights = self.competition_tracker.get_subnet_weights(
competition_schedule
)

# Initialize the UIDs to eval.
if not os.path.exists(self.uids_filepath):
logging.warning("No uids state file found. Starting from scratch.")
Expand Down Expand Up @@ -337,8 +355,9 @@ def __init__(self):
self.miner_iterator = MinerIterator(self.metagraph.uids.tolist())

# Setup a ModelMetadataStore
chain_store_subtensor = bt.subtensor(config=self.config)
self.metadata_store = ChainModelMetadataStore(
subtensor=self.subtensor,
subtensor=chain_store_subtensor,
subnet_uid=self.config.netuid,
wallet=self.wallet,
)
Expand Down Expand Up @@ -372,6 +391,11 @@ def __init__(self):
)
self.clean_thread.start()

# == Initialize the weight setting thread ==
if not self.config.offline:
self.weight_thread = threading.Thread(target=self.set_weights, daemon=True)
self.weight_thread.start()

def __del__(self):
if hasattr(self, "stop_event"):
self.stop_event.set()
Expand Down Expand Up @@ -720,38 +744,72 @@ def clean_models(self):

logging.info("Exiting clean models loop.")

async def try_set_weights(self, block: int, ttl: int):
def set_weights(self):
"""Set weights on the chain regularly."""

# Check that we have some weights internally for startup situations.
all_zero_weights = True
while all_zero_weights is True:
# Technically returns a tensor but it evaluates to true.
with self.weight_lock:
all_zero_weights = torch.all(self.weights == 0)
logging.trace(
"Waiting 60 seconds for internal weights before continuing to try set weights."
)
time.sleep(60)

while not self.stop_event.is_set():
try:
set_weights_success = False
while not set_weights_success:
set_weights_success, _ = asyncio.run(self.try_set_weights(ttl=60))
# Wait for 60 seconds before we try to set weights again.
if set_weights_success:
logging.info("Successfully set weights.")
else:
time.sleep(60)
except Exception as e:
logging.error(f"Error in set weights: {e}")

# Only set weights once every hour
time.sleep(60 * 60)

logging.info("Exiting set weights loop.")

async def try_set_weights(self, ttl: int) -> typing.Tuple[bool, str]:
"""Sets the weights on the chain with ttl, without raising exceptions if it times out."""

async def _try_set_weights():
async def _try_set_weights() -> typing.Tuple[bool, str]:
with self.metagraph_lock:
uids = self.metagraph.uids
try:
weight_subtensor = bt.subtensor(config=self.config)
success, message = weight_subtensor.set_weights(
with self.weight_lock:
self.weights.nan_to_num(0.0)
weights_to_set = self.weights

return self.weights_subtensor.set_weights(
netuid=self.config.netuid,
wallet=self.wallet,
uids=uids,
weights=self.weights.numpy(),
wait_for_inclusion=False,
weights=weights_to_set.numpy(),
wait_for_inclusion=True,
version_key=constants.weights_version_key,
max_retries=1,
)
if not success:
logging.warning(
f"Failed to set weights (will retry later): {message}"
)
else:
# We only update the last epoch when we successfully set weights.
self.last_epoch = block
except:
logging.warning("Failed to set weights. Trying again later.")
except Exception as e:
logging.warning(
f"Failed to set weights due to {e}. Trying again later."
)
return (False, str(e))

try:
logging.debug(f"Setting weights.")
await asyncio.wait_for(_try_set_weights(), ttl)
logging.debug(f"Finished setting weights.")
status = await asyncio.wait_for(_try_set_weights(), ttl)
logging.debug(f"Finished setting weights with status: {status}.")
return status
except asyncio.TimeoutError:
logging.error(f"Failed to set weights after {ttl} seconds")
return (False, f"Timeout after {ttl} seconds")

def _get_current_block(self) -> int:
"""Returns the current block."""
Expand Down Expand Up @@ -1140,10 +1198,11 @@ async def run_step(self):
# Align competition_tracker to only track active competitions.
self.competition_tracker.reset_competitions(active_competition_ids)
# Update self.weights to the merged values across active competitions.
self.weights = self.competition_tracker.get_subnet_weights(
competitions=competition_schedule,
min_comp_weight_threshold=constants.MIN_WEIGHT_THRESHOLD,
)
with self.weight_lock:
self.weights = self.competition_tracker.get_subnet_weights(
competitions=competition_schedule,
min_comp_weight_threshold=constants.MIN_WEIGHT_THRESHOLD,
)

# Prioritize models for keeping up to the sample_min for the next eval loop.
# If the model has any significant weight, prioritize by weight with greater weights being kept first.
Expand Down Expand Up @@ -1280,6 +1339,11 @@ def log_step(
"uids": uids,
"uid_data": {},
}

# Get a copy of weights to print.
with self.weight_lock:
log_weights = self.weights

for uid in uids:
step_log["uid_data"][str(uid)] = {
"uid": uid,
Expand All @@ -1294,7 +1358,7 @@ def log_step(
),
"win_rate": win_rate[uid],
"win_total": wins[uid],
"weight": self.weights[uid].item(),
"weight": log_weights[uid].item(),
"norm_weight": competition_weights[
uid
].item(), # Named norm_weight for leaderboard pipeline compatibilty.
Expand Down Expand Up @@ -1329,7 +1393,7 @@ def log_step(
str(round(step_log["uid_data"][str(uid)]["average_loss"], 4)),
str(round(step_log["uid_data"][str(uid)]["epsilon_adv"], 4)),
str(round(step_log["uid_data"][str(uid)]["win_rate"], 4)),
str(round(self.weights[uid].item(), 4)),
str(round(log_weights[uid].item(), 4)),
str(round(competition_weights[uid].item(), 4)),
str(step_log["uid_data"][str(uid)]["block"]),
str(step_log["uid_data"][str(uid)]["competition_id"]),
Expand All @@ -1339,7 +1403,7 @@ def log_step(
console = Console()
console.print(table)

ws, ui = self.weights.topk(len(self.weights))
ws, ui = log_weights.topk(len(log_weights))
table = Table(title=f"Weights >= {constants.WEIGHT_SYNC_MINER_MIN_PERCENT}")
table.add_column("uid", justify="right", style="cyan", no_wrap=True)
table.add_column("weight", style="magenta")
Expand Down Expand Up @@ -1390,7 +1454,7 @@ def log_step(
"win_total_data": {
str(uid): uid_data[str(uid)]["win_total"] for uid in uids
},
"weight_data": {str(uid): self.weights[uid].item() for uid in uids},
"weight_data": {str(uid): log_weights[uid].item() for uid in uids},
"competition_weight_data": {
str(uid): competition_weights[uid].item() for uid in uids
},
Expand Down Expand Up @@ -1480,24 +1544,9 @@ async def run(self):

while True:
try:

# First run a step.
await self.try_run_step(ttl=75 * 60)
self.global_step += 1

block = self._get_current_block()

# Then check if we should set weights and do so if needed.
if not self.config.offline:
blocks_until_epoch = block - self.last_epoch

if blocks_until_epoch >= self.config.blocks_per_epoch:
await self.try_set_weights(block=block, ttl=60)
else:
logging.debug(
f"{blocks_until_epoch} / {self.config.blocks_per_epoch} blocks until next epoch."
)

except KeyboardInterrupt:
logging.info(
"KeyboardInterrupt caught, gracefully closing the wandb run..."
Expand Down

0 comments on commit bbbb561

Please sign in to comment.