Prepare batches in a parallel thread #186
The following is the fastest solution implemented, where the inputs are passed via a shared-memory buffer of fixed size.

```python
import multiprocessing as mp
import os
import random
from time import perf_counter
import numpy as np
import pynapple as nap
nap.nap_config.suppress_conversion_warnings = True
# class DebugSemaphore:
#     def __init__(self, initial_value=1):
#         self.semaphore = mp.Semaphore(initial_value)
#         self.counter = mp.Value('i', initial_value)
#     def acquire(self, timeout=None):
#         result = self.semaphore.acquire(timeout=timeout)
#         if result:
#             with self.counter.get_lock():
#                 self.counter.value -= 1
#                 print("acquire:", self.counter.value)
#         return result
#     def release(self):
#         self.semaphore.release()
#         with self.counter.get_lock():
#             self.counter.value += 1
#             print("release:", self.counter.value)
#     def get_value(self):
#         with self.counter.get_lock():
#             return self.counter.value
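
# Server: runs in its own process (JAX on GPU), reads each batch from the workers'
# shared-memory buffers, convolves it with the basis, and performs one GLM update.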
class Server:
    def __init__(self, conns, semaphore_dict, shared_arrays, stop_event, num_iterations, shared_results, array_shape, n_basis_funcs=9, bin_size=None, hist_window_sec=None, neuron_id=0):
        os.environ["JAX_PLATFORM_NAME"] = "gpu"
        import nemos
        self.model = nemos.glm.GLM(
            regularizer=nemos.regularizer.UnRegularized(
                solver_name="GradientDescent",
                solver_kwargs={"stepsize": 0.2, "acceleration": False},
            )
        )
        # set mp attributes
        self.conns = conns
        self.semaphore_dict = semaphore_dict
        self.stop_event = stop_event
        self.num_iterations = num_iterations
        self.shared_results = shared_results
        self.bin_size = bin_size
        self.hist_window_size = int(hist_window_sec / bin_size)
        self.basis = nemos.basis.RaisedCosineBasisLog(
            n_basis_funcs, mode="conv", window_size=self.hist_window_size
        )
        self.array_shape = array_shape
        self.neuron_id = neuron_id
        self.shared_arrays = shared_arrays
        print(f"ARRAY SHAPE {self.array_shape}")
    def run(self):
        params, state = None, None
        print("server run called...")
        counter = 0
        while not self.stop_event.is_set() and counter < self.num_iterations:
            for conn in self.conns:
                if conn.poll(1):  # Check if there is data in the pipe
                    print("data received in the pipe...")
                    try:
                        t0 = perf_counter()
                        worker_id = conn.recv()
                        print(f"control message worker {worker_id} loaded, time: {np.round(perf_counter() - t0, 5)}")
                        # Read data from the shared array
                        t0 = perf_counter()
                        x_count = np.frombuffer(self.shared_arrays[worker_id], dtype=np.float32).reshape(self.array_shape)
                        print(f"data loaded, time: {np.round(perf_counter() - t0, 5)}")
                        # Notify the worker that the data has been read
                        t0 = perf_counter()
                        self.semaphore_dict[worker_id].release()
                        print(f"release worker, time: {np.round(perf_counter() - t0, 5)}")
                        y = x_count[:, self.neuron_id]
                        # initialize at first iteration
                        t0 = perf_counter()
                        X = self.basis.compute_features(x_count)
                        print(f"convolution performed, time: {np.round(perf_counter() - t0, 5)}")
                        if counter == 0:
                            params, state = self.model.initialize_solver(X, y)
                            print("initialized parameters...")
                        # update
                        t0 = perf_counter()
                        params, state = self.model.update(params, state, X, y)
                        print(f"update performed, time: {np.round(perf_counter() - t0, 5)}")
                        counter += 1
                    except Exception as e:
                        print(f"Exception: {e}")
        # stop workers
        self.stop_event.set()
        # add the model to the manager shared param
        self.shared_results[:] = [(params, state)]
        print("run returned for server")
class Worker:
    def __init__(self, conn, worker_id, spike_times, time_quiet, batch_size_sec, n_batches, shared_array, semaphore, bin_size=0.001, hist_window_sec=0.4, shutdown_flag=None, n_seconds=8000):
        """Store parameters and config jax"""
        # store worker info
        self.conn = conn
        self.worker_id = worker_id
        # store multiprocessing attributes
        self.shutdown_flag = shutdown_flag
        self.shared_array = shared_array
        self.semaphore = semaphore
        # store model design hyperparameters
        self.bin_size = bin_size
        self.hist_window_size = int(hist_window_sec / bin_size)
        self.batch_size_sec = batch_size_sec
        self.batch_size = int(batch_size_sec / bin_size)
        self.spike_times = spike_times
        self.epochs = self.compute_starts(n_bat=n_batches, time_quiet=time_quiet, n_seconds=n_seconds)
        print(f"worker {worker_id} stored model parameters...")
        # set worker based seed
        np.random.seed(123 + worker_id)
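    # Pre-compute n_bat candidate epochs: each epoch starts at a random time and is
    # extended until its intersection with time_quiet covers batch_size_sec of data.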
    def compute_starts(self, n_bat, time_quiet, n_seconds):
        iset_batches = []
        cnt = 0
        while cnt < n_bat:
            start = np.random.uniform(0, n_seconds)
            end = start + self.batch_size_sec
            tot_time = nap.IntervalSet(end, n_seconds).intersect(time_quiet)
            if tot_time.tot_length() < self.batch_size_sec:
                continue
            ep = nap.IntervalSet(start, end).intersect(time_quiet)
            delta_t = self.batch_size_sec - ep.tot_length()
            while delta_t > 0:
                end += delta_t
                ep = nap.IntervalSet(start, end).intersect(time_quiet)
                delta_t = self.batch_size_sec - ep.tot_length()
            iset_batches.append(ep)
            cnt += 1
        return iset_batches
    def batcher(self):
        ep = self.epochs[np.random.choice(range(len(self.epochs)))]
        X_counts = self.spike_times.count(self.bin_size, ep=ep)
        return nap.TsdFrame(X_counts.t, X_counts.d.astype(np.float32), time_support=X_counts.time_support)
    def run(self):
        try:
            while not self.shutdown_flag.is_set():
                print(f"worker {self.worker_id} acquiring semaphore...")
                if not self.semaphore.acquire(timeout=1):
                    continue
                print(f"worker {self.worker_id} preparing a batch...")
                t0 = perf_counter()
                x_count = self.batcher()
                n_samp = x_count.shape[0]
                splits = [x_count.get(a, b).d for a, b in x_count.time_support.values]
                padding = np.vstack([np.vstack((s, np.full((1, *s.shape[1:]), np.nan))) for s in splits])
                print(f"worker {self.worker_id} batch ready, time: {np.round(perf_counter() - t0, 5)}")
                # Write data to shared memory using dedicated slice
                t0 = perf_counter()
                buffer_array = np.frombuffer(self.shared_array, dtype=np.float32)
                np.copyto(buffer_array, padding[:n_samp].flatten())
                print(f"worker {self.worker_id} batch copied, time: {np.round(perf_counter() - t0, 5)}")
                print(f"worker {self.worker_id} sending control message...")
                t0 = perf_counter()
                self.conn.send(self.worker_id)
                print(f"worker {self.worker_id} sent control message to the server, time: {np.round(perf_counter() - t0, 5)}...")
                # # Wait for confirmation from server
                # if not self.conn.recv():
                #     print(f"worker {self.worker_id} retrying to send control message...")
                #     continue
        finally:
            print(f"worker {self.worker_id} exits loop...")

def worker_process(conn, worker_id, *args, **kwargs):
    worker = Worker(conn, worker_id, *args, **kwargs)
    worker.run()
    print(f"run returned for worker {worker_id}")


def server_process(conns, semaphore_dict, shared_arrays, *args, **kwargs):
    server = Server(conns, semaphore_dict, shared_arrays, *args, **kwargs)
    server.run()
    print("run returned for server")

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # Use 'spawn' start method
    # set MP parameters
    shutdown_flag = mp.Event()
    # shared params
    manager = mp.Manager()
    shared_results = manager.list()  # return the model to the main thread
    # generate some data
    n_neurons = 195
    n_sec = 8000.0
    n_batches = 1000
    spikes = [np.random.uniform(0.0, n_sec, 2000).astype(np.float32) for i in range(n_neurons)]
    spikes = np.array(spikes)
    ts_dict = {key: nap.Ts(spikes[key, :].flatten()) for key in range(spikes.shape[0])}
    spike_times = nap.TsGroup(ts_dict)
    neuron_id = 0  # id of neuron to fit
    batch_size_sec = n_sec / n_batches  # 8 sec batches
    gap_starts = np.random.uniform(5, 95, 10)
    gaps = nap.IntervalSet(gap_starts, gap_starts + 6)
    time_quiet = spike_times.time_support.set_diff(gaps)
    bin_size = 0.0001
    hist_window_sec = 0.004
    # set the number of iterations
    num_iterations = 100
    # num of workers
    n_workers = 4
    # shared arrays for data transfer
    array_shape = (int(batch_size_sec / bin_size), len(ts_dict))  # Adjust to match actual data size
    shared_arrays = {i: mp.Array('f', array_shape[0] * array_shape[1], lock=False) for i in range(n_workers)}
    # set up pipes, semaphores, and workers
    parent_conns, child_conns = zip(*[mp.Pipe() for _ in range(n_workers)])
    semaphore_dict = {i: mp.Semaphore(1) for i in range(n_workers)}
    workers = []
    # conn, worker_id, spike_times, time_quiet, batch_size_sec, n_batches, shared_array, semaphore, bin_size=0.001, hist_window_sec=0.4, shutdown_flag=None
    for i, conn in enumerate(child_conns):
        p = mp.Process(
            target=worker_process,
            args=(conn, i, spike_times, time_quiet, batch_size_sec, n_batches, shared_arrays[i], semaphore_dict[i]),
            kwargs=dict(
                bin_size=bin_size,
                shutdown_flag=shutdown_flag,
                hist_window_sec=hist_window_sec,
                n_seconds=n_sec,
            ),
        )
        p.start()
        workers.append(p)
        print(f"Worker id {i} pid = {p.pid}")
    # conns, semaphore_dict, shared_arrays, stop_event, num_iterations, shared_results, array_shape, n_basis_funcs=9, bin_size=None, hist_window_sec=None, neuron_id=0
    server = mp.Process(
        target=server_process,
        args=(parent_conns, semaphore_dict, shared_arrays, shutdown_flag, num_iterations, shared_results, array_shape),
        kwargs=dict(n_basis_funcs=9, hist_window_sec=hist_window_sec, bin_size=bin_size, neuron_id=neuron_id),
    )
    server.start()
    server.join()
    out = shared_results[0]
    if out:
        params, state = out
        print("final params", params)
    else:
        print("no shared model in the list...")
    # Signal workers to stop
    shutdown_flag.set()
    print("flag set")
    # Release all semaphores to unblock workers if they are waiting
    for semaphore in semaphore_dict.values():
        semaphore.release()
    # Join worker processes
    for p in workers:
        p.join(timeout=5)
        if p.is_alive():
            print(f"Worker {p.pid} did not exit gracefully, terminating.")
            p.terminate()
        else:
            print(f"Joined worker {p.pid}")
    print("Script terminated")
```

There are a couple of implementation details to remember:
Is your feature request related to a problem? Please describe.
Currently, the feature construction for a batch is done on the fly, but for large problems (many neurons, small bin size, many features) it may become a bottleneck for model fitting.
Describe the solution you'd like
A solution that we can implement is to have a parallel thread (a server) prepare N batches and store them in a queue. The main thread fitting the model would get a batch from the queue and run an iteration of the solver. The server would keep checking the queue and, whenever fewer than N batches are available, prepare an extra one. A minimal sketch of this producer/consumer scheme is given below.
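A minimal, hypothetical sketch of the scheme described above, using a background thread and a bounded `queue.Queue`; the names `serve_batches`, `fit_with_prefetch`, `make_batch`, and `run_iteration` are placeholders for the actual batch construction and solver update, not part of the original proposal.

```python
import queue
import threading


def serve_batches(batch_queue, make_batch, stop_event):
    """Producer: keep the queue topped up with pre-built batches."""
    while not stop_event.is_set():
        try:
            # put() blocks while the queue already holds N batches, so the
            # producer only does work when fewer than N are available.
            batch_queue.put(make_batch(), timeout=1)
        except queue.Full:
            continue


def fit_with_prefetch(make_batch, run_iteration, num_iterations, n_prefetch=5):
    """Consumer: pull ready-made batches and run one solver iteration per batch."""
    batch_queue = queue.Queue(maxsize=n_prefetch)
    stop_event = threading.Event()
    producer = threading.Thread(
        target=serve_batches, args=(batch_queue, make_batch, stop_event), daemon=True
    )
    producer.start()
    try:
        for _ in range(num_iterations):
            X, y = batch_queue.get()  # blocks until a batch is ready
            run_iteration(X, y)       # e.g. params, state = model.update(params, state, X, y)
    finally:
        stop_event.set()
        producer.join(timeout=1)
```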
Describe alternatives you've considered
This is a scheme of how such a solution could be implemented.
Additional context
This should play nicely with the `run_iterator` method for the stochastic optimizers in jaxopt.
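For reference, a batch queue like the one sketched above can be exposed as a plain Python iterator, which is the form an iterator-driven stochastic solver expects; the generator below is an illustrative assumption (the queue, iteration count, and solver call are placeholders, not part of the original proposal).

```python
def queue_iterator(batch_queue, num_iterations):
    """Yield pre-built batches pulled from the producer's queue."""
    for _ in range(num_iterations):
        yield batch_queue.get()

# e.g., something along the lines of:
# solver.run_iterator(init_params, iterator=queue_iterator(batch_queue, 100))
```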