Source code for mlscorecheck.individual._check_scores_tptn

"""
This module implements the checking of scores with possible
tp and tn combinations.
"""

import copy

from ..core import NUMERICAL_TOLERANCE, logger, update_uncertainty

from ._interval import Interval, IntervalUnion
from ._utils import resolve_aliases_and_complements, create_intervals
from ._tptn_solution_bundles import tptn_solutions, sens_tp, spec_tn, is_applicable_tptn
from ._pair_solutions import solution_specifications

# Names exported by a star-import of this module; the two underscore-prefixed
# helpers are listed explicitly next to the public entry point.
__all__ = [
    "check_scores_tptn_pairs",
    "_check_scores_tptn_pairs",
    "_check_scores_tptn_intervals",
]

# The order in which ``_check_scores_tptn_pairs`` tests the reported scores:
# it walks this list front to back and skips every name that is not present
# in the (alias-resolved) input scores.
preferred_order = [
    "acc",
    "sens",
    "spec",
    "bacc",
    "npv",
    "ppv",
    "f1p",
    "f1n",
    "fbp",
    "fbn",
    "fm",
    "bm",
    "pt",
    "lrp",
    "lrn",
    "mk",
    "dor",
    "ji",
    "gm",
    "upm",
    "kappa",
    "mcc",
]


def iterate_tptn(
    *,
    score: str,
    score_value: float,
    valid_pairs: dict,
    sol_fun,
    params: dict,
    iterate_by: str
) -> dict:
    """
    Iterate through the potential values of tp or tn and construct feasible pairs

    For each key of ``valid_pairs`` (a value of the figure iterated by) the
    solution function is evaluated, its result is intersected with the interval
    stored for that key and shrunk to integers. Keys whose resulting interval is
    empty are dropped. When the score is not applicable for the figure being
    solved for, or the solution function returns ``None`` for a key, the stored
    interval is kept unchanged.

    Args:
        score (str): name of the score being used
        score_value (float): the value of the score
        valid_pairs (dict(int,Interval|IntervalUnion)): the valid tp,tn pairs
        sol_fun (callable): the solution function providing an interval for ``tp`` if
                            iterate_by is ``tn`` and vice versa.
        params (dict): the parameters to use
        iterate_by (str): the figure to iterate by (``tp``/``tn``)

    Returns:
        dict(int,Interval|IntervalUnion): the feasible (``tp``,``tn``) pairs
    """
    if not valid_pairs:
        # nothing to narrow down
        return {}

    # the figure being solved for is the counterpart of the one iterated by
    solved_figure = "tp" if iterate_by == "tn" else "tn"

    # the applicability check depends only on the score, its value and the
    # figure solved for - none of them changes during the iteration, so it
    # is evaluated once instead of once per key
    if not is_applicable_tptn(score, score_value, solved_figure):
        return dict(valid_pairs)

    results = {}

    for value, feasible in valid_pairs.items():
        interval = sol_fun(**params, **{iterate_by: value})

        if interval is None:
            # no solution available for this value: keep the stored interval
            results[value] = feasible
            continue

        interval = interval.intersection(feasible).shrink_to_integers()

        if not interval.is_empty():
            results[value] = interval

    return results


def update_sens(p: int, valid_pairs: dict, score_int, solve_for: str) -> dict:
    """
    Restrict the feasible pairs by the reported sensitivity

    The sensitivity interval is turned into an interval of ``tp`` values,
    which is intersected with ``Interval(0, p + 1)`` and shrunk to integers.
    If ``tp`` is the figure solved for, each stored interval is intersected
    with it and the empty ones are dropped; otherwise the keys are the ``tp``
    values and only the keys contained in it are kept.

    Args:
        p (int): the number of positives
        valid_pairs (dict(int,Interval|IntervalUnion)): the actual intervals
        score_int (Interval): the score interval
        solve_for (str): 'tp'/'tn' - the figure to solve for

    Returns:
        dict(int,Interval|IntervalUnion): the updated intervals
    """
    tp_range = sens_tp(sens=score_int, p=p)
    tp_range = tp_range.intersection(Interval(0, p + 1)).shrink_to_integers()

    if solve_for != "tp":
        # the keys themselves are tp values: filter the keys
        return {
            key: entry for key, entry in valid_pairs.items() if tp_range.contains(key)
        }

    # the stored intervals are tp intervals: narrow each of them
    narrowed = {}
    for key, entry in valid_pairs.items():
        restricted = entry.intersection(tp_range)
        if not restricted.is_empty():
            narrowed[key] = restricted

    return narrowed


