Commit

Merge pull request #3 from Contextualist/feat-generate-batch

Parallel batched requests
Contextualist authored Feb 5, 2024
2 parents 2e211e3 + a4cea86 commit d10a59a
Showing 1 changed file with 51 additions and 15 deletions.
66 changes: 51 additions & 15 deletions generate.py
@@ -4,6 +4,7 @@
 import openai
 from tqdm import tqdm
 
+import asyncio
 import argparse
 import re
 from typing import Any
@@ -31,35 +32,70 @@ def split_params(params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
     return client_params, params
 
 
-def batch_request(model: Model, prompts: dict[str, list], conf: Config):
+async def batch_request(
+    model: Model, prompts: dict[str, list], conf: Config, batch_size: int = 1
+):
     client_params, completion_params = split_params(model.openai_params)
-    client = openai.OpenAI(**client_params)
+    client = openai.AsyncOpenAI(**client_params)
     docsd = DocumentDir(conf.data_dir)
     pbar = tqdm(total=conf.sample * len(prompts), desc=f"model {model.name}")
-    for pname, messages in prompts.items():
-        msg_list = []
-        for _ in range(conf.sample):
-            content = ""
-            while not content:
-                chat_completion = client.chat.completions.create(
-                    messages=messages,
-                    **completion_params,
-                )
-                content = chat_completion.choices[0].message.content
-            msg_list.append([*messages, {"role": "assistant", "content": content}])
+
+    results = {}
+
+    async def make_request(pname: str, messages: list):
+        nonlocal client, completion_params, results
+        content = ""
+        while not content:
+            chat_completion = await client.chat.completions.create(
+                messages=messages,
+                **completion_params,
+            )
+            content = chat_completion.choices[0].message.content
+        results.setdefault(pname, [])
+        results[pname].append([*messages, {"role": "assistant", "content": content}])
+
+    todo = list(prompts.items()) * conf.sample
+    tasks = set()
+
+    def queue_request():
+        nonlocal todo, tasks
+        pname, message = todo.pop()
+        tasks.add(asyncio.create_task(make_request(pname, message)))
+
+    for i in range(batch_size):
+        if todo:
+            queue_request()
+
+    while tasks:
+        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+        for task in done:
+            pbar.update(1)
+            tasks.remove(task)
+            await task
+
+        while todo and len(tasks) < batch_size:
+            queue_request()
+
+    for pname, msg_list in results.items():
         docsd.dump(msg_list, pname, model.name)
 
     pbar.close()
 
 
-if __name__ == "__main__":
+async def main():
     argp = argparse.ArgumentParser(
         description="Gather sample responses from model endpoints"
     )
+    argp.add_argument("--batch-size", type=int, default=4)
     argp.add_argument("config", type=str)
     args = argp.parse_args()
 
     conf = load_config(args.config)
     prompts = {p.name: parse_chat(p.chat) for p in conf.prompt}
     for model in conf.model:
-        batch_request(model, prompts, conf)
+        await batch_request(model, prompts, conf, batch_size=args.batch_size)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
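
The heart of this change is a bounded-concurrency scheduler: keep at most batch_size requests in flight, and whenever one finishes, start another from the backlog. Below is a minimal, self-contained sketch of that pattern under stated assumptions; fake_request, bounded_batch, and the item list are illustrative stand-ins (not part of this repository), with fake_request taking the place of the real client.chat.completions.create call.

import asyncio
import random

async def fake_request(name: str) -> str:
    # Stand-in for the real OpenAI completion call.
    await asyncio.sleep(random.uniform(0.1, 0.5))
    return f"response for {name}"

async def bounded_batch(items: list[str], batch_size: int = 4) -> list[str]:
    todo = list(items)
    tasks = set()
    results = []

    def queue_request():
        # Launch one task from the backlog.
        tasks.add(asyncio.create_task(fake_request(todo.pop())))

    # Prime the window with up to batch_size in-flight tasks.
    for _ in range(batch_size):
        if todo:
            queue_request()

    while tasks:
        # Wake as soon as any one task finishes, not when all do.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            tasks.remove(task)
            results.append(await task)
        # Refill so the window stays full while work remains.
        while todo and len(tasks) < batch_size:
            queue_request()

    return results

if __name__ == "__main__":
    print(asyncio.run(bounded_batch([f"prompt-{i}" for i in range(10)])))

Compared with launching every request at once (e.g. a single asyncio.gather over the whole prompts-times-sample workload), this windowed approach caps the number of simultaneous connections, which keeps a large job from overwhelming the endpoint or tripping rate limits.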

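Assuming an OpenAI-compatible endpoint is configured, a typical invocation after this change might look like the following; the config filename is hypothetical (its schema is whatever load_config expects):

python generate.py --batch-size 8 myconfig.toml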