Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] RichHandler breaks Progress bar with multiprocessing #3529

Open
2 tasks done
abulgher opened this issue Oct 12, 2024 · 5 comments
Open
2 tasks done

[BUG] RichHandler breaks Progress bar with multiprocessing #3529

abulgher opened this issue Oct 12, 2024 · 5 comments

Comments

@abulgher
Copy link

Describe the bug
The output of the log messages is not written above the progress bar as in normal circumstances but it is written on the side of the progress bar and another progress bar is drawn underneath.
See picture.
rich-handler

import logging
import logging.handlers
import multiprocessing
import multiprocessing.pool
from rich.logging import RichHandler
from rich.progress import Progress
from random import choice, random
import time


class ProcessLogger(multiprocessing.Process):
    """Dedicated process that drains LogRecords from a shared queue and
    re-emits them through a RichHandler, so that only this process writes
    log output to the terminal.
    """

    # Singleton instance, managed by the class methods below.
    _global_process_logger = None

    def __init__(self):
        super().__init__()
        # Process-safe queue; -1 requests unbounded capacity.
        self.queue = multiprocessing.Queue(-1)

    @classmethod
    def get_global_logger(cls):
        """Return the singleton created by create_global_logger().

        Raises:
            RuntimeError: if no global process logger has been created yet.
        """
        if cls._global_process_logger is not None:
            return cls._global_process_logger
        # RuntimeError (a subclass of Exception) is more precise than the
        # bare Exception raised before; `except Exception` callers still
        # catch it.
        raise RuntimeError("No global process logger exists.")

    @classmethod
    def create_global_logger(cls):
        """Create, memoize and return the singleton ProcessLogger."""
        cls._global_process_logger = ProcessLogger()
        return cls._global_process_logger

    @staticmethod
    def configure():
        """Attach a RichHandler to this process's root logger."""
        root = logging.getLogger()
        h = RichHandler(rich_tracebacks=True, markup=True, show_path=False, log_time_format='%Y%m%d-%H:%M:%S')
        fs = '%(message)s'
        f = logging.Formatter(fs)
        h.setFormatter(f)
        root.addHandler(h)

    def stop(self):
        """Ask run() to exit by enqueueing the None sentinel."""
        self.queue.put_nowait(None)

    def run(self):
        # Executed in the logger process: pull records until the sentinel.
        self.configure()
        while True:
            try:
                record = self.queue.get()
                if record is None:  # sentinel pushed by stop()
                    break
                logger = logging.getLogger(record.name)
                logger.handle(record)
            except Exception:
                # Last-resort reporting: never let a bad record kill the
                # logging loop.
                import sys
                import traceback
                print('Whoops! Problem:', file=sys.stderr)
                traceback.print_exc(file=sys.stderr)

def configure_new_process(log_process_queue):
    """Route this process's root logging through *log_process_queue*.

    Installs a QueueHandler on the root logger and lowers the level to
    DEBUG so every record is forwarded to the logger process.
    """
    root_logger = logging.getLogger()
    root_logger.addHandler(logging.handlers.QueueHandler(log_process_queue))
    root_logger.setLevel(logging.DEBUG)


class ProcessWithLogging(multiprocessing.Process):
    """A Process whose target's log records are forwarded to a ProcessLogger.

    Before invoking *target*, the child process installs a QueueHandler
    pointing at *log_process*'s queue via configure_new_process().
    """

    def __init__(self, target, args=None, kwargs=None, log_process=None):
        """Store the call spec; fall back to the global ProcessLogger.

        args/kwargs default to None instead of the shared []/{} literals
        used before — mutable default arguments are shared across every
        instance and leak state between them.
        """
        super().__init__()
        self.target = target
        self.args = [] if args is None else args
        self.kwargs = {} if kwargs is None else kwargs
        if log_process is None:
            log_process = ProcessLogger.get_global_logger()
        self.log_process_queue = log_process.queue

    def run(self):
        # Runs in the child: redirect logging to the queue, then call target.
        configure_new_process(self.log_process_queue)
        self.target(*self.args, **self.kwargs)


