Source code for domainlab.utils.perf_metrics

"""Classification Performance"""
import torch
from torchmetrics.classification import (
    AUROC,
    Accuracy,
    ConfusionMatrix,
    F1Score,
    Precision,
    Recall,
    Specificity,
)
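
These constructors follow the pre-0.11 torchmetrics signature, where the task is inferred from num_classes. A hedged compatibility note: in torchmetrics >= 0.11 (an assumption about the installed version, not something this module pins), the AUC class was removed and every classification metric requires an explicit task argument, so an equivalent construction would look roughly like:

# Sketch only: constructors for torchmetrics >= 0.11 (assumed API),
# where each classification metric takes an explicit `task` argument.
from torchmetrics.classification import AUROC, Accuracy

acc = Accuracy(task="multiclass", num_classes=10, average="micro")
auroc = AUROC(task="multiclass", num_classes=10, average="macro")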


class PerfMetricClassif:
    """Classification Performance metrics"""

    def __init__(self, num_classes, agg_precision_recall_f1="macro"):
        super().__init__()
        # NOTE: only micro aggregation makes sense for accuracy
        self.acc = Accuracy(num_classes=num_classes, average="micro")
        # "macro": calculate the metric for each class separately, then average
        # the metrics across classes (with equal weight for each class)
        self.precision = Precision(
            num_classes=num_classes, average=agg_precision_recall_f1
        )
        self.recall = Recall(num_classes=num_classes, average=agg_precision_recall_f1)
        self.f1_score = F1Score(
            num_classes=num_classes, average=agg_precision_recall_f1
        )
        # NOTE: AUROC does not support "micro" aggregation
        self.auroc = AUROC(num_classes=num_classes, average=agg_precision_recall_f1)
        self.specificity = Specificity(
            num_classes=num_classes, average=agg_precision_recall_f1
        )
        self.confmat = ConfusionMatrix(num_classes=num_classes)

    def cal_metrics(self, model, loader_te, device):
        """
        :param model: model exposing infer_y_vpicn, whose second return value
            is the class-probability tensor of shape (N, C)
        :param loader_te: test dataloader yielding (x, y, ...) batches with
            one-hot encoded labels
        :param device: for final test, GPU can be used
        """
        self.acc.reset()
        self.precision.reset()
        self.recall.reset()
        self.f1_score.reset()
        self.auroc.reset()
        self.specificity.reset()
        self.confmat.reset()
        self.acc = self.acc.to(device)
        self.precision = self.precision.to(device)
        self.recall = self.recall.to(device)
        self.f1_score = self.f1_score.to(device)
        self.auroc = self.auroc.to(device)
        self.specificity = self.specificity.to(device)
        self.confmat = self.confmat.to(device)
        model_local = model.to(device)
        model_local.eval()
        with torch.no_grad():
            for x_s, y_s, *_ in loader_te:
                x_s, y_s = x_s.to(device), y_s.to(device)
                _, prob, _, *_ = model_local.infer_y_vpicn(x_s)
                # labels arrive one-hot encoded; recover class indices
                _, target_label = torch.max(y_s, 1)
                # self.acc.update(pred_label, target_label)
                # bug found by xinyuejohn:
                # pred_label is (N, C), which would need conversion to (N);
                # prob is (N, C), which torchmetrics accepts directly
                self.acc.update(prob, target_label)
                self.precision.update(prob, target_label)
                self.recall.update(prob, target_label)
                self.specificity.update(prob, target_label)
                self.f1_score.update(prob, target_label)
                self.auroc.update(prob, target_label)
                self.confmat.update(prob, target_label)
        acc_y = self.acc.compute()
        precision_y = self.precision.compute()
        recall_y = self.recall.compute()
        specificity_y = self.specificity.compute()
        f1_score_y = self.f1_score.compute()
        auroc_y = self.auroc.compute()
        confmat_y = self.confmat.compute()
        dict_metric = {
            "acc": acc_y,
            "precision": precision_y,
            "recall": recall_y,
            "specificity": specificity_y,
            "f1": f1_score_y,
            "auroc": auroc_y,
            "confmat": confmat_y,
        }
        keys = list(dict_metric)
        keys.remove("confmat")
        # scalar metrics: move to CPU and extract plain Python numbers
        for key in keys:
            dict_metric[key] = dict_metric[key].cpu().numpy().sum()
        dict_metric["confmat"] = dict_metric["confmat"].cpu().numpy()
        return dict_metric
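
A minimal usage sketch, assuming PerfMetricClassif is imported from domainlab.utils.perf_metrics. The ToyModel below is purely hypothetical: any nn.Module whose infer_y_vpicn(x) returns the class-probability tensor of shape (N, C) as its second element works, since cal_metrics only unpacks that slot.

# Hypothetical example, not part of the module.
import torch
from torch import nn

from domainlab.utils.perf_metrics import PerfMetricClassif


class ToyModel(nn.Module):
    """Stand-in model; only infer_y_vpicn is required by cal_metrics."""

    def __init__(self, dim_in=4, num_classes=3):
        super().__init__()
        self.net = nn.Linear(dim_in, num_classes)

    def infer_y_vpicn(self, tensor_x):
        prob = torch.softmax(self.net(tensor_x), dim=1)
        # placeholders for the remaining v, i, c, n slots
        return None, prob, None, None, None


x = torch.randn(8, 4)
# labels cover all three classes so AUROC is well defined
y_idx = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])
y_onehot = torch.nn.functional.one_hot(y_idx, num_classes=3)
loader = [(x, y_onehot)]  # stands in for a real DataLoader

metric = PerfMetricClassif(num_classes=3)
results = metric.cal_metrics(ToyModel(), loader, device=torch.device("cpu"))
print(results["acc"], results["confmat"])

Note on the final conversion loop: with "macro" (or "micro") averaging each compute() call already yields a zero-dimensional tensor, so .cpu().numpy().sum() simply extracts the scalar as a plain number; only the confusion matrix is kept as a full (C, C) array.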