def update_spec(n: int, valid_pairs: dict, score_int, solve_for: str) -> dict:
    """
    Restrict the feasible pairs by the reported specificity

    The specificity interval is turned into an interval of ``tn`` values,
    which is intersected with ``Interval(0, n + 1)`` and shrunk to integers.
    If ``tn`` is the figure solved for, each stored interval is intersected
    with it and the empty ones are dropped; otherwise the keys are the ``tn``
    values and only the keys contained in it are kept.

    Args:
        n (int): the number of negatives
        valid_pairs (dict(int,Interval|IntervalUnion)): the actual intervals
        score_int (Interval|IntervalUnion): the score interval
        solve_for (str): 'tp'/'tn' - the figure to solve for

    Returns:
        dict(int,Interval|IntervalUnion): the updated intervals
    """
    tn_range = spec_tn(spec=score_int, n=n)
    tn_range = tn_range.intersection(Interval(0, n + 1)).shrink_to_integers()

    if solve_for != "tn":
        # the keys themselves are tn values: filter the keys
        return {
            key: entry for key, entry in valid_pairs.items() if tn_range.contains(key)
        }

    # the stored intervals are tn intervals: narrow each of them
    narrowed = {}
    for key, entry in valid_pairs.items():
        restricted = entry.intersection(tn_range)
        if not restricted.is_empty():
            narrowed[key] = restricted

    return narrowed


def initialize_valid_pairs(
    p: int, n: int, iterate_by: str, init_tptn_intervals: dict
) -> dict:
    """
    Build the initial mapping from the values of the figure iterated by to the
    interval (or interval union) of the other figure.

    Without prefiltered intervals every value in ``0..p`` (iterating by ``tp``)
    or ``0..n`` (iterating by ``tn``) is mapped to one shared interval object.
    With prefiltered intervals the keys are the integers covered by the
    intervals of the figure iterated by, and each key gets its own deep copy
    of the interval union of the other figure.

    Args:
        p (int): the number of positives
        n (int): the number of negatives
        iterate_by (str): the figure to iterate by ('tp'/'tn')
        init_tptn_intervals (dict(str,tuple)): the prefiltered 'tp' and 'tn' intervals

    Returns:
        dict(int,Interval|IntervalUnion): the initialized pairs
    """
    if init_tptn_intervals is None:
        # NOTE(review): the upper bound here is n + 1 / p + 1, whereas
        # _check_scores_tptn_intervals starts from Interval(0, p) and
        # Interval(0, n) - confirm how Interval treats its upper bound.
        if iterate_by == "tp":
            full_range, n_keys = Interval(0, n + 1), p + 1
        else:
            full_range, n_keys = Interval(0, p + 1), n + 1
        # all keys deliberately share the very same interval object
        return dict.fromkeys(range(n_keys), full_range)

    tp_union = IntervalUnion(init_tptn_intervals["tp"])
    tn_union = IntervalUnion(init_tptn_intervals["tn"])

    if iterate_by == "tp":
        key_source, template = tp_union, tn_union
    else:
        key_source, template = tn_union, tp_union

    pairs = {}
    for piece in key_source.intervals:
        for key in range(piece.lower_bound, piece.upper_bound + 1):
            pairs[key] = copy.deepcopy(template)

    return pairs


