Skip to content

Commit

Permalink
add support for arbitrary body and any headers
Browse files Browse the repository at this point in the history
  • Loading branch information
Peddle committed Nov 7, 2023
1 parent 101f4ac commit e9d344d
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 17 deletions.
53 changes: 36 additions & 17 deletions potassium/potassium.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import time
from flask import Flask, request, make_response, abort
from typing import Optional
from flask import Flask, request, make_response, abort, Response as FlaskResponse
from werkzeug.serving import make_server
from threading import Thread, Lock, Condition
import functools
import traceback
import json as jsonlib
from termcolor import colored


Expand All @@ -19,10 +21,34 @@ def __init__(self, json: dict):


class Response():
def __init__(self, status: int = 200, json: dict = {}):
self.json = json
def __init__(self, status: int = 200, json: Optional[dict] = None, headers: Optional[dict] = None, body: Optional[bytes] = None):
assert json == None or body == None, "Potassium Response object cannot have both json and body set"

self.headers = headers if headers != None else {}

# convert json to body if not None
if json != None:
self.body = jsonlib.dumps(json).encode("utf-8")
self.headers["Content-Type"] = "application/json"
else:
self.body = body

self.status = status

@property
def json(self):
if self.body == None:
return None
try:
return jsonlib.loads(self.body.decode("utf-8"))
except:
return None

@json.setter
def json(self, json):
self.body = jsonlib.dumps(json).encode("utf-8")
self.headers["Content-Type"] = "application/json"


class InvalidEndpointTypeException(Exception):
def __init__(self):
Expand Down Expand Up @@ -102,10 +128,9 @@ def wrapper(request):
if type(out) != Response:
raise Exception("Potassium Response object not returned")

# check if out.json is a dict
if type(out.json) != dict:
if type(out.body) != bytes:
raise Exception(
"Potassium Response object json must be a dict")
"Potassium Response object body must be bytes")

return out

Expand Down Expand Up @@ -144,7 +169,6 @@ def _handle_generic(self, endpoint, flask_request):
except:
res = make_response()
res.status_code = 423
res.headers['X-Endpoint-Type'] = endpoint.type
return res

res = None
Expand All @@ -157,28 +181,26 @@ def _handle_generic(self, endpoint, flask_request):
except:
res = make_response()
res.status_code = 400
res.headers['X-Endpoint-Type'] = endpoint.type
self._gpu_lock.release()
return res

if endpoint.type == "handler":
try:
out = endpoint.func(req)
res = make_response(out.json)
res.status_code = out.status
res.headers['X-Endpoint-Type'] = endpoint.type

# create flask response
res = make_response()
res = FlaskResponse(
out.body, status=out.status, headers=out.headers)
except:
tb_str = traceback.format_exc()
print(colored(tb_str, "red"))
res = make_response(tb_str)
res.status_code = 500
res.headers['X-Endpoint-Type'] = endpoint.type
self._idle_start_time = time.time()
self._last_inference_start_time = None
self._gpu_lock.release()
elif endpoint.type == "background":


# run as threaded task
def task(endpoint, lock, req):
try:
Expand All @@ -199,7 +221,6 @@ def task(endpoint, lock, req):

# send task start success message
res = make_response({'started': True})
res.headers['X-Endpoint-Type'] = endpoint.type
else:
raise InvalidEndpointTypeException()

Expand Down Expand Up @@ -241,7 +262,6 @@ def warm():
"warm": True,
})
res.status_code = 200
res.headers['X-Endpoint-Type'] = "warmup"
return res

@flask_app.route('/_k/status', methods=["GET"])
Expand All @@ -265,7 +285,6 @@ def status():
})

res.status_code = 200
res.headers['X-Endpoint-Type'] = "status"
return res

return flask_app
Expand Down
13 changes: 13 additions & 0 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ def handler2(context: dict, request: potassium.Request) -> potassium.Response:
status=200
)

@app.handler("/some_binary_response")
def handler3(context: dict, request: potassium.Request) -> potassium.Response:
return potassium.Response(
body=b"hello",
status=200,
headers={"Content-Type": "application/octet-stream"}
)

@app.handler("/some_path/child_path")
def handler2_id(context: dict, request: potassium.Request) -> potassium.Response:
return potassium.Response(
Expand All @@ -43,6 +51,11 @@ def handler2_id(context: dict, request: potassium.Request) -> potassium.Response
assert res.status_code == 200
assert res.json == {"hello": "some_path"}

res = client.post("/some_binary_response", json={})
assert res.status_code == 200
assert res.data == b"hello"
assert res.headers["Content-Type"] == "application/octet-stream"

res = client.post("/some_path/child_path", json={})
assert res.status_code == 200
assert res.json == {"hello": "some_path/child_path"}
Expand Down
34 changes: 34 additions & 0 deletions tests/test_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
import potassium

def test_json_response():
response = potassium.Response(
status=200,
json={"key": "value"}
)

assert response.status == 200
assert response.json == {"key": "value"}
assert response.headers["Content-Type"] == "application/json"

response.json = {"key": "value2"}
assert response.json == {"key": "value2"}

def test_body_response():
response = potassium.Response(
status=200,
body=b"Hello, world!"
)

assert response.status == 200
assert response.body == b"Hello, world!"
assert 'Content-Type' not in response.headers

response.json = {"key": "value2"}

assert response.json == {"key": "value2"}
assert response.headers["Content-Type"] == "application/json"




0 comments on commit e9d344d

Please sign in to comment.