
Source code for mmeval.metrics.mse

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from typing import Dict, List, Optional, Sequence

from mmeval.core import BaseMetric


class MeanSquaredError(BaseMetric):
    """Mean Squared Error metric for image.

    Formula: mean((a-b)^2).

    Args:
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.

    Examples:

        >>> from mmeval import MeanSquaredError as MSE
        >>> import numpy as np
        >>>
        >>> mse = MSE()
        >>> gts = np.random.randint(0, 255, size=(3, 32, 32))
        >>> preds = np.random.randint(0, 255, size=(3, 32, 32))
        >>> mse(preds, gts)  # doctest: +ELLIPSIS
        {'mse': ...}

        Calculate MeanSquaredError between 2 images with mask:

        >>> img1 = np.ones((32, 32, 3))
        >>> img2 = np.ones((32, 32, 3)) * 2
        >>> mask = np.ones((32, 32, 3)) * 2
        >>> mask[:16] *= 0
        >>> MSE.compute_mse(img1, img2, mask)
        0.000015378700496
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
    def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarray], masks: Optional[Sequence[np.ndarray]] = None) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Add MeanSquaredError scores of a batch to ``self._results``.

        Args:
            predictions (Sequence[np.ndarray]): Predictions of the model.
            groundtruths (Sequence[np.ndarray]): The ground truth images.
            masks (Sequence[np.ndarray], optional): Mask images.
        """
        for i, (prediction,
                groundtruth) in enumerate(zip(predictions, groundtruths)):
            assert groundtruth.shape == prediction.shape, (
                'Image shapes are different: '
                f'{groundtruth.shape}, {prediction.shape}.')
            if masks is None:
                result = self.compute_mse(prediction, groundtruth)
            else:
                # When the prediction is a single image.
                if len(prediction.shape) <= 3:
                    result = self.compute_mse(prediction, groundtruth,
                                              masks[i])
                # When the prediction is a video, average over its frames.
                else:
                    result_sum = 0
                    for j in range(prediction.shape[0]):
                        result_sum += self.compute_mse(prediction[j],
                                                       groundtruth[j],
                                                       masks[i][j])
                    result = result_sum / prediction.shape[0]
            self._results.append(result)
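The sketch below is a hypothetical usage of ``add`` (the batch sizes, image shapes, and all-ones masks are illustrative assumptions, not part of the module): single images take the ``len(prediction.shape) <= 3`` branch, while a video-shaped prediction is scored frame by frame and averaged.

import numpy as np
from mmeval import MeanSquaredError

mse = MeanSquaredError()

# A batch of two (H, W, C) images with all-ones masks (assumed shapes).
preds = [np.random.randint(0, 255, size=(32, 32, 3)) for _ in range(2)]
gts = [np.random.randint(0, 255, size=(32, 32, 3)) for _ in range(2)]
masks = [np.ones((32, 32, 3)) for _ in range(2)]
mse.add(preds, gts, masks)

# A video-shaped prediction (frames, H, W, C) has 4 dimensions, so it is
# scored per frame and the frame scores are averaged into one result.
mse.add([np.random.randint(0, 255, size=(5, 32, 32, 3))],
        [np.random.randint(0, 255, size=(5, 32, 32, 3))],
        [np.ones((5, 32, 32, 3))])

# ``self._results`` now holds three per-sample scores: 2 images + 1 video.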
    def compute_metric(self, results: List[np.float32]) -> Dict[str, float]:
        """Compute the MeanSquaredError metric.

        This method would be invoked in ``BaseMetric.compute`` after
        distributed synchronization.

        Args:
            results (List[np.float32]): A list consisting of the
                MeanSquaredError scores. This list has already been synced
                across all ranks.

        Returns:
            Dict[str, float]: The computed MeanSquaredError metric.
        """
        return {'mse': float(np.array(results).mean())}
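A minimal check of ``compute_metric`` called directly with already-gathered per-image scores (the scores here are chosen purely for illustration); in normal use this method is only invoked by ``BaseMetric.compute``.

import numpy as np
from mmeval import MeanSquaredError

metric = MeanSquaredError()
# Two per-image scores whose average is exactly 0.5.
print(metric.compute_metric([np.float32(0.25), np.float32(0.75)]))
# {'mse': 0.5}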
    @staticmethod
    def compute_mse(prediction: np.ndarray,
                    groundtruth: np.ndarray,
                    mask: Optional[np.ndarray] = None) -> np.float32:
        """Calculate the Mean Squared Error between two images.

        Args:
            prediction (np.ndarray): Images with range [0, 255].
            groundtruth (np.ndarray): Images with range [0, 255].
            mask (np.ndarray, optional): Mask of evaluation.

        Returns:
            np.float32: MeanSquaredError result.
        """
        prediction = prediction / 255.
        groundtruth = groundtruth / 255.
        diff = groundtruth - prediction
        diff *= diff
        if mask is not None:
            diff *= mask
            result = diff.sum() / mask.sum()
        else:
            result = diff.mean()
        return result
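A worked, hand-checkable sketch of ``compute_mse`` (the 4x4 arrays are illustrative assumptions): after dividing by 255, images that differ by exactly 255 give squared differences of 1.0, and the masked branch averages only over the pixels the mask weights.

import numpy as np
from mmeval import MeanSquaredError

# Unmasked: normalized images differ by exactly 1.0 at every pixel.
print(MeanSquaredError.compute_mse(np.zeros((4, 4)), np.full((4, 4), 255.0)))
# 1.0

# Masked: only the two masked columns contribute, each with squared
# difference (255 / 255)^2 = 1.0, so the mask-weighted average is 1.0,
# while the unmasked mean over the half-zero image is 0.5.
pred = np.zeros((4, 4))
gt = np.zeros((4, 4))
gt[:, :2] = 255.0
mask = np.zeros((4, 4))
mask[:, :2] = 1.0
print(MeanSquaredError.compute_mse(pred, gt))        # 0.5 (unmasked mean)
print(MeanSquaredError.compute_mse(pred, gt, mask))  # 1.0 (masked mean)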
# Keep the deprecated metric name as an alias.
# The deprecated metric names will be removed in 1.0.0!
MSE = MeanSquaredError
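An end-to-end sketch, assuming ``BaseMetric.compute`` can be called with no arguments to aggregate the accumulated results (batch sizes and shapes below are illustrative); the deprecated ``MSE`` alias behaves identically to ``MeanSquaredError``.

import numpy as np
from mmeval.metrics.mse import MSE  # deprecated alias for MeanSquaredError

metric = MSE()
for _ in range(3):  # three batches of two images each (assumed sizes)
    preds = [np.random.randint(0, 255, size=(16, 16, 3)) for _ in range(2)]
    gts = [np.random.randint(0, 255, size=(16, 16, 3)) for _ in range(2)]
    metric.add(preds, gts)

print(metric.compute())  # {'mse': ...} averaged over all six images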