def _check_scores_tptn_pairs(
    p: int,
    n: int,
    scores: dict,
    eps,
    *,
    numerical_tolerance: float = NUMERICAL_TOLERANCE,
    solve_for: str = None,
    init_tptn_intervals: dict = None
) -> dict:
    """
    Check scores by iteratively reducing the set of feasible ``tp``, ``tn`` pairs.

    The scores are processed in the order given by ``preferred_order``; after
    each score the mapping of feasible pairs is narrowed, and the processing
    stops as soon as no feasible pair is left.

    Args:
        p (int): the number of positives
        n (int): the number of negatives
        scores (dict): the available reported scores
        eps (float|dict(str,float)): the numerical uncertainties for all scores or each
                                        score individually
        numerical_tolerance (float): the additional numerical tolerance
        solve_for (str): the figure solving for (the other is used to iterate by) (``tp``/``tn``)
                        If None, ``tn`` is used when ``p < n`` and ``tp`` otherwise.
        init_tptn_intervals (None|dict(str,tuple)): the initial tp and tn intervals

    Returns:
        dict: a summary of the results. When the ``inconsistency`` flag is True, it indicates
        that the set of feasible ``tp``, ``tn`` pairs is empty. The list under the key
        ``details`` provides further details from the analysis of the scores one after the other.
        Under the key ``n_valid_tptn_pairs`` one finds the number of tp and tn pairs compatible with
        all scores. The keys ``iterate_by`` and ``solve_for`` echo the figures used, and
        ``evidence`` holds one feasible pair (or None if there is none).

    Raises:
        ValueError: if ``solve_for`` is neither ``tp`` nor ``tn`` (nor None)
    """
    if solve_for is None:
        solve_for = "tn" if p < n else "tp"

    if solve_for not in {"tp", "tn"}:
        raise ValueError(
            "The specified ``solve_for`` variable needs to be either "
            "``tp`` or ``tn``."
        )

    iterate_by = "tn" if solve_for == "tp" else "tp"

    # resolving aliases and complements
    scores = resolve_aliases_and_complements(scores)

    # updating the uncertainties
    eps = eps if isinstance(eps, dict) else {score: eps for score in scores}
    eps = update_uncertainty(eps, numerical_tolerance)

    params = {
        "p": p,
        "n": n,
        "beta_positive": scores.get("beta_positive"),
        "beta_negative": scores.get("beta_negative"),
    }

    # always a dict from here on: every narrowing step returns a dict as well
    valid_pairs = initialize_valid_pairs(p, n, iterate_by, init_tptn_intervals)

    details = []

    for score in [score for score in preferred_order if score in scores]:
        n_pairs_before = len(valid_pairs)
        logger.info("testing %s, feasible tptn pairs: %d", score, n_pairs_before)

        score_int = Interval(scores[score] - eps[score], scores[score] + eps[score])

        # the interval stays in params, so later solution functions see it too
        params[score] = score_int

        detail = {
            "testing": score,
            "score_interval": score_int,
            "n_tptn_pairs_before": n_pairs_before,
        }

        if score == "sens":
            valid_pairs = update_sens(
                p=p, valid_pairs=valid_pairs, score_int=score_int, solve_for=solve_for
            )
        elif score == "spec":
            valid_pairs = update_spec(
                n=n, valid_pairs=valid_pairs, score_int=score_int, solve_for=solve_for
            )
        else:
            valid_pairs = iterate_tptn(
                score=score,
                score_value=scores[score],
                valid_pairs=valid_pairs,
                sol_fun=tptn_solutions[score][solve_for],
                params=params,
                iterate_by=iterate_by,
            )

        n_pairs_after = len(valid_pairs)
        detail["n_tptn_pairs_after"] = n_pairs_after
        detail["decision"] = "continue" if n_pairs_after > 0 else "infeasible"
        details.append(detail)

        if n_pairs_after == 0:
            logger.info("no more feasible tp,tn pairs left")
            break

    total_count = sum(interval.integer_counts() for interval in valid_pairs.values())
    logger.info("constructing final tp, tn pair set")
    logger.info("final number of intervals: %d", len(valid_pairs))
    logger.info("final number of pairs: %d", total_count)

    # one feasible pair as evidence: the first key and a representative
    # integer of the interval stored for it
    evidence = None
    if valid_pairs:
        first_key = next(iter(valid_pairs))
        evidence = {
            iterate_by: first_key,
            solve_for: valid_pairs[first_key].representing_int(),
        }

    return {
        "inconsistency": len(valid_pairs) == 0,
        "details": details,
        "n_valid_tptn_pairs": total_count,
        "iterate_by": iterate_by,
        "solve_for": solve_for,
        "evidence": evidence,
    }


