Module elpis.trainer.metrics

from typing import Callable, Dict, Optional, Sequence

import evaluate
import numpy as np
from loguru import logger
from transformers import EvalPrediction, Wav2Vec2Processor


def create_metrics(
    metric_names: Sequence[str], processor: Wav2Vec2Processor
) -> Optional[Callable[[EvalPrediction], Dict]]:
    """Build a compute_metrics callable for the given metric names.

    Returns None when no metric names are supplied, so the Trainer can
    skip metric computation entirely.
    """
    if len(metric_names) == 0:
        return None

    # Note: evaluate.combine was causing many unexpected errors, so each
    # metric is loaded and computed individually instead.
    metrics = {name: evaluate.load(name) for name in metric_names}

    def compute_metrics(pred: EvalPrediction) -> Dict:
        # Adapted from https://huggingface.co/blog/fine-tune-xlsr-wav2vec2
        pred_logits = pred.predictions

        # Replace the -100 loss-masking sentinel with the pad token id so
        # the labels can be decoded back to text.
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id  # type: ignore

        # Adapted from: https://discuss.huggingface.co/t/code-review-compute-metrics-for-wer-with-wav2vec2processorwithlm/16841/3
        # A processor with a language model decodes the logits directly; a
        # plain processor needs a greedy argmax over the vocabulary first.
        if type(processor).__name__ == "Wav2Vec2ProcessorWithLM":
            pred_str = processor.batch_decode(pred_logits).text
        else:
            pred_ids = np.argmax(pred_logits, axis=-1)
            pred_str = processor.batch_decode(pred_ids)

        # We do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        logger.debug(f"METRICS->pred: {pred_str} label: {label_str}")

        result = {
            name: metric.compute(predictions=pred_str, references=label_str)
            for name, metric in metrics.items()
        }
        logger.debug(f"Metrics Result: {result}")
        return result

    return compute_metrics
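
Because evaluate.combine is avoided, each entry in the metrics dict is an independent evaluate module whose compute method returns a single score. A minimal sketch of one such metric in isolation, assuming the evaluate and jiwer packages are installed:

import evaluate

# Word error rate: (substitutions + deletions + insertions) / reference words.
wer = evaluate.load("wer")

score = wer.compute(
    predictions=["hello world"],
    references=["hello there world"],
)
print(score)  # 0.333...: one deleted word against a three-word reference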

Functions

def create_metrics(metric_names: Sequence[str], processor: transformers.models.wav2vec2.processing_wav2vec2.Wav2Vec2Processor) -> Optional[Callable[[transformers.trainer_utils.EvalPrediction], Dict]]
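
A hedged usage sketch of the returned callable, fed a fabricated EvalPrediction built from random logits and a known transcript. The facebook/wav2vec2-base-960h checkpoint is illustrative only, and the evaluate and jiwer packages are assumed to be installed:

import numpy as np
from transformers import EvalPrediction, Wav2Vec2Processor

from elpis.trainer.metrics import create_metrics

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
compute_metrics = create_metrics(["wer", "cer"], processor)
assert compute_metrics is not None  # metric_names was non-empty

# Fabricated evaluation batch: random logits over the tokenizer's vocabulary,
# plus label ids encoding a known transcript.
label_ids = np.array([processor.tokenizer("HELLO WORLD").input_ids])
logits = np.random.randn(1, 50, processor.tokenizer.vocab_size).astype(np.float32)

result = compute_metrics(EvalPrediction(predictions=logits, label_ids=label_ids))
print(result)  # e.g. {"wer": ..., "cer": ...}

In real training, the same callable is passed to transformers.Trainer via its compute_metrics argument and runs on every evaluation pass.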