Skip to content

Commit

Permalink
Fix patch for parameter partitioning in zero.Init() (#6388)
Browse files Browse the repository at this point in the history
This PR fixes the issue reported in #5921.
With this change, we only apply the patch for parameter partitioning to
classes that have `__init__` so that we can avoid applying the patch
multiple times.
A class that does not define its own `__init__` now inherits its superclass's (already patched) one.
So this PR also applies the patch to the root class,
`torch.nn.modules.module.Module`.

Thanks @VeryLazyBoy for the report and initial solution.

---------

Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
tohtana and loadams authored Sep 4, 2024
1 parent 9d17116 commit ddeb0c1
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def new_tensor(cls, *args, **kwargs) -> Tensor:


# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls):
def get_all_subclasses(cls, include_root=True):
subclass_list = []

def recurse(cl):
Expand All @@ -272,7 +272,10 @@ def recurse(cl):

recurse(cls)

return set(subclass_list)
ret = set(subclass_list)
if include_root:
ret.add(cls)
return ret


@instrument_w_nvtx
Expand Down Expand Up @@ -465,11 +468,13 @@ def wrapper(*args, **kwargs):
return wrapper

def _enable_class_apply(cls):
cls._old_apply_of_skip_init_hook = cls._apply
cls._apply = partition_after_empty_init(cls._apply)
if '_apply' in cls.__dict__:
cls._old_apply_of_skip_init_hook = cls._apply
cls._apply = partition_after_empty_init(cls._apply)

def _disable_class_apply(cls):
cls._apply = cls._old_apply_of_skip_init_hook
if hasattr(cls, '_old_apply_of_skip_init_hook'):
cls._apply = cls._old_apply_of_skip_init_hook

# add hooks for to_empty: apply_(empty_like)
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
Expand Down Expand Up @@ -522,12 +527,14 @@ def wrapper(module, *args, **kwargs):
return wrapper

def _enable_class(cls):
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)
if '__init__' in cls.__dict__:
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)

def _init_subclass(cls, **kwargs):
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)
if '__init__' in cls.__dict__:
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)

# Replace .__init__() for all existing subclasses of torch.nn.Module recursively
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
Expand Down Expand Up @@ -567,7 +574,8 @@ def unpatch_init_and_builtins(self):
if self.patched:

def _disable_class(cls):
cls.__init__ = cls._old_init
if hasattr(cls, '_old_init'):
cls.__init__ = cls._old_init

for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_disable_class(subclass)
Expand Down

0 comments on commit ddeb0c1

Please sign in to comment.