def check_all_negative_base(sols: list) -> bool:
    """
    Tell whether every solution is flagged with the ``negative base`` message

    Args:
        sols (list(dict)): the list of solutions

    Returns:
        bool: True if all solutions have negative base (also for an empty
        list), False otherwise
    """
    for solution in sols:
        if solution.get("message") != "negative base":
            return False
    return True


def check_any_zero_division(sols: list) -> bool:
    """
    Tell whether at least one solution is flagged with the ``zero division`` message

    Args:
        sols (list(dict)): the list of solutions

    Returns:
        bool: True if at least one solution has zero division, False otherwise
        (also for an empty list)
    """
    for solution in sols:
        if solution.get("message") == "zero division":
            return True
    return False


def update_tptn(tp, tn, sols: list):
    """
    Narrow the tp and tn intervals by the union of the solutions

    For each figure the non-None entries of the solutions are collected into
    an interval union; the incoming interval is intersected with that union
    and shrunk to integers.

    Args:
        tp (Interval|IntervalUnion): the true positive interval
        tn (Interval|IntervalUnion): the true negative interval
        sols (list(dict)): the list of the solutions

    Returns:
        IntervalUnion, IntervalUnion: the updated tp and tn intervals
    """

    def collect(figure):
        # union of all the solutions available for the given figure
        return IntervalUnion(
            [solution[figure] for solution in sols if solution[figure] is not None]
        )

    tp_solutions = collect("tp")
    tn_solutions = collect("tn")

    logger.info("the tp solutions: %s", tp_solutions)
    logger.info("the tn solutions: %s", tn_solutions)

    narrowed_tp = tp.intersection(tp_solutions).shrink_to_integers()
    narrowed_tn = tn.intersection(tn_solutions).shrink_to_integers()

    return narrowed_tp, narrowed_tn


def _check_scores_tptn_intervals(
    p: int,
    n: int,
    scores: dict,
    eps,
    *,
    numerical_tolerance: float = NUMERICAL_TOLERANCE
) -> dict:
    """
    Check scores by narrowing the feasible ``tp`` and ``tn`` intervals with the
    solutions available for pairs of scores.

    Every unordered pair of the (alias-resolved) score names is visited once.
    Pairs without an entry in ``solution_specifications`` are skipped; for the
    others the solutions are evaluated and used to intersect the running ``tp``
    and ``tn`` intervals. The whole iteration stops at the first pair for which
    all solutions report a negative base.

    Args:
        p (int): the number of positives
        n (int): the number of negatives
        scores (dict): the available reported scores
        eps (float|dict(str,float)): the numerical uncertainties for all scores or each
                                        score individually
        numerical_tolerance (float): the additional numerical tolerance

    Returns:
        dict: a summary of the results. When the ``inconsistency`` flag is True, it indicates
        that both the ``tp`` and the ``tn`` intervals are empty. The list under the key
        ``details`` provides further details from the analysis of the score pairs one after
        the other. Under the keys ``tp`` and ``tn`` one finds the final intervals or interval
        unions in tuple form.
    """
    logger.info("checking the scores %s", str(scores))

    # resolving aliases and complements
    scores = resolve_aliases_and_complements(scores)

    # updating the uncertainties
    eps = eps if isinstance(eps, dict) else {score: eps for score in scores}
    eps = update_uncertainty(eps, numerical_tolerance)

    params = {
        "p": p,
        "n": n,
        "beta_positive": scores.get("beta_positive"),
        "beta_negative": scores.get("beta_negative"),
    }
    # presumably adds one interval per score name (the entries are used via
    # ``to_tuple()`` below) - verify against create_intervals
    params |= create_intervals(scores, eps)

    # the running intervals, narrowed by each solvable pair of scores
    tp = Interval(0, p)
    tn = Interval(0, n)

    details = []

    score_names = list(scores.keys())

    # visit every unordered pair (score0, score1) exactly once
    for idx, score0 in enumerate(score_names):
        for score1 in score_names[idx + 1 :]:
            detail = {
                "base_score_0": score0,
                "base_score_1": score1,
                "base_score_0_interval": params[score0].to_tuple(),
                "base_score_1_interval": params[score1].to_tuple(),
            }

            # the solutions are keyed by the alphabetically sorted pair
            if tuple(sorted([score0, score1])) not in solution_specifications:
                logger.info("there is no solution for %s and %s", score0, score1)
                details.append(
                    detail
                    | {
                        "inconsistency": False,
                        "explanation": "there is no solution for the pair",
                    }
                )
                continue

            logger.info(
                "evaluating the tp and tn solution for %s and %s", score0, score1
            )

            sols = solution_specifications[(tuple(sorted([score0, score1])))].evaluate(
                params
            )

            if check_all_negative_base(sols):
                details.append(
                    detail
                    | {
                        "inconsistency": True,
                        "explanation": "all solutions lead to negative bases",
                    }
                )
                logger.info("all negative bases - iteration finished")
                # leaves the inner loop without running its ``else`` clause,
                # so the outer loop is terminated as well
                # NOTE(review): tp and tn are not touched on this path, so the
                # final ``inconsistency`` flag below does not reflect the True
                # recorded in this detail - confirm that this is intended.
                break
            if check_any_zero_division(sols):
                details.append(
                    detail
                    | {
                        "inconsistency": False,
                        "explanation": "zero division indicates an "
                        "underdetermined system",
                    }
                )
                # NOTE(review): the message says "all", but the check above
                # triggers when any solution reports a zero division
                logger.info("all zero divisions - iteration continued")
                continue

            logger.info(
                "intervals before: %s, %s", str(tp.to_tuple()), str(tn.to_tuple())
            )
            tp, tn = update_tptn(tp, tn, sols)
            logger.info(
                "intervals after: %s, %s", str(tp.to_tuple()), str(tn.to_tuple())
            )

            details.append(
                detail
                | {
                    "tp_after": tp.to_tuple(),
                    "tn_after": tn.to_tuple(),
                    "inconsistency": (tp.is_empty()) and (tn.is_empty()),
                }
            )

        else:
            # the inner loop finished without ``break``: go on with the next score0
            continue
        # reached only when the inner loop was left by ``break``
        break

    return {
        "inconsistency": (tp.is_empty()) and (tn.is_empty()),
        "tp": tp.to_tuple(),
        "tn": tn.to_tuple(),
        "details": details,
    }


