Module elpis.trainer
Source code
from elpis.trainer.trainer import run_job
__all__ = ["run_job"]
Sub-modules
elpis.trainer.data_collator
elpis.trainer.metrics
elpis.trainer.trainer
elpis.trainer.utils
Functions
def run_job(job: Job, log_file: Optional[pathlib.Path] = None) -> pathlib.Path
Fine-tunes a model for use in transcription.
Parameters
job: Info about the training job, e.g. training options.
log_file: An optional file to write training logs to.
Returns
A path to the folder containing the trained model.
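A minimal usage sketch. How a Job is constructed (and its import path) is not documented on this page, so the Job(...) call below is a hypothetical placeholder; only run_job and its signature come from this module.

from pathlib import Path

from elpis.trainer import run_job

# Hypothetical: build a Job describing the model, dataset and training
# options. Job's fields and import path are not shown on this page.
job = Job(...)

# Train, optionally writing logs to a file; returns the trained model folder.
model_dir = run_job(job, log_file=Path("training.log"))
print(f"Model saved to {model_dir}")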
Source code
def run_job(
    job: Job,
    log_file: Optional[Path] = None,
) -> Path:
    """Fine-tunes a model for use in transcription.

    Parameters:
        job: Info about the training job, e.g. training options.
        log_file: An optional file to write training logs to.

    Returns:
        A path to the folder containing the trained model.
    """
    logging_context = log_to_file(log_file) if log_file is not None else nullcontext()
    with logging_context:
        # Setup required directories.
        output_dir = job.training_args.output_dir
        cache_dir = job.model_args.cache_dir
        Path(output_dir).mkdir(exist_ok=True, parents=True)
        job.save(Path(output_dir) / "job.json")

        set_seed(job.training_args.seed)

        logger.info("Preparing Datasets...")
        config = create_config(job)
        dataset = create_dataset(job)
        tokenizer = create_tokenizer(job, config, dataset)
        logger.info(f"Tokenizer: {tokenizer}")  # type: ignore
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            job.model_args.model_name_or_path,
            cache_dir=cache_dir,
            token=job.data_args.token,
            trust_remote_code=job.data_args.trust_remote_code,
        )
        dataset = prepare_dataset(job, tokenizer, feature_extractor, dataset)
        logger.info("Finished Preparing Datasets")

        update_config(job, config, tokenizer)

        logger.info("Downloading pretrained model...")
        model = create_ctc_model(job, config)
        logger.info("Downloaded model.")

        # Now save everything to be able to create a single processor later
        # make sure all processes wait until data is saved
        logger.info("Saving config, tokenizer and feature extractor.")
        with job.training_args.main_process_first():
            # only the main process saves them
            if is_main_process(job.training_args.local_rank):
                feature_extractor.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)  # type: ignore
                config.save_pretrained(output_dir)  # type: ignore

        try:
            processor = AutoProcessor.from_pretrained(output_dir)
        except (OSError, KeyError):
            warnings.warn(
                "Loading a processor from a feature extractor config that does not"
                " include a `processor_class` attribute is deprecated and will be"
                " removed in v5. Please add the following attribute to your"
                " `preprocessor_config.json` file to suppress this warning:"
                " `'processor_class': 'Wav2Vec2Processor'`",
                FutureWarning,
            )
            processor = Wav2Vec2Processor.from_pretrained(output_dir)

        data_collator = DataCollatorCTCWithPadding(processor=processor)  # type: ignore

        # Initialize Trainer
        trainer = Trainer(
            model=model,  # type: ignore
            data_collator=data_collator,
            args=job.training_args,
            compute_metrics=create_metrics(job.data_args.eval_metrics, processor),
            train_dataset=dataset["train"] if job.training_args.do_train else None,  # type: ignore
            eval_dataset=dataset["eval"] if job.training_args.do_eval else None,  # type: ignore
            tokenizer=processor,  # type: ignore
        )

        logger.info("Begin training model...")
        train(job, trainer, dataset)
        logger.info("Finished training!")

        evaluate(job, trainer, dataset)
        clean_up(job, trainer)

        return Path(output_dir)
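run_job only enters a real logging context when log_file is given, falling back to contextlib.nullcontext() otherwise. A standalone sketch of that conditional-context pattern; the log_to_file helper below is a simplified stand-in for the one in elpis.trainer.utils, whose actual implementation is not shown on this page.

import logging
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)

@contextmanager
def log_to_file(path: Path):
    """Attach a FileHandler to the logger for the duration of the block."""
    handler = logging.FileHandler(path)
    logger.addHandler(handler)
    try:
        yield
    finally:
        logger.removeHandler(handler)
        handler.close()

def do_work(log_file: Optional[Path] = None) -> None:
    # Same shape as run_job: a file-logging context when a log file is
    # given, otherwise a no-op nullcontext().
    logging_context = log_to_file(log_file) if log_file is not None else nullcontext()
    with logging_context:
        logger.info("working...")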