-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 59b4162
Showing
19 changed files
with
725 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
name: CI | ||
|
||
on: | ||
push: | ||
branches: [main] | ||
pull_request: | ||
branches: [main] | ||
|
||
jobs: | ||
check-n-test: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: '3.12' | ||
cache: 'pip' | ||
- name: Install dependencies | ||
run: | | ||
pip install -r requirements.txt | ||
pip install -r requirements-dev.txt | ||
- name: Formatter | ||
run: | | ||
ruff format --check | ||
- name: Unit tests | ||
run: | | ||
pytest lone_arena/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
/data/ | ||
|
||
__pycache__/ | ||
.pytest_cache/ | ||
.ruff_cache/ | ||
.vscode/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
repos: | ||
- repo: https://github.com/astral-sh/ruff-pre-commit | ||
rev: v0.1.14 | ||
hooks: | ||
- id: ruff-format |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2024 Contextualist | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
<img align="right" src="media/lone_arena-sketch-small.png" width="160" /> | ||
|
||
# Lone Arena | ||
|
||
When comparing two LLM checkpoints, human evaluation could be tedious. | ||
Let's strip down the evaluation process to just a single question: | ||
|
||
 | ||
|
||
Press <kbd>f</kbd> or <kbd>j</kbd> to choose the winner of each match. | ||
|
||
Inspired by [Chatbot Arena](https://chat.lmsys.org). | ||
|
||
## Get Started | ||
|
||
1. In a Python (>= 3.12) environment, `pip install -r requirements.txt` | ||
2. Fill out `config.toml` with your model endpoint information and prompts. See [`config-example.toml`](config-example.toml). | ||
3. Run `python generate.py config.toml` to gather responses from models. | ||
4. Run `python evaluate.py config.toml` to host your competition! | ||
|
||
|
||
## Approach | ||
|
||
Two models/checkpoints are compared by anonymous evaluation of their responses to the same prompt. For each prompt: | ||
|
||
1. For each model, generate 8 sample responses. Run a single-elimination tournament to get top 3 responses. (8 matches x 2 models) | ||
2. Let the best responses of two models compete, then 2nd best of two models, then 3rd best. Winner of each gets 4.8, 3.2, 2.0 points, respectively. (3 matches) | ||
|
||
Matches are shuffled. | ||
Number of samples and points are configurable. | ||
In the future, I might implement [Elo](https://en.wikipedia.org/wiki/Elo_rating_system) for comparing multiple models. | ||
|
||
## Develop | ||
|
||
```bash | ||
pip install -r requirements-dev.txt | ||
pre-commit install | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
data_dir = "./data" | ||
sample = 8 | ||
|
||
[[model]] | ||
# Entries other than `name` are passed to `openai.OpenAI` or `openai.OpenAI.chat.completions.create` | ||
name = "chatglm3-6b-sft-checkpoint100" | ||
base_url = "http://localhost:8000/v1" | ||
api_key = "dummy-key" | ||
max_retries = 1 | ||
model = "" | ||
|
||
[[model]] | ||
name = "gpt-3.5" | ||
api_key = "sk-..." | ||
max_retries = 1 | ||
model = "gpt-3.5-turbo" | ||
|
||
|
||
[[prompt]] | ||
name = "demo-sleep-zh" | ||
chat = """ | ||
user: 睡不着怎么办? | ||
assistant: 你可以试试给我讲个故事呀。 | ||
user: 诶嘿,那你想听什么样的故事呢? | ||
""" | ||
|
||
[[prompt]] | ||
name = "demo-sleep-en" | ||
chat = """ | ||
user: What should I do if I can't fall asleep? | ||
assistant: Hmmm, maybe you can try telling me a story. | ||
user: lol, what kind of story do you like? | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
from lone_arena.chatcup import Top3_1v1 | ||
from lone_arena.config import load_config, Config | ||
from lone_arena.files import DocumentDir, ResultDir | ||
from lone_arena.format import format_conversation | ||
|
||
import gradio as gr | ||
|
||
from queue import Queue | ||
import argparse | ||
from functools import partial | ||
|
||
# Cross-thread channels between the tournament driver (main, run in a
# background gradio load event) and the UI callbacks:
#   req_queue:   driver -> UI, next pair of rendered documents to display
#   rsp_queue:   UI -> driver, index (0 or 1) of the response the judge chose
#   prog_notify: driver -> UI, initial (completed, total) match counts
req_queue = Queue()
rsp_queue = Queue()
prog_notify = Queue()
|
||
|
||
def main(conf: Config):
    """Drive the full evaluation tournament; blocks until all matches are judged.

    Runs in a background thread via ``demo.load`` and exchanges match data
    with the UI callbacks through the module-level queues. Returns the
    populated result table for the ``result_table`` component.
    """

    def compete(a, b):
        # Present the two candidate documents to the judge and block until a
        # decision arrives on rsp_queue; return (winner, loser).
        ca, cb = (
            format_conversation(docs.get(a, [])),
            format_conversation(docs.get(b, [])),
        )
        req_queue.put((ca, cb))
        result = rsp_queue.get()
        if result == 0:
            return a, b
        return b, a

    mnames = [x.name for x in conf.model]
    pnames = [x.name for x in conf.prompt]
    docsd = DocumentDir(conf.data_dir)
    docs = docsd.load(pnames, mnames)

    # Resume support: prompts already judged in a previous run are loaded
    # from disk and excluded from the remaining schedule.
    resultd = ResultDir(conf.data_dir)
    podiums, pnames_todo = resultd.load(pnames)
    if podiums:
        if pnames_todo:
            print("Partially completed, resuming...")
        else:
            msg = f"Loaded completed result. To re-evaluate, remove results from {resultd.result_dir}"
            print(msg)
            req_queue.put((msg, ""))

    cup = Top3_1v1(pnames_todo, mnames, conf.sample, conf.top3_scores)
    todo_match, total_match = cup.nmatch(), cup.nmatch(len(pnames))
    # Seed the progress bar with the number of matches completed in prior runs.
    prog_notify.put((total_match - todo_match, total_match))
    itournament = cup.run(compete)
    for pname, podium in zip(pnames_todo, itournament):
        podiums.append(podium)
        # Persist each prompt's result immediately so an interrupted session
        # can resume from here.
        resultd.dump(podium, docs, pname)
    msg = f"End of evaluation. Winning responses can be found in {resultd.result_dir}"
    req_queue.put((msg, ""))

    score_weights = [p.score_weight for p in conf.prompt]
    result = cup.tabulate_result(podiums, score_weights)
    return gr.DataFrame(visible=True, value=result)
|
||
|
||
def init():
    """First page load: fetch the opening match and seed the progress bar."""
    doc_a, doc_b = req_queue.get()
    completed, total = prog_notify.get()
    fraction = round(completed / total, 2)
    return doc_a, doc_b, completed, total, gr.Slider(value=fraction)
|
||
|
||
def on_decision(completed_match: int, total_match: int, ev_data: gr.EventData):
    """Record which response the judge picked and fetch the next pair."""
    if completed_match == total_match:
        return gr.Markdown(), gr.Markdown()  # no more updates
    # Buttons carry elem_id "choose1"/"choose2"; report a 0-based index.
    chosen = int(ev_data.target.elem_id[-1]) - 1  # type: ignore
    rsp_queue.put(chosen)
    return req_queue.get()
|
||
|
||
def progbar_update(completed_match: int, total_match: int):
    """Bump the completed-match counter (capped at total) and redraw the bar."""
    if completed_match < total_match:
        completed_match = completed_match + 1
    fraction = round(completed_match / total_match, 2)
    return completed_match, gr.Slider(value=fraction)
|
||
|
||
# Injected into the page <head> (see ui()): maps the "f"/"j" keys to
# clicking the left/right choose buttons, except while the user is typing
# in an input or textarea.
shortcut_js = """
<script>
function shortcuts(e) {
switch (e.target.tagName.toLowerCase()) {
case "input":
case "textarea":
return;
}
switch (e.key.toLowerCase()) {
case "f": return document.getElementById("choose1").click();
case "j": return document.getElementById("choose2").click();
}
}
document.addEventListener('keypress', shortcuts, false);
</script>
"""
|
||
|
||
def ui(conf: Config):
    """Build the Gradio app: two response panels, choice buttons, progress bar.

    The tournament itself runs in ``main`` via a ``demo.load`` event and
    feeds the panels through the module-level queues.
    """
    with gr.Blocks(
        title="Lone Arena",
        head=shortcut_js,
        theme=gr.themes.Default(spacing_size=gr.themes.sizes.spacing_lg),
    ) as demo:
        # total_match starts at 1 rather than 0 — presumably to avoid a
        # zero division before init() delivers the real totals; confirm.
        completed_match, total_match = gr.State(0), gr.State(1)
        progbar = gr.Slider(0, 1, 0, label="Progress", container=False)
        gr.Markdown(
            """
## Which of the two responses is better?
"""
        )
        with gr.Row(equal_height=True):
            with gr.Column(variant="panel"):
                choose1 = gr.Button("👇This one's better [f]", elem_id="choose1")
                candidate1 = gr.Markdown(line_breaks=True)
            with gr.Column(variant="panel"):
                choose2 = gr.Button("👇This one's better [j]", elem_id="choose2")
                candidate2 = gr.Markdown(line_breaks=True)
        result_table = gr.DataFrame(visible=True, row_count=len(conf.prompt) + 1)

        # Either button: record the decision and show the next pair, then
        # advance the progress bar.
        gr.on(
            triggers=[choose1.click, choose2.click],
            fn=on_decision,
            inputs=[completed_match, total_match],
            outputs=[candidate1, candidate2],
        ).then(
            progbar_update,
            inputs=[completed_match, total_match],
            outputs=[completed_match, progbar],
        )
        demo.load(
            init,
            outputs=[candidate1, candidate2, completed_match, total_match, progbar],
        )
        # workaround for https://github.com/gradio-app/gradio/issues/7101
        demo.load(
            lambda: gr.DataFrame(visible=False),
            outputs=[result_table],
        )
        # Kick off the tournament driver; it returns the final result table.
        demo.load(partial(main, conf), outputs=[result_table])
    return demo
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: load the TOML config and host the judging Web UI.
    parser = argparse.ArgumentParser(description="Host the evaluation Web UI")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("config", type=str)
    cli = parser.parse_args()
    conf = load_config(cli.config)

    demo = ui(conf)
    demo.launch(server_port=cli.port, show_api=False, quiet=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from lone_arena.config import load_config, Model, Config | ||
from lone_arena.files import DocumentDir | ||
|
||
import openai | ||
from tqdm import tqdm | ||
|
||
import argparse | ||
import re | ||
from typing import Any | ||
|
||
# Matches a "role: content" line that opens a new chat message.
ROLE_TAG = re.compile(r"^(user|assistant|system): ?(.*)$")


def parse_chat(chat: str) -> list[dict]:
    """Parse a role-tagged transcript into OpenAI-style message dicts.

    Lines beginning with ``user:``/``assistant:``/``system:`` start a new
    message; any other line is appended (newline-joined) to the previous
    message's content.
    """
    messages: list[dict] = []
    for line in chat.splitlines():
        tag = ROLE_TAG.match(line)
        if tag is not None:
            messages.append({"role": tag.group(1), "content": tag.group(2)})
        else:
            # Continuation line: must belong to an already-opened message.
            assert len(messages) > 0, f"missing role tag for {chat!r}"
            messages[-1]["content"] += "\n" + line
    return messages
|
||
|
||
def split_params(params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split mixed OpenAI settings into (client kwargs, completion kwargs).

    Keys understood by the ``openai.OpenAI`` constructor go to the first
    dict; everything else is forwarded to ``chat.completions.create``.
    The input dict is left unmodified.
    """
    client_keys = ("api_key", "organization", "base_url", "timeout", "max_retries")
    client_params = {k: v for k, v in params.items() if k in client_keys}
    completion_params = {k: v for k, v in params.items() if k not in client_keys}
    return client_params, completion_params
|
||
|
||
def batch_request(model: Model, prompts: dict[str, list], conf: Config):
    """Collect ``conf.sample`` responses per prompt from one model endpoint.

    Empty completions are re-requested until non-empty text arrives; each
    prompt's sampled conversations are persisted via ``DocumentDir``.
    """
    client_params, completion_params = split_params(model.openai_params)
    client = openai.OpenAI(**client_params)
    documents = DocumentDir(conf.data_dir)
    progress = tqdm(total=conf.sample * len(prompts), desc=f"model {model.name}")

    def sample_once(messages):
        # Retry until the endpoint yields a non-empty completion.
        while True:
            reply = client.chat.completions.create(
                messages=messages,
                **completion_params,
            )
            text = reply.choices[0].message.content
            if text:
                return text

    for pname, messages in prompts.items():
        conversations = []
        for _ in range(conf.sample):
            reply_msg = {"role": "assistant", "content": sample_once(messages)}
            conversations.append([*messages, reply_msg])
            progress.update(1)
        documents.dump(conversations, pname, model.name)
    progress.close()
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: parse every configured prompt once, then gather one
    # batch of samples per configured model.
    parser = argparse.ArgumentParser(
        description="Gather sample responses from model endpoints"
    )
    parser.add_argument("config", type=str)
    cli = parser.parse_args()

    conf = load_config(cli.config)
    prompts = {prompt.name: parse_chat(prompt.chat) for prompt in conf.prompt}
    for model in conf.model:
        batch_request(model, prompts, conf)
Empty file.
Oops, something went wrong.