class PoolWithLogging(multiprocessing.pool.Pool):
    """Process pool whose workers forward their logging to a ProcessLogger.

    Each worker runs configure_new_process() as its initializer, pointing
    its root logger at the ProcessLogger's queue.
    """

    def __init__(self, processes=None, context=None, log_process=None):
        source = log_process if log_process is not None else ProcessLogger.get_global_logger()
        super().__init__(
            processes=processes,
            initializer=configure_new_process,
            initargs=(source.queue,),
            context=context,
        )


# Log levels that worker_process picks from at random for each record.
LEVELS = [logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL]

# Message texts that worker_process picks from at random.
MESSAGES = [
    'Random message #1',
    'Random message #2',
    'Random message #3',
]


def worker_process(param=None):
    """Emit ten random log records with random pauses, then return *param*.

    Intended to run inside a child process whose root logger has been
    redirected to the logging queue by configure_new_process().
    """
    name = multiprocessing.current_process().name
    print('Worker started: %s' % name)
    # Hoisted out of the loop: the root logger is loop-invariant, there is
    # no reason to look it up on every iteration.
    logger = logging.getLogger()
    for i in range(10):
        time.sleep(random())
        level = choice(LEVELS)
        message = choice(MESSAGES)
        logger.log(level, message)
    print('Worker finished: {}, param: {}'.format(name, param))
    return param


def main():
    """Start the logger process and ten workers, driving a progress bar
    from the main process while the workers run."""
    process_logger = ProcessLogger.create_global_logger()
    process_logger.start()

    workers = []
    with Progress() as progress:
        n = 10
        loop_task = progress.add_task('[red]Loop', total=n)
        for i in range(n):
            worker = ProcessWithLogging(worker_process)
            workers.append(worker)
            worker.start()

        # Poll worker exit codes. The short sleep keeps this loop from
        # busy-spinning at 100% CPU (the original had no pause at all),
        # and the generator avoids building a throwaway list for sum().
        while (n_finished := sum(worker.exitcode is not None for worker in workers)) < n:
            progress.update(loop_task, completed=n_finished)
            time.sleep(0.1)
        progress.update(loop_task, completed=n)

        for w in workers:
            w.join()

    process_logger.stop()
    process_logger.join()


if __name__ == '__main__':
    main()

Platform

Click to expand
What platform (Win/Linux/Mac) are you running on? What terminal software are you using? Windows 11 Enterprise, Command Prompt.

I may ask you to copy and paste the output of the following commands. It may save some time if you do it now.

If you're using Rich in a terminal:

python -m rich.diagnose
pip freeze | grep rich

╭───────────────────────── <class 'rich.console.Console'> ─────────────────────────╮
│ A high level console interface. │
│ │
│ ╭──────────────────────────────────────────────────────────────────────────────╮ │
│ │ │ │
│ ╰──────────────────────────────────────────────────────────────────────────────╯ │
│ │
│ color_system = 'truecolor' │
│ encoding = 'utf-8' │
│ file = <_io.TextIOWrapper name='' mode='w' encoding='utf-8'> │
│ height = 51 │
│ is_alt_screen = False │
│ is_dumb_terminal = False │
│ is_interactive = True │
│ is_jupyter = False │
│ is_terminal = True │
│ legacy_windows = False │
│ no_color = False │
│ options = ConsoleOptions( │
│ size=ConsoleDimensions(width=103, height=51), │
│ legacy_windows=False, │
│ min_width=1, │
│ max_width=103, │
│ is_terminal=True, │
│ encoding='utf-8', │
│ max_height=51, │
│ justify=None, │
│ overflow=None, │
│ no_wrap=False, │
│ highlight=None, │
│ markup=None, │
│ height=None │
│ ) │
│ quiet = False │
│ record = False │
│ safe_box = True │
│ size = ConsoleDimensions(width=103, height=51) │
│ soft_wrap = False │
│ stderr = False │
│ style = None │
│ tab_size = 8 │
│ width = 103 │
╰──────────────────────────────────────────────────────────────────────────────────╯
╭── <class 'rich._windows.WindowsConsoleFeatures'> ───╮
│ Windows features available. │
│ │
│ ╭─────────────────────────────────────────────────╮ │
│ │ WindowsConsoleFeatures(vt=True, truecolor=True) │ │
│ ╰─────────────────────────────────────────────────╯ │
│ │
│ truecolor = True │
│ vt = True │
╰─────────────────────────────────────────────────────╯
╭────── Environment Variables ───────╮
│ { │
│ 'TERM': None, │
│ 'COLORTERM': None, │
│ 'CLICOLOR': None, │
│ 'NO_COLOR': None, │
│ 'TERM_PROGRAM': None, │
│ 'COLUMNS': None, │
│ 'LINES': None, │
│ 'JUPYTER_COLUMNS': None, │
│ 'JUPYTER_LINES': None, │
│ 'JPY_PARENT_PID': None, │
│ 'VSCODE_VERBOSE_LOGGING': None │
│ } │
╰────────────────────────────────────╯
platform="Windows"

