Commit: Add wandb overview

steffencruz committed Apr 5, 2024
1 parent 9123d56 commit 7980ef4
Showing 5 changed files with 88 additions and 35 deletions.
46 changes: 36 additions & 10 deletions app.py
@@ -1,4 +1,5 @@
# https://huggingface.co/docs/hub/en/spaces-github-actions
import os
import time
import pandas as pd
import streamlit as st
@@ -7,10 +8,21 @@
# prompt-based completion score stats
# introspect specific RUN-UID-COMPLETION
# cache individual file loads
# Hotkey churn
# Hotkey

# TODO: limit the historical lookup to something reasonable (e.g. 30 days)
# TODO: Add sidebar for filters such as tags, hotkeys, etc.
# TODO: Show trends for runs (versions, hotkeys, etc.). An area chart would be nice, a gantt chart would be better
# TODO: Add a search bar for runs
# TODO: Find a reason to make a pie chart (task distribution, maybe)
# TODO: remove repetition plots (it's not really a thing any more)
# TODO: MINER SKILLSET STAR CHART
# TODO: Status codes for runs vs time (from analysis notebook)

WANDB_PROJECT = "opentensor-dev/alpha-validators"
DEFAULT_FILTERS = {"tags": {"$in": [f'1.1.{i}' for i in range(10)]}}
PROJECT_URL = f'https://wandb.ai/{WANDB_PROJECT}/table?workspace=default'
MAX_RECENT_RUNS = 100
DEFAULT_FILTERS = {}#{"tags": {"$in": [f'1.1.{i}' for i in range(10)]}}
DEFAULT_SELECTED_HOTKEYS = None
DEFAULT_TASK = 'qa'
DEFAULT_COMPLETION_NTOP = 10
@@ -24,7 +36,7 @@
'About': f"""
This dashboard is part of the OpenTensor project. \n
To see runs in wandb, go to: \n
https://wandb.ai/{WANDB_PROJECT}/table?workspace=default
[Wandb Table](https://wandb.ai/{WANDB_PROJECT}/table?workspace=default) \n
"""
},
layout = "centered"
@@ -36,44 +48,58 @@
st.markdown('#')



with st.spinner(text=f'Checking wandb...'):
df_runs = io.load_runs(project=WANDB_PROJECT, filters=DEFAULT_FILTERS, min_steps=10)
df_runs = io.load_runs(project=WANDB_PROJECT, filters=DEFAULT_FILTERS, min_steps=10, max_recent=MAX_RECENT_RUNS)

metric.wandb(df_runs)

# add vertical space
st.markdown('#')

runid_c1, runid_c2 = st.columns([3, 1])
# make multiselect for run_ids with label on same line
run_ids = runid_c1.multiselect('Select one or more weights and biases run by id:', df_runs['run_id'], key='run_id', default=df_runs['run_id'][:3], help=f'Select one or more runs to analyze. You can find the raw data for these runs [here]({PROJECT_URL}).')
n_runs = len(run_ids)
df_runs_subset = df_runs[df_runs['run_id'].isin(run_ids)]

st.markdown('#')

tab1, tab2, tab3, tab4 = st.tabs(["Raw Data", "UID Health", "Completions", "Prompt-based scoring"])
tab1, tab2, tab3, tab4 = st.tabs(["Run Data", "UID Health", "Completions", "Prompt-based scoring"])

### Wandb Runs ###
with tab1:

st.markdown('#')
st.subheader(":violet[Run] Data")
with st.expander(f'Show :violet[raw] wandb data'):
with st.expander(f'Show :violet[all] wandb runs'):

edited_df = st.data_editor(
df_runs.assign(Select=False).set_index('Select'),
column_config={"Select": st.column_config.CheckboxColumn(required=True)},
disabled=df_runs.columns,
use_container_width=True,
)
df_runs_subset = df_runs[edited_df.index==True]
n_runs = len(df_runs_subset)
if edited_df.index.any():
df_runs_subset = df_runs[edited_df.index==True]
n_runs = len(df_runs_subset)

if n_runs:
df = io.load_data(df_runs_subset, load=True, save=True)
df = inspect.clean_data(df)
print(f'\nNans in columns: {df.isna().sum()}')
df_long = inspect.explode_data(df)
if 'rewards' in df_long:
df_long['rewards'] = df_long['rewards'].astype(float)
else:
st.info(f'You must select at least one run to load data')
st.stop()

metric.runs(df_long)

timeline_color = st.radio('Color by:', ['state', 'version', 'netuid'], key='timeline_color', horizontal=True)
plot.timeline(df_runs, color=timeline_color)

st.markdown('#')
st.subheader(":violet[Event] Data")
with st.expander(f'Show :violet[raw] event data for **{n_runs} selected runs**'):
@@ -97,7 +123,7 @@

uid_src = st.radio('Select task type:', step_types, horizontal=True, key='uid_src')
df_uid = df_long[df_long.task.str.contains(uid_src)] if uid_src != 'all' else df_long

metric.uids(df_uid, uid_src)
uids = st.multiselect('UID:', sorted(df_uid['uids'].unique()), key='uid')
with st.expander(f'Show UID health data for **{n_runs} selected runs** and **{len(uids)} selected UIDs**'):
@@ -158,7 +184,7 @@
# completion_src = msg_col1.radio('Select one:', ['followup', 'answer'], horizontal=True, key='completion_src')
completion_src = st.radio('Select task type:', step_types, horizontal=True, key='completion_src')
df_comp = df_long[df_long.task.str.contains(completion_src)] if completion_src != 'all' else df_long

completion_info.info(f"Showing **{completion_src}** completions for **{n_runs} selected runs**")

completion_ntop = msg_col2.slider('Top k:', min_value=1, max_value=50, value=DEFAULT_COMPLETION_NTOP, key='completion_ntop')
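A note on the run-selection pattern in the "Run Data" tab above: `st.data_editor` is given a boolean `Select` column moved into the index, so every real data column can be locked via `disabled` while the checkboxes stay editable, and the edited index doubles as a row mask. A minimal standalone sketch of the same idea (the sample frame is illustrative):

```python
import pandas as pd
import streamlit as st

df_runs = pd.DataFrame({'run_id': ['abc123', 'def456'], 'state': ['running', 'finished']})

# Expose a checkbox per row while locking all real data columns.
edited_df = st.data_editor(
    df_runs.assign(Select=False).set_index('Select'),
    column_config={"Select": st.column_config.CheckboxColumn(required=True)},
    disabled=df_runs.columns,
    use_container_width=True,
)

# The edited index carries the checkbox state and serves as a row mask.
df_selected = df_runs[edited_df.index == True]
st.write(f'{len(df_selected)} runs selected')
```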
52 changes: 34 additions & 18 deletions opendashboards/assets/io.py
@@ -1,5 +1,6 @@
import os
import re
import time
import pandas as pd
import streamlit as st

@@ -12,8 +13,16 @@
is_object_dtype,
)

@st.cache_data
def load_runs(project, filters, min_steps=10):
# @st.cache_data
def load_runs(project, filters, min_steps=10, max_recent=100, local_path='wandb_runs.csv', local_stale_time=3600):
# TODO: clean up the caching logic (e.g. take into account the args)

dtypes = {'state': 'category', 'hotkey': 'category', 'version': 'category', 'spec_version': 'category', 'start_time': 'datetime64[s]', 'end_time': 'datetime64[s]', 'duration': 'timedelta64[s]'}

if local_path and os.path.exists(local_path) and (time.time() - float(os.path.getmtime(local_path))) < local_stale_time:
frame = pd.read_csv(local_path)
return frame.astype({k:v for k,v in dtypes.items() if k in frame.columns})

runs = []
n_events = 0
successful = 0
@@ -22,20 +31,25 @@ def load_runs(project, filters, min_steps=10):

all_runs = utils.get_runs(project, filters)
for i, run in enumerate(all_runs):

if i > max_recent:
break
summary = run.summary
step = summary.get('_step',-1) + 1
if step < min_steps:
msg.warning(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
continue

prog_msg = f'Loading data {i/len(all_runs)*100:.0f}% ({successful}/{len(all_runs)} runs, {n_events} events)'
progress.progress(i/len(all_runs),f'{prog_msg}... **fetching** `{run.name}`')
progress.progress(min(i/len(all_runs),1),f'{prog_msg}... **fetching** `{run.name}`')

duration = summary.get('_runtime')
end_time = summary.get('_timestamp')
# extract values for selected tags
rules = {'hotkey': re.compile('^[0-9a-z]{48}$',re.IGNORECASE), 'version': re.compile('^\\d\.\\d+\.\\d+$'), 'spec_version': re.compile('\\d{4}$')}
rules = {
'version': re.compile('^\\d\.\\d+\.\\d+$'),
'spec_version': re.compile('\\d{4}$'),
'hotkey': re.compile('^[0-9a-z]{48}$',re.IGNORECASE)
}
tags = {k: tag for k, rule in rules.items() for tag in run.tags if rule.match(tag)}
# include bool flag for remaining tags
tags.update({k: k in run.tags for k in ('mock','disable_set_weights')})
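
The tag-extraction block above matches each wandb run tag against a dict of regex rules; if several tags match the same rule, the last one wins, because later comprehension items overwrite earlier keys. A small sketch with hypothetical tags:

```python
import re

run_tags = ['1.1.4', '2024', '5' + 'f' * 47]  # hypothetical wandb run tags
rules = {
    'version': re.compile(r'^\d\.\d+\.\d+$'),
    'spec_version': re.compile(r'\d{4}$'),
    'hotkey': re.compile(r'^[0-9a-z]{48}$', re.IGNORECASE),
}
# One entry per rule; later matching tags overwrite earlier ones.
tags = {k: tag for k, rule in rules.items() for tag in run_tags if rule.match(tag)}
# -> {'version': '1.1.4', 'spec_version': '2024', 'hotkey': '5fff...'}
```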
@@ -44,25 +58,27 @@ def load_runs(project, filters, min_steps=10):
'state': run.state,
'num_steps': step,
'num_completions': step*sum(len(v) for k, v in run.summary.items() if k.endswith('completions') and isinstance(v, list)),
'entity': run.entity,
'duration': pd.to_timedelta(duration, unit="s").round('T'), # round to nearest minute
'start_time': pd.to_datetime(end_time-duration, unit="s").round('T'),
'end_time': pd.to_datetime(end_time, unit="s").round('T'),
'netuid': run.config.get('netuid'),
**tags,
'username': run.user.username,
'run_id': run.id,
'run_name': run.name,
'project': run.project,
'url': run.url,
# 'entity': run.entity,
# 'project': run.project,
'run_path': os.path.join(run.entity, run.project, run.id),
'start_time': pd.to_datetime(end_time-duration, unit="s"),
'end_time': pd.to_datetime(end_time, unit="s"),
'duration': pd.to_timedelta(duration, unit="s").round('s'),
**tags
})
n_events += step
successful += 1

progress.empty()
msg.empty()
frame = pd.DataFrame(runs)
mappings = {'state': 'category', 'hotkey': 'category', 'version': 'category', 'spec_version': 'category'}
return frame.astype({k:v for k,v in mappings.items() if k in frame.columns})
frame.to_csv(local_path, index=False)
return frame.astype({k:v for k,v in dtypes.items() if k in frame.columns})
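
The rewritten `load_runs` swaps `st.cache_data` for a file-freshness check: if the local CSV was modified within `local_stale_time` seconds, it is read back (with dtypes restored) instead of re-querying the wandb API. The freshness check in isolation, as a minimal sketch (function and file names are illustrative):

```python
import os
import time
import pandas as pd

def read_fresh_csv(local_path='wandb_runs.csv', stale_seconds=3600):
    """Return the cached frame if the file is newer than `stale_seconds`, else None."""
    if os.path.exists(local_path):
        age = time.time() - os.path.getmtime(local_path)
        if age < stale_seconds:
            return pd.read_csv(local_path)
    return None  # caller falls through to the slow wandb fetch and rewrites the cache
```

As the TODO in the diff notes, this cache ignores the function arguments, so a call with different filters can still be served the stale file.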


@st.cache_data
@@ -84,7 +100,7 @@ def load_data(selected_runs, load=True, save=False):
if load and os.path.exists(file_path):
progress.progress(i/len(selected_runs),f'{prog_msg}... **reading** `{file_path}`')
try:
df = utils.load_data(file_path)
df = utils.read_data(file_path)
except Exception as e:
info.warning(f'Failed to load history from `{file_path}`')
st.exception(e)
Expand All @@ -97,7 +113,7 @@ def load_data(selected_runs, load=True, save=False):

print(f'Downloaded {df.shape[0]} events from `{run.run_path}`. Columns: {df.columns}')
df.info()

if save and run.state != 'running':
df.to_csv(file_path, index=False)
# st.info(f'Saved history to {file_path}')
@@ -137,7 +153,7 @@ def filter_dataframe(df: pd.DataFrame, demo_selection=None) -> pd.DataFrame:
df = df.loc[demo_selection]
run_msg.info(f"Selected {len(df)} runs")
return df

df = df.copy()

# Try to convert datetimes into a standard format (datetime, no timezone)
6 changes: 6 additions & 0 deletions opendashboards/assets/plot.py
@@ -2,6 +2,12 @@
import streamlit as st
import opendashboards.utils.plotting as plotting

def timeline(df_runs, color='state'):
return st.plotly_chart(
plotting.plot_gantt(df_runs, color=color),
use_container_width=True
)

# @st.cache_data
def uid_diversty(df, rm_failed=True):
return st.plotly_chart(
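The new `timeline` helper is a thin Streamlit wrapper around `plotting.plot_gantt`; app.py drives it with the radio selector shown in the first diff:

```python
timeline_color = st.radio('Color by:', ['state', 'version', 'netuid'], horizontal=True)
plot.timeline(df_runs, color=timeline_color)
```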
12 changes: 7 additions & 5 deletions opendashboards/utils/plotting.py
@@ -20,6 +20,7 @@
import pandas as pd
import numpy as np
import networkx as nx
import streamlit as st

import plotly.express as px
import plotly.graph_objects as go
@@ -28,21 +29,22 @@

plotly_config = {"width": 800, "height": 600, "template": "plotly_white"}

def plot_gantt(df_runs: pd.DataFrame, y='username'):
fig = px.timeline(df_runs,
x_start="start_time", x_end="end_time", y=y, color="state",
def plot_gantt(df_runs: pd.DataFrame, y='username', color="state"):
color_discrete_map={'running': 'green', 'finished': 'grey', 'killed':'blue', 'crashed':'orange', 'failed': 'red'}
fig = px.timeline(df_runs.astype({color: str}),
x_start="start_time", x_end="end_time", y=y, color=color,
title="Timeline of WandB Runs",
category_orders={'run_name': df_runs.run_name.unique()},
hover_name="run_name",
hover_data=[col for col in ['hotkey','user','username','run_id','num_steps','num_completions'] if col in df_runs],
color_discrete_map={'running': 'green', 'finished': 'grey', 'killed':'blue', 'crashed':'orange', 'failed': 'red'},
color_discrete_map={k: v for k, v in color_discrete_map.items() if k in df_runs[color].unique()},
opacity=0.3,
width=1200,
height=800,
template="plotly_white",
)
# remove y axis ticks
fig.update_yaxes(tickfont_size=8, title='')
fig.update_yaxes(title='')
return fig

def plot_throughput(df: pd.DataFrame, n_minutes: int = 10) -> go.Figure:
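Two details of the `plot_gantt` change are worth calling out: the color column is cast to `str` so that numeric or categorical columns (such as `netuid`) get discrete rather than continuous coloring, and the color map is filtered down to the states actually present in the data. A reduced sketch with illustrative data:

```python
import pandas as pd
import plotly.express as px

df = pd.DataFrame({
    'start_time': pd.to_datetime(['2024-04-01', '2024-04-02']),
    'end_time':   pd.to_datetime(['2024-04-03', '2024-04-04']),
    'username':   ['alice', 'bob'],
    'state':      ['running', 'crashed'],
})

color_discrete_map = {'running': 'green', 'finished': 'grey', 'crashed': 'orange'}
fig = px.timeline(
    df.astype({'state': str}),  # string-typed column guarantees discrete colors
    x_start='start_time', x_end='end_time', y='username', color='state',
    # keep only the colors whose states actually occur in the data
    color_discrete_map={k: v for k, v in color_discrete_map.items() if k in df['state'].unique()},
)
fig.show()
```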
7 changes: 5 additions & 2 deletions opendashboards/utils/utils.py
@@ -144,10 +144,12 @@ def read_data(path: str, nrows: int = None):
"""Load data from csv."""
df = pd.read_csv(path, nrows=nrows)
# filter out events with missing step length
df = df.loc[df.step_length.notna()]
# df = df.loc[df.step_length.notna()]

# detect list columns which as stored as strings
list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].str.startswith("[").all()]
def is_list_col(x):
return isinstance(x, str) and x[0]=='[' and x[-1]==']'
list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].apply(is_list_col).all()]
# convert string representation of list to list
df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')

Expand All @@ -161,6 +163,7 @@ def load_data(selected_runs, load=True, save=False, explode=True, datadir='data/
if not os.path.exists(datadir):
os.makedirs(datadir)

st.write(selected_runs)
pbar = tqdm.tqdm(selected_runs.index, desc="Loading runs", total=len(selected_runs), unit="run")
for i, idx in enumerate(pbar):
run = selected_runs.loc[idx]
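The old list-column detection in `read_data` (`df[c].str.startswith("[").all()`) can misclassify object columns that mix strings with missing or non-string values; the replacement tests each cell explicitly before `eval`-ing the column back into Python lists. A standalone sketch on illustrative data:

```python
import pandas as pd

df = pd.DataFrame({
    'rewards': ['[0.1, 0.9]', '[0.5]'],  # lists serialized as strings in the CSV
    'name':    ['run-a', None],          # object column with a missing value
})

def is_list_col(x):
    # A cell counts as a serialized list only if it is a bracket-wrapped string.
    return isinstance(x, str) and x[0] == '[' and x[-1] == ']'

list_cols = [c for c in df.columns if df[c].dtype == 'object' and df[c].apply(is_list_col).all()]
df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')  # parse back into lists
print(list_cols)  # ['rewards']
```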
