Commit 59b4162: init
Contextualist committed Jan 29, 2024 (0 parents)
Showing 19 changed files with 725 additions and 0 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,29 @@
name: CI

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
check-n-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip'
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install -r requirements-dev.txt
- name: Formatter
run: |
ruff format --check
- name: Unit tests
run: |
pytest lone_arena/
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
/data/

__pycache__/
.pytest_cache/
.ruff_cache/
.vscode/
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,5 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.14
hooks:
- id: ruff-format
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Contextualist

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
38 changes: 38 additions & 0 deletions README.md
@@ -0,0 +1,38 @@
<img align="right" src="media/lone_arena-sketch-small.png" width="160" />

# Lone Arena

When comparing two LLM checkpoints, human evaluation can be tedious.
Let's strip the evaluation process down to a single question:

![lone_arena-ui-en](media/lone_arena-ui-en.png)

Press <kbd>f</kbd> or <kbd>j</kbd> to choose the winner of each match.

Inspired by [Chatbot Arena](https://chat.lmsys.org).

## Get Started

1. In a Python (>= 3.12) environment, `pip install -r requirements.txt`
2. Fill out `config.toml` with your model endpoint information and prompts. See [`config-example.toml`](config-example.toml).
3. Run `python generate.py config.toml` to gather responses from models.
4. Run `python evaluate.py config.toml` to host your competition!


## Approach

Two models/checkpoints are compared by anonymously evaluating their responses to the same prompts. For each prompt:

1. For each model, generate 8 sample responses and run a single-elimination tournament to rank its top 3 responses. (8 matches x 2 models)
2. Pit the two models' best responses against each other, then their 2nd-best, then their 3rd-best. The winner of each match earns 4.8, 3.2, and 2.0 points, respectively. (3 matches)

Matches are shuffled.
The number of samples and the point values are configurable.
In the future, I might implement [Elo](https://en.wikipedia.org/wiki/Elo_rating_system) for comparing multiple models.
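
Here is a minimal sketch of the per-prompt scheme, assuming a hypothetical `judge(a, b)` callback that returns the winning response; the function names below are illustrative only (the actual implementation is `Top3_1v1` in `lone_arena/chatcup.py`):

```python
import random

def top3(samples: list[str], judge) -> list[str]:
    """Single elimination over 8 samples: quarterfinals, semifinals, final,
    plus a 3rd-place match (8 matches per model). A sketch, not the real code."""
    assert len(samples) == 8
    samples = random.sample(samples, len(samples))  # shuffle the bracket
    semifinalists = [judge(a, b) for a, b in zip(samples[::2], samples[1::2])]
    finalists = [judge(*semifinalists[:2]), judge(*semifinalists[2:])]
    first = judge(*finalists)
    second = next(f for f in finalists if f is not first)
    third = judge(*[s for s in semifinalists if s not in finalists])
    return [first, second, third]

def face_off(top_a: list[str], top_b: list[str], judge,
             points: tuple[float, ...] = (4.8, 3.2, 2.0)) -> tuple[float, float]:
    """Best vs. best, 2nd vs. 2nd, 3rd vs. 3rd; the winner takes the points."""
    score_a = score_b = 0.0
    for a, b, pts in zip(top_a, top_b, points):
        if judge(a, b) is a:
            score_a += pts
        else:
            score_b += pts
    return score_a, score_b
```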

## Develop

```bash
pip install -r requirements-dev.txt
pre-commit install
```
33 changes: 33 additions & 0 deletions config-example.toml
@@ -0,0 +1,33 @@
data_dir = "./data"
sample = 8

[[model]]
# Entries other than `name` are passed to either `openai.OpenAI()` (client options) or `client.chat.completions.create()` (request options)
name = "chatglm3-6b-sft-checkpoint100"
base_url = "http://localhost:8000/v1"
api_key = "dummy-key"
max_retries = 1
model = ""

[[model]]
name = "gpt-3.5"
api_key = "sk-..."
max_retries = 1
model = "gpt-3.5-turbo"


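# Each line of `chat` is "role: content" (role is user, assistant, or system);
# a line without a role tag is appended to the previous message.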
[[prompt]]
name = "demo-sleep-zh"
chat = """
user: 睡不着怎么办?
assistant: 你可以试试给我讲个故事呀。
user: 诶嘿,那你想听什么样的故事呢?
"""

[[prompt]]
name = "demo-sleep-en"
chat = """
user: What should I do if I can't fall asleep?
assistant: Hmmm, maybe you can try telling me a story.
user: lol, what kind of story do you like?
"""
151 changes: 151 additions & 0 deletions evaluate.py
@@ -0,0 +1,151 @@
from lone_arena.chatcup import Top3_1v1
from lone_arena.config import load_config, Config
from lone_arena.files import DocumentDir, ResultDir
from lone_arena.format import format_conversation

import gradio as gr

from queue import Queue
import argparse
from functools import partial

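# These module-level queues bridge the blocking evaluation loop in main()
# (started via demo.load at the bottom of ui()) and the Gradio event handlers:
# main() pushes each pair of candidates onto req_queue and waits on rsp_queue
# for the human's pick; prog_notify seeds the progress bar on page load.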
req_queue = Queue()
rsp_queue = Queue()
prog_notify = Queue()


def main(conf: Config):
def compete(a, b):
ca, cb = (
format_conversation(docs.get(a, [])),
format_conversation(docs.get(b, [])),
)
req_queue.put((ca, cb))
result = rsp_queue.get()
if result == 0:
return a, b
return b, a

mnames = [x.name for x in conf.model]
pnames = [x.name for x in conf.prompt]
docsd = DocumentDir(conf.data_dir)
docs = docsd.load(pnames, mnames)

resultd = ResultDir(conf.data_dir)
podiums, pnames_todo = resultd.load(pnames)
if podiums:
if pnames_todo:
print("Partially completed, resuming...")
else:
msg = f"Loaded completed result. To re-evaluate, remove results from {resultd.result_dir}"
print(msg)
req_queue.put((msg, ""))

cup = Top3_1v1(pnames_todo, mnames, conf.sample, conf.top3_scores)
todo_match, total_match = cup.nmatch(), cup.nmatch(len(pnames))
prog_notify.put((total_match - todo_match, total_match))
itournament = cup.run(compete)
for pname, podium in zip(pnames_todo, itournament):
podiums.append(podium)
resultd.dump(podium, docs, pname)
msg = f"End of evaluation. Winning responses can be found in {resultd.result_dir}"
req_queue.put((msg, ""))

score_weights = [p.score_weight for p in conf.prompt]
result = cup.tabulate_result(podiums, score_weights)
return gr.DataFrame(visible=True, value=result)


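# Runs on page load: block until main() publishes the first pair of candidates
# and the initial progress, then populate the UI.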
def init():
da, db = req_queue.get()
cm, tm = prog_notify.get()
return da, db, cm, tm, gr.Slider(value=round(cm / tm, 2))


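# A button click records a decision: the clicked button's elem_id ("choose1" or
# "choose2") identifies the winner, which is reported back to compete() via
# rsp_queue; the next pair of candidates is then fetched from req_queue.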
def on_decision(completed_match: int, total_match: int, ev_data: gr.EventData):
if completed_match == total_match:
return gr.Markdown(), gr.Markdown() # no more updates
rsp_queue.put(int(ev_data.target.elem_id[-1]) - 1) # type: ignore
doc_a, doc_b = req_queue.get()
return doc_a, doc_b


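# Chained after on_decision (see the gr.on(...).then(...) wiring below) to
# advance the progress bar by one match.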
def progbar_update(completed_match: int, total_match: int):
if completed_match < total_match:
completed_match += 1
progress = round(completed_match / total_match, 2)
return completed_match, gr.Slider(value=progress)


shortcut_js = """
<script>
function shortcuts(e) {
switch (e.target.tagName.toLowerCase()) {
case "input":
case "textarea":
return;
}
switch (e.key.toLowerCase()) {
case "f": return document.getElementById("choose1").click();
case "j": return document.getElementById("choose2").click();
}
}
document.addEventListener('keypress', shortcuts, false);
</script>
"""


def ui(conf: Config):
with gr.Blocks(
title="Lone Arena",
head=shortcut_js,
theme=gr.themes.Default(spacing_size=gr.themes.sizes.spacing_lg),
) as demo:
completed_match, total_match = gr.State(0), gr.State(1)
progbar = gr.Slider(0, 1, 0, label="Progress", container=False)
gr.Markdown(
"""
## Which of the two responses is better?
"""
)
with gr.Row(equal_height=True):
with gr.Column(variant="panel"):
choose1 = gr.Button("👇This one's better [f]", elem_id="choose1")
candidate1 = gr.Markdown(line_breaks=True)
with gr.Column(variant="panel"):
choose2 = gr.Button("👇This one's better [j]", elem_id="choose2")
candidate2 = gr.Markdown(line_breaks=True)
result_table = gr.DataFrame(visible=True, row_count=len(conf.prompt) + 1)

gr.on(
triggers=[choose1.click, choose2.click],
fn=on_decision,
inputs=[completed_match, total_match],
outputs=[candidate1, candidate2],
).then(
progbar_update,
inputs=[completed_match, total_match],
outputs=[completed_match, progbar],
)
demo.load(
init,
outputs=[candidate1, candidate2, completed_match, total_match, progbar],
)
# workaround for https://github.com/gradio-app/gradio/issues/7101
demo.load(
lambda: gr.DataFrame(visible=False),
outputs=[result_table],
)
demo.load(partial(main, conf), outputs=[result_table])
return demo


if __name__ == "__main__":
argp = argparse.ArgumentParser(description="Host the evaluation Web UI")
argp.add_argument("--port", type=int, default=7860)
argp.add_argument("config", type=str)
args = argp.parse_args()
conf = load_config(args.config)

demo = ui(conf)
demo.launch(server_port=args.port, show_api=False, quiet=True)
65 changes: 65 additions & 0 deletions generate.py
@@ -0,0 +1,65 @@
from lone_arena.config import load_config, Model, Config
from lone_arena.files import DocumentDir

import openai
from tqdm import tqdm

import argparse
import re
from typing import Any

ROLE_TAG = re.compile(r"^(user|assistant|system): ?(.*)$")


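# Convert the config's "role: content" lines into OpenAI-style message dicts;
# lines without a role tag are appended to the previous message's content.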
def parse_chat(chat: str) -> list[dict]:
msg = []
for li in chat.splitlines():
if (m := ROLE_TAG.match(li)) is not None:
msg.append({"role": m.group(1), "content": m.group(2)})
continue
assert len(msg) > 0, f"missing role tag for {chat!r}"
msg[-1]["content"] += "\n" + li
return msg


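# Split a model's parameters into kwargs for the OpenAI() client constructor
# and kwargs for chat.completions.create().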
def split_params(params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
params = params.copy()
client_params = {}
for k in ["api_key", "organization", "base_url", "timeout", "max_retries"]:
if k in params:
client_params[k] = params.pop(k)
return client_params, params


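# For each prompt, collect `conf.sample` completions (retrying when the endpoint
# returns an empty message) and save the full conversations with DocumentDir.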
def batch_request(model: Model, prompts: dict[str, list], conf: Config):
client_params, completion_params = split_params(model.openai_params)
client = openai.OpenAI(**client_params)
docsd = DocumentDir(conf.data_dir)
pbar = tqdm(total=conf.sample * len(prompts), desc=f"model {model.name}")
for pname, messages in prompts.items():
msg_list = []
for _ in range(conf.sample):
content = ""
while not content:
chat_completion = client.chat.completions.create(
messages=messages,
**completion_params,
)
content = chat_completion.choices[0].message.content
msg_list.append([*messages, {"role": "assistant", "content": content}])
pbar.update(1)
docsd.dump(msg_list, pname, model.name)
pbar.close()


if __name__ == "__main__":
argp = argparse.ArgumentParser(
description="Gather sample responses from model endpoints"
)
argp.add_argument("config", type=str)
args = argp.parse_args()

conf = load_config(args.config)
prompts = {p.name: parse_chat(p.chat) for p in conf.prompt}
for model in conf.model:
batch_request(model, prompts, conf)
Empty file added: lone_arena/__init__.py
(The remaining changed files are not shown.)
