tar creation with multilingual
Signed-off-by: Abhinav Khattar <[email protected]>
aklife97 committed May 9, 2021
1 parent 65f7001 commit bbdd886
Showing 1 changed file with 57 additions and 29 deletions.
nemo/collections/nlp/data/machine_translation/preproc_mt_data.py (57 additions, 29 deletions)
@@ -179,45 +179,73 @@ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None) -> None:
                     raise ValueError(
                         'src_file_name and tgt_file_name needed to create tarred dataset but could not be found.'
                     )
-                if cfg.get('multilingual'):
-                    raise ValueError(
-                        'Tarred dataset cannot be created with multilingual tag set to True. Pre-process each dataset one-by-one.'
-                    )
                 # Preprocess data and cache for use during training
                 if self.global_rank == 0:
                     logging.info(
                         f"Using tarred dataset for src: {cfg.train_ds.get('src_file_name')} and tgt: {cfg.train_ds.get('tgt_file_name')}"
                     )
+
+                    if not cfg.get('multilingual'):
+                        src_file_list = [cfg.train_ds.get('src_file_name')]
+                        tgt_file_list = [cfg.train_ds.get('tgt_file_name')]
+                        outdir_list = [cfg.get('preproc_out_dir')]
+                    else:
+                        src_file_list = cfg.train_ds.get('src_file_name')
+                        tgt_file_list = cfg.train_ds.get('tgt_file_name')
+                        if isinstance(cfg.get('src_language'), ListConfig):
+                            langs = cfg.get('src_language')
+                        elif isinstance(cfg.get('tgt_language'), ListConfig):
+                            langs = cfg.get('tgt_language')
+                        else:
+                            raise ValueError(
+                                "Expect either cfg.src_language or cfg.tgt_language to be a list when multilingual=True."
+                            )
+                        outdir_list = []
+                        for lang in langs:
+                            outdir_list.append(os.path.join(cfg.get('preproc_out_dir'), lang))
+
+                    if len(src_file_list) != len(tgt_file_list) or len(src_file_list) != len(outdir_list):
+                        raise ValueError(
+                            "Number of source files, target files, and multilingual language pairs must be the same."
+                        )
+
                     # TODO: have to get tokenizers inside .preprocess_parallel because they can't be pickled
-                    self.train_tar_files, self.train_metadata_file = MTDataPreproc.preprocess_parallel_dataset(
-                        clean=cfg.train_ds.clean,
-                        src_fname=cfg.train_ds.get('src_file_name'),
-                        tgt_fname=cfg.train_ds.get('tgt_file_name'),
-                        out_dir=cfg.get('preproc_out_dir'),
-                        encoder_tokenizer_name=cfg.encoder_tokenizer.get('library'),
-                        encoder_model_name=cfg.encoder.get('model_name'),
-                        encoder_tokenizer_model=self.encoder_tokenizer_model,
-                        encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0),
-                        decoder_tokenizer_name=cfg.decoder_tokenizer.get('library'),
-                        decoder_model_name=cfg.decoder.get('model_name'),
-                        decoder_tokenizer_model=self.decoder_tokenizer_model,
-                        decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0),
-                        max_seq_length=cfg.train_ds.get('max_seq_length', 512),
-                        tokens_in_batch=cfg.train_ds.get('tokens_in_batch', 8192),
-                        lines_per_dataset_fragment=cfg.train_ds.get('lines_per_dataset_fragment', 1000000),
-                        num_batches_per_tarfile=cfg.train_ds.get('num_batches_per_tarfile', 1000),
-                        min_seq_length=1,
-                        global_rank=self.global_rank,
-                        world_size=self.world_size,
-                        n_jobs=cfg.train_ds.get('n_preproc_jobs', -2),
-                        tar_file_prefix=cfg.train_ds.get('tar_file_prefix', 'parallel'),
-                    )
+                    metadata_file_list = []
+                    for idx, src_file in enumerate(src_file_list):
+                        self.train_tar_files, self.train_metadata_file = MTDataPreproc.preprocess_parallel_dataset(
+                            clean=cfg.train_ds.clean,
+                            src_fname=src_file,
+                            tgt_fname=tgt_file_list[idx],
+                            out_dir=outdir_list[idx],
+                            encoder_tokenizer_name=cfg.encoder_tokenizer.get('library'),
+                            encoder_model_name=cfg.encoder.get('model_name'),
+                            encoder_tokenizer_model=self.encoder_tokenizer_model,
+                            encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0),
+                            decoder_tokenizer_name=cfg.decoder_tokenizer.get('library'),
+                            decoder_model_name=cfg.decoder.get('model_name'),
+                            decoder_tokenizer_model=self.decoder_tokenizer_model,
+                            decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0),
+                            max_seq_length=cfg.train_ds.get('max_seq_length', 512),
+                            tokens_in_batch=cfg.train_ds.get('tokens_in_batch', 8192),
+                            lines_per_dataset_fragment=cfg.train_ds.get('lines_per_dataset_fragment', 1000000),
+                            num_batches_per_tarfile=cfg.train_ds.get('num_batches_per_tarfile', 1000),
+                            min_seq_length=1,
+                            global_rank=self.global_rank,
+                            world_size=self.world_size,
+                            n_jobs=cfg.train_ds.get('n_preproc_jobs', -2),
+                            tar_file_prefix=cfg.train_ds.get('tar_file_prefix', 'parallel'),
+                        )
+                        metadata_file_list.append(self.train_metadata_file)
                     # update config
                     # self._cfg.train_ds.tar_files = self.tar_files_to_string(self.train_tar_files)
                     # self._cfg.train_ds.tar_files = self.train_tar_files
-                    self._cfg.train_ds.metadata_file = self.train_metadata_file
+                    if not cfg.get('multilingual'):
+                        self._cfg.train_ds.metadata_file = metadata_file_list[0]
+                    else:
+                        self._cfg.train_ds.metadata_file = metadata_file_list
 
                     logging.info(
-                        f"Using tarred dataset created at {self.train_tar_files} and metadata created at {self._cfg.train_ds.metadata_file}"
+                        f"Using tarred dataset created in folder(s) {outdir_list} and metadata created at {self._cfg.train_ds.metadata_file}"
                     )
             elif cfg.train_ds.get('tar_files') is not None and cfg.train_ds.get('metadata_file') is None:
                 raise ValueError('A metadata file is required for tarred dataset but cfg.metadata_file is None.')
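
For orientation, below is a minimal sketch of the config shape the new multilingual branch consumes and how it derives one preprocessing output directory per language. It assumes an OmegaConf-style config like NeMo's; all file names, paths, and language codes are hypothetical illustrations, not values taken from this commit.

    # Hypothetical config sketch: the list-valued language side drives the
    # per-dataset loop, and one out_dir per language is derived from
    # preproc_out_dir. Paths and language codes are made up for illustration.
    import os

    from omegaconf import ListConfig, OmegaConf

    cfg = OmegaConf.create(
        {
            'multilingual': True,
            'src_language': ['de', 'fr'],  # one source language per dataset
            'tgt_language': 'en',
            'preproc_out_dir': '/tmp/preproc',  # made-up path
            'train_ds': {
                'src_file_name': ['train.de-en.de', 'train.fr-en.fr'],  # made-up files
                'tgt_file_name': ['train.de-en.en', 'train.fr-en.en'],
            },
        }
    )

    src_file_list = cfg.train_ds.get('src_file_name')
    tgt_file_list = cfg.train_ds.get('tgt_file_name')

    # Mirrors the committed logic: whichever language field is a list
    # determines the set of datasets to preprocess one by one.
    if isinstance(cfg.get('src_language'), ListConfig):
        langs = cfg.get('src_language')
    elif isinstance(cfg.get('tgt_language'), ListConfig):
        langs = cfg.get('tgt_language')
    else:
        raise ValueError("Expect either src_language or tgt_language to be a list when multilingual=True.")

    outdir_list = [os.path.join(cfg.get('preproc_out_dir'), lang) for lang in langs]
    print(outdir_list)  # ['/tmp/preproc/de', '/tmp/preproc/fr']

In the non-multilingual case the same code path wraps the single file pair and output directory in one-element lists, so the per-dataset loop still runs once and metadata_file is recorded as a single path rather than a list.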
