Prepare batches in a parallel thread #186

Open · BalzaniEdoardo opened this issue Jul 11, 2024 · 1 comment

BalzaniEdoardo (Collaborator)

Is your feature request related to a problem? Please describe.
Currently, the feature construction for a batch is done on the fly, but for large problems (many neurons, small bin size, many features) it may become a bottleneck for model fitting.

Describe the solution you'd like
A solution we can implement is to have a parallel thread (a server) prepare N batches and store them in a queue. The main thread fitting the model gets a batch from the queue and runs an iteration of the solver. The server keeps checking the queue and, whenever fewer than N batches are available, prepares an extra batch.

Describe alternatives you've considered
This is a scheme of how such a solution could be implemented:

import threading
import queue
import time
import numpy as np

# Shutdown flag
shutdown_flag = threading.Event()

def prepare_batch():
    time.sleep(0.1)  # Simulate time taken to prepare a batch
    X = np.random.rand(70000, 195, 9).astype(np.float32)
    y = np.random.rand(70000, 195).astype(np.float32)
    return (X, y)

def batch_loader(batch_queue, batch_size, shutdown_flag):
    # Keep up to `batch_size` pre-computed batches in the queue
    while not shutdown_flag.is_set():
        if batch_queue.qsize() < batch_size:
            batch = prepare_batch()
            batch_queue.put(batch)
        else:
            time.sleep(0.01)  # queue is full, avoid busy-waiting

def algorithm_update(model, batch_queue, shutdown_flag, max_iterations, stopping_criterion):
    # `model` is assumed to expose jaxopt-style initialize_solver/update methods (e.g. a nemos GLM)
    params, state = None, None
    iteration = 0
    while iteration < max_iterations and not shutdown_flag.is_set():
        try:
            batch = batch_queue.get(timeout=1)  # Use timeout to periodically check shutdown_flag
            X, y = batch
            # Initialize the solver on the first batch, then run one update per batch
            if iteration == 0:
                params, state = model.initialize_solver(X, y)
            params, state = model.update(params, state, X, y)
            batch_queue.task_done()
            iteration += 1
            # Check your stopping criterion here
            if stopping_criterion(state):
                break
        except queue.Empty:
            continue
    return params, state

# Example stopping criterion function
def stopping_criterion(state):
    # Implement your stopping criterion logic here
    # Return True if the criterion is met, otherwise False
    return False

# Parameters
batch_size = 5  # Number of pre-loaded batches
batch_queue = queue.Queue(maxsize=batch_size)
max_iterations = 100  # Maximum number of iterations
model = ...  # placeholder: the model being fit (e.g. a nemos GLM exposing initialize_solver/update)

# Start the batch loader thread
loader_thread = threading.Thread(target=batch_loader, args=(batch_queue, batch_size, shutdown_flag))
loader_thread.daemon = True  # This makes the batch loader a daemon thread
loader_thread.start()

# Main thread for algorithm update
try:
    params, state = algorithm_update(model, batch_queue, shutdown_flag, max_iterations, stopping_criterion)
finally:
    # Set the shutdown flag to stop the loader thread
    shutdown_flag.set()
    # Wait for the loader thread to exit
    loader_thread.join()

Additional context
This should play nicely with the run_iterator method for the stochastic optimizers in jaxopt.
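
For reference, a minimal sketch of how the queue could be wrapped in a plain Python generator that a stochastic solver could consume batch-by-batch (the queue_iterator name is hypothetical, and the exact iterator format expected by jaxopt's run_iterator should be checked against its docs; imports are the same as in the sketch above):

def queue_iterator(batch_queue, shutdown_flag):
    # Yield pre-computed (X, y) batches until the shutdown flag is set
    while not shutdown_flag.is_set():
        try:
            yield batch_queue.get(timeout=1)  # timeout lets us re-check the flag
        except queue.Empty:
            continue

# The main thread could then do `for X, y in queue_iterator(batch_queue, shutdown_flag): ...`
# or pass the generator to a jaxopt stochastic solver.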

BalzaniEdoardo (Collaborator, Author) commented Jul 31, 2024

The following is the fastest solution implemented so far: the inputs are passed to the server via fixed-size shared-memory buffers.

import multiprocessing as mp
import os
import random
from time import perf_counter
import numpy as np
import pynapple as nap


nap.nap_config.suppress_conversion_warnings = True

# class DebugSemaphore:
#     def __init__(self, initial_value=1):
#         self.semaphore = mp.Semaphore(initial_value)
#         self.counter = mp.Value('i', initial_value)

#     def acquire(self, timeout=None):
#         result = self.semaphore.acquire(timeout=timeout)
#         if result:
#             with self.counter.get_lock():
#                 self.counter.value -= 1
#                 print("acquire:", self.counter.value)

#         return result

#     def release(self):
#         self.semaphore.release()
#         with self.counter.get_lock():
#             self.counter.value += 1
#             print("release:", self.counter.value)

#     def get_value(self):
#         with self.counter.get_lock():
#             return self.counter.value

class Server:
    def __init__(self, conns, semaphore_dict, shared_arrays, stop_event, num_iterations, shared_results, array_shape, n_basis_funcs=9, bin_size=None, hist_window_sec=None, neuron_id=0):
        os.environ["JAX_PLATFORM_NAME"] = "gpu"
        import nemos
        self.model = nemos.glm.GLM(
            regularizer=nemos.regularizer.UnRegularized(
                solver_name="GradientDescent",
                solver_kwargs={"stepsize": 0.2, "acceleration": False},
            )
        )
        # set mp attributes
        self.conns = conns
        self.semaphore_dict = semaphore_dict
        self.stop_event = stop_event
        self.num_iterations = num_iterations
        self.shared_results = shared_results
        self.bin_size = bin_size
        self.hist_window_size = int(hist_window_sec / bin_size)
        self.basis = nemos.basis.RaisedCosineBasisLog(
            n_basis_funcs, mode="conv", window_size=self.hist_window_size
        )
        self.array_shape = array_shape
        self.neuron_id = neuron_id
        self.shared_arrays = shared_arrays
        print(f"ARRAY SHAPE {self.array_shape}")

    def run(self):
        params, state = None, None
        print("server run called...")
        counter = 0
        while not self.stop_event.is_set() and counter < self.num_iterations:
            for conn in self.conns:
                if conn.poll(1):  # Check if there is data in the pipe
                    print("data received in the pipe...")
                    try:
                        t0 = perf_counter()
                        worker_id = conn.recv()
                        print(f"control message worker {worker_id} loaded, time: {np.round(perf_counter() - t0, 5)}")
                        
                        # Read data from the shared array
                        t0 = perf_counter()
                        x_count = np.frombuffer(self.shared_arrays[worker_id], dtype=np.float32).reshape(self.array_shape)
                        print(f"data loaded, time: {np.round(perf_counter() - t0, 5)}")
                        
                        # Notify the worker that the data has been read
                        t0 = perf_counter()
                        self.semaphore_dict[worker_id].release()
                        print(f"release worker, time: {np.round(perf_counter() - t0, 5)}")

                        y = x_count[:, self.neuron_id]
                        
                        # initialize at first iteration
                        t0 = perf_counter()
                        X = self.basis.compute_features(x_count)
                        print(f"convolution performed, time: {np.round(perf_counter() - t0, 5)}")
                        
                        if counter == 0:
                            params, state = self.model.initialize_solver(X, y)
                            print("initialized parameters...")
                        
                        # update
                        t0 = perf_counter()
                        params, state = self.model.update(params, state, X, y)
                        print(f"update performed, time: {np.round(perf_counter() - t0, 5)}")
                        counter += 1

                    except Exception as e:
                        print(f"Exception: {e}")
                        pass

        # stop workers
        self.stop_event.set()
        # add the model to the manager shared param
        self.shared_results[:] = [(params, state)]
        print("run returned for server")

class Worker:
    def __init__(self, conn, worker_id, spike_times, time_quiet, batch_size_sec, n_batches, shared_array, semaphore, bin_size=0.001, hist_window_sec=0.4, shutdown_flag=None, n_seconds=8000):
        """Store parameters and config jax"""

        # store worker info
        self.conn = conn
        self.worker_id = worker_id

        # store multiprocessing attributes
        self.shutdown_flag = shutdown_flag
        self.shared_array = shared_array
        self.semaphore = semaphore

        # store model design hyperparameters
        self.bin_size = bin_size
        self.hist_window_size = int(hist_window_sec / bin_size)
        self.batch_size_sec = batch_size_sec
        self.batch_size = int(batch_size_sec / bin_size)
        self.spike_times = spike_times
        self.epochs = self.compute_starts(n_bat=n_batches, time_quiet=time_quiet, n_seconds=n_seconds)
        print(f"worker {worker_id} stored model parameters...")

        # set worker based seed
        np.random.seed(123 + worker_id)

    def compute_starts(self, n_bat, time_quiet, n_seconds):
        iset_batches = []
        cnt = 0
        while cnt < n_bat:
            start = np.random.uniform(0, n_seconds)
            end = start + self.batch_size_sec
            tot_time = nap.IntervalSet(end, n_seconds).intersect(time_quiet)
            if tot_time.tot_length() < self.batch_size_sec:
                continue
            ep = nap.IntervalSet(start, end).intersect(time_quiet)
            delta_t = self.batch_size_sec - ep.tot_length()

            while delta_t > 0:
                end += delta_t
                ep = nap.IntervalSet(start, end).intersect(time_quiet)
                delta_t = self.batch_size_sec - ep.tot_length()

            iset_batches.append(ep)
            cnt += 1

        return iset_batches

    def batcher(self):
        ep = self.epochs[np.random.choice(range(len(self.epochs)))]
        X_counts = self.spike_times.count(self.bin_size, ep=ep)
        return nap.TsdFrame(X_counts.t, X_counts.d.astype(np.float32), time_support=X_counts.time_support)

    def run(self):
        try:
            while not self.shutdown_flag.is_set():
                print(f"worker {self.worker_id} acquiring semaphore...")
                if not self.semaphore.acquire(timeout=1):
                    continue
                print(f"worker {self.worker_id} preparing a batch...")
                t0 = perf_counter()
                x_count = self.batcher()
                n_samp = x_count.shape[0]
                splits = [x_count.get(a, b).d for a, b in x_count.time_support.values]
                padding = np.vstack([np.vstack((s, np.full((1, *s.shape[1:]), np.nan))) for s in splits])
                print(f"worker {self.worker_id} batch ready, time: {np.round(perf_counter() - t0, 5)}")
                # Write data to shared memory using dedicated slice
                t0 = perf_counter()
                buffer_array = np.frombuffer(self.shared_array, dtype=np.float32)
                np.copyto(buffer_array, padding[:n_samp].flatten())
                print(f"worker {self.worker_id} batch copied, time: {np.round(perf_counter() - t0, 5)}")

                print(f"worker {self.worker_id} sending control message...")
                t0 = perf_counter()
                self.conn.send(self.worker_id)
                print(f"worker {self.worker_id} sent control message to the server, time: {np.round(perf_counter() - t0, 5)}...")

                # # Wait for confirmation from server
                # if not self.conn.recv():
                #     print(f"worker {self.worker_id} retrying to send control message...")
                #     continue
        finally:
            print(f"worker {self.worker_id} exits loop...")

def worker_process(conn, worker_id, *args, **kwargs):
    worker = Worker(conn, worker_id, *args, **kwargs)
    worker.run()
    print(f"run returned for worker {worker_id}")

def server_process(conns, semaphore_dict, shared_arrays, *args, **kwargs):
    server = Server(conns, semaphore_dict, shared_arrays, *args, **kwargs)
    server.run()
    print("run returned for server")

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # Use 'spawn' start method

    # set MP parameters
    shutdown_flag = mp.Event()

    # shared params
    manager = mp.Manager()
    shared_results = manager.list()  # return the model to the main thread

    # generate some data
    n_neurons = 195
    n_sec = 8000.0
    n_batches = 1000
    spikes = [np.random.uniform(0.0, n_sec, 2000).astype(np.float32) for i in range(n_neurons)]
    spikes = np.array(spikes)
    ts_dict = {key: nap.Ts(spikes[key, :].flatten()) for key in range(spikes.shape[0])}
    spike_times = nap.TsGroup(ts_dict)
    neuron_id = 0  # id of neuron to fit
    batch_size_sec = n_sec / n_batches  # 8-second batches (8000 s / 1000 batches)
    gap_starts = np.random.uniform(5, 95, 10)
    gaps = nap.IntervalSet(gap_starts, gap_starts + 6)
    time_quiet = spike_times.time_support.set_diff(gaps)
    bin_size = 0.0001
    hist_window_sec = 0.004

    # set the number of iterations
    num_iterations = 100

    # num of workers
    n_workers = 4

    # shared arrays for data transfer
    array_shape = (int(batch_size_sec / bin_size), len(ts_dict))  # Adjust to match actual data size
    shared_arrays = {i: mp.Array('f', array_shape[0] * array_shape[1], lock=False) for i in range(n_workers)}

    # set up pipes, semaphores, and workers
    parent_conns, child_conns = zip(*[mp.Pipe() for _ in range(n_workers)])
    
    semaphore_dict = {i: mp.Semaphore(1) for i in range(n_workers)}
    workers = []
    #conn, worker_id, spike_times, time_quiet, batch_size_sec, n_batches, shared_array, semaphore, bin_size=0.001, hist_window_sec=0.4, shutdown_flag=None
    for i, conn in enumerate(child_conns):
        p = mp.Process(
            target=worker_process,
            args=(conn, i, spike_times, time_quiet, batch_size_sec, n_batches, shared_arrays[i], semaphore_dict[i]),
            kwargs=dict(
                bin_size=bin_size,
                shutdown_flag=shutdown_flag,
                hist_window_sec=hist_window_sec,
                n_seconds=n_sec
            )
        )
        p.start()
        workers.append(p)
        print(f"Worker id {i} pid = {p.pid}")

    # conns, semaphore_dict, shared_arrays, stop_event, num_iterations, shared_results, array_shape, n_basis_funcs=9, bin_size=None, hist_window_sec=None, neuron_id=0
    server = mp.Process(
        target=server_process,
        args=(parent_conns, semaphore_dict, shared_arrays, shutdown_flag, num_iterations, shared_results, array_shape),
        kwargs=dict(n_basis_funcs=9, hist_window_sec=hist_window_sec, bin_size=bin_size, neuron_id=neuron_id)
    )
    server.start()
    server.join()

    # the server puts (params, state) into the managed list before exiting
    if shared_results:
        params, state = shared_results[0]
        print("final params", params)
    else:
        print("no shared model in the list...")

    # Signal workers to stop
    shutdown_flag.set()
    print("flag set")

    # Release all semaphores to unblock workers if they are waiting
    for semaphore in semaphore_dict.values():
        semaphore.release()

    # Join worker processes
    for p in workers:
        p.join(timeout=5)
        if p.is_alive():
            print(f"Worker {p.pid} did not exit gracefully, terminating.")
            p.terminate()
        else:
            print(f"Joined worker {p.pid}")

    print("Script terminated")

There are a few implementation details to remember:

  • jax is imported at run time in the server. If the workers also need jax, it should be configured to use the CPU by setting the environment variable os.environ["JAX_PLATFORM_NAME"] = "cpu" in the worker function body before importing jax (otherwise it will result in GPU memory-allocation issues; this may be solvable but it's hard). See the sketch after this list.
  • mp.set_start_method("spawn", force=True) is mandatory, otherwise jax-compiled functions like jnp.exp would not be picklable and the sub-processes would not start.
  • Wrapping the setup in if __name__ == "__main__" is mandatory, to make sure that only the main process spawns the workers.
  • Using queues or pipes to pass the arrays around would be way too slow to make a difference in execution time. The shared buffer is by far the most efficient way to pass the data around.
  • If the semaphore logic is off, one can use the commented-out DebugSemaphore to keep track of the semaphore counters.
  • Release each semaphore as soon as the server has finished reading the data, so that the workers don't hang.
  • Make sure workers can't hang forever at the semaphore by acquiring it with a timeout.
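
As noted in the first bullet, here is a minimal sketch of the worker-side configuration, assuming the workers themselves need jax (the script above does not import jax in the workers, and the worker_process_with_jax name is hypothetical):

import os

def worker_process_with_jax(conn, worker_id, *args, **kwargs):
    # Pin the backend *before* the first jax import in this process,
    # otherwise jax may try to allocate GPU memory already held by the server
    os.environ["JAX_PLATFORM_NAME"] = "cpu"
    import jax  # noqa: F401  -- guaranteed to use the CPU backend

    worker = Worker(conn, worker_id, *args, **kwargs)
    worker.run()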
