From ab6ce425f2894953ba393c80c5614af9d6244d31 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Mon, 13 Jan 2020 09:06:57 -0800
Subject: [PATCH] Improve robustness of fast-stat-sync

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/986

Differential Revision: D19372552

Pulled By: myleott

fbshipit-source-id: 58a41ab71a09924b20832810d4e07dea932861cb
---
 fairseq/data/iterators.py |  2 --
 fairseq/trainer.py        | 42 ++++++++++++++++++++----------------------
 2 files changed, 20 insertions(+), 24 deletions(-)

diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py
index c7b7c895eb..b5fd966304 100644
--- a/fairseq/data/iterators.py
+++ b/fairseq/data/iterators.py
@@ -242,8 +242,6 @@ def load_state_dict(self, state_dict):
     def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0):
 
         def shuffle_batches(batches, seed):
-            # set seed based on the seed and epoch number so that we get
-            # reproducible results when resuming from checkpoints
             with data_utils.numpy_seed(seed):
                 np.random.shuffle(batches)
             return batches
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 76c21360c5..438552bb7a 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -637,7 +637,8 @@ def _all_gather_list_sync(
     def _fast_stat_sync_sum(
         self,
         logging_outputs: List[Dict[str, Any]],
-        *extra_stats_to_sum
+        *extra_stats_to_sum,
+        min_buffer_size: int = 50,
     ):
         """
         Sync logging outputs across workers. fast_stat_sync_sum is
@@ -647,35 +648,32 @@ def _fast_stat_sync_sum(
         num_extra = len(extra_stats_to_sum)
         if len(logging_outputs) > 0:
             sorted_keys = sorted(logging_outputs[0].keys())
-            stats = list(extra_stats_to_sum) + [
+            stats = [0.] + list(extra_stats_to_sum) + [
                 sum(log.get(k, 0) for log in logging_outputs)
                 for k in sorted_keys
             ]
+            stats = stats + [0.]*(min_buffer_size - len(stats))
             buf = torch.cuda.DoubleTensor(stats)
-
-            # When the number of batches is not evenly divisible by the
-            # number of GPUs, logging_outputs will be empty for some
-            # workers in the last iteration. But we still need to know
-            # the keys and buffer size, so we cache the state in case it
-            # needs to be reused by this worker later.
-            self._fss_buf = buf
-            self._fss_sorted_keys = sorted_keys
-        elif self._fss_buf is not None:
-            buf = self._fss_buf
-            buf.zero_()
-            buf[:num_extra] = torch.cuda.DoubleTensor(extra_stats_to_sum)
-            sorted_keys = self._fss_sorted_keys
         else:
-            raise RuntimeError(
-                'fast_stat_sync failed, perhaps (# GPUs) > (# batches)?'
-            )
+            buf = torch.zeros(min_buffer_size, dtype=torch.double, device='cuda')
+            buf[0] = 1.  # flag to indicate we should fallback to _all_gather_list_sync
 
+        # stats buffer is organized like:
+        # 0: flag to indicate whether fast-stat-sync should be disabled
+        # 1-i: extra_stats_to_sum
+        # i-j: values from logging_outputs (sorted by key)
+        # j-min_buffer_size: padded with 0s
         distributed_utils.all_reduce(buf)
+
         buf = buf.tolist()
-        extra_stats_to_sum, stats = buf[:num_extra], buf[num_extra:]
-        stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
-        return [stats] + extra_stats_to_sum
+        fallback = buf[0]
+        if fallback > 0.:
+            # fallback to _all_gather_list_sync
+            return self._all_gather_list_sync(logging_outputs, *extra_stats_to_sum)
+        else:
+            extra_stats_to_sum, stats = buf[1:num_extra + 1], buf[num_extra + 1:]
+            stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
+            return [stats] + extra_stats_to_sum
 
     def _check_grad_norms(self, grad_norm):
         """Check that grad norms are consistent across workers."""
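
The mechanism this patch introduces is easier to follow in isolation. Below is a minimal, hypothetical sketch of the same flag-plus-padding protocol, not the fairseq code itself: the names sync_stats, MIN_BUFFER_SIZE, and fallback_fn are invented for illustration, and it assumes torch.distributed has already been initialized (e.g. with the gloo backend so it runs on CPU).

    import torch
    import torch.distributed as dist

    MIN_BUFFER_SIZE = 50  # fixed length so every rank all-reduces the same shape

    def sync_stats(stats, fallback_fn):
        """Sum a list of scalar stats across workers with one all_reduce.

        Buffer layout (mirroring the patch):
          [0]               flag: 1.0 if this rank had nothing to contribute
          [1:1+len(stats)]  the scalar stats themselves
          [1+len(stats):]   zero padding out to MIN_BUFFER_SIZE

        Like the patch, this assumes every non-empty rank contributes the
        same number of stats in the same order.
        """
        if stats:
            values = [0.0] + [float(s) for s in stats]
            values += [0.0] * (MIN_BUFFER_SIZE - len(values))
            buf = torch.tensor(values, dtype=torch.double)
        else:
            # This rank received no batches (e.g. #GPUs > #batches on the
            # last iteration). Raise the flag instead of raising an error.
            buf = torch.zeros(MIN_BUFFER_SIZE, dtype=torch.double)
            buf[0] = 1.0

        # Element-wise sum across ranks; any rank's flag survives the sum.
        dist.all_reduce(buf)

        if buf[0].item() > 0.0:
            # Every rank sees the same summed flag, so all ranks take the
            # slow path together and collectives stay in lockstep.
            return fallback_fn(stats)
        return buf[1:1 + len(stats)].tolist()

The design point worth noting: the fall-back decision is itself derived from the all-reduced buffer, so ranks can never disagree about which code path to take. By contrast, the caching approach this patch removes (self._fss_buf / self._fss_sorted_keys) could leave a worker that had never seen a batch without cached state, ending in the RuntimeError the patch deletes.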