Skip to content

Commit

Permalink
Use jinja templates
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-lightly committed Nov 7, 2023
1 parent feb637b commit d16406c
Show file tree
Hide file tree
Showing 7 changed files with 805 additions and 195 deletions.
278 changes: 264 additions & 14 deletions compute_prototype.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import random
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Counter, Dict, Set, Tuple
from typing import Any, Counter, Dict, Set, Tuple

from jinja2 import Environment, FileSystemLoader
from labelformat.formats import LightlyObjectDetectionInput
from labelformat.model.object_detection import ObjectDetectionInput
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
from PIL import Image


Expand All @@ -22,12 +28,20 @@ def main() -> None:
od_insights_data = get_object_detection_insights(label_input=label_input)
present_object_detection_insights(od_insights_data=od_insights_data)

output_folder = Path("/Users/michal/tmp/lightly_insights_output")
create_html_report(
output_folder=output_folder,
image_data=image_insights_data,
od_data=od_insights_data,
)


@dataclass(frozen=True)
class ImageInsightsData:
num_images: int
images_sizes: Counter[Tuple[int, int]]
image_sizes: Counter[Tuple[int, int]]
filename_set: Set[str]
image_folder: Path


@dataclass
Expand Down Expand Up @@ -57,28 +71,30 @@ class ObjectDetectionInsightsData:

def get_image_insights(image_folder: Path) -> ImageInsightsData:
num_images = 0
images_sizes = Counter[Tuple[int, int]]()
image_sizes = Counter[Tuple[int, int]]()
filename_set = set()

# Param: Recursive?
# Param: Subsample?
for image_path in image_folder.glob("*.jpg"):
sorted_paths = sorted(image_folder.glob("*.jpg"))
for image_path in sorted_paths:
num_images += 1
filename_set.add(image_path.name)
with Image.open(image_path) as image:
images_sizes[image.size] += 1
image_sizes[image.size] += 1

return ImageInsightsData(
num_images=num_images,
images_sizes=images_sizes,
image_sizes=image_sizes,
filename_set=filename_set,
image_folder=image_folder,
)


def present_image_insights(image_insights_data: ImageInsightsData) -> None:
print(f"Num images: {image_insights_data.num_images}")
print(f"Images sizes: {image_insights_data.images_sizes.most_common()}")
print(f"Filename sample: {list(image_insights_data.filename_set)[:5]}")
print(f"Images sizes: {image_insights_data.image_sizes.most_common()}")
print(f"Filename sample: {list(image_insights_data.filename_set)[:10]}")