rich==13.5.2

Copy link

Thank you for your issue. Give us a little time to review it.

PS. You might want to check the FAQ if you haven't done so already.

This is an automated reply, generated by FAQtory

@willmcgugan
Copy link
Collaborator

Rich is only able to capture output within a single process. Any text written by another process is likely to break the output.

The solution is to ensure that only your main process writes to stdout. Your subprocesses will need to communicate their output to the main process, possibly with a Pipe.

@abulgher
Copy link
Author

Thanks, I see...

What about different threads?

My application is CPU bound, so I need multiple processes if I want to speed it up. But I can imagine to have a separate thread collecting the messages from all the processes (I don't know if you can have a pipe between a thread and another process) and send them to the RichHandler.

Thanks for your answer,
abulgher

@willmcgugan
Copy link
Collaborator

Rich is thread-safe. You can have a thread reading from a pipe, so your solution should be doable.

@abulgher
Copy link
Author

So I am almost there.

I kept the multiprocessing.Queue instead of Pipe because I only need one Queue instead of one Pipe for each Process. According to the documentation a multiprocessing.Queue is process- and thread-safe.

Here is the modified code that seems to work.

import logging
import logging.handlers
import threading
import multiprocessing
import multiprocessing.pool
from rich.logging import RichHandler
from rich.progress import Progress
from random import choice, random
import time


class ProcessLogger(threading.Thread):
    """Background thread that drains LogRecords from a multiprocessing
    queue and re-emits them through the local logging tree.

    Worker processes push records onto ``self.queue`` (via QueueHandler);
    this thread pops them and hands each to the matching logger in the
    main process, so only the main process writes to the console.
    """

    # Singleton instance, managed by the class methods below.
    _global_process_logger = None

    def __init__(self):
        super().__init__()
        # Process- and thread-safe queue; -1 requests unbounded capacity.
        self.queue = multiprocessing.Queue(-1)

    @classmethod
    def get_global_logger(cls):
        """Return the singleton created by create_global_logger().

        Raises:
            RuntimeError: if no global process logger has been created yet.
        """
        if cls._global_process_logger is not None:
            return cls._global_process_logger
        # RuntimeError is narrower than the bare Exception used before;
        # existing `except Exception` callers still catch it.
        raise RuntimeError("No global process logger exists.")

    @classmethod
    def create_global_logger(cls):
        """Create, memoize and return the singleton ProcessLogger."""
        cls._global_process_logger = ProcessLogger()
        return cls._global_process_logger

    @staticmethod
    def configure():
        """Attach a RichHandler to the root logger of this process."""
        root = logging.getLogger()
        root.setLevel(logging.INFO)  # named constant instead of magic 20
        h = RichHandler(rich_tracebacks=True, markup=True, show_path=False, log_time_format='%Y%m%d-%H:%M:%S')
        fs = '%(message)s'
        f = logging.Formatter(fs)
        h.setFormatter(f)
        root.addHandler(h)

    def stop(self):
        """Ask run() to exit by enqueueing the None sentinel."""
        self.queue.put_nowait(None)

    def run(self):
        # Executed in the logging thread: pull records until the sentinel.
        self.configure()
        while True:
            try:
                record = self.queue.get()
                if record is None:  # sentinel pushed by stop()
                    break
                logger = logging.getLogger(record.name)
                logger.handle(record)
            except Exception:
                # Last-resort reporting: never let a bad record kill the
                # logging loop.
                import sys
                import traceback
                traceback.print_exc(file=sys.stderr)

def configure_new_process(log_process_queue):
    """Install a QueueHandler on this process's root logger so every
    record emitted here is forwarded to the logging thread's queue."""
    queue_handler = logging.handlers.QueueHandler(log_process_queue)
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    root.addHandler(queue_handler)


class ProcessWithLogging(multiprocessing.Process):
    """multiprocessing.Process that forwards its logging to a ProcessLogger.

    Before running *target*, the child installs a QueueHandler pointing at
    *log_process*'s queue so its records reach the main process.
    """

    def __init__(self, target, args=None, kwargs=None, log_process=None):
        """Store the call spec; fall back to the global ProcessLogger.

        The previous defaults ``args=[]`` and ``kwargs={}`` were mutable
        default arguments, shared by every instance created without
        explicit values; None sentinels avoid that.
        """
        super().__init__()
        self.target = target
        self.args = args if args is not None else []
        self.kwargs = kwargs if kwargs is not None else {}
        if log_process is None:
            log_process = ProcessLogger.get_global_logger()
        self.log_process_queue = log_process.queue

    def run(self):
        # Runs in the child: redirect logging to the queue, then call target.
        configure_new_process(self.log_process_queue)
        self.target(*self.args, **self.kwargs)


class PoolWithLogging(multiprocessing.pool.Pool):
    """Pool whose worker processes forward their logging to a ProcessLogger.

    configure_new_process() runs as each worker's initializer, pointing the
    worker's root logger at the ProcessLogger's queue.
    """

    def __init__(self, processes=None, context=None, log_process=None):
        if log_process is None:
            log_process = ProcessLogger.get_global_logger()
        record_queue = log_process.queue
        super().__init__(
            processes=processes,
            initializer=configure_new_process,
            initargs=(record_queue,),
            context=context,
        )


# Log levels that worker_process picks from at random for each record.
LEVELS = [logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL]

# Message texts that worker_process picks from at random.
MESSAGES = [
    'Random message #1',
    'Random message #2',
    'Random message #3',
]


def worker_process(param=None):
    """Emit ten random log records with random pauses, then return *param*.

    Runs inside a child process whose root logger has been redirected to
    the logging queue by configure_new_process().
    """
    name = multiprocessing.current_process().name
    # Bound before the loop: previously `logger` was assigned only inside
    # the loop body, so the final logger.info() below depended on the loop
    # having run at least once (a NameError otherwise).
    logger = logging.getLogger()
    for i in range(10):
        time.sleep(random())
        level = choice(LEVELS)
        message = choice(MESSAGES)
        logger.log(level, message)
    logger.info("[orange3]Processor %s finished" % name)

    return param


def main():
    """Spawn worker processes and render their progress with rich, with a
    background thread relaying the workers' log records."""
    process_logger = ProcessLogger.create_global_logger()

    workers = []
    with Progress() as progress:
        # The logging thread is started only after Progress is active.
        process_logger.start()
        total = 10
        loop_task = progress.add_task('[red]Loop', total=total)
        log = logging.getLogger()
        for _ in range(total):
            worker = ProcessWithLogging(worker_process)
            log.info('Starting processor %s' % worker.name)
            workers.append(worker)
            worker.start()

        # Poll exit codes until every worker has finished.
        done = sum(w.exitcode is not None for w in workers)
        while done < total:
            progress.update(loop_task, completed=done)
            time.sleep(0.2)
            done = sum(w.exitcode is not None for w in workers)
        progress.update(loop_task, completed=total)

        for worker in workers:
            worker.join()

    process_logger.stop()
    process_logger.join()


if __name__ == '__main__':
    main()

There are three main differences between this and the previous implementation.

  1. Instead of a separate process doing the redirecting of the log event, I have a separate thread. Since the Thread and Process classes have a similar interface, the only visible difference in the code is that the ProcessLogger is now inheriting from Thread instead of Process.
  2. I have removed all print calls from the spawned processes. Those were certainly breaking the output.
  3. The logging thread must be started after the Progress instance is created. This means that the process_logger.start() call has to be inside the Progress context manager, which could be a strong limitation for complex applications.

Do you have an explanation for the third point?

Thanks
abulgher

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

2 participants