Skip to content

Commit

Permalink
Add custom colors per model choice
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jul 25, 2024
1 parent d93b854 commit bb635fd
Showing 1 changed file with 15 additions and 23 deletions.
38 changes: 15 additions & 23 deletions scripts/plotting/for_paper/plot_ais_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from natsort import natsorted

import matplotlib.pyplot as plt
import matplotlib.patches as patches


base_color = '#0562A0'
Expand Down Expand Up @@ -61,34 +60,25 @@
"semanticsam_sam": "SemSam\n$\it{(SAM)}$"
}

COLORS = {
'unet': '#045275',
'unetr_scratch': '#7CCBA2',
'unetr_sam': '#90477F',
'semanticsam_scratch': '#FCDE9C',
'semanticsam_sam': '#F0746E',
}


def make_livecell_barplot():
labels = list(LIVECELL_AIS.keys())
model_labels = [MODEL_NAME_MAPS[model] for model in labels]
scores = [LIVECELL_AIS[model][0] for model in labels]

max_index = scores.index(max(scores))

data = {"Model": model_labels, "Score": scores}
df = pd.DataFrame(data)

plt.figure(figsize=(20, 15))
bars = sns.barplot(x="Model", y="Score", data=df, color=base_color)

for i, bar in enumerate(bars.patches):
if i == max_index:
shadow = patches.FancyBboxPatch(
(bar.get_x() - 0.01, bar.get_y() - 0.01),
bar.get_width() + 0.02,
bar.get_height() + 0.0025,
boxstyle="round,pad=0.011",
linewidth=2.5,
edgecolor=None,
facecolor=highlight_color,
alpha=0.3,
zorder=-1
)
plt.gca().add_patch(shadow)
sns.barplot(x="Model", y="Score", data=df, hue='Model', legend=False, palette=list(COLORS.values()))

plt.xlabel(None)
plt.ylabel("Mean Segmentation Accuracy", fontweight="bold")
Expand All @@ -109,10 +99,10 @@ def make_livecell_barplot():

def make_covid_if_lineplot():
markers = {
'unet': 'P', 'unetr_scratch': 'X', 'unetr_sam': 'o', 'semanticsam_scratch': '^', 'semanticsam_sam': 'd'
'unet': 'o', 'unetr_scratch': 'o', 'unetr_sam': 'o', 'semanticsam_scratch': 'o', 'semanticsam_sam': 'o',
}
line_styles = {
'unet': '-', 'unetr_scratch': '--', 'unetr_sam': '-.', 'semanticsam_scratch': ':', 'semanticsam_sam': '-'
'unet': '-', 'unetr_scratch': '-', 'unetr_sam': '-.', 'semanticsam_scratch': '-', 'semanticsam_sam': '-.',
}

x = natsorted(COVID_IF_AIS.keys())
Expand All @@ -128,8 +118,10 @@ def make_covid_if_lineplot():
plt.figure(figsize=(20, 15))
for model in models:
sns.lineplot(
data=df[df["Model"] == model], x='Key', y='Score', marker=markers[model],
linestyle=line_styles[model], markersize=15, linewidth=2.5, label=MODEL_NAME_MAPS[model], color=base_color,
data=df[df["Model"] == model], x='Key', y='Score',
marker=markers[model], linestyle=line_styles[model],
markersize=15, linewidth=2.5, label=MODEL_NAME_MAPS[model],
color=COLORS[model],
)

plt.xlabel("Number of Images", fontweight="bold")
Expand Down

0 comments on commit bb635fd

Please sign in to comment.