def get_object_detection_insights(
Expand All @@ -96,14 +112,15 @@ def get_object_detection_insights(
num_images += 1
filename_set.add(label.image.filename)

total_data.num_objects += len(label.objects)
total_data.objects_per_image[len(label.objects)] += 1

num_objects_per_category = Counter[str]()

for obj in label.objects:
# Number of objects.
total_data.num_objects += 1
class_data[obj.category.name].num_objects += 1

# Objects per image.
total_data.objects_per_image[len(label.objects)] += 1
class_data[obj.category.name].objects_per_image[len(label.objects)] += 1
num_objects_per_category[obj.category.name] += 1

# Object sizes.
obj_size_abs = (
Expand All @@ -119,6 +136,11 @@ def get_object_detection_insights(
class_data[obj.category.name].object_sizes_abs[obj_size_abs] += 1
class_data[obj.category.name].object_sizes_rel[obj_size_rel] += 1

for category in label_input.get_categories():
class_data[category.name].objects_per_image[
num_objects_per_category[category.name]
] += num_objects_per_category[category.name]

return ObjectDetectionInsightsData(
num_images=num_images,
filename_set=filename_set,
Expand All @@ -131,7 +153,7 @@ def present_object_detection_insights(
od_insights_data: ObjectDetectionInsightsData,
) -> None:
print(f"Num images with labels: {od_insights_data.num_images}")
print(f"Filename sample: {list(od_insights_data.filename_set)[:5]}")
print(f"Filename sample: {list(od_insights_data.filename_set)[:10]}")
print(f"Num objects: {od_insights_data.total.num_objects}")
print(
f"Objects per image: {od_insights_data.total.objects_per_image.most_common()}"
Expand All @@ -151,5 +173,233 @@ def present_object_detection_insights(
print(f"Class histogram: {class_histogram.most_common()}")


def create_html_report(
output_folder: Path,
image_data: ImageInsightsData,
od_data: ObjectDetectionInsightsData,
) -> None:
output_folder.mkdir(parents=True, exist_ok=True)

image_props = get_image_props(
output_folder=output_folder,
image_data=image_data,
)
od_props = get_object_detection_props(
output_folder=output_folder,
od_data=od_data,
)
report_props = {
"image": image_props,
"object_detection": od_props,
# Now.
"date_generated": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}

# Setup Jinja2 environment
env = Environment(loader=FileSystemLoader(searchpath="./templates"))
template = env.get_template("report.html")

# Render the template with data
html_output = template.render(report_props)

# Write the HTML to file
html_output_path = output_folder / "report.html"
html_output_path.write_text(html_output)


def get_image_props(
output_folder: Path,
image_data: ImageInsightsData,
) -> Dict[str, Any]:
# Image size plot.
_width_heigth_pixels_plot(
output_file=output_folder / "image_size_plot.png",
size_histogram=image_data.image_sizes,
title="Image Sizes",
)

sample_folder = output_folder / "sample"
sample_folder.mkdir(parents=True, exist_ok=True)
sample_images = []
# TODO: Do this more efficiently.
rng = random.Random(42)
selection = rng.sample(sorted(list(image_data.filename_set)), k=8)
for filename in selection:
shutil.copy2(
src=image_data.image_folder / filename, dst=sample_folder / filename
)
sample_images.append(
{
"filename": filename,
"path": sample_folder / filename,
}
)

return {
"raw": image_data,
"num_images": image_data.num_images,
"image_sizes": image_data.image_sizes.most_common(),
"filename_sample": list(image_data.filename_set)[:10],
"image_size_plot": "image_size_plot.png",
"sample_images": sample_images,
}


def get_object_detection_props(
output_folder: Path,
od_data: ObjectDetectionInsightsData,
) -> Dict[str, Any]:
_width_heigth_pixels_plot(
output_file=output_folder / "object_size_abs_plot.png",
size_histogram=od_data.total.object_sizes_abs,
title="Object Sizes in Pixels",
)

_width_heigth_percent_plot(
output_file=output_folder / "object_size_rel_plot.png",
size_histogram=od_data.total.object_sizes_rel,
title="Object Sizes in Percent",
)

_objects_per_image_plot(
output_file=output_folder / "objects_per_image_plot.png",
objects_per_image=od_data.total.objects_per_image,
title="Objects per Image",
)

# Class histogram.
class_histogram = Counter[str]()
for class_name, class_data in od_data.classes.items():
class_histogram[class_name] += class_data.num_objects

return {
"num_images": od_data.num_images,
"filename_sample": list(od_data.filename_set)[:10],
"num_objects": od_data.total.num_objects,
"objects_per_image": od_data.total.objects_per_image.most_common(),
"object_size_abs_plot": "object_size_abs_plot.png",
"object_size_rel_plot": "object_size_rel_plot.png",
"objects_per_image_plot": "objects_per_image_plot.png",
"num_classes": len(od_data.classes),
"class_histogram": class_histogram.most_common(),
"classes": [
{
"name": class_name,
"num_objects": class_data.num_objects,
"objects_per_image": class_data.objects_per_image.most_common(),
}
for class_name, class_data in od_data.classes.items()
],
}


def _width_heigth_pixels_plot(
output_file: Path,
size_histogram: Counter[Tuple[float, float]],
title: str,
) -> None:
# Image size plot.
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111)
xs = []
ys = []
sizes = []
for size, count in size_histogram.items():
xs.append(size[0])
ys.append(size[1])
sizes.append(count)
ax.scatter(
xs,
ys,
s=sizes,
marker="o",
color="blue",
alpha=0.5,
)
ax.set_xlabel("Width (px)")
ax.set_ylabel("Height (px)")
ax.set_title(title)
ax.set_aspect("equal", "box")
ax.set_xlim(0, max(xs) * 1.1)
ax.set_ylim(0, max(ys) * 1.1)

# Save the plot.
plt.savefig(output_file)


def _width_heigth_percent_plot(
output_file: Path,
size_histogram: Counter[Tuple[float, float]],
title: str,
) -> None:
# Image size plot.
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111)
xs = []
ys = []
sizes = []
for size, count in size_histogram.items():
xs.append(size[0])
ys.append(size[1])
sizes.append(count)
ax.scatter(
xs,
ys,
s=sizes,
marker="o",
color="blue",
alpha=0.5,
)
ax.set_xlabel("Width (%)")
ax.set_ylabel("Height (%)")
ax.set_title(title)
ax.set_aspect("equal", "box")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

# Save the plot.
plt.savefig(output_file)


def _objects_per_image_plot(
output_file: Path,
objects_per_image: Counter[int],
title: str,
) -> None:
# Image size plot.
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111)

# Vertical bars.
xs = []
ys = []
for num_objects, count in objects_per_image.items():
xs.append(num_objects)
ys.append(count)
ax.bar(
xs,
ys,
color="blue",
alpha=0.5,
)

# Horizontal line.
ax.axhline(
y=sum(ys) / len(ys),
color="red",
linestyle="--",
)

# Show x-ticks only at integers.
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

ax.set_xlabel("Number of Objects")
ax.set_ylabel("Number of Images")
ax.set_title(title)

# Save the plot.
plt.savefig(output_file)


if __name__ == "__main__":
main()
Loading

0 comments on commit d16406c

Please sign in to comment.