diff --git a/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml b/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml
index 0c0e40562506..f111573f21eb 100644
--- a/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml
+++ b/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml
@@ -11,7 +11,7 @@ model:
   # configs for huggingface load_dataset function
   data_path: "librispeech_asr"
   data_name: null # name for the specific dataset to load, e.g., 'en' for MCV datasets, but some datasets don't require this field.
-  streaming: false # set True to use streaming mode, which doesn't wait for data downloading but each training step takes longer in the first epoch. If True, you'll need to specify trainer.max_steps instead of trainer.max_epochs.
+  streaming: false # set True to use streaming mode, which doesn't wait for data downloading but each training step takes longer in the first epoch. If True, you'll need to specify trainer.max_steps and trainer.limit_train_batches, instead of trainer.max_epochs.
 
   # keys for audio, sample_rate and transcription in the huggingface dataset, keys seperated by `.` for nested fields. See example at the bottom of this file.
   audio_key: "audio.array"
diff --git a/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml b/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml
index 30c082aff91f..1d00eca6d3be 100644
--- a/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml
+++ b/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml
@@ -4,11 +4,11 @@ input_manifest: null # Path of json file of evaluation data. Audio files should
 output_dir: null # Path to output directory where results will be stored
 num_workers: 12
 sample_rate: 16000
-evaluate: False # whether to get AUROC and DERs, the manifest must contains groundtruth if enabled
+evaluate: false # whether to get AUROC and DERs, the manifest must contains groundtruth if enabled
 
 prepare_manifest:
-  auto_split: True # whether to automatically split manifest entry by split_duration to avoid potential CUDA out of memory issue.
-  split_duration: 400 # try smaller number if you still have CUDA memory issue
+  auto_split: true # whether to automatically split manifest entry by split_duration to avoid potential CUDA out of memory issue.
+  split_duration: 400 # max length in seconds, try smaller number if you still have CUDA memory issue
 
 vad:
   model_path: "vad_multilingual_frame_marblenet" #.nemo local model path or pretrained model name or none
diff --git a/examples/asr/speech_to_text_finetune.py b/examples/asr/speech_to_text_finetune.py
index 148b11d8b70f..36a7bdc3bbdc 100644
--- a/examples/asr/speech_to_text_finetune.py
+++ b/examples/asr/speech_to_text_finetune.py
@@ -108,6 +108,7 @@ def get_base_model(trainer, cfg):
         # restore model from cached model dir
         asr_model = ASRModel.from_pretrained(model_name=pretrained_name)
 
+    asr_model.set_trainer(trainer)
     return asr_model
 
 
diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py
index 68dfaf3d6c76..138b2e36b7fa 100644
--- a/nemo/collections/asr/parts/utils/vad_utils.py
+++ b/nemo/collections/asr/parts/utils/vad_utils.py
@@ -57,7 +57,7 @@ def autocast(enabled=None):
 
 def prepare_manifest(config: dict) -> str:
     """
-    Perform VAD on long audio snippet might cause CUDA out of memory issue. 
+    Perform VAD on long audio snippet might cause CUDA out of memory issue.
     Automatically split manifest entry by split_duration to avoid the potential memory issue.
""" if 'prepared_manifest_vad_input' in config and config['prepared_manifest_vad_input']: @@ -132,7 +132,7 @@ def write_vad_infer_manifest(file: dict, args_func: dict) -> list: args_func: label (str): label for audio snippet.y split_duration (float): max duration of each audio clip (each line in json) - window_length_in_sec (float) : length of window for generating the frame. Used for taking care of joint. + window_length_in_sec (float) : length of window for generating the frame. Used for taking care of joint. Returns: res (list) : list of generated metadata line of json for file """ @@ -205,7 +205,7 @@ def write_vad_infer_manifest(file: dict, args_func: dict) -> list: def get_vad_stream_status(data: list) -> list: """ - Generate a list of status for each snippet in manifest. A snippet should be in single, start, next or end status. + Generate a list of status for each snippet in manifest. A snippet should be in single, start, next or end status. Used for concatenating to full audio file. Args: data (list): list of filepath of audio snippet @@ -256,9 +256,9 @@ def generate_overlap_vad_seq( out_dir: str = None, ) -> str: """ - Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. + Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. Two common smoothing filters are supported: majority vote (median) and average (mean). - This function uses multiprocessing to speed up. + This function uses multiprocessing to speed up. Args: frame_pred_dir (str): Directory of frame prediction file to be processed. smoothing_method (str): median or mean smoothing filter. @@ -322,7 +322,7 @@ def generate_overlap_vad_seq_per_tensor( """ Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate prediction with overlapping input window/segments See description in generate_overlap_vad_seq. - Use this for single instance pipeline. + Use this for single instance pipeline. """ # This function will be refactor for vectorization but this is okay for now @@ -441,7 +441,7 @@ def filter_short_segments(segments: torch.Tensor, threshold: float) -> torch.Ten Remove segments which duration is smaller than a threshold. For example, torch.Tensor([[0, 1.5], [1, 3.5], [4, 7]]) and threshold = 2.0 - -> + -> torch.Tensor([[1, 3.5], [4, 7]]) """ return segments[segments[:, 1] - segments[:, 0] >= threshold] @@ -482,20 +482,20 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Binarize predictions to speech and non-speech Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. - Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. + Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: sequence (torch.Tensor) : A tensor of frame level predictions. per_args: - onset (float): onset threshold for detecting the beginning and end of a speech - offset (float): offset threshold for detecting the end of a speech. 
+            onset (float): onset threshold for detecting the beginning and end of a speech
+            offset (float): offset threshold for detecting the end of a speech.
             pad_onset (float): adding durations before each speech segment
             pad_offset (float): adding durations after each speech segment;
             frame_length_in_sec (float): length of frame.
-
+
     Returns:
-        speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. 
+        speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
     """
 
     frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01)
@@ -545,9 +545,9 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te
 def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: torch.Tensor) -> torch.Tensor:
     """
     Remove speech segments list in to_be_removed_segments from original_segments.
-    For example, 
+    For example,
     remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]),
-    -> 
+    ->
     torch.Tensor([[start1, end1],[start3, end3]])
     """
     for y in to_be_removed_segments:
@@ -558,7 +558,7 @@ def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: tor
 
 @torch.jit.script
 def get_gap_segments(segments: torch.Tensor) -> torch.Tensor:
     """
-    Get the gap segments. 
+    Get the gap segments.
     For example,
     torch.Tensor([[start1, end1], [start2, end2], [start3, end3]]) -> torch.Tensor([[end1, start2], [end2, start3]])
     """
@@ -568,22 +568,21 @@ def get_gap_segments(segments: torch.Tensor) -> torch.Tensor:
 
 @torch.jit.script
 def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor:
-
     """
     Filter out short non_speech and speech segments.
     Reference
-    Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. 
-    Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py 
+    Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
+    Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
     Args:
-        speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. 
+        speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
         per_args:
             min_duration_on (float): threshold for small non_speech deletion
             min_duration_off (float): threshold for short speech segment deletion
-            filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True. 
+            filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True.
     Returns:
-        speech_segments(torch.Tensor): A tensor of filtered speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. 
+        speech_segments(torch.Tensor): A tensor of filtered speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
     """
     if speech_segments.shape == torch.Size([0]):
         return speech_segments
@@ -630,7 +629,7 @@ def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torc
 
 def prepare_gen_segment_table(sequence: torch.Tensor, per_args: dict) -> Tuple[str, dict]:
     """
-    Preparing for generating segment table. 
+    Preparing for generating segment table.
""" out_dir = per_args.get('out_dir', None) @@ -658,7 +657,7 @@ def prepare_gen_segment_table(sequence: torch.Tensor, per_args: dict) -> Tuple[s def generate_vad_segment_table_per_tensor(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor: """ See description in generate_overlap_vad_seq. - Use this for single instance pipeline. + Use this for single instance pipeline. """ UNIT_FRAME_LEN = 0.01 @@ -721,7 +720,7 @@ def generate_vad_segment_table( Args: vad_pred_dir (str): directory of prediction files to be processed. postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering. - frame_length_in_sec (float): frame length. + frame_length_in_sec (float): frame length. out_dir (str): output dir of generated table/csv file. num_workers(float): number of process for multiprocessing Returns: @@ -1070,7 +1069,7 @@ def gen_pred_from_speech_segments( speech_segments: torch.Tensor, prob: float, shift_length_in_sec: float = 0.01 ) -> np.array: """ - Generate prediction arrays like 000111000... from speech segments {[0,1][2,4]} + Generate prediction arrays like 000111000... from speech segments {[0,1][2,4]} """ pred = np.zeros(prob.shape) speech_segments = [list(i) for i in speech_segments] @@ -1086,7 +1085,7 @@ def gen_pred_from_speech_segments( def extract_labels(path2ground_truth_label: str, time: list) -> list: """ Extract ground-truth label for given time period. - path2ground_truth_label (str): path of groundtruth RTTM file + path2ground_truth_label (str): path of groundtruth RTTM file time (list) : a list of array representing time period. """ @@ -1273,7 +1272,6 @@ def stitch_segmented_asr_output( def construct_manifest_eval( input_manifest: str, stitched_output_manifest: str, aligned_vad_asr_output_manifest: str = "vad_asr_out.json" ) -> str: - """ Generate aligned manifest for evaluation. Because some pure noise samples might not appear in stitched_output_manifest. @@ -1393,7 +1391,7 @@ def get_nonspeech_segments( Args: speech_segments (List[List[float]]): speech segment intervals loaded by load_speech_segments() max_duration (Optional[float]): maximum duration of the audio, used to calculate the last silence segment - + Returns: nonspeech_segments (List[List[float]]): intervals of non-speech segments """ @@ -1483,8 +1481,8 @@ def plot_sample_from_rttm( def align_labels_to_frames(probs, labels, threshold=0.2): """ - Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms). - The threshold 0.2 is not important, since the actual ratio will always be close to an integer unless using frame/label + Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms). + The threshold 0.2 is not important, since the actual ratio will always be close to an integer unless using frame/label lengths that are not multiples of each other (e.g., 15ms frame length and 20ms label length), which is not valid. The value 0.2 here is just for easier unit testing. 
     Args:
@@ -1624,9 +1622,9 @@ def frame_vad_infer_load_manifest(cfg: DictConfig):
     """
     Load manifest file and prepare label/rttm mapping
     Args:
-        cfg: config file
+        cfg: DictConfig object
     Returns:
-        manifest_orig (List[Dict]): original manifest data 
+        manifest_orig (List[Dict]): original manifest data
         key_labels_map (Dict): mapping from unique_audio_name to its labels
         key_rttm_map (Dict): mapping from unique_audio_name to its rttm file
     """
@@ -1634,7 +1632,7 @@ def frame_vad_infer_load_manifest(cfg: DictConfig):
     key_labels_map = {}
     key_rttm_map = {}
     manifest_orig = []
-    manifest_file = Path(cfg.dataset).absolute().as_posix()
+    manifest_file = Path(cfg.input_manifest).absolute().as_posix()
     with open(manifest_file, 'r') as fin:
         for line in fin.readlines():
             entry = json.loads(line.strip())
@@ -1649,22 +1647,25 @@ def frame_vad_infer_load_manifest(cfg: DictConfig):
             manifest_orig.append(entry)
 
-            # always prefer RTTM labels if exist
-            if "label" not in entry and ("rttm_filepath" in entry or "rttm_file" in entry):
+            if cfg.evaluate:
+                # always prefer RTTM labels if exist
                 rttm_key = "rttm_filepath" if "rttm_filepath" in entry else "rttm_file"
-                segments = load_speech_segments_from_rttm(entry[rttm_key])
-                label_str = get_frame_labels(
-                    segments=segments,
-                    frame_length=cfg.vad.parameters.shift_length_in_sec,
-                    duration=entry['duration'],
-                    offset=entry['offset'],
-                )
-                key_rttm_map[uniq_audio_name] = entry[rttm_key]
-                key_labels_map[uniq_audio_name] = [float(x) for x in label_str.split()]
-            elif entry.get("label", None) is not None:
-                key_labels_map[uniq_audio_name] = [float(x) for x in entry["label"].split()]
-            elif cfg.evaluate:
-                raise ValueError("Must have either `label` or `rttm_filepath` in manifest when evaluate=True")
+                rttm_file = entry.get(rttm_key, None)
+                if rttm_file:
+                    rttm_file = get_full_path(audio_file=rttm_file, manifest_file=manifest_file)
+                    segments = load_speech_segments_from_rttm(rttm_file)
+                    label_str = get_frame_labels(
+                        segments=segments,
+                        frame_length=cfg.vad.parameters.shift_length_in_sec,
+                        duration=entry['duration'],
+                        offset=entry['offset'],
+                    )
+                    key_rttm_map[uniq_audio_name] = entry[rttm_key]
+                    key_labels_map[uniq_audio_name] = [float(x) for x in label_str.split()]
+                elif entry.get("label", None) is not None:
+                    key_labels_map[uniq_audio_name] = [float(x) for x in entry["label"].split()]
+                else:
+                    raise ValueError("Must have either `label` or `rttm_filepath` in manifest when evaluate=True")
 
     return manifest_orig, key_labels_map, key_rttm_map
 
 
@@ -1709,7 +1710,9 @@ def frame_vad_eval_detection_error(
         groundtruth = key_labels_map[key]
 
         reference, hypothesis = frame_vad_construct_pyannote_object_per_file(
-            prediction=key_pred_rttm_map[key], groundtruth=groundtruth, frame_length_in_sec=frame_length_in_sec,
+            prediction=key_pred_rttm_map[key],
+            groundtruth=groundtruth,
+            frame_length_in_sec=frame_length_in_sec,
         )
         metric(reference, hypothesis)
diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py
index 6393bb5581d6..5a0bbc2bea37 100644
--- a/nemo/core/classes/modelPT.py
+++ b/nemo/core/classes/modelPT.py
@@ -1639,6 +1639,16 @@ def cfg(self, cfg):
         if hasattr(self, '_hparams_initial') and 'cfg' in self._hparams_initial:
             self._hparams_initial['cfg'] = OmegaConf.to_object(self._cfg)
 
+    @property
+    def hparams(self):
+        """
+        Overwrite default hparams property to return the latest model config.
+        Without this change, the hparams property would return the old config if there was a direct change to
+        self._cfg (e.g., in self.setup_optimization()) that was not done via `self.cfg = new_cfg`.
+        """
+        self._set_hparams(OmegaConf.create({'cfg': self._cfg}))
+        return super().hparams
+
     @property
     def validation_step_outputs(self):
         """
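
As a rough illustration of the `hparams` override added in modelPT.py above, here is a minimal, self-contained sketch (not NeMo code; the `Base` and `Model` classes are hypothetical stand-ins for the Lightning/ModelPT pair, and only `omegaconf` is assumed). It shows why re-syncing the cached hparams from `self._cfg` on every access makes direct, in-place config edits visible even when they bypass the `cfg` setter.

# Illustrative sketch only -- not part of the patch.
from omegaconf import OmegaConf


class Base:
    """Stands in for a framework base class that caches hparams once."""

    def __init__(self):
        self._hparams_cache = None

    def _set_hparams(self, hp):
        # cache a snapshot of the hyperparameters
        self._hparams_cache = hp

    @property
    def hparams(self):
        return self._hparams_cache


class Model(Base):
    def __init__(self, cfg):
        super().__init__()
        self._cfg = OmegaConf.create(cfg)
        self._set_hparams(OmegaConf.create({'cfg': self._cfg}))

    @property
    def hparams(self):
        # re-sync the cached snapshot with the current config before returning it,
        # mirroring the override added in modelPT.py above
        self._set_hparams(OmegaConf.create({'cfg': self._cfg}))
        return super().hparams


if __name__ == "__main__":
    model = Model({'optim': {'lr': 0.1}})
    model._cfg.optim.lr = 0.01           # direct in-place change, bypassing any cfg setter
    print(model.hparams.cfg.optim.lr)    # prints 0.01 because hparams is rebuilt on access

Without the re-sync in the overridden property, the final line would still report the stale value captured at construction time, which is the situation the docstring in the diff describes.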