diff --git a/potassium/potassium.py b/potassium/potassium.py index 2b57e0b..dfbd0e7 100644 --- a/potassium/potassium.py +++ b/potassium/potassium.py @@ -1,9 +1,11 @@ import time -from flask import Flask, request, make_response, abort +from typing import Optional +from flask import Flask, request, make_response, abort, Response as FlaskResponse from werkzeug.serving import make_server from threading import Thread, Lock, Condition import functools import traceback +import json as jsonlib from termcolor import colored @@ -19,10 +21,34 @@ def __init__(self, json: dict): class Response(): - def __init__(self, status: int = 200, json: dict = {}): - self.json = json + def __init__(self, status: int = 200, json: Optional[dict] = None, headers: Optional[dict] = None, body: Optional[bytes] = None): + assert json == None or body == None, "Potassium Response object cannot have both json and body set" + + self.headers = headers if headers != None else {} + + # convert json to body if not None + if json != None: + self.body = jsonlib.dumps(json).encode("utf-8") + self.headers["Content-Type"] = "application/json" + else: + self.body = body + self.status = status + @property + def json(self): + if self.body == None: + return None + try: + return jsonlib.loads(self.body.decode("utf-8")) + except: + return None + + @json.setter + def json(self, json): + self.body = jsonlib.dumps(json).encode("utf-8") + self.headers["Content-Type"] = "application/json" + class InvalidEndpointTypeException(Exception): def __init__(self): @@ -102,10 +128,9 @@ def wrapper(request): if type(out) != Response: raise Exception("Potassium Response object not returned") - # check if out.json is a dict - if type(out.json) != dict: + if type(out.body) != bytes: raise Exception( - "Potassium Response object json must be a dict") + "Potassium Response object body must be bytes") return out @@ -144,7 +169,6 @@ def _handle_generic(self, endpoint, flask_request): except: res = make_response() res.status_code = 423 - res.headers['X-Endpoint-Type'] = endpoint.type return res res = None @@ -157,28 +181,26 @@ def _handle_generic(self, endpoint, flask_request): except: res = make_response() res.status_code = 400 - res.headers['X-Endpoint-Type'] = endpoint.type self._gpu_lock.release() return res if endpoint.type == "handler": try: out = endpoint.func(req) - res = make_response(out.json) - res.status_code = out.status - res.headers['X-Endpoint-Type'] = endpoint.type + + # create flask response + res = make_response() + res = FlaskResponse( + out.body, status=out.status, headers=out.headers) except: tb_str = traceback.format_exc() print(colored(tb_str, "red")) res = make_response(tb_str) res.status_code = 500 - res.headers['X-Endpoint-Type'] = endpoint.type self._idle_start_time = time.time() self._last_inference_start_time = None self._gpu_lock.release() elif endpoint.type == "background": - - # run as threaded task def task(endpoint, lock, req): try: @@ -199,7 +221,6 @@ def task(endpoint, lock, req): # send task start success message res = make_response({'started': True}) - res.headers['X-Endpoint-Type'] = endpoint.type else: raise InvalidEndpointTypeException() @@ -241,7 +262,6 @@ def warm(): "warm": True, }) res.status_code = 200 - res.headers['X-Endpoint-Type'] = "warmup" return res @flask_app.route('/_k/status', methods=["GET"]) @@ -265,7 +285,6 @@ def status(): }) res.status_code = 200 - res.headers['X-Endpoint-Type'] = "status" return res return flask_app diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index f54b198..dfb7147 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -26,6 +26,14 @@ def handler2(context: dict, request: potassium.Request) -> potassium.Response: status=200 ) + @app.handler("/some_binary_response") + def handler3(context: dict, request: potassium.Request) -> potassium.Response: + return potassium.Response( + body=b"hello", + status=200, + headers={"Content-Type": "application/octet-stream"} + ) + @app.handler("/some_path/child_path") def handler2_id(context: dict, request: potassium.Request) -> potassium.Response: return potassium.Response( @@ -43,6 +51,11 @@ def handler2_id(context: dict, request: potassium.Request) -> potassium.Response assert res.status_code == 200 assert res.json == {"hello": "some_path"} + res = client.post("/some_binary_response", json={}) + assert res.status_code == 200 + assert res.data == b"hello" + assert res.headers["Content-Type"] == "application/octet-stream" + res = client.post("/some_path/child_path", json={}) assert res.status_code == 200 assert res.json == {"hello": "some_path/child_path"} diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 0000000..15d20af --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,34 @@ +import pytest +import potassium + +def test_json_response(): + response = potassium.Response( + status=200, + json={"key": "value"} + ) + + assert response.status == 200 + assert response.json == {"key": "value"} + assert response.headers["Content-Type"] == "application/json" + + response.json = {"key": "value2"} + assert response.json == {"key": "value2"} + +def test_body_response(): + response = potassium.Response( + status=200, + body=b"Hello, world!" + ) + + assert response.status == 200 + assert response.body == b"Hello, world!" + assert 'Content-Type' not in response.headers + + response.json = {"key": "value2"} + + assert response.json == {"key": "value2"} + assert response.headers["Content-Type"] == "application/json" + + + +