
Source code for mmeval.metrics.rouge

# Copyright (c) OpenMMLab. All rights reserved.
# This class is modified from `torchmetrics
# <https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/rouge.py>`_.
import re
from collections import Counter
from difflib import SequenceMatcher
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional,
                    Sequence, Tuple, Union)

from mmeval import BaseMetric
from mmeval.metrics.utils import get_tokenizer, infer_language
from mmeval.utils import try_import

if TYPE_CHECKING:
    import nltk
else:
    nltk = try_import('nltk')


def _compute_precision_recall_fmeasure(matches: int, pred_len: int,
                                       reference_len: int) -> Dict[str, float]:
    """This computes precision, recall and F1 score based on matches.

    Args:
        matches (int): The number of matches or the length of
            the longest common subsequence.
        pred_len (int): The length of the tokenized predicted sentence.
        reference_len (int): The length of the tokenized reference sentence.

    Returns:
        Dict[str, float]: A dict with the following keys:

        - precision (float): The precision score.
        - recall (float): The recall score.
        - fmeasure (float): The f1-score.
    """
    if matches == 0:
        return dict(precision=0., recall=0., fmeasure=0.)

    precision = matches / pred_len
    recall = matches / reference_len

    fmeasure = 2 * precision * recall / (precision + recall)
    return dict(
        precision=float(precision),
        recall=float(recall),
        fmeasure=float(fmeasure))

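# A minimal sanity check (not part of the original source): with 3 matches,
# a 5-token prediction and a 6-token reference, precision = 3/5 = 0.6,
# recall = 3/6 = 0.5 and F1 = 2 * 0.6 * 0.5 / (0.6 + 0.5) ≈ 0.5455.
#
#   >>> _compute_precision_recall_fmeasure(3, pred_len=5, reference_len=6)
#   {'precision': 0.6, 'recall': 0.5, 'fmeasure': 0.5454545454545454}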

def _rougeL_score(pred: Sequence[str],
                  reference: Sequence[str]) -> Dict[str, float]:
    """This computes precision, recall and F1 score for the Rouge-L metric.

    Args:
        pred (Sequence[str]): A predicted sentence.
        reference (Sequence[str]): A referenced sentence.

    Returns:
        Dict[str, float]: The Rouge-L score, with ``precision``,
        ``recall`` and ``fmeasure`` keys.
    """
    pred_len, reference_len = len(pred), len(reference)
    if pred_len == 0 or reference_len == 0:
        return dict(precision=0., recall=0., fmeasure=0.)
    lcs = 0
    matches = SequenceMatcher(None, pred, reference).get_matching_blocks()
    for match in matches:
        lcs += match.size
    return _compute_precision_recall_fmeasure(lcs, pred_len, reference_len)

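# A quick illustration (not part of the original source): the sum of the
# sizes of ``difflib``'s matching blocks is used as the LCS length. For
# ['the', 'cat', 'sat'] against ['the', 'cat', 'is', 'sat'] the blocks have
# sizes 2 and 1, so lcs = 3, precision = 3/3 and recall = 3/4:
#
#   >>> _rougeL_score(['the', 'cat', 'sat'], ['the', 'cat', 'is', 'sat'])
#   {'precision': 1.0, 'recall': 0.75, 'fmeasure': 0.8571428571428571}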

def _rougeN_score(pred: Sequence[str], reference: Sequence[str],
                  n_gram: int) -> Dict[str, float]:
    """This computes precision, recall and F1 score for the Rouge-N metric.

    Args:
        pred (Sequence[str]): A predicted sentence.
        reference (Sequence[str]): A referenced sentence.
        n_gram (int): The number of words contained in a phrase
            when calculating word fragments.

    Returns:
        Dict[str, float]: The Rouge-N score, with ``precision``,
        ``recall`` and ``fmeasure`` keys.
    """

    def _create_ngrams(tokens: Sequence[str], n: int) -> Counter:
        ngrams: Counter = Counter()
        for i in range(len(tokens) - n + 1):
            ngram = tuple(tokens[i:i + n])
            ngrams[ngram] += 1
        return ngrams

    pred_ngrams = _create_ngrams(pred, n_gram)
    reference_ngrams = _create_ngrams(reference, n_gram)
    pred_len = sum(pred_ngrams.values())
    reference_len = sum(reference_ngrams.values())
    if pred_len == 0 or reference_len == 0:
        return dict(precision=0., recall=0., fmeasure=0.)

    # Count the hits: the multiset intersection of the n-grams of the
    # prediction and the reference.
    hits = sum(
        min(pred_ngrams[w], reference_ngrams[w]) for w in set(pred_ngrams))
    return _compute_precision_recall_fmeasure(hits, pred_len, reference_len)

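# For example (not part of the original source), with ``n_gram=2`` the
# prediction 'the cat is on the mat' and the reference 'a cat is on the mat'
# share 4 of their 5 bigrams each, so precision, recall and fmeasure are
# all 0.8:
#
#   >>> _rougeN_score('the cat is on the mat'.split(),
#   ...               'a cat is on the mat'.split(), n_gram=2)
#   {'precision': 0.8, 'recall': 0.8, 'fmeasure': 0.8}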

class ROUGE(BaseMetric):
    """Calculate Rouge Score used for automatic summarization.

    This metric proposed in `ROUGE: A Package for Automatic Evaluation of
    Summaries <https://www.aclweb.org/anthology/W04-1013.pdf>`_ is a common
    evaluation indicator in the fields of machine translation, automatic
    summarization, question and answer generation, etc.

    Args:
        rouge_keys (List or Tuple or int or str): A list of rouge types to
            calculate. Keys that are allowed are ``L``, and ``1`` through
            ``9``. Defaults to ``(1, 2, 'L')``.
        use_stemmer (bool): Use Porter stemmer to strip word suffixes to
            improve matching. Defaults to False.
        normalizer (Callable, optional): A user's own normalizer function.
            If this is ``None``, replacing any non-alpha-numeric characters
            with spaces is the default. Defaults to None.
        tokenizer (Callable or str, optional): A user's own tokenizer
            function. Defaults to None.
        accumulate (str): Useful in case of multi-reference rouge score.
            ``avg`` takes the average of all references with respect to
            predictions. ``best`` takes the best fmeasure score obtained
            between prediction and multiple corresponding references.
            Defaults to ``best``.
        lowercase (bool): If True, all characters will be lowercased.
            Defaults to True.
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.

    Examples:
        >>> from mmeval import ROUGE
        >>> predictions = ['the cat is on the mat']
        >>> references = [['a cat is on the mat']]
        >>> metric = ROUGE(rouge_keys='L')
        >>> metric.add(predictions, references)
        >>> metric.compute()
        {'rougeL_fmeasure': 0.8333333, 'rougeL_precision': 0.8333333,
         'rougeL_recall': 0.8333333}
    """

    def __init__(self,
                 rouge_keys: Union[List, Tuple, int, str] = (1, 2, 'L'),
                 use_stemmer: bool = False,
                 normalizer: Optional[Callable] = None,
                 tokenizer: Union[Callable, str, None] = None,
                 accumulate: str = 'best',
                 lowercase: bool = True,
                 **kwargs: Any):
        super().__init__(**kwargs)
        if isinstance(rouge_keys, (int, str)):
            rouge_keys = [rouge_keys]

        # Check the legitimacy of the rouge_keys.
        for rouge_key in rouge_keys:
            if isinstance(rouge_key, int):
                if rouge_key < 1 or rouge_key > 9:
                    raise ValueError(f'Got unknown rouge key {rouge_key}. '
                                     'Expected to be one of {1 - 9} or L')
            elif rouge_key != 'L':
                raise ValueError(f'Got unknown rouge key {rouge_key}. '
                                 'Expected to be one of {1 - 9} or L')
        self.rouge_keys = rouge_keys

        # Use the Porter stemmer from nltk if requested.
        if use_stemmer and nltk is not None:
            self.stemmer = nltk.stem.porter.PorterStemmer()
        elif use_stemmer and nltk is None:
            raise ValueError(
                'The nltk package is needed to use stemmer, '
                'check https://www.nltk.org/install.html for installation.')
        else:
            self.stemmer = None
        self.normalizer = normalizer

        # Select the tokenizer according to the entered value.
        self.tokenizer_fn = None
        if callable(tokenizer):
            self.tokenizer_fn = tokenizer
        elif isinstance(tokenizer, str):
            self.tokenizer_fn = get_tokenizer(tokenizer)
            if self.tokenizer_fn is None:
                raise ValueError('Right now, `tokenizer` only supports '
                                 "pre-defined 'en' or 'cn'.")
        else:
            assert tokenizer is None, \
                f'`tokenizer` supports Callable, str or None, but not `{type(tokenizer)}`'  # noqa: E501

        assert accumulate in ['best', 'avg'], \
            f'Wrong accumulate {accumulate}. Supported accumulate are "best" and "avg"'  # noqa: E501
        self.accumulate = accumulate
        self.lowercase = lowercase
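    # A hedged configuration sketch (not part of the original source): the
    # metric can score several rouge variants in one pass, e.g. unigram,
    # bigram and longest-common-subsequence rouge with Porter stemming
    # (which requires nltk) and averaged multi-reference accumulation:
    #
    #   >>> metric = ROUGE(rouge_keys=(1, 2, 'L'), use_stemmer=True,
    #   ...                accumulate='avg')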
    def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Add the intermediate results to ``self._results``.

        Args:
            predictions (Sequence[str]): An iterable of predicted sentences.
            references (Sequence[Sequence[str]]): An iterable of referenced
                sentences. Each predicted sentence may correspond to multiple
                referenced sentences.
        """
        # If the tokenizer is None, check the first sentence to infer which
        # language the tokenizer should be chosen for.
        if self.tokenizer_fn is None:
            language = infer_language(predictions[0])
            self.tokenizer_fn = get_tokenizer(language)

        # Traverse the predicted sentences.
        for prediction, _references in zip(predictions, references):
            scores_per_rouge_keys = self._compute_rouge_score(
                prediction, _references)
            self._results.append(scores_per_rouge_keys)
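    # For illustration (assuming the public ``BaseMetric.compute`` entry
    # point): with two references for one prediction and the default
    # ``accumulate='best'``, the best per-reference score is kept for each
    # rouge key; here both references tie at 5/6:
    #
    #   >>> metric = ROUGE(rouge_keys='L')
    #   >>> metric.add(['the cat is on the mat'],
    #   ...            [['a cat is on the mat', 'the cat sat on the mat']])
    #   >>> metric.compute()
    #   {'rougeL_fmeasure': 0.8333..., 'rougeL_precision': 0.8333...,
    #    'rougeL_recall': 0.8333...}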
    def _compute_rouge_score(self, prediction: str,
                             references: Sequence[str]) -> Sequence[tuple]:
        """Compute the rouge score.

        Args:
            prediction (str): The predicted sentence.
            references (Sequence[str]): The referenced sentences.
                Each predicted sentence may correspond to multiple
                referenced sentences.

        Returns:
            Sequence[tuple]: The rouge scores corresponding to each
            ``rouge_key``. Each score is a tuple of
            (fmeasure, precision, recall).
        """
        assert isinstance(references, Sequence), \
            f'The `references` should be a sequence of string, but got {type(references)}.'  # noqa: E501
        assert len(references) > 0, \
            'The number of references should be larger than 0.'

        pred_token = self._normalize_and_tokenize(prediction)
        ref_tokens = [
            self._normalize_and_tokenize(refs) for refs in references
        ]

        # Traverse the chosen rouge_keys.
        scores_per_rouge_keys = []
        for rouge_key in self.rouge_keys:
            # Traverse the tokens of references for a single prediction.
            scores = []
            for ref_token in ref_tokens:
                if isinstance(rouge_key, int):
                    score = _rougeN_score(pred_token, ref_token, rouge_key)
                else:
                    score = _rougeL_score(pred_token, ref_token)
                scores.append(score)

            # Accumulate the rouge score across multiple references.
            if self.accumulate == 'best':
                fmeasure = max(score['fmeasure'] for score in scores)
                precision = max(score['precision'] for score in scores)
                recall = max(score['recall'] for score in scores)
            else:
                fmeasure = sum(score['fmeasure']
                               for score in scores) / len(scores)
                precision = sum(score['precision']
                                for score in scores) / len(scores)
                recall = sum(score['recall']
                             for score in scores) / len(scores)
            scores_per_rouge_keys.append((fmeasure, precision, recall))
        return scores_per_rouge_keys
    def compute_metric(self, results: List[Any]) -> dict:
        """Compute the rouge metric.

        This method would be invoked in ``BaseMetric.compute`` after
        distributed synchronization.

        Args:
            results (List): A list of the per-sample rouge scores. This list
                has already been synced across all ranks.

        Returns:
            Dict[str, float]: The computed rouge score.
        """
        fmeasure = [0] * len(self.rouge_keys)
        recall = [0] * len(self.rouge_keys)
        precision = [0] * len(self.rouge_keys)

        # Sum the scores of each rouge key over all samples.
        for result in results:
            for i, each_rouge in enumerate(result):
                fmeasure[i] += each_rouge[0]
                precision[i] += each_rouge[1]
                recall[i] += each_rouge[2]

        metric_results = {}
        # Average over the synced ``results`` rather than the local
        # ``self._results``, so samples from all ranks are counted.
        num_samples = len(results)
        for i, rouge_key in enumerate(self.rouge_keys):
            metric_results[
                f'rouge{rouge_key}_fmeasure'] = fmeasure[i] / num_samples
            metric_results[
                f'rouge{rouge_key}_precision'] = precision[i] / num_samples
            metric_results[
                f'rouge{rouge_key}_recall'] = recall[i] / num_samples
        return metric_results
    def _normalize_and_tokenize(self, text: str) -> Sequence[str]:
        """Normalize and tokenize the given text.

        Rouge score should be calculated only over lowercased words and
        digits. Optionally, ``nltk.stem.porter.PorterStemmer`` can be used
        to strip word suffixes for better matching.

        Args:
            text (str): An input sentence.

        Returns:
            Sequence[str]: The tokens after normalization and tokenization.
        """
        # The regex-based normalization only applies to the whitespace
        # (English) tokenizer; a user-defined normalizer takes precedence.
        if self.tokenizer_fn == str.split:
            if callable(self.normalizer):
                text = self.normalizer(text)
            elif self.lowercase:
                text = re.sub(r'[^a-z0-9]+', ' ', text.lower())
            else:
                text = re.sub(r'[^A-Za-z0-9]+', ' ', text)
        tokens = self.tokenizer_fn(text)  # type: ignore
        if self.stemmer:
            tokens = [
                self.stemmer.stem(x) if len(x) > 3 else x for x in tokens
            ]
        tokens = [x for x in tokens if (isinstance(x, str) and len(x) > 0)]
        return tokens
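# A short illustration (not part of the original source) of the default
# English normalization above: with ``lowercase=True`` and the whitespace
# tokenizer, non-alphanumeric characters are replaced by spaces before
# splitting, so punctuation never yields tokens:
#
#   >>> import re
#   >>> re.sub(r'[^a-z0-9]+', ' ', 'The cat, the mat!'.lower()).split()
#   ['the', 'cat', 'the', 'mat']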