bf16_optimizer: fixes to different grad acc dtype (#6485)
- fix the step function to cast gradients to FP32 before the optimizer step
  when the gradient accumulation data type differs
- remove the redundant function initialize_optimizer_states()
nelyahu authored Sep 4, 2024
1 parent 9b7fc54 commit cfc6ed3
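
In essence, the change boils down to the pattern sketched below: when gradients are accumulated in a dtype other than FP32 (e.g. BF16), each flat gradient partition is cast to the FP32 parameter dtype just before the wrapped optimizer steps, and the temporary FP32 gradients are dropped afterwards so only the low-precision accumulation buffers stay resident. This is a minimal illustrative sketch, not DeepSpeed's actual class: the attribute names mirror the diff below, but the surrounding scaffolding and the example usage are assumptions.

import torch


class Bf16StepSketch:
    """Illustrative sketch only; mirrors the cast-before-step flow in the diff below."""

    def __init__(self, optimizer, fp32_flat_partitions, grad_flat_partitions, grad_acc_dtype=torch.bfloat16):
        self.optimizer = optimizer                                        # wrapped optimizer holding the FP32 partitions
        self.fp32_groups_flat_partition = fp32_flat_partitions            # flat FP32 parameter partitions
        self.fp32_groups_gradient_flat_partition = grad_flat_partitions   # flat gradient accumulation buffers
        self.grad_acc_dtype = grad_acc_dtype                              # dtype used for gradient accumulation

    def step(self):
        # Cast each accumulated gradient to the parameter dtype (FP32) so the
        # wrapped optimizer always steps in high precision.
        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
                                                   self.fp32_groups_gradient_flat_partition):
            param_partition.grad = (grad_partition.to(param_partition.dtype)
                                    if grad_partition.dtype != param_partition.dtype else grad_partition)

        self.optimizer.step()

        # The cast created temporary FP32 copies; release them so only the
        # low-precision accumulation buffers remain allocated.
        if self.grad_acc_dtype is not torch.float32:
            for param_partition in self.fp32_groups_flat_partition:
                param_partition.grad = None


# Assumed usage, for illustration only:
#   params = [torch.zeros(1024, dtype=torch.float32, requires_grad=True)]
#   grads = [torch.randn(1024, dtype=torch.bfloat16)]
#   Bf16StepSketch(torch.optim.AdamW(params), params, grads).step()
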
Showing 1 changed file with 10 additions and 23 deletions.
33 changes: 10 additions & 23 deletions deepspeed/runtime/bf16_optimizer.py
@@ -197,10 +197,6 @@ def _setup_for_real_optimizer(self):
 
             see_memory_usage(f'after initializing group {i}', force=True)
 
-        see_memory_usage('before initialize_optimizer', force=True)
-        self.initialize_optimizer_states()
-        see_memory_usage('end initialize_optimizer', force=True)
-
         self._grad_acc_hooks = []
         if self.immediate_grad_update:
             self.create_grad_acc_hooks()
@@ -252,25 +248,6 @@ def _lazy_init_hp_params_optimizer_state(self):
                                                     self.optimizer.state)
             self._hp_optimizer_states_linked = True
 
-    def initialize_optimizer_states(self):
-        """Take an optimizer step with zero-valued gradients to allocate internal
-        optimizer state.
-
-        This helps prevent memory fragmentation by allocating optimizer state at the
-        beginning of training instead of after activations have been allocated.
-        """
-        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
-                                                   self.fp32_groups_gradient_flat_partition):
-            # In case of grad acc dtype different than FP32, need to cast to high precision.
-            param_partition.grad = grad_partition.to(
-                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
-
-        if self.grad_acc_dtype is not torch.float32:
-            for param_partition in self.fp32_groups_flat_partition:
-                param_partition.grad = None
-
-        self.clear_hp_grads()
-
     def _split_flat_tensor(self, flat_tensor, num_elem_list):
         assert sum(num_elem_list) <= flat_tensor.numel()
         tensor_list = []
@@ -317,8 +294,18 @@ def step(self, closure=None):
                                         mpu=self.mpu,
                                         use_graph=self.graph_harvesting)
 
+        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
+                                                   self.fp32_groups_gradient_flat_partition):
+            # In case of grad acc dtype different than FP32, need to cast to high precision.
+            param_partition.grad = grad_partition.to(
+                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
+
         self.optimizer.step()
 
+        if self.grad_acc_dtype is not torch.float32:
+            for param_partition in self.fp32_groups_flat_partition:
+                param_partition.grad = None
+
         # We need to link optimizer state after the first step() call
         self._lazy_init_hp_params_optimizer_state()
 
