Skip to content

Commit

Permalink
Add request ID into request, and include headers (#46)
Browse files Browse the repository at this point in the history
* version bump

* implemented

* using case insensitive headers
  • Loading branch information
erik-dunteman authored Nov 28, 2023
1 parent 71a2810 commit e614c14
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
12 changes: 8 additions & 4 deletions potassium/potassium.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Generator, Optional, Union
from flask import Flask, request, make_response, abort, Response as FlaskResponse
from werkzeug.serving import make_server
from werkzeug.datastructures.headers import EnvironHeaders
from threading import Thread, Lock, Condition
import functools
import traceback
Expand All @@ -15,18 +16,19 @@ def __init__(self, type, func):
self.type = type
self.func = func


class Request():
def __init__(self, json: dict):
def __init__(self, id: str, headers: EnvironHeaders, json: dict):
self.id = id
self.headers = headers
self.json = json


ResponseBody = Union[bytes, Generator[bytes, None, None]]

class Response():
def __init__(self, status: int = 200, json: Optional[dict] = None, headers: Optional[dict] = None, body: Optional[ResponseBody] = None):
assert json == None or body == None, "Potassium Response object cannot have both json and body set"


self.headers = headers if headers != None else {}

# convert json to body if not None
Expand Down Expand Up @@ -181,7 +183,9 @@ def _handle_generic(self, endpoint, flask_request):

try:
req = Request(
json=flask_request.get_json()
headers=flask_request.headers,
json=flask_request.get_json(),
id=flask_request.headers.get("X-Banana-Request-Id", "")
)
except:
res = make_response()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
setup(
name='potassium',
packages=['potassium'],
version='0.4.0',
version='0.4.1',
license='Apache License 2.0',
# Give a short description about your library
description='The potassium package is a flask-like HTTP server for serving large AI models',
Expand Down
21 changes: 21 additions & 0 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def handler2_id(context: dict, request: potassium.Request) -> potassium.Response
status=200
)

@app.handler("/some_headers_request")
def handler5(context: dict, request: potassium.Request) -> potassium.Response:
assert request.headers["A"] == "a"
assert request.headers["B"] == "b"
assert request.headers["X-Banana-Request-Id"] == request.id
return potassium.Response(
headers={"A": "a", "B": "b", "X-Banana-Request-Id": request.id},
json={"hello": "some_headers_request", "id": request.id},
status=200
)

client = app.test_client()

res = client.post("/", json={})
Expand All @@ -77,8 +88,18 @@ def handler2_id(context: dict, request: potassium.Request) -> potassium.Response
assert res.status_code == 200
assert res.json == {"hello": "some_path/child_path"}

# note the capitalization of ID, we're testing that it's case insensitive
headers = {"A": "a", "B": "b", "X-Banana-Request-ID": "123"}
res = client.post("/some_headers_request", json={}, headers=headers)
assert res.status_code == 200
assert res.json == {"hello": "some_headers_request", "id": headers["X-Banana-Request-ID"]}
assert res.headers["X-Banana-Request-Id"] == "123"
assert res.headers["A"] == "a"
assert res.headers["B"] == "b"

res = client.post("/", data='{"key": unquoted_value}', content_type='application/json')
assert res.status_code == 400

# check status
res = client.get("/__status__")
assert res.status_code == 200
Expand Down

0 comments on commit e614c14

Please sign in to comment.