Support Tensor parallel on Falcon models #582

Merged · 10 commits · Oct 25, 2023

Changes from all commits
1 change: 1 addition & 0 deletions lmdeploy/pytorch_poc/config.py
@@ -31,6 +31,7 @@ class ModelConfig:
bos_token_id: int
eos_token_id: int
dtype: str
multi_query_attention: bool = False

def get_head_size(self):
return self.hidden_size // self.num_heads
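Multi-query attention keeps a single key/value head that every query head attends to, which is why downstream components (the cache engine and the tensor-parallel partitioning below) need to know about it. As a quick, self-contained illustration of what the flag describes, outside of any lmdeploy code and with made-up shapes:

import torch

batch, seq, num_heads, head_dim = 2, 16, 8, 64
q = torch.randn(batch, num_heads, seq, head_dim)
k = torch.randn(batch, 1, seq, head_dim)  # one shared key head
v = torch.randn(batch, 1, seq, head_dim)  # one shared value head

# broadcasting over the head dimension is what makes it "multi query":
# many query heads, one K/V head
attn = torch.softmax(q @ k.transpose(-1, -2) / head_dim**0.5, dim=-1)
out = attn @ v  # (batch, num_heads, seq, head_dim)
print(out.shape)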
10 changes: 6 additions & 4 deletions lmdeploy/pytorch_poc/engine/cache_engine.py
@@ -66,8 +66,9 @@ def gpu_cache(self):
def get_key_block_shape(self, local: bool = False) -> Tuple[int, int, int]:
"""get shape of key block."""
num_heads = self.num_heads
if local:
assert self.num_heads % self.world_size == 0
if local and not self.model_config.multi_query_attention:
assert self.num_heads % self.world_size == 0, \
f'num_heads: {self.num_heads}, world_size: {self.world_size}'
num_heads = self.num_heads // self.world_size
return (
self.block_size,
Expand All @@ -79,8 +80,9 @@ def get_value_block_shape(self,
local: bool = False) -> Tuple[int, int, int]:
"""get shape of value block."""
num_heads = self.num_heads
if local:
assert self.num_heads % self.world_size == 0
if local and not self.model_config.multi_query_attention:
assert self.num_heads % self.world_size == 0, \
f'num_heads: {self.num_heads}, world_size: {self.world_size}'
num_heads = self.num_heads // self.world_size
return (
self.block_size,
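The practical effect of the condition above: without MQA each rank's KV cache holds num_heads // world_size heads, while with MQA the single KV head is kept whole on every rank. A rough standalone sketch of just that branch (the numbers are made up; the real methods also return block_size and head_size in the shape tuple):

def local_kv_heads(num_heads, world_size, multi_query_attention):
    # mirror of the condition above: only divide heads across ranks
    # when the model is not multi-query
    if not multi_query_attention:
        assert num_heads % world_size == 0
        return num_heads // world_size
    return num_heads

print(local_kv_heads(32, 4, False))  # 8 KV heads per rank (MHA)
print(local_kv_heads(1, 4, True))    # the single MQA head is replicated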
9 changes: 8 additions & 1 deletion lmdeploy/pytorch_poc/engine/engine.py
@@ -480,10 +480,17 @@ def __init__(
num_cpu_blocks=0,
num_gpu_blocks=0)
if 'falcon' in model_path:
if hf_config.new_decoder_architecture:
# 40b-instruct, GQA
kv_dim = hf_config.hidden_size // hf_config.num_attention_heads
kv_dim *= hf_config.num_kv_heads
kv_head = hf_config.num_kv_heads
if hf_config.multi_query:
# 7b-instruct, MQA
kv_dim = hf_config.hidden_size // hf_config.num_attention_heads
kv_head = 1
else:
# rw-1b, MHA
kv_dim = hf_config.hidden_size
kv_head = hf_config.num_attention_heads
model_config = ModelConfig(
Expand All @@ -493,7 +500,7 @@ def __init__(
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
dtype=torch_dtype,
)
multi_query_attention=hf_config.multi_query)
elif 'chatglm' in model_path:
model_config = ModelConfig(
hf_config.hidden_size // hf_config.num_attention_heads *
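For concreteness, the three branches roughly evaluate as follows for the models named in the comments. The config numbers below are quoted from memory of the public Falcon checkpoints and are illustrative rather than authoritative, and the branch order is written with elif here purely for readability:

def falcon_kv_layout(hidden_size, num_attention_heads,
                     new_decoder_architecture=False, multi_query=False,
                     num_kv_heads=None):
    head_dim = hidden_size // num_attention_heads
    if new_decoder_architecture:      # e.g. 40b-instruct, GQA
        return head_dim * num_kv_heads, num_kv_heads
    elif multi_query:                 # e.g. 7b-instruct, MQA
        return head_dim, 1
    else:                             # e.g. rw-1b, MHA
        return hidden_size, num_attention_heads

# (kv_dim, kv_head), assumed config values:
print(falcon_kv_layout(8192, 128, new_decoder_architecture=True, num_kv_heads=8))  # (512, 8)
print(falcon_kv_layout(4544, 71, multi_query=True))                                # (64, 1)
print(falcon_kv_layout(2048, 32))                                                  # (2048, 32)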
167 changes: 165 additions & 2 deletions lmdeploy/pytorch_poc/models/falcon.py
@@ -1,20 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapter from:
# Adapted from:
# https://huggingface.co/tiiuae/falcon-7b-instruct
# https://github.com/huggingface/transformers/blob/v4.33-release/src/transformers/models/falcon/modeling_falcon.py # noqa

import logging
from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from torch.distributed._tensor import DeviceMesh
from transformers.modeling_outputs import \
BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon.modeling_falcon import build_alibi_tensor

from lmdeploy.pytorch_poc.dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from lmdeploy.pytorch_poc.kernels import (alibi_paged_attention_fwd,
fill_kv_cache, paged_attention_fwd)

logger = logging.getLogger()


# rotary pos emb helpers
# (torch.jit.script does not seem to support staticmethod...)
@@ -82,6 +89,141 @@ def forward(self, query, key, position_ids_or_past_key_values_length=0):

class PatchedFalconAttention(nn.Module):

# @classmethod
def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""

world_size = dist.get_world_size()

if mod_name in ['query_key_value']:
if self.new_decoder_architecture:
# e.g. 40b-instruct, GQA
# split qkv across groups
# no finer-grained partitioning
weight = mod.weight.reshape(
-1, # num groups
(self.num_heads + self.num_kv_heads * 2) * self.head_dim,
self.hidden_size,
)
elif self.multi_query:
# e.g. 7b-instruct, MQA
# split to q, copy kv
weight = mod.weight.reshape(
-1,
self.head_dim,
self.hidden_size,
)
q_weight = weight[:self.num_heads]
k_weight = weight[self.num_heads:self.num_heads + 1]
v_weight = weight[self.num_heads + 1:self.num_heads + 2]
q_weight_shards = torch.tensor_split(q_weight,
world_size,
dim=0)
weight_shards = []
for q in q_weight_shards:
# only shard q heads but
# copy single k/v head to all ranks
weight_shards.append(q)
weight_shards.append(k_weight)
weight_shards.append(v_weight)
mod.weight.data = torch.cat(weight_shards, dim=0)
# here we keep the weight 3D,
# so that the column-parallel split
# yields an integer number of heads

# no bias for 7b-instruct and 40b-instruct

colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)

if self.new_decoder_architecture or self.multi_query:
# return to 2D for later matmul
mod.weight.data = mod.weight.data.reshape(-1, self.hidden_size)

elif mod_name in ['dense']:
if self.new_decoder_architecture:
# e.g. 40b-instruct, GQA
weight = mod.weight.reshape(
self.hidden_size,
-1, # num groups
self.num_heads * self.head_dim,
)
elif self.multi_query:
# e.g. 7b-instruct, MQA
mod.weight.data = mod.weight.reshape(self.hidden_size, -1,
self.head_dim)

rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)

if self.new_decoder_architecture or self.multi_query:
mod.weight.data = mod.weight.reshape(self.hidden_size, -1)

@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs

def _split_heads(
self, fused_qkv: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Split the last dimension into (num_heads, head_dim), results share
same memory storage as `fused_qkv`

Args:
fused_qkv (`torch.tensor`, *required*):
[batch_size, seq_length, num_heads * 3 * head_dim]

Returns:
query: [batch_size, seq_length, num_heads, head_dim]
key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
if self.new_decoder_architecture:
# e.g. 40b-instruct model
batch, seq_len, _ = fused_qkv.shape
qkv = fused_qkv.view(batch, seq_len, -1,
self.num_heads // self.num_kv_heads + 2,
self.head_dim)
query = qkv[:, :, :, :-2]
key = qkv[:, :, :, [-2]]
value = qkv[:, :, :, [-1]]
# because the cache_engine & kernel
# already handle grouped attention,
# removing the broadcast is faster and saves memory
# key = torch.broadcast_to(key, query.shape)
# value = torch.broadcast_to(value, query.shape)

query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
return query, key, value
elif not self.multi_query:
# e.g. rw-1b model
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length,
self.num_heads // dist.get_world_size(),
3, self.head_dim)
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[...,
2, :]
else:
# e.g. 7b-instruct model
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
if not dist.is_initialized():
num_head = self.num_heads
else:
# this trick will, for example, split 11 into [4, 4, 3],
# matching the way column-parallel linear splits
# non-divisible dims
num_head = self.num_heads - dist.get_rank() - 1
num_head = 1 + num_head // dist.get_world_size()
fused_qkv = fused_qkv.view(batch_size, seq_length, num_head + 2,
self.head_dim)
return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[
..., [-1], :]
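The head-count arithmetic in that last branch can be sanity-checked in isolation; it assigns each rank the same number of query heads that torch.tensor_split (and therefore the column-parallel weight split) produces, e.g. 71 Falcon-7b query heads over 4 ranks:

import torch

num_heads, world_size = 71, 4

per_rank = [1 + (num_heads - rank - 1) // world_size
            for rank in range(world_size)]
split_sizes = [t.shape[0]
               for t in torch.tensor_split(torch.arange(num_heads), world_size)]

print(per_rank)     # [18, 18, 18, 17]
print(split_sizes)  # [18, 18, 18, 17]
assert per_rank == split_sizes and sum(per_rank) == num_heads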

def _contiguous_batching_forward(
self,
hidden_states: torch.Tensor,
@@ -106,7 +248,6 @@ def _contiguous_batching_forward(
hidden_states) # [batch_size, seq_length, 3 x hidden_size]

# 3 x [batch_size, seq_length, num_heads, head_dim]
# TODO: need further check when using TP
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

batch_size, query_length, _, _ = query_layer.shape
Expand Down Expand Up @@ -198,6 +339,28 @@ def forward(
head_mask, use_cache, output_attentions)


class PatchedFalconMLP(nn.Module):

@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['dense_h_to_4h']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['dense_4h_to_h']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)

@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs)
return outputs
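This is the standard column-then-row parallel MLP: dense_h_to_4h is split along its output features so each rank computes a slice of the intermediate activations, dense_4h_to_h is split along its input features, and the all-reduce above sums the per-rank partial outputs. A toy demonstration with plain tensors rather than the lmdeploy helpers (any elementwise activation works; gelu is used to stay close to Falcon):

import torch
import torch.nn.functional as F

hidden, ffn, world_size = 8, 32, 2
x = torch.randn(1, hidden)
w_up = torch.randn(ffn, hidden)    # stand-in for dense_h_to_4h.weight
w_down = torch.randn(hidden, ffn)  # stand-in for dense_4h_to_h.weight

ref = F.gelu(x @ w_up.t()) @ w_down.t()

partials = []
for rank in range(world_size):
    up_shard = torch.tensor_split(w_up, world_size, dim=0)[rank]      # column parallel
    down_shard = torch.tensor_split(w_down, world_size, dim=1)[rank]  # row parallel
    partials.append(F.gelu(x @ up_shard.t()) @ down_shard.t())

out = torch.stack(partials).sum(0)  # plays the role of the all_reduce
print(torch.allclose(ref, out, atol=1e-5))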


class PatchedFalconModel(nn.Module):

def _contiguous_batching_forward(
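To summarize the MQA handling of query_key_value in _distribute_partition_fn above: only the query heads are sharded, and a copy of the single key head and single value head is appended to every rank's slice, so that after the column-parallel split each GPU holds a self-contained [q_shard, k, v] weight. A toy re-enactment with made-up shapes and plain tensors rather than the patched module:

import torch

num_heads, head_dim, hidden, world_size = 6, 4, 24, 2
# fused MQA qkv weight: num_heads query heads, then one k and one v head
w = torch.randn((num_heads + 2) * head_dim, hidden).reshape(-1, head_dim, hidden)

q = w[:num_heads]
k = w[num_heads:num_heads + 1]
v = w[num_heads + 1:num_heads + 2]

shards = []
for q_shard in torch.tensor_split(q, world_size, dim=0):
    # each rank gets its share of q heads plus a copy of the k/v head
    shards.extend([q_shard, k, v])
reordered = torch.cat(shards, dim=0)

# a plain column-parallel split of the reordered weight now gives every
# rank an integer number of heads laid out as [q..., k, v]
per_rank = torch.tensor_split(reordered, world_size, dim=0)
print([tuple(p.shape) for p in per_rank])  # [(5, 4, 24), (5, 4, 24)]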
17 changes: 10 additions & 7 deletions lmdeploy/pytorch_poc/models/patch.py
@@ -28,22 +28,25 @@

# Falcon Models in transformer / on hub
MODULE_MAP.update({
'transformers.models.falcon.modeling_falcon.FalconAttention':
'modeling_falcon.FalconAttention':
'lmdeploy.pytorch_poc.models.falcon.PatchedFalconAttention',
'transformers.models.falcon.modeling_falcon.FalconModel':
'modeling_falcon.FalconModel':
'lmdeploy.pytorch_poc.models.falcon.PatchedFalconModel',
'transformers.models.falcon.modeling_falcon.FalconRotaryEmbedding':
'modeling_falcon.FalconRotaryEmbedding':
'lmdeploy.pytorch_poc.models.falcon.PatchedFalconRotaryEmbedding',
'modeling_falcon.FalconMLP':
'lmdeploy.pytorch_poc.models.falcon.PatchedFalconMLP',
'modeling_falcon.FalconForCausalLM':
'lmdeploy.pytorch_poc.models.falcon.PatchedFalconForCausalLM',
# for old implementations on hub
'modelling_RW.Attention':
'lmdeploy.pytorch_poc.models.falcon.PatchedFalconAttention',
'modelling_RW.MLP':
'lmdeploy.pytorch_poc.models.falcon.PatchedFalconMLP',
'modelling_RW.RWModel':
'lmdeploy.pytorch_poc.models.falcon.PatchedFalconModel',
'modelling_RW.RotaryEmbedding':
'lmdeploy.pytorch_poc.models.falcon.PatchedFalconRotaryEmbedding',
'transformers.models.falcon.modeling_falcon.FalconForCausalLM':
'lmdeploy.pytorch_poc.models.falcon.PatchedFalconForCausalLM',
# 'transformers.models.falcon.modeling_falcon.FalconDecoderLayer':
# 'lmdeploy.pytorch_poc.models.falcon.PatchedFalconDecoderLayer',
})

# baichuan
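The shortened keys ('modeling_falcon.FalconAttention' instead of the full 'transformers.models.falcon.modeling_falcon.FalconAttention') suggest the patcher matches on the tail of a class's qualified name, so one entry can cover both the in-tree transformers implementation and trust_remote_code copies on the hub (whose module is just 'modelling_RW'). A guess at how such a lookup might work; the real patch.py may do this differently:

from typing import Optional

def find_patched_class(qualname: str, module_map: dict) -> Optional[str]:
    """Return the rewrite target whose key is a suffix of `qualname`."""
    for key, target in module_map.items():
        if qualname == key or qualname.endswith('.' + key):
            return target
    return None

MODULE_MAP = {
    'modeling_falcon.FalconAttention':
    'lmdeploy.pytorch_poc.models.falcon.PatchedFalconAttention',
}
print(find_patched_class(
    'transformers.models.falcon.modeling_falcon.FalconAttention', MODULE_MAP))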