Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WandbWriter class for Weights & Biases logging integration #5406

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions detectron2/utils/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Optional
import torch
from fvcore.common.history_buffer import HistoryBuffer
import wandb

from detectron2.utils.file_io import PathManager

Expand Down Expand Up @@ -191,6 +192,62 @@ def close(self):
if "_writer" in self.__dict__:
self._writer.close()

class WandbWriter(EventWriter):
def __init__(self, project_name, run_name=None, window_size=20, **kwargs):
"""
Args:
project_name (str): The name of the W&B project.
run_name (str): The name of the W&B run.
window_size (int): The window size for smoothing metrics.
kwargs: Additional arguments for wandb.init().
"""
self._window_size = window_size
self._last_write = -1
wandb.init(project=project_name, name=run_name, **kwargs)

def write(self):
storage = get_event_storage()
new_last_write = self._last_write
metrics_dict = storage.latest_with_smoothing_hint(self._window_size).items()
wandb_metrics = {}
new_last_write = self._last_write
for k, (v, iter) in metrics_dict:
if iter > self._last_write:
wandb_metrics[k] = v
new_last_write = max(new_last_write, iter)
self._last_write = new_last_write

if len(storage._vis_data) >= 1:
# Create a list to store all images for this step
images_dict = {}

for img_name, img, step_num in storage._vis_data:
# Transpose from C,H,W to H,W,C
img = img.transpose(1, 2, 0)
# Add image to dictionary
images_dict[img_name] = wandb.Image(img)

# Log both metrics and all images for this step
log_dict = {
**wandb_metrics, # Unpack all metrics
**images_dict # Unpack all images
}
wandb.log(log_dict, step=iter)

# Storage stores all image data and rely on this writer to clear them.
# As a result it assumes only one writer will use its image data.
# An alternative design is to let storage store limited recent
# data (e.g. only the most recent image) that all writers can access.
# In that case a writer may not see all image data if its period is long.
storage.clear_images()
else:
wandb.log(wandb_metrics, step=new_last_write)

self._last_write = new_last_write

def close(self):
wandb.finish()


class CommonMetricPrinter(EventWriter):
"""
Expand Down