
Improve robustness of fast-stat-sync
Summary: Pull Request resolved: fairinternal/fairseq-py#986

Differential Revision: D19372552

Pulled By: myleott

fbshipit-source-id: 58a41ab71a09924b20832810d4e07dea932861cb
Myle Ott authored and facebook-github-bot committed Jan 13, 2020
1 parent 7201ebc commit ab6ce42
Showing 2 changed files with 20 additions and 24 deletions.
fairseq/data/iterators.py (2 changes: 0 additions & 2 deletions)
@@ -242,8 +242,6 @@ def load_state_dict(self, state_dict):
     def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0):
 
         def shuffle_batches(batches, seed):
-            # set seed based on the seed and epoch number so that we get
-            # reproducible results when resuming from checkpoints
             with data_utils.numpy_seed(seed):
                 np.random.shuffle(batches)
             return batches
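
For context on the hunk above: shuffle_batches stays deterministic because data_utils.numpy_seed temporarily fixes numpy's RNG state, so the same seed always yields the same batch order when resuming. Below is a minimal, self-contained sketch of that pattern; the numpy_seed here is a rough stand-in written for illustration, not the fairseq implementation itself.

import contextlib

import numpy as np


@contextlib.contextmanager
def numpy_seed(seed):
    """Temporarily fix numpy's RNG state (rough stand-in for fairseq's data_utils.numpy_seed)."""
    if seed is None:
        yield
        return
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)


def shuffle_batches(batches, seed):
    # same seed -> same permutation, so resuming from a checkpoint reproduces the batch order
    with numpy_seed(seed):
        np.random.shuffle(batches)
    return batches


assert shuffle_batches(list(range(10)), seed=3) == shuffle_batches(list(range(10)), seed=3)
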
fairseq/trainer.py (42 changes: 20 additions & 22 deletions)
@@ -637,7 +637,8 @@ def _all_gather_list_sync(
     def _fast_stat_sync_sum(
         self,
         logging_outputs: List[Dict[str, Any]],
-        *extra_stats_to_sum
+        *extra_stats_to_sum,
+        min_buffer_size: int = 50,
     ):
         """
         Sync logging outputs across workers. fast_stat_sync_sum is
@@ -647,35 +648,32 @@ def _fast_stat_sync_sum(
         num_extra = len(extra_stats_to_sum)
         if len(logging_outputs) > 0:
             sorted_keys = sorted(logging_outputs[0].keys())
-            stats = list(extra_stats_to_sum) + [
+            stats = [0.] + list(extra_stats_to_sum) + [
                 sum(log.get(k, 0) for log in logging_outputs)
                 for k in sorted_keys
             ]
+            stats = stats + [0.]*(min_buffer_size - len(stats))
             buf = torch.cuda.DoubleTensor(stats)
-
-            # When the number of batches is not evenly divisible by the
-            # number of GPUs, logging_outputs will be empty for some
-            # workers in the last iteration. But we still need to know
-            # the keys and buffer size, so we cache the state in case it
-            # needs to be reused by this worker later.
-            self._fss_buf = buf
-            self._fss_sorted_keys = sorted_keys
-        elif self._fss_buf is not None:
-            buf = self._fss_buf
-            buf.zero_()
-            buf[:num_extra] = torch.cuda.DoubleTensor(extra_stats_to_sum)
-            sorted_keys = self._fss_sorted_keys
         else:
-            raise RuntimeError(
-                'fast_stat_sync failed, perhaps (# GPUs) > (# batches)?'
-            )
+            buf = torch.zeros(min_buffer_size, dtype=torch.double, device='cuda')
+            buf[0] = 1. # flag to indicate we should fallback to _all_gather_list_sync
 
+        # stats buffer is organized like:
+        # 0: flag to indicate whether fast-stat-sync should be disabled
+        # 1-i: extra_stats_to_sum
+        # i-j: values from logging_outputs (sorted by key)
+        # j-min_buffer_size: padded with 0s
         distributed_utils.all_reduce(buf)
 
         buf = buf.tolist()
-        extra_stats_to_sum, stats = buf[:num_extra], buf[num_extra:]
-        stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
-        return [stats] + extra_stats_to_sum
+        fallback = buf[0]
+        if fallback > 0.:
+            # fallback to _all_gather_list_sync
+            return self._all_gather_list_sync(logging_outputs, *extra_stats_to_sum)
+        else:
+            extra_stats_to_sum, stats = buf[1:num_extra + 1], buf[num_extra + 1:]
+            stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
+            return [stats] + extra_stats_to_sum
 
     def _check_grad_norms(self, grad_norm):
         """Check that grad norms are consistent across workers."""
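
The buffer layout described in the comment above (flag, extra stats, summed logging values, zero padding) and the fallback path are easiest to follow in a small single-process sketch. The snippet below is illustrative only: it runs on CPU, replaces the distributed all_reduce with a plain tensor sum across two simulated workers, and the pack_stats/unpack_stats helpers are hypothetical names, not fairseq APIs.

from typing import Any, Dict, List, Tuple

import torch


def pack_stats(logging_outputs: List[Dict[str, Any]],
               extra_stats: List[float],
               min_buffer_size: int = 50) -> Tuple[torch.Tensor, List[str]]:
    """Pack stats into the fixed layout: [fallback flag, extras..., summed log values..., 0 padding]."""
    if len(logging_outputs) > 0:
        sorted_keys = sorted(logging_outputs[0].keys())
        stats = [0.] + list(extra_stats) + [
            sum(log.get(k, 0) for log in logging_outputs) for k in sorted_keys
        ]
        stats = stats + [0.] * (min_buffer_size - len(stats))
        return torch.tensor(stats, dtype=torch.double), sorted_keys
    # this worker had no batches this step: send zeros with the fallback flag set
    buf = torch.zeros(min_buffer_size, dtype=torch.double)
    buf[0] = 1.
    return buf, []


def unpack_stats(reduced: torch.Tensor, num_extra: int, sorted_keys: List[str]):
    """Invert pack_stats after the (simulated) all_reduce; None means 'fall back'."""
    vals = reduced.tolist()
    if vals[0] > 0.:
        return None  # some worker set the flag, so callers use _all_gather_list_sync instead
    extras = vals[1:num_extra + 1]
    stats = {k: vals[num_extra + 1 + i] for i, k in enumerate(sorted_keys)}
    return extras, stats


# Two simulated workers, one logging output each; summing the buffers stands in for all_reduce.
buf_a, keys = pack_stats([{'loss': 2.0, 'ntokens': 100}], extra_stats=[1.0])
buf_b, _ = pack_stats([{'loss': 3.0, 'ntokens': 120}], extra_stats=[1.0])
print(unpack_stats(buf_a + buf_b, num_extra=1, sorted_keys=keys))
# ([2.0], {'loss': 5.0, 'ntokens': 220.0})
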
