Skip to content

Commit

Permalink
fix NBLAST batching for small number of queries/targets
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Nov 14, 2023
1 parent 98aa6c0 commit 7dce6b3
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions navis/nbl/nblast_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,6 @@ def nblast_smart(query: Union[Dotprops, NeuronList],
approx_nn=approx_nn,
progress=progress,
smat_kwargs=smat_kwargs)

# Add queries and targets
for i, ix in enumerate(qix):
this.append(query_dps[ix], query_self_hits[ix])
Expand Down Expand Up @@ -1141,7 +1140,7 @@ def nblast_allbyall(x: NeuronList,
return scores


def test_single_query_time(q, t, it=20):
def test_single_query_time(q, t, it=100):
"""Test average time of a single NBLAST query."""
# Get a median-sized query and target
q_ix = np.argsort(q.n_points)[len(q)//2]
Expand All @@ -1156,7 +1155,7 @@ def test_single_query_time(q, t, it=20):
return np.mean(timings) # seconds per medium sized query


def find_batch_partition(q, t, T=10, N_cores=None):
def find_batch_partition(q, t, T=10, n_cores=None):
"""Find partitions such that each batch takes about `T` seconds.
Parameters
Expand All @@ -1165,7 +1164,7 @@ def find_batch_partition(q, t, T=10, N_cores=None):
Query and targets, respectively.
T : int
Time (in seconds) to aim for.
N_cores : int, optional
n_cores : int, optional
Number of cores that will be used. If provided, will try to
make sure that (n_rows * n_cols) is a multiple of n_cores by
increasing the number of rows (thereby decreasing the time
Expand All @@ -1176,6 +1175,7 @@ def find_batch_partition(q, t, T=10, N_cores=None):
n_rows, n_cols
"""
# Test a single query
time_per_query = test_single_query_time(q, t)

# Number of queries per job such that each job runs in `T` second
Expand All @@ -1184,14 +1184,11 @@ def find_batch_partition(q, t, T=10, N_cores=None):
# Number of neurons per batch
neurons_per_batch = max(1, int(np.sqrt(queries_per_batch)))

#n_rows = max(1, len(q) // neurons_per_batch)
#n_cols = max(1, len(t) // neurons_per_batch)
n_rows = max(1, 1000 // neurons_per_batch)
n_cols = max(1, 1000 // neurons_per_batch)
n_rows = max(1, len(q) // neurons_per_batch)
n_cols = max(1, len(t) // neurons_per_batch)

if N_cores and ((n_rows * n_cols) > N_cores):
while (n_rows * n_cols) % N_cores:
print(n_rows, n_cols, (n_rows * n_cols) % N_cores)
if n_cores and ((n_rows * n_cols) > n_cores):
while (n_rows * n_cols) % n_cores:
n_rows += 1

return n_rows, n_cols
Expand Down

0 comments on commit 7dce6b3

Please sign in to comment.