Skip to content

Commit

Permalink
feat(server): support interrupting stream, support slash commands (#248)
Browse files Browse the repository at this point in the history
* fix(server): started work on interrupting stream, added support for slash commands

* fix: more progress towards interrupting generation
  • Loading branch information
ErikBjare authored Nov 5, 2024
1 parent b3f3118 commit c1d136f
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions gptme/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import atexit
import io
import logging
from collections.abc import Generator
from contextlib import redirect_stdout
from datetime import datetime
from importlib import resources
Expand Down Expand Up @@ -104,20 +105,6 @@ def api_conversation_generate(logfile: str):
# load conversation
manager = LogManager.load(logfile, branch=req_json.get("branch", "main"))

# if prompt is a user-command, execute it
if manager.log[-1].role == "user":
f = io.StringIO()
print("Begin capturing stdout, to pass along command output.")
with redirect_stdout(f):
resp = execute_cmd(manager.log[-1], manager, confirm_func)
print("Done capturing stdout.")
if resp:
manager.write()
output = f.getvalue()
return flask.jsonify(
[{"role": "system", "content": output, "stored": False}]
)

# performs reduction/context trimming, if necessary
msgs = prepare_messages(manager.log.messages)

Expand Down Expand Up @@ -154,7 +141,7 @@ def api_conversation_generate(logfile: str):
return flask.jsonify({"error": str(e)})

# Streaming response
def generate():
def generate() -> Generator[str, None, None]:
# Start with an empty message
output = ""
try:
Expand All @@ -166,6 +153,21 @@ def generate():
yield f"data: {flask.json.dumps({'error': 'No messages to process'})}\n\n"
return

# if prompt is a user-command, execute it
last_msg = manager.log[-1]
if last_msg.role == "user" and last_msg.content.startswith("/"):
f = io.StringIO()
print("Begin capturing stdout, to pass along command output.")
with redirect_stdout(f):
resp = execute_cmd(manager.log[-1], manager, confirm_func)
print("Done capturing stdout.")
output = f.getvalue().strip()
if resp and output:
print(f"Replying with command output: {output}")
manager.write()
yield f"data: {flask.json.dumps({'role': 'system', 'content': output, 'stored': False})}\n\n"
return

# Stream tokens from the model
logger.debug(f"Starting token stream with model {model}")
for char in (char for chunk in _stream(msgs, model) for char in chunk):
Expand All @@ -184,6 +186,7 @@ def generate():
msg = Message("assistant", output)
msg = msg.replace(quiet=True)
manager.append(msg)
yield f"data: {flask.json.dumps({'role': 'assistant', 'content': output, 'stored': True})}\n\n"

# Execute any tools and stream their output
for reply_msg in execute_msg(msg, confirm_func):
Expand All @@ -193,11 +196,19 @@ def generate():
manager.append(reply_msg)
yield f"data: {flask.json.dumps({'role': reply_msg.role, 'content': reply_msg.content, 'stored': True})}\n\n"

except GeneratorExit:
logger.info("Client disconnected during generation, interrupting")
if output:
output += "\n\n[interrupted]"
msg = Message("assistant", output)
msg = msg.replace(quiet=True)
manager.append(msg)
raise
except Exception as e:
logger.exception("Error during generation")
yield f"data: {flask.json.dumps({'error': str(e)})}\n\n"
finally:
logger.info("Generation complete")
logger.info("Generation completed")

return flask.Response(
generate(),
Expand Down

0 comments on commit c1d136f

Please sign in to comment.