Module elpis.trainer

Expand source code
from elpis.trainer.trainer import run_job

__all__ = ["run_job"]

Sub-modules

elpis.trainer.data_collator
elpis.trainer.metrics
elpis.trainer.trainer
elpis.trainer.utils

Functions

def run_job(job: Job, log_file: Optional[pathlib.Path] = None) -> pathlib.Path

Fine-tunes a model for use in transcription.

Parameters

job: Info about the training job, e.g. training options, model options and the preprocessed dataset to train with. log_file: An optional file to write training logs to.

Returns

A path to the folder containing the trained model.

Expand source code
def run_job(
    job: Job,
    log_file: Optional[Path] = None,
) -> Path:
    """Fine-tunes a model for use in transcription.

    Parameters:
        job: Info about the training job, e.g. training options, model
            options and the preprocessed dataset to train with.
        log_file: An optional file to write training logs to.

    Returns:
        A path to the folder containing the trained model.
    """
    # Tee log output to the given file for the duration of the job; otherwise
    # leave logging untouched.
    logging_context = log_to_file(log_file) if log_file is not None else nullcontext()
    with logging_context:
        # Setup required directories.
        output_dir = job.training_args.output_dir
        cache_dir = job.model_args.cache_dir
        Path(output_dir).mkdir(exist_ok=True, parents=True)

        # Persist the job definition next to the model for reproducibility,
        # and seed all RNGs so runs are repeatable.
        job.save(Path(output_dir) / "job.json")
        set_seed(job.training_args.seed)

        logger.info("Preparing Datasets...")
        config = create_config(job)
        dataset = create_dataset(job)

        tokenizer = create_tokenizer(job, config, dataset)
        logger.info(f"Tokenizer: {tokenizer}")  # type: ignore
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            job.model_args.model_name_or_path,
            cache_dir=cache_dir,
            token=job.data_args.token,
            trust_remote_code=job.data_args.trust_remote_code,
        )
        dataset = prepare_dataset(job, tokenizer, feature_extractor, dataset)
        logger.info("Finished Preparing Datasets")

        update_config(job, config, tokenizer)

        logger.info("Downloading pretrained model...")
        model = create_ctc_model(job, config)
        logger.info("Downloaded model.")

        # Now save everything to be able to create a single processor later
        # make sure all processes wait until data is saved
        logger.info("Saving config, tokenizer and feature extractor.")
        with job.training_args.main_process_first():
            # only the main process saves them
            if is_main_process(job.training_args.local_rank):
                feature_extractor.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)  # type: ignore
                config.save_pretrained(output_dir)  # type: ignore

        # Reload the just-saved artifacts as a single processor. Older
        # feature-extractor configs lack a `processor_class` attribute, in
        # which case AutoProcessor cannot resolve the class — fall back to
        # Wav2Vec2Processor explicitly.
        try:
            processor = AutoProcessor.from_pretrained(output_dir)
        except (OSError, KeyError):
            warnings.warn(
                "Loading a processor from a feature extractor config that does not"
                " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
                " attribute to your `preprocessor_config.json` file to suppress this warning: "
                " `'processor_class': 'Wav2Vec2Processor'`",
                FutureWarning,
            )
            processor = Wav2Vec2Processor.from_pretrained(output_dir)

        data_collator = DataCollatorCTCWithPadding(processor=processor)  # type: ignore

        # Initialize Trainer
        trainer = Trainer(
            model=model,  # type: ignore
            data_collator=data_collator,
            args=job.training_args,
            compute_metrics=create_metrics(job.data_args.eval_metrics, processor),
            train_dataset=dataset["train"] if job.training_args.do_train else None,  # type: ignore
            eval_dataset=dataset["eval"] if job.training_args.do_eval else None,  # type: ignore
            tokenizer=processor,  # type: ignore
        )

        logger.info("Begin training model...")
        train(job, trainer, dataset)
        logger.info("Finished training!")

        evaluate(job, trainer, dataset)
        clean_up(job, trainer)

        return Path(output_dir)