Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support deploy qwen-14b-chat #482

Merged
merged 3 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ ______________________________________________________________________

## News 🎉

- \[2023/09\] TurboMind supports Qwen-14B
- \[2023/09\] TurboMind supports InternLM-20B
- \[2023/09\] TurboMind supports all features of Code Llama: code completion, infilling, chat / instruct, and python specialist. Click [here](./docs/en/supported_models/codellama.md) for deployment guide
- \[2023/09\] TurboMind supports Baichuan2-7B
Expand Down Expand Up @@ -65,6 +66,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
| InternLM-7B | Yes | Yes | Yes | Yes | No |
| InternLM-20B | Yes | Yes | Yes | Yes | No |
| QWen-7B | Yes | Yes | Yes | No | No |
| QWen-14B | Yes | Yes | Yes | No | No |
| Baichuan-7B | Yes | Yes | Yes | Yes | No |
| Baichuan2-7B | Yes | Yes | No | No | No |
| Code Llama | Yes | Yes | No | No | No |
Expand Down
2 changes: 2 additions & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ ______________________________________________________________________

## 更新 🎉

- \[2023/09\] TurboMind 支持 Qwen-14B
- \[2023/09\] TurboMind 支持 InternLM-20B 模型
- \[2023/09\] TurboMind 支持 Code Llama 所有功能:代码续写、填空、对话、Python专项。点击[这里](./docs/zh_cn/supported_models/codellama.md)阅读部署方法
- \[2023/09\] TurboMind 支持 Baichuan2-7B
Expand Down Expand Up @@ -66,6 +67,7 @@ LMDeploy 由 [MMDeploy](https://github.com/open-mmlab/mmdeploy) 和 [MMRazor](ht
| InternLM-7B | Yes | Yes | Yes | Yes | No |
| InternLM-20B | Yes | Yes | Yes | Yes | No |
| QWen-7B | Yes | Yes | Yes | No | No |
| QWen-14B | Yes | Yes | Yes | No | No |
| Baichuan-7B | Yes | Yes | Yes | Yes | No |
| Baichuan2-7B | Yes | Yes | No | No | No |
| Code Llama | Yes | Yes | No | No | No |
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ def messages2prompt(self, messages, sequence_start=True):
return ret


@MODELS.register_module(name='qwen-14b')
@MODELS.register_module(name='qwen-7b')
class Qwen7BChat(BaseModel):
"""Chat template for Qwen-7B-Chat."""
Expand Down
48 changes: 32 additions & 16 deletions lmdeploy/serve/turbomind/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import fire
import safetensors
import torch
from safetensors.torch import load_file
from sentencepiece import SentencePieceProcessor

import lmdeploy
Expand Down Expand Up @@ -108,6 +109,35 @@ def tokenizer_info_qwen(model_dir: str):
return n_words, bos_id, eos_id


def load_checkpoint(model_path):
"""Load checkpoint files into torch format.

Args:
model_path (str): the checkpoint folder
Returns:
Dict[str, torch.Tensor]: weight in torch format
"""
suffixes = ['.safetensors', '.bin']
for suffix in suffixes:
files = [
file for file in os.listdir(model_path) if file.endswith(suffix)
]
if len(files) > 0:
break

assert len(files) > 0, f'could not find checkpoints in {model_path}'
files = sorted(files)
print(files)
params = {}
for file in files:
if file.endswith('.bin'):
tmp = torch.load(osp.join(model_path, file), map_location='cpu')
else:
tmp = load_file(osp.join(model_path, file))
params.update(tmp)
return params


def export(model_name: str,
num_layer: int,
norm_eps: float,
Expand Down Expand Up @@ -437,14 +467,7 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
_qweight = 'weight'
_suffixes = [_qweight, 'bias']

_files = [file for file in os.listdir(model_path) if file.endswith('.bin')]
_files = sorted(_files)
print(_files)

_params = {}
for _file in _files:
_tmp = torch.load(osp.join(model_path, _file), map_location='cpu')
_params.update(_tmp)
_params = load_checkpoint(model_path)

def get_tensor(name):
"""return tensor according its name."""
Expand Down Expand Up @@ -837,14 +860,7 @@ def deploy_qwen(model_name: str, model_path: str, tokenizer_path: str,
# convert weights from hf to turbomind
model_params = {}

_files = [file for file in os.listdir(model_path) if file.endswith('.bin')]
_files = sorted(_files)
print(_files)

_params = {}
for _file in _files:
_tmp = torch.load(osp.join(model_path, _file), map_location='cpu')
_params.update(_tmp)
_params = load_checkpoint(model_path)

def get_tensor(name, trans=True):
"""return a transposed tensor according its name."""
Expand Down