LoRA support for HF::AutoModelForCausalLM (NVIDIA#10982)
* add LinearAdapter

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add hf lora example

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove unused imports

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* subclass mixin

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove stale imports

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* undo

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix scale

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* regex selector for peft

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* move lora

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fmt

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* hf_auto_model_for_causal_lm finetune recipe

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
2 people authored and XuesongYang committed Jan 18, 2025
1 parent 786dc38 commit 9d85a50
Showing 4 changed files with 222 additions and 4 deletions.
105 changes: 105 additions & 0 deletions examples/llm/peft/hf.py
@@ -0,0 +1,105 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fiddle as fdl
from pytorch_lightning.loggers import WandbLogger
from nemo import lightning as nl
from nemo.collections import llm


def mk_hf_dataset(tokenizer):
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN

def formatting_prompts_func(examples):
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}"""
instruction = examples["context"]
input = examples["question"]
output = examples["answers"]['text']
if isinstance(output, list):
output = output[0]
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
ans = tokenizer(text)
tokens = ans['input_ids']
return {
'tokens': tokens,
'labels': tokens[1:] + [tokens[-1]],
}

from datasets import load_dataset

dataset = load_dataset("rajpurkar/squad", split="train")
dataset = dataset.map(formatting_prompts_func, batched=False, batch_size=2)
return dataset


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='meta-llama/Llama-3.2-1B')
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
parser.add_argument('--devices', default=1)
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument('--wandb-project', type=str, default=None)
args = parser.parse_args()

wandb = None
if args.wandb_project is not None:
model = '_'.join(args.model.split('/')[-2:])
wandb = WandbLogger(
project=args.wandb_project,
name=f'{model}_dev{args.devices}_strat_{args.strategy}',
)
grad_clip = 0.5
if args.strategy == 'fsdp':
        # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
grad_clip = None
use_dist_samp = False
tokenizer = llm.HfAutoModelForCausalLM.configure_tokenizer(args.model)

llm.api.finetune(
model=llm.HfAutoModelForCausalLM(args.model),
data=llm.HfDatasetDataModule(
mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id
),
trainer=nl.Trainer(
devices=args.devices,
max_steps=args.max_steps,
accelerator=args.accelerator,
strategy=args.strategy,
log_every_n_steps=1,
limit_val_batches=0.0,
num_sanity_val_steps=0,
accumulate_grad_batches=10,
gradient_clip_val=grad_clip,
use_distributed_sampler=use_dist_samp,
logger=wandb,
),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(max_lr=1e-5, clip_grad=0.5)),
log=None,
peft=llm.peft.LoRA(
target_modules=['*_proj'],
dim=32,
),
)
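
A side note on the label construction in mk_hf_dataset above: the labels are the input ids shifted left by one position, with the last id repeated so both lists keep the same length. A tiny worked sketch (the ids below are invented, not produced by a real tokenizer):

    tokens = [101, 7592, 2088, 102]
    labels = tokens[1:] + [tokens[-1]]   # -> [7592, 2088, 102, 102]
    assert len(labels) == len(tokens)    # next-token targets stay aligned with the inputs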
(second changed file; its file header was not captured in this view)
@@ -18,6 +18,7 @@
from transformers import AutoModelForCausalLM

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.llm import fn
from nemo.lightning import io


@@ -33,7 +34,7 @@ def masked_cross_entropy(logits, targets, mask=None):
return F.cross_entropy(logits, targets)


class HfAutoModelForCausalLM(pl.LightningModule, io.IOMixin):
class HfAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin):
def __init__(self, model_name='gpt2', load_pretrained_weights=True, tokenizer=None, loss_fn=masked_cross_entropy):
super().__init__()
self.save_hyperparameters()
52 changes: 51 additions & 1 deletion nemo/collections/llm/peft/lora.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import re
from dataclasses import dataclass, field
from typing import List, Literal

import torch
from megatron.core import parallel_state
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from torch import nn
@@ -69,6 +71,49 @@ def forward(self, x):
return linear_output + adapter_output, bias


class LinearAdapter(nn.Module):
def __init__(
self, orig_linear, dim=8, alpha=32, dropout=0.1, dropout_position='post', lora_A_init_method='xavier'
):
super(LinearAdapter, self).__init__()
assert isinstance(orig_linear, nn.Linear)

self.orig_linear = orig_linear
self.dim = dim
self.scale = alpha / dim

# Freezer
device = self.orig_linear.weight.device
self.orig_linear.weight.requires_grad = False
if self.orig_linear.bias is not None:
self.orig_linear.bias.requires_grad = False

in_features = self.orig_linear.in_features
out_features = self.orig_linear.out_features
dtype = self.orig_linear.weight.dtype
self.lora_a = nn.Parameter(torch.zeros((in_features, dim), dtype=dtype, device=device))
self.lora_b = nn.Parameter(torch.zeros((dim, out_features), dtype=dtype, device=device))
if lora_A_init_method == 'xavier':
torch.nn.init.uniform_(self.lora_a)
else:
nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))

self.dropout = nn.Dropout(p=dropout)
assert dropout_position in ['pre', 'post'], dropout_position
self.dropout_position = dropout_position

def forward(self, x):
res = self.orig_linear(x)
if self.dropout_position == 'pre':
x = self.dropout(x)
lora_res = x @ self.lora_a
lora_res = lora_res @ self.lora_b
lora_res = lora_res * self.scale
if self.dropout_position == 'post':
lora_res = self.dropout(lora_res)
return res + lora_res


@dataclass
class LoRA(PEFT):
"""
@@ -142,13 +187,13 @@ def wildcard_match(pattern, key):
match = regex_pattern.match(key)
return match is not None

tp_size = parallel_state.get_tensor_model_parallel_world_size()
full_name = f"{prefix}.{name}" if prefix else name
if name in self.target_modules or any(wildcard_match(pattern, full_name) for pattern in self.target_modules):
if HAVE_TE and isinstance(m, TEColumnParallelLinear) or isinstance(m, TELayerNormColumnParallelLinear):
input_is_parallel = False
# m.in_features and m.out_features are divided by tp_size already,
# but in_features and out_features passed to ParallelLinearAdapter are not.
tp_size = parallel_state.get_tensor_model_parallel_world_size()
in_features = m.in_features
out_features = m.out_features * tp_size
# LoRA is applied after layernorm, so layernorm output must be returned
@@ -158,6 +203,7 @@ def wildcard_match(pattern, key):
m.return_layernorm_output_gathered = True
elif HAVE_TE and isinstance(m, TERowParallelLinear):
input_is_parallel = True
tp_size = parallel_state.get_tensor_model_parallel_world_size()
in_features = m.in_features * tp_size
out_features = m.out_features
elif isinstance(m, ColumnParallelLinear):
@@ -168,6 +214,10 @@ def wildcard_match(pattern, key):
input_is_parallel = True
in_features = m.input_size
out_features = m.output_size
elif isinstance(m, nn.Linear):
return LinearAdapter(
m, dim=self.dim, alpha=self.alpha, dropout=self.dropout, lora_A_init_method=self.lora_A_init_method
)
else:
raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}")

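To summarize the new code path in lora.py: plain nn.Linear modules whose qualified name matches one of target_modules (for example the '*_proj' pattern used in the example script, which is meant to catch q_proj/k_proj/v_proj/o_proj-style projections) are replaced by LinearAdapter. The adapter freezes the original weight and returns orig_linear(x) plus an (alpha / dim)-scaled, dropout-regularized low-rank term x @ lora_a @ lora_b. A minimal usage sketch, not part of the commit, assuming the import path matches the file shown above:

    import torch
    from torch import nn

    from nemo.collections.llm.peft.lora import LinearAdapter  # path assumed from the diff above

    base = nn.Linear(128, 256)
    adapter = LinearAdapter(base, dim=8, alpha=32, dropout=0.1)

    x = torch.randn(4, 128)
    y = adapter(x)                        # frozen base output + scaled low-rank update
    assert y.shape == (4, 256)
    assert not base.weight.requires_grad  # the wrapped Linear is frozen by the adapter
    assert adapter.lora_a.requires_grad   # only the LoRA factors remain trainable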
66 changes: 64 additions & 2 deletions nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py
@@ -130,7 +130,7 @@ def pretrain_recipe(
model_name: str = '',
) -> run.Partial:
"""
Create a pre-training recipe for Mistral 7B model.
Create a pre-training recipe for a HfAutoModelForCausalLM model.
This function sets up a complete configuration for pre-training, including
model, trainer, data, logging, optimization, and resumption settings.
@@ -155,7 +155,7 @@
"""
return run.Partial(
fn,
model=model(model_name),
model=model(model_name, load_pretrained_weights=False),
trainer=trainer(
num_nodes=num_nodes,
num_gpus_per_node=num_gpus_per_node,
@@ -166,3 +166,65 @@
optim=pytorch_adam_with_cosine_annealing(max_lr=3e-4),
resume=default_resume(),
)


@run.cli.factory(target=finetune, name=NAME)
def finetune_recipe(
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
peft_scheme: Optional[str] = 'lora',
model_name: str = '',
) -> run.Partial:
"""
Create a fine-tuning recipe for a HfAutoModelForCausalLM model.
This function sets up a complete configuration for fine-tuning, including
model, trainer, data, logging, optimization, and resumption settings.
The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the fine-tuning run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.
Returns:
run.Partial: Partial configuration for fine-tuning.
Examples:
CLI usage:
$ nemo llm finetune --factory hf_auto_model_for_causal_lm
Python API usage:
>>> recipe = finetune_recipe(name="llama3_8b_finetune", num_nodes=2)
>>> print(recipe)
Note:
This recipe uses the SQuAD dataset for fine-tuning. For more information
on fine-tuning LLMs with NeMo, see the fine-tuning guide in the
`examples/llm/finetune/` directory.
"""
recipe = run.Partial(
finetune,
model=model(model_name, load_pretrained_weights=True),
trainer=trainer(
num_nodes=num_nodes,
num_gpus_per_node=num_gpus_per_node,
callbacks=[run.Config(TimingCallback)],
),
data=run.Config(MockDataModule, seq_length=4096, global_batch_size=512, micro_batch_size=1),
log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
optim=pytorch_adam_with_cosine_annealing(max_lr=3e-4),
resume=default_resume(),
)
if peft_scheme is None or peft_scheme.lower() == 'none':
recipe.optim.config.lr = 5e-6
elif peft_scheme.lower() == 'lora':
recipe.peft = run.Config(LoRA)
recipe.optim.config.lr = 1e-4
else:
raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")
return recipe
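
For completeness, a hedged sketch of driving the new factory from Python, using only arguments that appear in its signature; the module alias and import path are assumptions based on the file location nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py, and the model name mirrors the default in the example script:

    from nemo.collections.llm.recipes import hf_auto_model_for_causal_lm as hf_recipe

    recipe = hf_recipe.finetune_recipe(
        name="llama32_1b_lora",
        num_nodes=1,
        num_gpus_per_node=1,
        peft_scheme="lora",                   # attaches run.Config(LoRA) and sets lr to 1e-4
        model_name="meta-llama/Llama-3.2-1B",
    )
    # Passing peft_scheme=None (or "none") instead skips the adapter and lowers lr to 5e-6.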
