Module elpis.datasets

from elpis.datasets.dataset import CleaningOptions, Dataset, ProcessingBatch
from elpis.datasets.preprocessing import process_batch
from elpis.datasets.processing import prepare_dataset, create_dataset

# Public API of elpis.datasets.
# NOTE(review): the entries (and closing bracket) of this list are missing from
# this extracted listing; presumably they are the names imported above —
# confirm against the real package __init__.
__all__ = [




def create_dataset(job: Job) -> datasets.dataset_dict.DatasetDict | datasets.dataset_dict.IterableDatasetDict
def create_dataset(job: Job) -> DatasetDict | IterableDatasetDict:
    """Loads the dataset described by ``job``.

    When ``job.data_args.dataset_name_or_path`` names an existing directory the
    dataset is built by ``create_local_dataset``; otherwise it is built by
    ``create_hf_dataset`` (presumably a Hugging Face dataset name).

    Args:
        job: The job whose ``data_args.dataset_name_or_path`` selects the source.

    Returns:
        Whatever the selected loader returns.
    """
    source = Path(job.data_args.dataset_name_or_path)
    loader = create_local_dataset if source.is_dir() else create_hf_dataset
    return loader(job)
def prepare_dataset(job: Job, tokenizer: AutoTokenizer, feature_extractor: AutoFeatureExtractor, dataset: datasets.dataset_dict.DatasetDict | datasets.dataset_dict.IterableDatasetDict) -> datasets.dataset_dict.DatasetDict | datasets.dataset_dict.IterableDatasetDict

Runs some preprocessing over the given dataset.


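Args:
    job: The job whose data and training arguments drive the preprocessing.
    tokenizer: Tokenizer used to encode the transcription text into labels.
    feature_extractor: Feature extractor used to turn audio into input values.
    dataset: The dataset on which to apply the preprocessing.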

# Cleans, sample-limits, resamples, featurizes and length-filters a dataset.
#
# NOTE(review): this is an extracted source listing with lines missing, e.g.:
# - the end of the docstring;
# - the closing brackets of several calls;
# - the right-hand side of `dataset =`;
# - the start of a call (presumably a log statement), fused onto the end of
#   the `dataset.filter(` call.
# The comments below describe only the visible lines.
def prepare_dataset(
    job: Job,
    tokenizer: AutoTokenizer,
    feature_extractor: AutoFeatureExtractor,
    dataset: DatasetDict | IterableDatasetDict,
) -> DatasetDict | IterableDatasetDict:
    """Runs some preprocessing over the given dataset.

    Args:
        job: The job whose data and training arguments drive the preprocessing.
        tokenizer: Tokenizer used to encode the transcription text into labels.
        feature_extractor: Feature extractor used to turn audio into input
            values; its sampling rate also sets the resampling target and
            scales the duration limits.
        dataset: The dataset on which to apply the preprocessing
    dataset = clean_dataset(job, dataset)
    dataset = constrain_to_max_samples(job, dataset)

    # Load the audio data and resample if necessary.
    dataset = dataset.cast_column(
        Audio(sampling_rate=feature_extractor.sampling_rate),  # type: ignore

    # Per-example transform: audio -> "input_values"/"input_length",
    # transcription text -> "labels".
    def _prepare_dataset(batch: Dict) -> Dict[str, List]:
        audio = batch[job.data_args.audio_column_name]
        inputs = feature_extractor(  # type: ignore
            audio["array"], sampling_rate=audio["sampling_rate"]

        batch["input_values"] = inputs.input_values[0]
        batch["input_length"] = len(batch["input_values"])

        # Encode the transcription text into label ids, forwarding the
        # phonemizer language to the tokenizer when one is configured.
        additional_kwargs = {}
        phoneme_language = job.data_args.phoneme_language
        if phoneme_language is not None:
            additional_kwargs["phonemizer_lang"] = phoneme_language

        batch["labels"] = tokenizer(batch[job.data_args.text_column_name], **additional_kwargs).input_ids  # type: ignore
        return batch

    # Duration limits (seconds) scaled by the sampling rate so they can be
    # compared with "input_length" (presumably a sample count).
    max_input_length = (
        job.data_args.max_duration_in_seconds * feature_extractor.sampling_rate  # type: ignore
    min_input_length = (
        job.data_args.min_duration_in_seconds * feature_extractor.sampling_rate  # type: ignore

    # Inclusive on both ends.
    def is_audio_in_length_range(length: int):
        return length >= min_input_length and length <= max_input_length

    # NOTE(review): presumably main_process_first lets the main process run
    # (and cache) this preprocessing before the other processes — confirm.
    with job.training_args.main_process_first(desc="dataset map preprocessing"):
        worker_count = job.data_args.preprocessing_num_workers

        # num_proc/desc are only passed when the dataset is not streamed.
        kwargs = {}
        if not job.data_args.stream_dataset:
            kwargs = {
                "num_proc": worker_count,
                "desc": "Dataset Preprocessing",

        dataset =

        # Keep only examples whose input_length lies within
        # [min_input_length, max_input_length] (both ends inclusive).
        # NOTE(review): the fused f-string below indexes dataset['train'][0],
        # which assumes a "train" split with random access. A streamed
        # IterableDatasetDict would presumably not support that — verify.
        dataset = dataset.filter(
            is_audio_in_length_range, input_columns=["input_length"], **kwargs
        )"Test encoding labels: {dataset['train'][0]['labels']}")

    return dataset
def process_batch(batch: ProcessingBatch, output_dir: pathlib.Path = PosixPath('/tmp')) -> Iterable[pathlib.Path]

Generates training files from the processing batch and puts them in the given directory.


Args:
    batch: The processing batch to generate files from.
    output_dir: The directory in which to stick the files.

Returns:
    The paths of the generated files.

# Extracts annotations from a batch's transcription, cleans them, and generates
# training files for them in output_dir.
#
# NOTE(review): extracted listing with lines missing: the docstring closer, the
# closing brackets of each call, and the iterable arguments to map()/chain().
# DEFAULT_DIR is rendered as PosixPath('/tmp') in the generated signature.
def process_batch(
    batch: ProcessingBatch, output_dir: Path = DEFAULT_DIR
) -> Iterable[Path]:
    """Generates training files from the processing batch and puts them in
    the given directory.

    Args:
        batch: The processing batch to generate files from
        output_dir: The directory in which to stick the files.

    Returns:
        The paths of the generated files.
    annotations = extract_annotations(
        transcription_file=batch.transcription_file, elan_options=batch.elan_options

    # Apply the batch's cleaning options to each annotation. map() is lazy, so
    # cleaning happens only as the result is consumed.
    annotations = map(
        lambda annotation: clean_annotation(annotation, batch.cleaning_options),

    # Generate training files from the annotations and chain the per-annotation
    # results into a single iterable of paths.
    # NOTE(review): if the missing lines build this purely from map()/chain(),
    # no files are written until the caller iterates the result — confirm.
    return chain(
            lambda annotation: generate_training_files(
                annotation, output_dir=output_dir


class CleaningOptions (punctuation_to_remove: str = '', punctuation_to_explode: str = '', words_to_remove: List[str] = <factory>)

A class representing cleaning options for a dataset.

# NOTE(review): decorators were dropped from this extracted listing. The
# field(default_factory=...) default and the generated signature suggest a
# @dataclass, and from_dict (which takes cls) is presumably a @classmethod.
class CleaningOptions:
    """A class representing cleaning options for a dataset."""

    # Punctuation to remove/explode and words to drop during cleaning; exact
    # semantics presumably live in the cleaning code — confirm there.
    punctuation_to_remove: str = ""
    punctuation_to_explode: str = ""
    # default_factory gives each instance its own list rather than a shared one.
    words_to_remove: List[str] = field(default_factory=list)

    def from_dict(cls, data: Dict[str, Any]) -> CleaningOptions:
        """Builds an instance from a mapping of field values."""
        # NOTE(review): this comprehension is garbled in the listing (the
        # key/value expression around `data[...]` is missing). Presumably it
        # maps each dataclass field name to its value in `data`.
        kwargs = { data[] for field in fields(CleaningOptions)}
        return cls(**kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Returns a shallow copy of the instance attributes.

        The words_to_remove list is shared with the instance, not copied.
        """
        return dict(self.__dict__)

var punctuation_to_explode : str
var punctuation_to_remove : str
var words_to_remove : List[str]

def from_dict(data: Dict[str, Any]) -> CleaningOptions
def from_dict(cls, data: Dict[str, Any]) -> CleaningOptions:
    """Builds a CleaningOptions from a mapping of field values."""
    # NOTE(review): this comprehension is garbled in the listing (the key/value
    # expression around `data[...]` is missing). Presumably it maps each
    # dataclass field name to its value in `data` — confirm against source.
    kwargs = { data[] for field in fields(CleaningOptions)}
    return cls(**kwargs)


def to_dict(self) -> Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
    """Returns a new dict holding this object's instance attributes.

    The copy is shallow: mutable values (e.g. lists) are shared with the
    instance rather than duplicated.
    """
    return {name: value for name, value in self.__dict__.items()}
class Dataset (name: str, files: List[Path], cleaning_options: CleaningOptions, elan_options: Optional[ElanOptions])

A class representing an unprocessed dataset.

# NOTE(review): decorators were dropped from this extracted listing. Evidence:
# - __post_init__ implies a @dataclass.
# - mismatched_files, colliding_files, transcript_files and
#   valid_transcriptions are read as attributes (never called) inside this
#   class and are listed as `var` in the generated docs, so they are
#   presumably properties.
# - is_audio, is_transcript and corresponding_audio_name take no self
#   (presumably static methods).
# - from_dict takes cls (presumably a class method).
# Other lines are missing too; see the per-method notes.
class Dataset:
    """A class representing an unprocessed dataset."""

    name: str
    files: List[Path]
    cleaning_options: CleaningOptions
    elan_options: Optional[ElanOptions]

    def __post_init__(self):
        # Store files in sorted (full-path) order. The groupby-based checks
        # below only group adjacent items, so they depend on this ordering.
        self.files = sorted(self.files)

    def is_empty(self) -> bool:
        """Returns true iff the dataset contains no files."""
        return len(self.files) == 0

    def has_elan(self) -> bool:
        """Returns true iff any of the files in the dataset is an elan file."""
        return any(map((lambda file_name: file_name.suffix == ".eaf"), self.files))

    def is_valid(self) -> bool:
        """Returns true iff this dataset is valid for processing.

        Valid means all of the following hold:
        - at least one file;
        - an even number of files (presumably one .wav per transcript);
        - no mismatched files;
        - no colliding files.
        """
        # NOTE(review): the closing parenthesis is missing from this listing.
        return (
            not self.is_empty()
            and len(self.files) % 2 == 0
            and len(self.mismatched_files) == 0
            and len(self.colliding_files) == 0

    # True iff the file's suffix is exactly ".wav" (case-sensitive).
    def is_audio(file: Path) -> bool:
        return file.suffix == ".wav"

    # True iff the file's suffix is in TRANSCRIPTION_EXTENSIONS (defined
    # outside this listing).
    def is_transcript(file: Path) -> bool:
        return file.suffix in TRANSCRIPTION_EXTENSIONS

    def corresponding_audio_name(transcript_file: Path) -> Path:
        """Gets the corresponding audio file name for a given transcript file."""
        # NOTE(review): Path(...) coerces only for .parent; .stem is read from
        # the raw argument, so passing a str raises AttributeError.
        return Path(transcript_file).parent / (transcript_file.stem + ".wav")

    def transcript_files(self) -> Iterable[Path]:
        """Returns an iterable of all transcription files within the dataset."""
        # filter() is lazy and single-pass; it keeps self.files' sorted order.
        return filter(Dataset.is_transcript, self.files)

    def mismatched_files(self) -> Set[Path]:
        """Returns the set of transcript files with no corresponding
        audio and vice versa.

        Corresponding in this case means that for every transcript file with
        name x.some_extension, there is a corresponding file x.wav in the dataset.

        Returns:
            A set of the mismatched file paths.
        # NOTE(review): groupby only merges *adjacent* items with equal keys,
        # but self.files is sorted by full path, not by stem. Same-stem files
        # can therefore be split apart. For example, x.eaf, x.f.wav, x.wav sort
        # in that order, so x.eaf and x.wav are both reported as mismatched
        # (taking .eaf as a transcription extension).
        # Grouping by stem also ignores the parent directory: a/x.eaf next to
        # b/x.wav counts as matched, although corresponding_audio_name points
        # at a/x.wav.
        # Sorting by the grouping key first (or grouping via a dict) would fix
        # the adjacency issue.
        grouped_by_stems = groupby(self.files, lambda path: path.stem)

        # A stem group is mismatched when it has audio but no transcript, or
        # a transcript but no audio; the whole group is then reported.
        def mismatches(files: Iterable[Path]) -> list[Path]:
            files = list(files)
            has_audio = any(Dataset.is_audio(file) for file in files)
            has_transcript = any(Dataset.is_transcript(file) for file in files)
            return [] if has_transcript == has_audio else files

        groups = (mismatches(g) for _, g in grouped_by_stems)
        result = set(chain.from_iterable(groups))
        return result

    def colliding_files(self) -> Set[Path]:
        """Returns the set of transcript files whose names collide.

        Collide means that two transcript files would be for the same .wav
        file.

        Returns:
            A set of the colliding file paths.
        # NOTE(review): same groupby caveat as mismatched_files.
        # - transcript_files keeps the full-path sort order, so same-stem
        #   transcripts separated by another transcript (e.g. x.eaf, x.f.txt,
        #   x.txt) are not detected.
        # - The parent directory is ignored, so a/x.eaf and b/x.eaf can be
        #   reported even though they map to different .wav files.
        grouped_by_stems = groupby(self.transcript_files, lambda path: path.stem)

        # Any stem group with two or more transcripts is a collision.
        def collisions(files: Iterable[Path]) -> list[Path]:
            files = list(files)
            return files if len(files) >= 2 else []

        collision_groups = (collisions(g) for _, g in grouped_by_stems)
        return set(chain.from_iterable(collision_groups))

    def from_dict(cls, data: Dict[str, Any]) -> Dataset:
        """Builds a Dataset from a serialized mapping.

        "name", "files" and "cleaning_options" are required; "elan_options" is
        optional and defaults to None.
        """
        name = data["name"]
        files = [Path(file) for file in data["files"]]
        cleaning_options = CleaningOptions.from_dict(data["cleaning_options"])

        elan_options = None
        if "elan_options" in data:
            elan_options = ElanOptions.from_dict(data["elan_options"])

        # NOTE(review): the constructor arguments are missing from this listing.
        return cls(

    def valid_transcriptions(self):
        """Returns the transcript files that are neither mismatched nor colliding."""
        # NOTE(review): the lambda re-evaluates
        # mismatched_files | colliding_files for every transcript. If those are
        # plain (uncached) properties, the grouping work is redone per path;
        # computing the union once before filtering would avoid that.
        # The lambda's closing parenthesis is missing from this listing.
        is_valid = lambda path: path not in (
            self.mismatched_files | self.colliding_files
        return filter(is_valid, self.transcript_files)

    # NOTE(review): the generator's element expression (presumably a
    # ProcessingBatch built from each transcript, its corresponding audio file
    # and the dataset's options) is missing from this listing.
    def to_batches(self) -> Iterable[ProcessingBatch]:
        """Converts a valid dataset to an iterable of processing jobs, matching
        transcript and audio files.
        return (
            for transcription_file in self.valid_transcriptions

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the dataset into a plain dict; "elan_options" is included
        only when ELAN options are set."""
        # NOTE(review): lines are missing here:
        # - the element expression of the "files" comprehension (presumably
        #   str(file));
        # - the dict's closing brace;
        # - presumably a "name" entry, since from_dict reads data["name"].
        result = {
            "files": [ for file in self.files],
            "cleaning_options": self.cleaning_options.to_dict(),

        if self.elan_options is not None:
            result["elan_options"] = self.elan_options.to_dict()

        return result

var cleaning_options : CleaningOptions
var elan_options : Optional[ElanOptions]
var files : List[pathlib.Path]
var name : str

def corresponding_audio_name(transcript_file: Path) -> pathlib.Path

Gets the corresponding audio file name for a given transcript file.

def corresponding_audio_name(transcript_file: Path) -> Path:
    """Gets the corresponding audio file name for a given transcript file.

    The audio file lives in the same directory as the transcript and shares
    its stem, with a ``.wav`` extension (``a/x.eaf`` -> ``a/x.wav``).

    Args:
        transcript_file: Path of the transcript. A plain ``str`` is also
            accepted and converted to a ``Path``.

    Returns:
        The expected path of the matching ``.wav`` file.
    """
    # Coerce once and read both the directory and the stem from the coerced
    # value, so str arguments work as well as Path arguments.
    path = Path(transcript_file)
    return path.parent / (path.stem + ".wav")
def from_dict(data: Dict[str, Any]) -> Dataset
def from_dict(cls, data: Dict[str, Any]) -> Dataset:
    """Builds a Dataset from a serialized mapping.

    "name", "files" (converted to Path objects) and "cleaning_options" are
    required; "elan_options" is optional and defaults to None when absent.
    """
    name = data["name"]
    files = [Path(file) for file in data["files"]]
    cleaning_options = CleaningOptions.from_dict(data["cleaning_options"])

    elan_options = None
    if "elan_options" in data:
        elan_options = ElanOptions.from_dict(data["elan_options"])

    # NOTE(review): the constructor arguments after `cls(` are missing from
    # this extracted listing.
    return cls(
def is_audio(file: Path) -> bool
def is_audio(file: Path) -> bool:
    """Returns true iff ``file`` has exactly the ``.wav`` extension (case-sensitive)."""
    extension = file.suffix
    return extension == ".wav"
def is_transcript(file: Path) -> bool
def is_transcript(file: Path) -> bool:
    """Returns true iff ``file``'s extension is one of TRANSCRIPTION_EXTENSIONS."""
    extension = file.suffix
    return extension in TRANSCRIPTION_EXTENSIONS

var colliding_files

Returns the set of transcript files whose names collide.

Collide means that two transcript files would be for the same .wav file.


Returns: a set of the colliding file paths.

var mismatched_files

Returns the set of transcript files with no corresponding audio file, and of audio files with no corresponding transcript.

Corresponding in this case means that for every transcript file with name x.some_extension, there is a corresponding file x.wav in the dataset.


Returns: a set of the mismatched file paths.

var transcript_files : Iterable[pathlib.Path]

Returns an iterable of all transcription files within the dataset.

def transcript_files(self) -> Iterable[Path]:
    """Returns a lazy, single-pass iterable over the dataset's transcription files."""
    keep = Dataset.is_transcript
    return filter(keep, self.files)
var valid_transcriptions
def valid_transcriptions(self):
    """Returns the transcript files that are neither mismatched nor colliding."""
    # NOTE(review): the lambda re-evaluates mismatched_files | colliding_files
    # for every transcript. If those are plain (uncached) properties, the
    # grouping work is redone per path; computing the union once before
    # filtering would avoid that.
    # The lambda's closing parenthesis is missing from this extracted listing.
    is_valid = lambda path: path not in (
        self.mismatched_files | self.colliding_files
    return filter(is_valid, self.transcript_files)


def has_elan(self) -> bool

Returns true iff any of the files in the dataset is an elan file.

def has_elan(self) -> bool:
    """Returns true iff at least one file in the dataset has the ``.eaf`` (ELAN) extension."""
    return any(path.suffix == ".eaf" for path in self.files)
def is_empty(self) -> bool

Returns true iff the dataset contains no files.

def is_empty(self) -> bool:
    """Returns true iff the dataset holds no files at all."""
    file_count = len(self.files)
    return file_count == 0
def is_valid(self) -> bool

Returns true iff this dataset is valid for processing.

def is_valid(self) -> bool:
    """Returns true iff this dataset is valid for processing.

    Valid means all of the following hold:
    - at least one file;
    - an even number of files (presumably one .wav per transcript);
    - no mismatched files;
    - no colliding files.
    """
    # NOTE(review): the closing parenthesis of this expression is missing from
    # this extracted listing.
    return (
        not self.is_empty()
        and len(self.files) % 2 == 0
        and len(self.mismatched_files) == 0
        and len(self.colliding_files) == 0
def to_batches(self) -> Iterable[ProcessingBatch]

Converts a valid dataset to a list of processing jobs, matching transcript and audio files.

# NOTE(review): the generator's element expression (presumably a
# ProcessingBatch built from each transcript, its corresponding audio file and
# the dataset's options) is missing from this extracted listing.
def to_batches(self) -> Iterable[ProcessingBatch]:
    """Converts a valid dataset to an iterable of processing jobs, matching
    transcript and audio files.
    return (
        for transcription_file in self.valid_transcriptions
def to_dict(self) -> Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
    """Serializes the dataset into a plain dict; "elan_options" is included
    only when ELAN options are set."""
    # NOTE(review): lines are missing from this extracted listing:
    # - the element expression of the "files" comprehension (presumably
    #   str(file));
    # - the dict's closing brace;
    # - presumably a "name" entry, since Dataset's from_dict reads
    #   data["name"].
    result = {
        "files": [ for file in self.files],
        "cleaning_options": self.cleaning_options.to_dict(),

    if self.elan_options is not None:
        result["elan_options"] = self.elan_options.to_dict()

    return result
class ProcessingBatch (audio_file: Path, transcription_file: Path, cleaning_options: CleaningOptions, elan_options: Optional[ElanOptions])

A class encapsulating the data needed for an individual processing job.

# NOTE(review): decorators were dropped from this extracted listing. The
# field-only body and the generated signature suggest a @dataclass, and
# from_dict (which takes cls) is presumably a @classmethod.
class ProcessingBatch:
    """A class encapsulating the data needed for an individual processing job."""

    audio_file: Path
    transcription_file: Path
    cleaning_options: CleaningOptions
    elan_options: Optional[ElanOptions]

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the batch into a plain dict.

        Paths become strings, and the options objects are serialized with
        their own to_dict. The "elan_options" key is omitted when no ELAN
        options are set.
        """
        result = {}

        result["audio_file"] = str(self.audio_file)
        result["transcription_file"] = str(self.transcription_file)
        result["cleaning_options"] = self.cleaning_options.to_dict()
        if self.elan_options is not None:
            result["elan_options"] = self.elan_options.to_dict()

        return result

    def from_dict(cls, data: Dict[str, Any]) -> ProcessingBatch:
        """Builds a ProcessingBatch from a mapping with the keys written by to_dict."""
        audio_file = Path(data["audio_file"])
        transcription_file = Path(data["transcription_file"])
        cleaning_options = CleaningOptions.from_dict(data["cleaning_options"])
        # NOTE(review): to_dict omits "elan_options" when it is None, so this
        # unconditional lookup raises KeyError on such a round trip. Dataset's
        # from_dict guards the same key with `if "elan_options" in data`.
        elan_options = ElanOptions.from_dict(data["elan_options"])
        # NOTE(review): the constructor arguments are missing from this listing.
        return cls(

var audio_file : pathlib.Path
var cleaning_options : CleaningOptions
var elan_options : Optional[ElanOptions]
var transcription_file : pathlib.Path

def from_dict(data: Dict[str, Any]) -> ProcessingBatch
def from_dict(cls, data: Dict[str, Any]) -> ProcessingBatch:
    """Builds a ProcessingBatch from a mapping with the keys written by to_dict."""
    audio_file = Path(data["audio_file"])
    transcription_file = Path(data["transcription_file"])
    cleaning_options = CleaningOptions.from_dict(data["cleaning_options"])
    # NOTE(review): ProcessingBatch.to_dict omits "elan_options" when it is
    # None, so this unconditional lookup raises KeyError on such a round trip.
    # Dataset's from_dict guards the same key with `if "elan_options" in data`.
    elan_options = ElanOptions.from_dict(data["elan_options"])
    # NOTE(review): the constructor arguments are missing from this listing.
    return cls(


def to_dict(self) -> Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
    """Serializes this processing batch into a plain dict.

    Paths become strings, and the options objects are serialized with their
    own ``to_dict``. The ``elan_options`` key is omitted when no ELAN options
    are set.
    """
    serialized: Dict[str, Any] = {
        "audio_file": str(self.audio_file),
        "transcription_file": str(self.transcription_file),
        "cleaning_options": self.cleaning_options.to_dict(),
    }
    elan = self.elan_options
    if elan is not None:
        serialized["elan_options"] = elan.to_dict()
    return serialized