
How to add Precision and Recall Metrics #18

Open

SasipreetamMorsa opened this issue Apr 17, 2021 · 0 comments

SasipreetamMorsa commented Apr 17, 2021

I would like to get precision and recall metrics during training of the VarMisuse task. I have tried modifying `varmisuse_task.py` with the following code:

Inside `make_task_output_model`, at approximately line 438:

```python
predicted = tf.argmax(tf.nn.softmax(logits), 1, output_type=tf.int32)
prediction_is_correct = tf.equal(predicted, correct_choices)
accuracy = tf.reduce_mean(tf.cast(prediction_is_correct, tf.float32))

TP = tf.count_nonzero(predicted * correct_choices)
TN = tf.count_nonzero((1 - predicted) * (1 - correct_choices))
FP = tf.count_nonzero(predicted * (1 - correct_choices))
FN = tf.count_nonzero((1 - predicted) * correct_choices)

precision = tf.divide(TP, TP + FP)
recall = tf.divide(TP, TP + FN)

tf.summary.scalar('accuracy', accuracy)
model_ops['task_metrics'] = {
    'loss': tf.reduce_mean(per_graph_loss),
    'total_loss': tf.reduce_sum(per_graph_loss),
    'accuracy': accuracy,
    'precision': precision,
    'recall': recall,
    'num_correct_predictions': tf.reduce_sum(tf.cast(prediction_is_correct, tf.int32)),
}
```

Inside `pretty_print_epoch_task_metrics`:

```python
acc = sum([m['num_correct_predictions'] for m in task_metric_results]) / float(num_graphs)
precision = sum([m['precision'] for m in task_metric_results]) / float(num_graphs)
recall = sum([m['recall'] for m in task_metric_results]) / float(num_graphs)
return "Accuracy: %.3f | Precision: %.3f | Recall: %.3f" % (acc, precision, recall)
```

However, this code outputs `nan` for both precision and recall. If anyone knows why this happens and could point me in the right direction, I would greatly appreciate it.
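My current guess (unconfirmed) is that the `nan` comes from a zero denominator: whenever a batch contains no positive predictions, `TP + FP` is zero, `tf.divide(TP, TP + FP)` evaluates `0 / 0` to `nan`, and a single such batch then makes the epoch-level averages in `pretty_print_epoch_task_metrics` `nan` as well (the same applies to `TP + FN` for recall). A minimal zero-safe sketch, reusing the `predicted` and `correct_choices` tensors from the snippet above (TF 1.x style):

```python
# Zero-safe per-batch precision/recall (sketch): report 0.0 instead of nan
# when a batch has no positive predictions (TP + FP == 0) or no positive
# labels (TP + FN == 0).
TP = tf.count_nonzero(predicted * correct_choices, dtype=tf.float32)
FP = tf.count_nonzero(predicted * (1 - correct_choices), dtype=tf.float32)
FN = tf.count_nonzero((1 - predicted) * correct_choices, dtype=tf.float32)

precision = tf.where(tf.equal(TP + FP, 0.0), 0.0, TP / (TP + FP))
recall = tf.where(tf.equal(TP + FN, 0.0), 0.0, TP / (TP + FN))
```

Even with the `nan` handled, dividing the summed per-batch ratios by `num_graphs` in `pretty_print_epoch_task_metrics` would scale them down, since each batch contributes a value in [0, 1]; dividing by the number of batches (or accumulating raw TP/FP/FN counts over the epoch and computing the ratios once) seems more appropriate, but I may be missing something.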