def check_scores_tptn_pairs(
    p: int,
    n: int,
    scores: dict,
    eps,
    *,
    numerical_tolerance: float = NUMERICAL_TOLERANCE,
    solve_for: str = None,
    prefilter_by_pairs: bool = False
) -> dict:
    """
    Check scores by iteratively reducing the set of feasible ``tp``, ``tn`` pairs.

    Args:
        p (int): the number of positives
        n (int): the number of negatives
        scores (dict): the available reported scores
        eps (float|dict(str,float)): the numerical uncertainties for all scores or each
                                        score individually
        numerical_tolerance (float): the additional numerical tolerance
        solve_for (str): the figure solving for (the other is used to iterate by) (``tp``/``tn``)
                        If None, the optimal one is being used.
        prefilter_by_pairs (bool): whether to prefilter the tp and tn intervals by the
                                    pairwise solutions

    Returns:
        dict: a summary of the results. When the ``inconsistency`` flag is True, it indicates
        that the set of feasible ``tp``, ``tn`` pairs is empty. The list under the key
        ``details`` provides further details from the analysis of the scores one after the other.
        Under the key ``n_valid_tptn_pairs`` one finds the number of tp and tn pairs compatible with
        all scores. Under the key ``prefiltering_details`` one finds the results of the prefiltering
        by using the solutions for the score pairs (present only when ``prefilter_by_pairs``
        is True).
    """
    if not prefilter_by_pairs:
        return _check_scores_tptn_pairs(
            p,
            n,
            scores,
            eps,
            numerical_tolerance=numerical_tolerance,
            solve_for=solve_for,
        )

    # first narrow the tp and tn intervals by the pairwise solutions ...
    results_interval = _check_scores_tptn_intervals(
        p, n, scores, eps, numerical_tolerance=numerical_tolerance
    )

    # ... then run the pair-by-pair check starting from the narrowed intervals
    results = _check_scores_tptn_pairs(
        p,
        n,
        scores,
        eps,
        numerical_tolerance=numerical_tolerance,
        solve_for=solve_for,
        init_tptn_intervals={
            "tp": results_interval["tp"],
            "tn": results_interval["tn"],
        },
    )

    results["prefiltering_details"] = results_interval

    return results