Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/contrastive loss #50

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Prev Previous commit
trained after 2 days
Dmitry Fadeev committed Feb 10, 2024
commit 8eeca2f015b32cfa639b31824958ee655100bac2
4 changes: 2 additions & 2 deletions embedding.pt
Git LFS file not shown
3 changes: 3 additions & 0 deletions instance.pt
Git LFS file not shown
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -408,6 +408,7 @@ def train(_run, _log):
# if not (_run._id is None):
torch.save(network.state_dict(), model_path)
torch.save(embedding, 'embedding.pt')
torch.save(instance, 'instance.pt')
pickle.dump(history, open(os.path.join(checkpoint_dir, 'history_semantic.pkl'), 'wb'))


12 changes: 6 additions & 6 deletions utils/loss.py
Original file line number Diff line number Diff line change
@@ -198,8 +198,8 @@ def contrastive_loss(embedding, num_planes, segmentation, device, temperature=0.
embeddings = []

# Debug print
print(f"Batch size: {b}, Channels: {c}, Height: {h}, Width: {w}")
print(f"Number of planes: {num_planes}")
# print(f"Batch size: {b}, Channels: {c}, Height: {h}, Width: {w}")
# print(f"Number of planes: {num_planes}")

# print(embedding.size()) # 2 x 192 x 256 CHECK
nonzero = 0
@@ -210,8 +210,8 @@ def contrastive_loss(embedding, num_planes, segmentation, device, temperature=0.
# print(feature.shape) # num pixels of plane i x 2 CHECK
embeddings.append(feature)

# Debug print
print(f"Non-zero features count: {nonzero}")
# # Debug print
# print(f"Non-zero features count: {nonzero}")

centers = []
for feature in embeddings:
@@ -253,8 +253,8 @@ def contrastive_loss(embedding, num_planes, segmentation, device, temperature=0.
loss = - (temperature / base_temperature) * log_prob

# Debug print
print(f"Logits shape: {logits.shape}, Positive shape: {positive.shape}")
print(f"Sample logits: {logits[:5]}, Sample positive: {positive[:5]}")
# print(f"Logits shape: {logits.shape}, Positive shape: {positive.shape}")
# print(f"Sample logits: {logits[:5]}, Sample positive: {positive[:5]}")
print(f"Loss tensor: {loss}")

return torch.mean(loss), torch.mean(loss), torch.tensor(0)
31 changes: 31 additions & 0 deletions utils/subset_npz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import shutil
import os

def main():
# Path to the directory containing all .npz files
npz_directory = '/Users/dimafadeev/Desktop/Catalog/TUM/WS23/ML3D/repo/processed_data/train'

# Path to the .txt file containing the list of files to subset
txt_file_path = '/Users/dimafadeev/Desktop/Catalog/TUM/WS23/ML3D/repo/processed_data/train.txt'

# Path to the directory where you want to save the subset
subset_directory = '/Users/dimafadeev/Desktop/Catalog/TUM/WS23/ML3D/repo/processed_data/train_subset'

# Make sure the subset directory exists
os.makedirs(subset_directory, exist_ok=True)

# Read the list of .npz file names from the .txt file
with open(txt_file_path, 'r') as file:
subset_files = [line.strip() for line in file]

# Copy the subset .npz files
for file_name in subset_files:
full_file_path = os.path.join(npz_directory, file_name)
if os.path.isfile(full_file_path):
# Copy the file to the subset directory
shutil.copy(full_file_path, subset_directory)
else:
print(f"File {file_name} not found in the npz directory.")

if __name__ == "__main__":
main()