diff --git a/generate.py b/generate.py
index c5c6787..6ab35b8 100644
--- a/generate.py
+++ b/generate.py
@@ -4,6 +4,7 @@
 
 import openai
 from tqdm import tqdm
+import asyncio
 import argparse
 import re
 from typing import Any
@@ -31,35 +32,70 @@ def split_params(params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]
     return client_params, params
 
 
-def batch_request(model: Model, prompts: dict[str, list], conf: Config):
+async def batch_request(
+    model: Model, prompts: dict[str, list], conf: Config, batch_size: int = 1
+):
     client_params, completion_params = split_params(model.openai_params)
-    client = openai.OpenAI(**client_params)
+    client = openai.AsyncOpenAI(**client_params)
     docsd = DocumentDir(conf.data_dir)
     pbar = tqdm(total=conf.sample * len(prompts), desc=f"model {model.name}")
-    for pname, messages in prompts.items():
-        msg_list = []
-        for _ in range(conf.sample):
-            content = ""
-            while not content:
-                chat_completion = client.chat.completions.create(
-                    messages=messages,
-                    **completion_params,
-                )
-                content = chat_completion.choices[0].message.content
-            msg_list.append([*messages, {"role": "assistant", "content": content}])
+
+    results = {}
+
+    async def make_request(pname: str, messages: list):
+        nonlocal client, completion_params, results
+        content = ""
+        while not content:
+            chat_completion = await client.chat.completions.create(
+                messages=messages,
+                **completion_params,
+            )
+            content = chat_completion.choices[0].message.content
+        results.setdefault(pname, [])
+        results[pname].append([*messages, {"role": "assistant", "content": content}])
+
+    todo = list(prompts.items()) * conf.sample
+    tasks = set()
+
+    def queue_request():
+        nonlocal todo, tasks
+        pname, message = todo.pop()
+        tasks.add(asyncio.create_task(make_request(pname, message)))
+
+    for i in range(batch_size):
+        if todo:
+            queue_request()
+
+    while tasks:
+        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+        for task in done:
             pbar.update(1)
+            tasks.remove(task)
+            await task
+
+        while todo and len(tasks) < batch_size:
+            queue_request()
+
+    for pname, msg_list in results.items():
         docsd.dump(msg_list, pname, model.name)
+
     pbar.close()
 
 
-if __name__ == "__main__":
+async def main():
     argp = argparse.ArgumentParser(
         description="Gather sample responses from model endpoints"
     )
+    argp.add_argument("--batch-size", type=int, default=4)
     argp.add_argument("config", type=str)
     args = argp.parse_args()
     conf = load_config(args.config)
     prompts = {p.name: parse_chat(p.chat) for p in conf.prompt}
 
     for model in conf.model:
-        batch_request(model, prompts, conf)
+        await batch_request(model, prompts, conf, batch_size=args.batch_size)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
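
For reference, the concurrency pattern this diff introduces is a bounded in-flight window: at most batch_size requests run concurrently, and asyncio.wait(..., return_when=FIRST_COMPLETED) wakes the loop so the window can be refilled as soon as any request finishes. Below is a minimal, self-contained sketch of that same pattern under stated assumptions: fake_request and run_batched are hypothetical stand-ins for the AsyncOpenAI call and batch_request, not part of generate.py.

import asyncio
import random


# fake_request is a hypothetical stand-in for the AsyncOpenAI completion
# call: it sleeps for a random interval and returns a canned string.
async def fake_request(name: str) -> str:
    await asyncio.sleep(random.uniform(0.1, 0.5))
    return f"response for {name}"


# run_batched is a hypothetical stand-in for batch_request, reduced to
# the bounded-window logic the diff adds.
async def run_batched(items: list[str], batch_size: int = 4) -> list[str]:
    todo = list(items)
    tasks: set[asyncio.Task] = set()
    results: list[str] = []

    def queue_next() -> None:
        # Start one request from the backlog and track it in the window.
        tasks.add(asyncio.create_task(fake_request(todo.pop())))

    # Prime the window with up to batch_size in-flight requests.
    for _ in range(batch_size):
        if todo:
            queue_next()

    while tasks:
        # Wake as soon as any in-flight request finishes.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            tasks.remove(task)
            results.append(await task)  # re-raises if the task failed
        # Refill the window from the backlog.
        while todo and len(tasks) < batch_size:
            queue_next()

    return results


if __name__ == "__main__":
    print(asyncio.run(run_batched([f"prompt-{i}" for i in range(10)])))

An asyncio.Semaphore held inside each request coroutine plus a single asyncio.gather is a common alternative; the explicit wait-and-refill loop used in the diff has the advantage that per-completion work, such as ticking the tqdm progress bar, happens as each response arrives.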