diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..5cd7f41 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,29 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + check-n-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + cache: 'pip' + - name: Install dependencies + run: | + pip install -r requirements.txt + pip install -r requirements-dev.txt + + - name: Formatter + run: | + ruff format --check + + - name: Unit tests + run: | + pytest lone_arena/ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fdfc8c7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +/data/ + +__pycache__/ +.pytest_cache/ +.ruff_cache/ +.vscode/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a768a1f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,5 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.14 + hooks: + - id: ruff-format diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b7e6d9e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Contextualist + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..6fd05c9 --- /dev/null +++ b/README.md @@ -0,0 +1,38 @@ + + +# Lone Arena + +When comparing two LLM checkpoints, human evaluation could be tedious. +Let's strip down the evaluation process to just a single question: + +![lone_arena-ui-en](media/lone_arena-ui-en.png) + +Press f or j to choose the winner of each match. + +Inspired by [Chatbot Arena](https://chat.lmsys.org). + +## Get Started + +1. In a Python (>= 3.12) environment, `pip install -r requirements.txt` +2. Fill out `config.toml` with your model endpoint infomation and prompts. See [`config-example.toml`](config-example.toml). +3. Run `python generate.py config.toml` to gather responses from models. +4. Run `python evaluate.py config.toml` to host your competition! + + +## Approach + +Two models/checkpoints are compared by anonymous evaluation of their responses to the same prompt. For each prompt: + +1. For each model, generate 8 sample responses. Run a single-elimination tournament to get top 3 responses. (8 matches x 2 models) +2. Let the best responses of two models compete, then 2nd best of two models, then 3rd best. Winner of each gets 4.8, 3.2, 2.0 points, respectively. 
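As a back-of-envelope check (a sketch assuming the default `sample = 8`), the per-prompt workload is:

```python
samples = 8                      # responses generated per model per prompt
bracket = samples - 1            # single-elimination matches for 8 entrants
third_place = 1                  # semifinal losers play one match for the 3rd spot
finals = 3                       # best vs best, 2nd vs 2nd, 3rd vs 3rd
print(2 * (bracket + third_place) + finals)   # 19 human judgements per prompt
```

The cross-model finals are the only matches where the two models meet directly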
(3 matches) + +Matches are shuffled. +Number of samples and points are configurable. +In the future, I might implement [Elo](https://en.wikipedia.org/wiki/Elo_rating_system) for comparing multiple models. + +## Develop + +```bash +pip install -r requirements-dev.txt +pre-commit install +``` diff --git a/config-example.toml b/config-example.toml new file mode 100644 index 0000000..e8054a4 --- /dev/null +++ b/config-example.toml @@ -0,0 +1,33 @@ +data_dir = "./data" +sample = 8 + +[[model]] +# Entries other than `name` are passed to `openai.OpenAI` or `openai.OpenAI.chat.completion.create` +name = "chatglm3-6b-sft-checkpoint100" +base_url = "http://localhost:8000/v1" +api_key = "dummy-key" +max_retries = 1 +model = "" + +[[model]] +name = "gpt-3.5" +api_key = "sk-..." +max_retries = 1 +model = "gpt-3.5-turbo" + + +[[prompt]] +name = "demo-sleep-zh" +chat = """ +user: 睡不着怎么办? +assistant: 你可以试试给我讲个故事呀。 +user: 诶嘿,那你想听什么样的故事呢? +""" + +[[prompt]] +name = "demo-sleep-en" +chat = """ +user: What should I do if I can't fall asleep? +assistant: Hmmm, maybe you can try telling me a story. +user: lol, what kind of story do you like? +""" diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..eba7c40 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,151 @@ +from lone_arena.chatcup import Top3_1v1 +from lone_arena.config import load_config, Config +from lone_arena.files import DocumentDir, ResultDir +from lone_arena.format import format_conversation + +import gradio as gr + +from queue import Queue +import argparse +from functools import partial + +req_queue = Queue() +rsp_queue = Queue() +prog_notify = Queue() + + +def main(conf: Config): + def compete(a, b): + ca, cb = ( + format_conversation(docs.get(a, [])), + format_conversation(docs.get(b, [])), + ) + req_queue.put((ca, cb)) + result = rsp_queue.get() + if result == 0: + return a, b + return b, a + + mnames = [x.name for x in conf.model] + pnames = [x.name for x in conf.prompt] + docsd = DocumentDir(conf.data_dir) + docs = docsd.load(pnames, mnames) + + resultd = ResultDir(conf.data_dir) + podiums, pnames_todo = resultd.load(pnames) + if podiums: + if pnames_todo: + print("Partially completed, resuming...") + else: + msg = f"Loaded completed result. To re-evaluate, remove results from {resultd.result_dir}" + print(msg) + req_queue.put((msg, "")) + + cup = Top3_1v1(pnames_todo, mnames, conf.sample, conf.top3_scores) + todo_match, total_match = cup.nmatch(), cup.nmatch(len(pnames)) + prog_notify.put((total_match - todo_match, total_match)) + itournament = cup.run(compete) + for pname, podium in zip(pnames_todo, itournament): + podiums.append(podium) + resultd.dump(podium, docs, pname) + msg = f"End of evaluation. 
Winning responses can be found in {resultd.result_dir}" + req_queue.put((msg, "")) + + score_weights = [p.score_weight for p in conf.prompt] + result = cup.tabulate_result(podiums, score_weights) + return gr.DataFrame(visible=True, value=result) + + +def init(): + da, db = req_queue.get() + cm, tm = prog_notify.get() + return da, db, cm, tm, gr.Slider(value=round(cm / tm, 2)) + + +def on_decision(completed_match: int, total_match: int, ev_data: gr.EventData): + if completed_match == total_match: + return gr.Markdown(), gr.Markdown() # no more updates + rsp_queue.put(int(ev_data.target.elem_id[-1]) - 1) # type: ignore + doc_a, doc_b = req_queue.get() + return doc_a, doc_b + + +def progbar_update(completed_match: int, total_match: int): + if completed_match < total_match: + completed_match += 1 + progress = round(completed_match / total_match, 2) + return completed_match, gr.Slider(value=progress) + + +shortcut_js = """ + +""" + + +def ui(conf: Config): + with gr.Blocks( + title="Lone Arena", + head=shortcut_js, + theme=gr.themes.Default(spacing_size=gr.themes.sizes.spacing_lg), + ) as demo: + completed_match, total_match = gr.State(0), gr.State(1) + progbar = gr.Slider(0, 1, 0, label="Progress", container=False) + gr.Markdown( + """ + ## Which of the two responses is better? + """ + ) + with gr.Row(equal_height=True): + with gr.Column(variant="panel"): + choose1 = gr.Button("👇This one's better [f]", elem_id="choose1") + candidate1 = gr.Markdown(line_breaks=True) + with gr.Column(variant="panel"): + choose2 = gr.Button("👇This one's better [j]", elem_id="choose2") + candidate2 = gr.Markdown(line_breaks=True) + result_table = gr.DataFrame(visible=True, row_count=len(conf.prompt) + 1) + + gr.on( + triggers=[choose1.click, choose2.click], + fn=on_decision, + inputs=[completed_match, total_match], + outputs=[candidate1, candidate2], + ).then( + progbar_update, + inputs=[completed_match, total_match], + outputs=[completed_match, progbar], + ) + demo.load( + init, + outputs=[candidate1, candidate2, completed_match, total_match, progbar], + ) + # workaround for https://github.com/gradio-app/gradio/issues/7101 + demo.load( + lambda: gr.DataFrame(visible=False), + outputs=[result_table], + ) + demo.load(partial(main, conf), outputs=[result_table]) + return demo + + +if __name__ == "__main__": + argp = argparse.ArgumentParser(description="Host the evaluation Web UI") + argp.add_argument("--port", type=int, default=7860) + argp.add_argument("config", type=str) + args = argp.parse_args() + conf = load_config(args.config) + + demo = ui(conf) + demo.launch(server_port=args.port, show_api=False, quiet=True) diff --git a/generate.py b/generate.py new file mode 100644 index 0000000..c5c6787 --- /dev/null +++ b/generate.py @@ -0,0 +1,65 @@ +from lone_arena.config import load_config, Model, Config +from lone_arena.files import DocumentDir + +import openai +from tqdm import tqdm + +import argparse +import re +from typing import Any + +ROLE_TAG = re.compile(r"^(user|assistant|system): ?(.*)$") + + +def parse_chat(chat: str) -> list[dict]: + msg = [] + for li in chat.splitlines(): + if (m := ROLE_TAG.match(li)) is not None: + msg.append({"role": m.group(1), "content": m.group(2)}) + continue + assert len(msg) > 0, f"missing role tag for {chat!r}" + msg[-1]["content"] += "\n" + li + return msg + + +def split_params(params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + params = params.copy() + client_params = {} + for k in ["api_key", "organization", "base_url", "timeout", "max_retries"]: + if 
k in params: + client_params[k] = params.pop(k) + return client_params, params + + +def batch_request(model: Model, prompts: dict[str, list], conf: Config): + client_params, completion_params = split_params(model.openai_params) + client = openai.OpenAI(**client_params) + docsd = DocumentDir(conf.data_dir) + pbar = tqdm(total=conf.sample * len(prompts), desc=f"model {model.name}") + for pname, messages in prompts.items(): + msg_list = [] + for _ in range(conf.sample): + content = "" + while not content: + chat_completion = client.chat.completions.create( + messages=messages, + **completion_params, + ) + content = chat_completion.choices[0].message.content + msg_list.append([*messages, {"role": "assistant", "content": content}]) + pbar.update(1) + docsd.dump(msg_list, pname, model.name) + pbar.close() + + +if __name__ == "__main__": + argp = argparse.ArgumentParser( + description="Gather sample responses from model endpoints" + ) + argp.add_argument("config", type=str) + args = argp.parse_args() + + conf = load_config(args.config) + prompts = {p.name: parse_chat(p.chat) for p in conf.prompt} + for model in conf.model: + batch_request(model, prompts, conf) diff --git a/lone_arena/__init__.py b/lone_arena/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lone_arena/chatcup.py b/lone_arena/chatcup.py new file mode 100644 index 0000000..19ba9f7 --- /dev/null +++ b/lone_arena/chatcup.py @@ -0,0 +1,87 @@ +from .tournament import single_elimination, pair_matches, run_tournament, Player, Podium + +from attrs import define +import pandas as pd + +from typing import Callable, Iterable, Protocol, TYPE_CHECKING + + +class Cup(Protocol): + def nmatch(self, nprompt: int | None = None) -> int: + ... + + def run( + self, compete: Callable[[Player, Player], tuple[Player, Player]] + ) -> Iterable[Podium]: + ... + + def tabulate_result( + self, podiums: list[Podium], score_weights: list[float] + ) -> pd.DataFrame: + ... 
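# Illustrative sketch (comments only, not part of the module): a Cup
# implementation such as Top3_1v1 below is driven by evaluate.py roughly as
#
#     cup = Top3_1v1(prompt_names, model_names, conf.sample, conf.top3_scores)
#     for prompt_name, podium in zip(prompt_names, cup.run(compete)):
#         ...                     # persist each podium as it is decided
#     table = cup.tabulate_result(podiums, score_weights)
#
# where `compete(a, b)` returns the pair reordered as (winner, loser); in the
# web UI that decision comes from a human pressing f or j.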
+ + +@define +class Top3_1v1: + prompt_names: list[str] + model_names: list[str] + nplayer: int + top3_scores: tuple[float, float, float] + + def nmatch(self, nprompt: int | None = None) -> int: + nprompt = nprompt or len(self.prompt_names) + return nprompt * (self.nplayer * 2 + 3) + + def run( + self, compete: Callable[[Player, Player], tuple[Player, Player]] + ) -> Iterable[Podium]: + assert len(self.model_names) == 2, "expect 2 models" + for pname in self.prompt_names: + ta = single_elimination( + [(pname, self.model_names[0], i) for i in range(self.nplayer)] + ) + tb = single_elimination( + [(pname, self.model_names[1], i) for i in range(self.nplayer)] + ) + run_tournament(ta, tb, compete=compete) + + top3 = list(zip(ta.podium.players, tb.podium.players)) + tp = pair_matches(top3) + run_tournament(tp, compete=compete) + + yield tp.podium + + def tabulate_result( + self, podiums: list[Podium], score_weights: list[float] + ) -> pd.DataFrame: + ptags: list[list] = [p.players for p in podiums] + tb = [] + scores = [0.0] * len(self.model_names) + total_scores = [0.0] * len(self.model_names) + for ptag, weight in zip(ptags, score_weights): + for i in range(len(self.model_names)): + scores[i] = ( + self.top3_scores[0] * (ptag[0][1] == self.model_names[i]) + + self.top3_scores[1] * (ptag[1][1] == self.model_names[i]) + + self.top3_scores[2] * (ptag[2][1] == self.model_names[i]) + ) + total_scores[i] += scores[i] * weight + tb.append( + { + "Prompt": ptag[0][0], + **{m: round(s, 1) for m, s in zip(self.model_names, scores)}, + "weight": f"{weight:.1f}×", + } + ) + tb.append( + { + "Prompt": "TOTAL", + **{m: round(s, 1) for m, s in zip(self.model_names, total_scores)}, + "weight": "", + } + ) + return pd.DataFrame(tb) + + +if TYPE_CHECKING: + _: type[Cup] = Top3_1v1 diff --git a/lone_arena/config.py b/lone_arena/config.py new file mode 100644 index 0000000..05c4006 --- /dev/null +++ b/lone_arena/config.py @@ -0,0 +1,52 @@ +from attrs import define, Factory +import cattrs + +import tomllib +from pathlib import Path +from math import log2 +from typing import Any + +cattrs.register_structure_hook(Path, lambda d, t: t(d)) + + +@define +class Model: + name: str + # params for openai.OpenAI and openai.OpenAI.chat.completion.create + openai_params: dict[str, Any] = Factory(dict) + + @classmethod + def from_dict(cls, d): + return cls( + name=d.pop("name"), + openai_params=d, + ) + + +cattrs.register_structure_hook(Model, lambda d, t: t.from_dict(d)) + + +@define +class Prompt: + name: str + chat: str + score_weight: float = 1.0 + + +@define +class Config: + data_dir: Path = Path("./data") + sample: int = 8 + top3_scores: tuple[float, float, float] = (4.8, 3.2, 2.0) + model: list[Model] = Factory(list) + prompt: list[Prompt] = Factory(list) + + def __attrs_post_init__(self): + assert log2(self.sample).is_integer(), "config: sample must be power of 2" + assert len(self.prompt) > 0, "config: expect at least 1 prompt" + + +def load_config(fname: str) -> Config: + with open(fname, "rb") as f: + d = tomllib.load(f) + return cattrs.structure(d, Config) diff --git a/lone_arena/files.py b/lone_arena/files.py new file mode 100644 index 0000000..6f60961 --- /dev/null +++ b/lone_arena/files.py @@ -0,0 +1,62 @@ +from .tournament import Podium, Player + +from attrs import define + +from pathlib import Path +from functools import cached_property +import json + +type Messages = list[dict] +type Documents = dict[Player, Messages] + + +@define(slots=False) +class DocumentDir: + data_dir: Path + + @cached_property + def 
doc_dir(self) -> Path: + d = self.data_dir / "response" + d.mkdir(exist_ok=True, parents=True) + return d + + def load(self, prompt_names: list[str], model_names: list[str]) -> Documents: + docs = {} + for pname in prompt_names: + for mname in model_names: + with (self.doc_dir / pname / f"{mname}.jsonl").open("r") as fi: + for i, line in enumerate(fi): + docs[(pname, mname, i)] = json.loads(line) + return docs + + def dump(self, msg_list: list[Messages], prompt_name: str, model_name: str): + pdir = self.doc_dir / prompt_name + pdir.mkdir(exist_ok=True) + with (pdir / f"{model_name}.jsonl").open("w") as fo: + for msg in msg_list: + json.dump(msg, fo, ensure_ascii=False) + print(file=fo) + + +@define(slots=False) +class ResultDir: + data_dir: Path + + @cached_property + def result_dir(self) -> Path: + d = self.data_dir / "result" + d.mkdir(exist_ok=True, parents=True) + return d + + def load(self, prompt_names: list[str]) -> tuple[list[Podium], list[str]]: + podiums = [] + pnames_todo = [] + for pname in prompt_names: + if (self.result_dir / f"{pname}.json").exists(): + podiums.append(Podium.load_from(self.result_dir / f"{pname}.json")) + else: + pnames_todo.append(pname) + return podiums, pnames_todo + + def dump(self, podium: Podium, docs: Documents, prompt_name: str): + podium.dump(self.result_dir / f"{prompt_name}.json", docs) diff --git a/lone_arena/format.py b/lone_arena/format.py new file mode 100644 index 0000000..124d4b7 --- /dev/null +++ b/lone_arena/format.py @@ -0,0 +1,20 @@ +def format_conversation(conv): + if not conv: + return "" + conv_f = [] + for e in conv[:-1]: + conv_f.append( + f"**{' '*(9-len(e['role']))}{e['role']}** {e['content']}" + ) + conv_f.append( + f"**{' '*(9-len(conv[-1]['role']))}{conv[-1]['role']}** {conv[-1]['content']}" + ) + s = "\n\n".join(conv_f) + return min_lines(s, 6) + + +def min_lines(s, n): + ln = len(s.splitlines()) + if ln < n: + s += "
" * (n - ln) + return s diff --git a/lone_arena/test_tournament.py b/lone_arena/test_tournament.py new file mode 100644 index 0000000..64c51b0 --- /dev/null +++ b/lone_arena/test_tournament.py @@ -0,0 +1,38 @@ +import pytest +from .tournament import * + + +@pytest.mark.parametrize("case", ["4", "16"]) +def test_single_elimination(case): + players, n_init, n_to_top = { + "4": (list(range(4)), 2, 1), + "16": (list(range(16)), 8, 3), + }[case] + + t = single_elimination(players) + assert len(t.init_matches) == n_init + m = t.init_matches[0] + for _ in range(n_to_top): + wt = m.winner_to + assert wt is not None and isinstance(wt[0], Match) + m = wt[0] + assert m.winner_to is not None and m.winner_to[0] is t.podium + assert len(t.podium.players) == 3 + + +def test_pair_matches(): + players = [(0, 1), (2, 3), (4, 5), (6, 7)] + t = pair_matches(players) + assert len(t.init_matches) == 4 + assert len(t.podium.players) == 4 + + +def test_run_tournament(): + def compete(p1, p2): + if p1 < p2: + return p2, p1 + return p1, p2 + + t = single_elimination([0, 7, 3, 4, 1, 6, 2, 5]) + run_tournament(t, compete=compete) + assert t.podium.players == [7, 6, 5] diff --git a/lone_arena/tournament.py b/lone_arena/tournament.py new file mode 100644 index 0000000..a9e1e26 --- /dev/null +++ b/lone_arena/tournament.py @@ -0,0 +1,109 @@ +from attrs import define, Factory, NOTHING as TBD + +from math import log2 +from itertools import chain, batched +import random +from pathlib import Path +import json +from typing import Self, Callable, Protocol +from collections.abc import Hashable + +type Player = Hashable + + +class PlayerBox(Protocol): + players: list[Player] + + +@define +class Match: + winner_to: tuple[PlayerBox, int] | None = None + loser_to: tuple[PlayerBox, int] | None = None + players: list[Player] = Factory(lambda: [TBD, TBD]) + + def __repr__(self) -> str: + return f"Match({self.players})" + + def moveon(self, winner: Player, loser: Player) -> list["Match"]: + next_matches = [self._fill(winner, True), self._fill(loser, False)] + return [m for m in next_matches if isinstance(m, Match)] + + def _fill(self, v: Player, is_winning: bool) -> PlayerBox | None: + next_pos = self.winner_to if is_winning else self.loser_to + if next_pos is None: + return None + m, i = next_pos + m.players[i] = v + if any((x is TBD for x in m.players)): + return None + return m + + +@define +class Podium: + players: list[Player] + + @classmethod + def for_(cls, n: int) -> Self: + return cls(players=[TBD] * n) + + def dump(self, path: Path, docs: dict[Player, list] | None = None): + if docs is None: + d = [{"id": p} for p in self.players] + else: + d = [{"id": p, "chat": docs[p]} for p in self.players] + json.dump(d, path.open("w"), ensure_ascii=False, indent=2) + + @classmethod + def load_from(cls, path: Path) -> Self: + d = json.load(path.open("r")) + return cls(players=[tuple(x["id"]) for x in d]) + + +@define +class Tournament: + init_matches: list[Match] + podium: Podium + + +def run_tournament( + *tournaments: Tournament, + compete: Callable[[Player, Player], tuple[Player, Player]], +): + pool = list(chain.from_iterable(t.init_matches for t in tournaments)) + while pool: + m = pool.pop(random.randrange(len(pool))) + random.shuffle(m.players) + winner, loser = compete(*m.players) + pool.extend(m.moveon(winner, loser)) + + +def single_elimination(players: list[Player]) -> Tournament: + nplayer = len(players) + assert log2(nplayer).is_integer(), "expect 2^n players" + + top3 = Podium.for_(3) + final = Match((top3, 0), (top3, 1)) + 
loser_final = Match((top3, 2), None) + semifinal0 = Match((final, 0), (loser_final, 0)) + semifinal1 = Match((final, 1), (loser_final, 1)) + leaves = [semifinal0, semifinal1] + while len(leaves) * 2 < nplayer: + leaves = list( + chain.from_iterable( + [Match((m, 0), None), Match((m, 1), None)] for m in leaves + ) + ) + + for i, p in enumerate(batched(players, 2)): + leaves[i].players = list(p) + + return Tournament(leaves, top3) + + +def pair_matches(players: list[tuple[Player, Player]]) -> Tournament: + assert all(len(p) == 2 for p in players), "expect n pairs" + n = len(players) + winners = Podium.for_(n) + matches = [Match((winners, i), None, list(p)) for i, p in enumerate(players)] + return Tournament(matches, winners) diff --git a/media/lone_arena-sketch-small.png b/media/lone_arena-sketch-small.png new file mode 100644 index 0000000..7c3b7b6 Binary files /dev/null and b/media/lone_arena-sketch-small.png differ diff --git a/media/lone_arena-ui-en.png b/media/lone_arena-ui-en.png new file mode 100644 index 0000000..381907e Binary files /dev/null and b/media/lone_arena-ui-en.png differ diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..c8a1f70 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +pytest +ruff +pre-commit diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4bb1856 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +gradio +openai +attrs +cattrs +pandas +tqdm
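For orientation, here is a minimal sketch (not part of the diff) of driving `lone_arena.tournament` directly under Python ≥ 3.12, mirroring `test_run_tournament`; the `compete` callback is a stand-in for the human judgement that `evaluate.py` collects through the web UI:

```python
from lone_arena.tournament import single_elimination, run_tournament

def compete(p1, p2):
    # Stand-in judge: always prefer the numerically larger player.
    return (p1, p2) if p1 > p2 else (p2, p1)

t = single_elimination([0, 7, 3, 4, 1, 6, 2, 5])  # player count must be a power of 2
run_tournament(t, compete=compete)
print(t.podium.players)  # [7, 6, 5]: the final's winner, the final's loser, the third-place winner
```

The final's winner takes podium slot 0, the final's loser slot 1, and the winner of the third-place match between the semifinal losers slot 2.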
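The role-tagged `chat` strings in `config.toml` are turned into OpenAI-style message lists by `parse_chat` in `generate.py`; any line without a `user:`/`assistant:`/`system:` prefix is folded into the previous message. A small sketch, assuming it is run from the repository root with the requirements installed (the example chat text is made up):

```python
from generate import parse_chat

chat = """\
user: What should I do if I can't fall asleep?
assistant: Hmmm, maybe you can try telling me a story.
user: lol, what kind of story do you like?
A long one, please."""

messages = parse_chat(chat)
assert [m["role"] for m in messages] == ["user", "assistant", "user"]
assert messages[-1]["content"].endswith("\nA long one, please.")
```

`generate.py` then writes each prompt's samples to `data/response/<prompt>/<model>.jsonl`, and `evaluate.py` saves each decided podium, together with the winning conversations, to `data/result/<prompt>.json`.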
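Likewise, `split_params` in `generate.py` decides which keys of a `[[model]]` entry go to the `openai.OpenAI` client constructor and which are forwarded to `chat.completions.create`. A sketch under the same assumptions; `temperature` here is just a hypothetical extra sampling parameter:

```python
from generate import split_params

model_params = {
    "base_url": "http://localhost:8000/v1",
    "api_key": "dummy-key",
    "max_retries": 1,
    "model": "",
    "temperature": 0.9,   # hypothetical extra key, passed through to the completion call
}
client_params, completion_params = split_params(model_params)
assert sorted(client_params) == ["api_key", "base_url", "max_retries"]
assert sorted(completion_params) == ["model", "temperature"]
```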