Shortcuts

Source code for mmeval.metrics.connectivity_error

# Copyright (c) OpenMMLab. All rights reserved.Dict
import cv2
import numpy as np
from typing import Dict, List, Sequence

from mmeval.core import BaseMetric


[docs]class ConnectivityError(BaseMetric): """Connectivity error for evaluating alpha matte prediction. Args: step (float): Step of threshold when computing intersection between `alpha` and `pred_alpha`. Default to 0.1 . norm_const (int): Divide the result to reduce its magnitude. Defaults to 1000 . **kwargs: Keyword parameters passed to :class:`BaseMetric`. Note: The current implementation assumes the image / alpha / trimap a numpy array with pixel values ranging from 0 to 255. The pred_alpha should be masked by trimap before passing into this metric. The trimap is the most commonly used prior knowledge. As the name implies, trimap is a ternary graph and each pixel takes one of {0, 128, 255}, representing the foreground, the unknown and the background respectively. Examples: >>> from mmeval import ConnectivityError >>> import numpy as np >>> >>> connectivity_error = ConnectivityError() >>> pred_alpha = np.zeros((32, 32), dtype=np.uint8) >>> gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255 >>> trimap = np.zeros((32, 32), dtype=np.uint8) >>> trimap[:16, :16] = 128 >>> trimap[16:, 16:] = 255 >>> connectivity_error(pred_alpha, gt_alpha, trimap) {'connectivity_error': ...} """ def __init__(self, step: float = 0.1, norm_const: int = 1000, **kwargs) -> None: super().__init__(**kwargs) self.step = step self.norm_const = norm_const
[docs] def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray], trimaps: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501 """Add ConnectivityError score of batch to ``self._results`` Args: pred_alphas (Sequence[np.ndarray]): Predict the probability that pixels belong to the foreground. gt_alphas (Sequence[np.ndarray]): Probability that the actual pixel belongs to the foreground. trimaps (Sequence[np.ndarray]): Broadly speaking, the trimap consists of foreground and unknown region. """ for pred_alpha, gt_alpha, trimap in zip(pred_alphas, gt_alphas, trimaps): assert pred_alpha.shape == gt_alpha.shape, 'The shape of ' \ '`pred_alpha` and `gt_alpha` should be the same, but got: ' \ f'{pred_alpha.shape} and {gt_alpha.shape}' thresh_steps = np.arange(0, 1 + self.step, self.step) round_down_map = -np.ones_like(gt_alpha) for i in range(1, len(thresh_steps)): gt_alpha_thresh = gt_alpha >= thresh_steps[i] pred_alpha_thresh = pred_alpha >= thresh_steps[i] intersection = gt_alpha_thresh & pred_alpha_thresh intersection = intersection.astype(np.uint8) # connected components _, output, stats, _ = cv2.connectedComponentsWithStats( intersection, connectivity=4) # start from 1 in dim 0 to exclude background size = stats[1:, -1] # largest connected component of the intersection omega = np.zeros_like(gt_alpha) if len(size) != 0: max_id = np.argmax(size) # plus one to include background omega[output == max_id + 1] = 1 mask = (round_down_map == -1) & (omega == 0) round_down_map[mask] = thresh_steps[i - 1] round_down_map[round_down_map == -1] = 1 gt_alpha_diff = gt_alpha - round_down_map pred_alpha_diff = pred_alpha - round_down_map # only calculate difference larger than or equal to 0.15 gt_alpha_phi = 1 - gt_alpha_diff * (gt_alpha_diff >= 0.15) pred_alpha_phi = 1 - pred_alpha_diff * (pred_alpha_diff >= 0.15) connectivity_error = np.sum( np.abs(gt_alpha_phi - pred_alpha_phi) * (trimap == 128)) # divide by norm_const to reduce the magnitude of the result connectivity_error /= self.norm_const self._results.append(connectivity_error)
[docs] def compute_metric(self, results: List) -> Dict[str, float]: """Compute the ConnectivityError metric. Args: results (List): A list that consisting the ConnectivityError score. This list has already been synced across all ranks. Returns: Dict[str, float]: The computed ConnectivityError metric. The keys are the names of the metrics, and the values are corresponding results. """ return {'connectivity_error': float(np.array(results).mean())}
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.