-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_sim.py
138 lines (113 loc) · 4.66 KB
/
plot_sim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def plot_correlation_comparison(
    diffusion_csv, gan_csv, baseline_csv, output_filename=None
):
    """
    Create a comparison plot of correlation coefficients for diffusion, GAN, and baseline metrics.

    Each CSV is read with ``pandas.read_csv`` and must contain a ``"Metric"``
    column and a ``"Kendall Correlation"`` column.

    The figure has two panels:
      * left  - "non-reference" metrics: every metric listed in the baseline CSV;
      * right - "reference" metrics: metrics in the diffusion CSV that are not
        in the baseline CSV.
    Within a panel, metrics are ordered by the absolute diffusion correlation.
    Metrics that do not appear in the diffusion CSV are therefore not drawn.
    GAN/baseline metrics missing for a row become NaN and draw no bar.

    Parameters:
        diffusion_csv: Path to CSV file containing diffusion model correlations
        gan_csv: Path to CSV file containing GAN model correlations
        baseline_csv: Path to CSV file containing baseline (real image) correlations
        output_filename: Optional filename to save the plot (written at 600 dpi)

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure still open.
    """
    diff_df = pd.read_csv(diffusion_csv)
    gan_df = pd.read_csv(gan_csv)
    baseline_df = pd.read_csv(baseline_csv)

    diff_metrics = set(diff_df["Metric"])
    baseline_metrics = set(baseline_df["Metric"])
    # Non-reference metrics: the ones that also have a real-image baseline.
    nonref_metrics = baseline_metrics
    # Reference metrics: present for the diffusion results but with no baseline.
    ref_metrics = diff_metrics - baseline_metrics

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 9))

    bar_height = 0.25

    def plot_metric_set(ax, metrics, title):
        """Draw grouped horizontal bars for ``metrics`` on ``ax`` and return ``ax``."""
        diff_filtered = diff_df[diff_df["Metric"].isin(metrics)].copy()
        gan_filtered = gan_df[gan_df["Metric"].isin(metrics)]
        baseline_filtered = baseline_df[baseline_df["Metric"].isin(metrics)]

        # Rank by strength of association: a strongly negative correlation is
        # as informative as a strongly positive one, so sort on |tau|.
        diff_filtered["Abs_Correlation"] = diff_filtered["Kendall Correlation"].abs()
        diff_filtered = diff_filtered.sort_values("Abs_Correlation", ascending=True)
        metrics_order = diff_filtered["Metric"].to_numpy()
        y_pos = range(len(metrics_order))

        def aligned(df):
            # Reorder another model's values to match the diffusion ordering.
            return (
                df.set_index("Metric")
                .reindex(metrics_order)["Kendall Correlation"]
                .to_numpy()
            )

        # (values, legend label, colour) in drawing order; each series is
        # shifted up by one bar height relative to the previous one.
        series = [
            (diff_filtered["Kendall Correlation"].to_numpy(), "Diffusion Models", "#2196F3"),
            (aligned(gan_filtered), "GANs", "#FF9800"),
        ]
        if not baseline_filtered.empty:
            series.append(
                (aligned(baseline_filtered), "Real Images (Baseline)", "#4CAF50")
            )

        for offset_idx, (values, label, color) in enumerate(series):
            ax.barh(
                [y + offset_idx * bar_height for y in y_pos],
                values,
                height=bar_height,
                label=label,
                color=color,
                alpha=0.7,
            )

        ax.axvline(x=0, color="black", linestyle="-", linewidth=0.5)
        ax.grid(True, axis="x", linestyle="--", alpha=0.4)

        ax.set_xlabel("Kendall Correlation Coefficient")
        ax.set_ylabel("Quality Metrics")
        ax.set_title(title, pad=10)

        # Centre each label on its bar group (2 or 3 bars depending on
        # whether baseline values exist for this panel).
        group_center = bar_height * (len(series) - 1) / 2
        ax.set_yticks([y + group_center for y in y_pos])
        ax.set_yticklabels(metrics_order)

        ax.legend(loc="lower right")
        return ax

    plot_metric_set(
        ax1,
        nonref_metrics,
        "Non-Reference Quality Metrics\nCorrelation with Human Assessment",
    )
    plot_metric_set(
        ax2, ref_metrics, "Reference Quality Metrics\nCorrelation with Human Assessment"
    )

    fig.tight_layout()
    if output_filename:
        fig.savefig(output_filename, dpi=600, bbox_inches="tight")
    return plt
# Example usage: regenerate the Kendall-correlation comparison figure for the
# 20240930_150033 feature-extraction run. Guarded so that importing this
# module does not try to read/write these machine-specific absolute paths.
if __name__ == "__main__":
    from pathlib import Path

    run_dir = Path(
        "/home/ksamamov/GitLab/Notebooks/feat_ext_bench/data/features/20240930_150033"
    )
    plot_correlation_comparison(
        run_dir / "diff_ground_truth_correlations.csv",
        run_dir / "gan_ground_truth_correlations.csv",
        run_dir / "baseline_ground_truth_correlations.csv",
        run_dir / "kendall_sim_quality_metrics_correlation.png",
    )