From 7dce6b3974951fe9668f9a972e325033f7b34e47 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Tue, 14 Nov 2023 13:03:47 +0000 Subject: [PATCH] fix NBLAST batching for small number of queries/targets --- navis/nbl/nblast_funcs.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/navis/nbl/nblast_funcs.py b/navis/nbl/nblast_funcs.py index 2f92c220..5d57661d 100644 --- a/navis/nbl/nblast_funcs.py +++ b/navis/nbl/nblast_funcs.py @@ -582,7 +582,6 @@ def nblast_smart(query: Union[Dotprops, NeuronList], approx_nn=approx_nn, progress=progress, smat_kwargs=smat_kwargs) - # Add queries and targets for i, ix in enumerate(qix): this.append(query_dps[ix], query_self_hits[ix]) @@ -1141,7 +1140,7 @@ def nblast_allbyall(x: NeuronList, return scores -def test_single_query_time(q, t, it=20): +def test_single_query_time(q, t, it=100): """Test average time of a single NBLAST query.""" # Get a median-sized query and target q_ix = np.argsort(q.n_points)[len(q)//2] @@ -1156,7 +1155,7 @@ def test_single_query_time(q, t, it=20): return np.mean(timings) # seconds per medium sized query -def find_batch_partition(q, t, T=10, N_cores=None): +def find_batch_partition(q, t, T=10, n_cores=None): """Find partitions such that each batch takes about `T` seconds. Parameters @@ -1165,7 +1164,7 @@ def find_batch_partition(q, t, T=10, N_cores=None): Query and targets, respectively. T : int Time (in seconds) to aim for. - N_cores : int, optional + n_cores : int, optional Number of cores that will be used. If provided, will try to make sure that (n_rows * n_cols) is a multiple of n_cores by increasing the number of rows (thereby decreasing the time @@ -1176,6 +1175,7 @@ def find_batch_partition(q, t, T=10, N_cores=None): n_rows, n_cols """ + # Test a single query time_per_query = test_single_query_time(q, t) # Number of queries per job such that each job runs in `T` second @@ -1184,14 +1184,11 @@ def find_batch_partition(q, t, T=10, N_cores=None): # Number of neurons per batch neurons_per_batch = max(1, int(np.sqrt(queries_per_batch))) - #n_rows = max(1, len(q) // neurons_per_batch) - #n_cols = max(1, len(t) // neurons_per_batch) - n_rows = max(1, 1000 // neurons_per_batch) - n_cols = max(1, 1000 // neurons_per_batch) + n_rows = max(1, len(q) // neurons_per_batch) + n_cols = max(1, len(t) // neurons_per_batch) - if N_cores and ((n_rows * n_cols) > N_cores): - while (n_rows * n_cols) % N_cores: - print(n_rows, n_cols, (n_rows * n_cols) % N_cores) + if n_cores and ((n_rows * n_cols) > n_cores): + while (n_rows * n_cols) % n_cores: n_rows += 1 return n_rows, n_cols