Source code for mmeval.metrics.matting_mse

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Sequence

import numpy as np

from mmeval.core import BaseMetric


class MattingMeanSquaredError(BaseMetric):
    """Mean Squared Error metric for image matting.

    This metric computes the per-pixel squared error averaged across all
    pixels, i.e. mean((a - b)^2).

    Args:
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.

    Note:
        The current implementation assumes that the image / alpha / trimap is
        a numpy array with pixel values ranging from 0 to 255.

        The pred_alpha should be masked by the trimap before being passed to
        this metric.

        The trimap is the most commonly used prior knowledge. As the name
        implies, a trimap is a ternary map in which each pixel takes one of
        {0, 128, 255}, representing the background, the unknown region and
        the foreground respectively.

    Examples:
        >>> from mmeval import MattingMeanSquaredError as MattingMSE
        >>> import numpy as np
        >>>
        >>> matting_mse = MattingMSE()
        >>> pred_alpha = np.zeros((32, 32), dtype=np.uint8)
        >>> gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255
        >>> trimap = np.zeros((32, 32), dtype=np.uint8)
        >>> trimap[:16, :16] = 128
        >>> trimap[16:, 16:] = 255
        >>> matting_mse(pred_alpha, gt_alpha, trimap)  # doctest: +ELLIPSIS
        {'matting_mse': ...}
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray], trimaps: Sequence[np.ndarray]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Add the MattingMeanSquaredError score of a batch to ``self._results``.

        Args:
            pred_alphas (Sequence[np.ndarray]): Predicted probability that
                each pixel belongs to the foreground.
            gt_alphas (Sequence[np.ndarray]): Ground-truth probability that
                each pixel belongs to the foreground.
            trimaps (Sequence[np.ndarray]): Ternary maps marking the
                foreground, background and unknown regions.
        """
        for pred_alpha, gt_alpha, trimap in zip(pred_alphas, gt_alphas,
                                                trimaps):
            assert pred_alpha.shape == gt_alpha.shape, 'The shape of ' \
                '`pred_alpha` and `gt_alpha` should be the same, but got: ' \
                f'{pred_alpha.shape} and {gt_alpha.shape}'
            # Normalize the total squared error by the number of pixels in
            # the unknown region (trimap == 128); pred_alpha is expected to
            # have been masked by the trimap beforehand (see the class note).
            weight_sum = (trimap == 128).sum()
            if weight_sum != 0:
                mse_result = ((pred_alpha - gt_alpha)**2).sum() / weight_sum
            else:
                mse_result = 0
            self._results.append(mse_result)

    def compute_metric(self, results: List) -> Dict[str, float]:
        """Compute the MattingMeanSquaredError metric.

        Args:
            results (List): A list consisting of the per-sample
                MattingMeanSquaredError scores. This list has already been
                synced across all ranks.

        Returns:
            Dict[str, float]: The computed MattingMeanSquaredError metric.
            The keys are the names of the metrics, and the values are the
            corresponding results.
        """
        return {'matting_mse': float(np.array(results).mean())}


# Keep the deprecated metric name as an alias.
# The deprecated Metric names will be removed in 1.0.0!
MattingMSE = MattingMeanSquaredError
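
A minimal usage sketch (not part of the module source above): it shows the incremental add-then-compute workflow described in the docstrings and cross-checks the result against the per-sample formula used in ``add``. The batch size, array shapes and random values are illustrative assumptions, and ``compute()`` is assumed to be the usual ``BaseMetric`` reduction step.

import numpy as np

from mmeval import MattingMeanSquaredError

matting_mse = MattingMeanSquaredError()

# Two synthetic samples. Float arrays avoid uint8 wrap-around when the
# difference is squared; values follow the 0-255 convention from the class
# docstring, and an all-128 trimap marks the whole image as unknown, so no
# extra masking of pred_alpha is needed here.
rng = np.random.default_rng(0)
pred_alphas = [rng.integers(0, 256, (64, 64)).astype(np.float64) for _ in range(2)]
gt_alphas = [rng.integers(0, 256, (64, 64)).astype(np.float64) for _ in range(2)]
trimaps = [np.full((64, 64), 128, dtype=np.uint8) for _ in range(2)]

# Accumulate per-sample scores, then reduce once at the end (assumed
# BaseMetric API: ``add`` to accumulate, ``compute`` to reduce).
matting_mse.add(pred_alphas, gt_alphas, trimaps)
result = matting_mse.compute()

# Cross-check: per sample, the summed squared error is divided by the number
# of unknown-region pixels, and the final score is the mean over samples.
expected = float(np.mean([
    ((p - g) ** 2).sum() / (t == 128).sum()
    for p, g, t in zip(pred_alphas, gt_alphas, trimaps)
]))
assert np.isclose(result['matting_mse'], expected)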