How can I add the importance_weights of each class to the corn loss? #39

Open

teinhonglo opened this issue Aug 4, 2023 · 3 comments

Comments

@teinhonglo

Hi,

Thanks for sharing the code.

I noticed that the CORAL loss has an importance_weights argument.
Could I add importance_weights for each class to the CORN loss as well?

Many thanks,
Tien-Hong

@rasbt
Member

rasbt commented Aug 4, 2023

Yes, they could be added. We omitted them for simplicity in the CORN paper.

@teinhonglo
Author

teinhonglo commented Aug 4, 2023

Thanks for your kind reply.

I haven't run the code yet, but if importance_weights has one weight per binary task, i.e. shape (NUM_CLASSES - 1, 1),
is the following modified code (see the # comments) correct?

import torch
import torch.nn.functional as F


def corn_loss(logits, y_train, num_classes, importance_weights):
    # Build the conditional training subsets: one binary task per rank threshold.
    sets = []
    for i in range(num_classes-1):
        label_mask = y_train > i-1
        label_tensor = (y_train[label_mask] > i).to(torch.int64)
        sets.append((label_mask, label_tensor))

    num_examples = 0
    losses = 0.
    for task_index, s in enumerate(sets):
        train_examples = s[0]
        train_labels = s[1]

        if len(train_labels) < 1:
            continue

        num_examples += len(train_labels)
        pred = logits[train_examples, task_index]

        # Binary cross entropy from logits for this task
        loss = -torch.sum(F.logsigmoid(pred)*train_labels
                          + (F.logsigmoid(pred) - pred)*(1-train_labels))

        # losses += loss  # original, unweighted version
        losses += importance_weights[task_index] * loss  # weighted version

    return losses/num_examples
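
For reference, here is a quick sketch of how I would call it, assuming 5 classes and dummy data (the weight values below are just placeholders):

import torch

num_classes = 5
logits = torch.randn(8, num_classes - 1)       # one logit column per binary task
y_train = torch.randint(0, num_classes, (8,))  # ordinal labels in [0, num_classes - 1]

# one weight per binary task, i.e., num_classes - 1 entries
importance_weights = torch.tensor([1.0, 2.0, 2.0, 1.0])

loss = corn_loss(logits, y_train, num_classes, importance_weights)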

@teinhonglo teinhonglo changed the title Could I add the importance_weights of each class to the corn loss? How can I add the importance_weights of each class to the corn loss? Aug 6, 2023
@rasbt
Member

rasbt commented Aug 7, 2023

Yes, this looks correct to me. You could also add a default argument so that the function behaves as before if someone doesn't specify the importance weights:

import torch
import torch.nn.functional as F


def corn_loss(logits, y_train, num_classes, importance_weights=None):
    # Build the conditional training subsets: one binary task per rank threshold.
    sets = []
    for i in range(num_classes-1):
        label_mask = y_train > i-1
        label_tensor = (y_train[label_mask] > i).to(torch.int64)
        sets.append((label_mask, label_tensor))

    num_examples = 0
    losses = 0.

    # Default to uniform weights, which reproduces the original, unweighted loss.
    if importance_weights is None:
        importance_weights = torch.ones(len(sets))

    for task_index, s in enumerate(sets):
        train_examples = s[0]
        train_labels = s[1]

        if len(train_labels) < 1:
            continue

        num_examples += len(train_labels)
        pred = logits[train_examples, task_index]

        # Binary cross entropy from logits for this task
        loss = -torch.sum(F.logsigmoid(pred)*train_labels
                          + (F.logsigmoid(pred) - pred)*(1-train_labels))

        # losses += loss  # original, unweighted version
        losses += importance_weights[task_index] * loss  # weighted version

    return losses/num_examples
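
For example, with random data, uniform weights should reproduce the unweighted result exactly (a quick sanity-check sketch using the function above):

import torch

num_classes = 4
logits = torch.randn(16, num_classes - 1)
y_train = torch.randint(0, num_classes, (16,))

loss_default = corn_loss(logits, y_train, num_classes)
loss_uniform = corn_loss(logits, y_train, num_classes,
                         importance_weights=torch.ones(num_classes - 1))
assert torch.allclose(loss_default, loss_uniform)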
