diff --git a/example.py b/example.py index e2b9cd4..e37497a 100644 --- a/example.py +++ b/example.py @@ -1,6 +1,8 @@ +import time from potassium import Potassium, Request, Response from transformers import pipeline import torch +import os app = Potassium("my_app") @@ -28,5 +30,26 @@ def handler(context: dict, request: Request) -> Response: status=200 ) +@app.background("/background") +def background(context: dict, request: Request): + time.sleep(5) + print('hi') + + +@app.handler("/stream") +def stream(context: dict, request: Request): + def stream(): + for i in range(100): + yield f"{i}\n" + time.sleep(1) + + + return Response( + body=stream(), + status=200, + headers={"Content-Type": "text/plain"} + ) + + if __name__ == "__main__": - app.serve() \ No newline at end of file + app.serve() diff --git a/potassium/__init__.py b/potassium/__init__.py index 8006620..a50db17 100644 --- a/potassium/__init__.py +++ b/potassium/__init__.py @@ -1,3 +1,4 @@ from .potassium import * from .hooks import * -from .store import Store, RedisConfig \ No newline at end of file +from .store import Store, RedisConfig +from .types import Request, Response diff --git a/potassium/exceptions.py b/potassium/exceptions.py new file mode 100644 index 0000000..93c42f6 --- /dev/null +++ b/potassium/exceptions.py @@ -0,0 +1,10 @@ +class InvalidEndpointTypeException(Exception): + def __init__(self): + super().__init__("Invalid endpoint type. Must be 'handler' or 'background'") + + +class RouteAlreadyInUseException(Exception): + def __init__(self): + super().__init__("Route already in use") + + diff --git a/potassium/potassium.py b/potassium/potassium.py index 9ad355b..fc16e0f 100644 --- a/potassium/potassium.py +++ b/potassium/potassium.py @@ -1,92 +1,134 @@ +from enum import Enum import time +import os from types import GeneratorType -from typing import Generator, Optional, Union +from typing import Callable +from dataclasses import dataclass from flask import Flask, request, make_response, abort, Response as FlaskResponse +import uuid from werkzeug.serving import make_server -from werkzeug.datastructures.headers import EnvironHeaders -from threading import Thread, Lock, Condition +from threading import Thread, Lock +from queue import Queue as ThreadQueue import functools -import traceback -import json as jsonlib from termcolor import colored - - +from multiprocessing import Pool as ProcessPool, Queue as ProcessQueue +from multiprocessing.pool import ThreadPool +from .status import PotassiumStatus, StatusEvent +from .worker import run_worker, init_worker +from .exceptions import RouteAlreadyInUseException, InvalidEndpointTypeException +from .types import Request, RequestHeaders, Response +import logging + +class HandlerType(Enum): + HANDLER = "HANDLER" + BACKGROUND = "BACKGROUND" + +@dataclass class Endpoint(): - def __init__(self, type, func): - self.type = type - self.func = func - -class Request(): - def __init__(self, id: str, headers: EnvironHeaders, json: dict): - self.id = id - self.headers = headers - self.json = json - -ResponseBody = Union[bytes, Generator[bytes, None, None]] - -class Response(): - def __init__(self, status: int = 200, json: Optional[dict] = None, headers: Optional[dict] = None, body: Optional[ResponseBody] = None): - assert json == None or body == None, "Potassium Response object cannot have both json and body set" - - - self.headers = headers if headers != None else {} - - # convert json to body if not None - if json != None: - self.body = jsonlib.dumps(json).encode("utf-8") - self.headers["Content-Type"] = "application/json" - else: - self.body = body - - self.status = status + type: HandlerType + func: Callable - @property - def json(self): - if self.body == None: - return None - if type(self.body) == bytes: - try: - return jsonlib.loads(self.body.decode("utf-8")) - except: - return None - return None - - @json.setter - def json(self, json): - self.body = jsonlib.dumps(json).encode("utf-8") - self.headers["Content-Type"] = "application/json" +class ResponseMailbox(): + def __init__(self, response_queue): + self._response_queue = response_queue + self._mailbox = {} + self._lock = Lock() + t = Thread(target=self._response_handler, daemon=True) + t.start() -class InvalidEndpointTypeException(Exception): - def __init__(self): - super().__init__("Invalid endpoint type. Must be 'handler' or 'background'") - + def _response_handler(self): + try: + while True: + request_id, payload = self._response_queue.get() + with self._lock: + if request_id not in self._mailbox: + self._mailbox[request_id] = ThreadQueue() + self._mailbox[request_id].put(payload) + except EOFError: + # queue closed, this happens when the server is shutting down + pass + + def get_response(self, request_id): + with self._lock: + if request_id not in self._mailbox: + self._mailbox[request_id] = ThreadQueue() + result, stream_id = self._mailbox[request_id].get() + + if stream_id is not None: + result.body = self._stream_body(stream_id) + + with self._lock: + del self._mailbox[request_id] + + return result + + def _stream_body(self, stream_id): + with self._lock: + if stream_id not in self._mailbox: + self._mailbox[stream_id] = ThreadQueue() + queue = self._mailbox[stream_id] -class RouteAlreadyInUseException(Exception): - def __init__(self): - super().__init__("Route already in use") + try: + while True: + result = queue.get() + if isinstance(result, Exception): + with self._lock: + del self._mailbox[stream_id] + raise result + elif result == None: + break + else: + yield result + except GeneratorExit: + while True: + # flush the queue + result = queue.get() + if result == None: + break + elif isinstance(result, Exception): + with self._lock: + del self._mailbox[stream_id] + raise result + with self._lock: + del self._mailbox[stream_id] class Potassium(): "Potassium is a simple, stateful, GPU-enabled, and autoscaleable web framework for deploying machine learning models." - def __init__(self, name): + def __init__(self, name, experimental_num_workers=1): self.name = name # default init function, if the user doesn't specify one - self._init_func = lambda: {} + self._init_func = lambda _: {} # dictionary to store unlimited Endpoints, by unique route self._endpoints = {} self._context = {} - self._gpu_lock = Lock() - self._background_task_cv = Condition() - self._sequence_number = 0 - self._sequence_number_lock = Lock() - self._idle_start_time = 0 - self._last_inference_start_time = None self._flask_app = self._create_flask_app() + self._event_queue = ProcessQueue() + self._response_queue = ProcessQueue() + self._response_mailbox = ResponseMailbox(self._response_queue) + + self._num_workers = experimental_num_workers + + self._worker_pool = None + + self.event_handler_thread = Thread(target=self._event_handler, daemon=True) + self.event_handler_thread.start() + + self._status = PotassiumStatus.initial(self._num_workers) + + def _event_handler(self): + try: + while True: + event = self._event_queue.get() + self._status = self._status.update(event) + except EOFError: + # this happens when the process is shutting down + pass + - # def init(self, func): """init runs once on server start, and is used to initialize the app's context. You can use this to load models onto the GPU, set up connections, etc. @@ -97,14 +139,8 @@ def init(self, func): - the context is not shared between multiple replicas of the app """ - def wrapper(): - print(colored("Running init()", 'yellow')) - self._context = func() - if not isinstance(self._context, dict): - raise Exception("Potassium init() must return a dictionary") - - self._init_func = wrapper - return wrapper + self._init_func = func + return func @staticmethod def _standardize_route(route): @@ -118,134 +154,46 @@ def _standardize_route(route): return route - # handler is a blocking http POST handler - def handler(self, route: str = "/"): - "handler is a blocking http POST handler" - + def _base_decorator(self, route: str, handler_type: HandlerType): route = self._standardize_route(route) if route in self._endpoints: raise RouteAlreadyInUseException() def actual_decorator(func): @functools.wraps(func) - def wrapper(request): + def wrapper(context, request): # send in app's stateful context if GPU, and the request - out = func(self._context, request) + out = func(context, request) - if type(out) != Response: - raise Exception("Potassium Response object not returned") - - if type(out.body) != bytes and type(out.body) != GeneratorType: - raise Exception( - "Potassium Response object body must be bytes", type(out.body)) + if handler_type == HandlerType.HANDLER: + if type(out) != Response: + raise Exception("Potassium Response object not returned") + if type(out.body) != bytes and type(out.body) != GeneratorType: + raise Exception( + "Potassium Response object body must be bytes", type(out.body)) return out - self._endpoints[route] = Endpoint(type="handler", func=wrapper) + self._endpoints[route] = Endpoint(type=handler_type, func=wrapper) return wrapper return actual_decorator + # handler is a blocking http POST handler + def handler(self, route: str = "/"): + "handler is a blocking http POST handler" + return self._base_decorator(route, HandlerType.HANDLER) + # background is a non-blocking http POST handler def background(self, route: str = "/"): "background is a non-blocking http POST handler" - route = self._standardize_route(route) - if route in self._endpoints: - raise RouteAlreadyInUseException() - - def actual_decorator(func): - @functools.wraps(func) - def wrapper(request): - # send in app's stateful context if GPU, and the request - return func(self._context, request) - - self._endpoints[route] = Endpoint( - type="background", func=wrapper) - return wrapper - return actual_decorator + return self._base_decorator(route, HandlerType.BACKGROUND) def test_client(self): "test_client returns a Flask test client for the app" + self._init_server() return self._flask_app.test_client() - # _handle_generic takes in a request and the endpoint it was routed to and handles it as expected by that endpoint - def _handle_generic(self, endpoint, flask_request): - # potassium rejects if lock already in use - try: - self._gpu_lock.acquire(blocking=False) - except: - res = make_response() - res.status_code = 423 - return res - - res = None - self._last_inference_start_time = time.time() - - try: - req = Request( - headers=flask_request.headers, - json=flask_request.get_json(), - id=flask_request.headers.get("X-Banana-Request-Id", "") - ) - except: - res = make_response() - res.status_code = 400 - self._gpu_lock.release() - return res - - if endpoint.type == "handler": - try: - out = endpoint.func(req) - - # create flask response - res = make_response() - res = FlaskResponse( - out.body, status=out.status, headers=out.headers) - except: - tb_str = traceback.format_exc() - print(colored(tb_str, "red")) - res = make_response(tb_str) - res.status_code = 500 - self._idle_start_time = time.time() - self._last_inference_start_time = None - self._gpu_lock.release() - elif endpoint.type == "background": - # run as threaded task - def task(endpoint, lock, req): - try: - endpoint.func(req) - except Exception as e: - # do any cleanup before re-raising user error - raise e - finally: - with self._background_task_cv: - self._background_task_cv.notify_all() - - self._idle_start_time = time.time() - self._last_inference_start_time = None - lock.release() - - thread = Thread(target=task, args=(endpoint, self._gpu_lock, req)) - thread.start() - - # send task start success message - res = make_response({'started': True}) - else: - raise InvalidEndpointTypeException() - - return res - - # WARNING: cover depends on this being called so it should not be changed - def _read_event_chan(self) -> bool: - """ - _read_event_chan essentially waits for a background task to finish, - and then returns True - """ - with self._background_task_cv: - # wait until the background task is done - self._background_task_cv.wait() - return True - def _create_flask_app(self): flask_app = Flask(__name__) @@ -253,20 +201,59 @@ def _create_flask_app(self): @flask_app.route('/', defaults={'path': ''}, methods=["POST"]) @flask_app.route('/', methods=["POST"]) def handle(path): - with self._sequence_number_lock: - self._sequence_number += 1 - route = "/" + path if route not in self._endpoints: + self._event_queue.put((StatusEvent.BAD_REQUEST_RECEIVED,)) abort(404) endpoint = self._endpoints[route] - return self._handle_generic(endpoint, request) - + request_id = request.headers.get("X-Banana-Request-Id", None) + if request_id is None: + request_id = str(uuid.uuid4()) + try: + req = Request( + headers=RequestHeaders(dict(request.headers.items())), + json=request.get_json(), + id=request_id + ) + except: + res = make_response() + res.status_code = 400 + self._event_queue.put((StatusEvent.BAD_REQUEST_RECEIVED,)) + return res + + self._event_queue.put((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) + + assert self._worker_pool is not None, "Worker pool not initialized" + # use an internal id for critical path to prevent user from accidentally + # breaking things by sending multiple requests with the same id + internal_id = str(uuid.uuid4()) + if endpoint.type == HandlerType.HANDLER: + self._worker_pool.apply_async(run_worker, args=(endpoint.func, req, internal_id, True)) + resp = self._response_mailbox.get_response(internal_id) + + flask_response = FlaskResponse( + resp.body, + status=resp.status, + headers=resp.headers + ) + elif endpoint.type == HandlerType.BACKGROUND: + self._worker_pool.apply_async(run_worker, args=(endpoint.func, req, internal_id)) + + flask_response = make_response({'started': True}) + else: + raise InvalidEndpointTypeException() + + return flask_response + @flask_app.route('/_k/warmup', methods=["POST"]) def warm(): - with self._sequence_number_lock: - self._sequence_number += 1 + request_id = str(uuid.uuid4()) + + # a bit of a hack but we need to send a start and end event to the event queue + # in order to update the status the way the load balancer expects + self._event_queue.put((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) + self._event_queue.put((StatusEvent.INFERENCE_END, request_id)) res = make_response({ "warm": True, }) @@ -276,33 +263,61 @@ def warm(): @flask_app.route('/_k/status', methods=["GET"]) @flask_app.route('/__status__', methods=["GET"]) def status(): - idle_time = 0 - inference_time = 0 - gpu_available = not self._gpu_lock.locked() - - if self._last_inference_start_time != None: - inference_time = int((time.time() - self._last_inference_start_time)*1000) - - if gpu_available: - idle_time = int((time.time() - self._idle_start_time)*1000) + cur_status = self._status res = make_response({ - "gpu_available": gpu_available, - "sequence_number": self._sequence_number, - "idle_time": idle_time, - "inference_time": inference_time, + "gpu_available": cur_status.gpu_available, + "sequence_number": cur_status.sequence_number, + "idle_time": int(cur_status.idle_time*1000), + "inference_time": int(cur_status.longest_inference_time*1000), }) res.status_code = 200 return res return flask_app + + def _init_server(self): + # unless the user has already set up logging, set up logging to stdout using + # a separate fd so that we don't get in the way of request logs + log = logging.getLogger('werkzeug') + if len(log.handlers) == 0: + # duplicate stdout + stdout_copy = os.dup(1) + # redirect flask logs to stdout_copy + log.addHandler(logging.StreamHandler(os.fdopen(stdout_copy, 'w'))) + + self._idle_start_time = time.time() + index_queue = ProcessQueue() + for i in range(self._num_workers): + index_queue.put(i) + if self._num_workers == 1: + Pool = ThreadPool + else: + Pool = ProcessPool + self._worker_pool = Pool( + self._num_workers, + init_worker, + ( + index_queue, + self._event_queue, + self._response_queue, + self._init_func, + self._num_workers + ) + ) + + while True: + if self._status.num_workers_started == self._num_workers: + break + print(colored(f"Started {self._num_workers} workers", 'green')) # serve runs the http server def serve(self, host="0.0.0.0", port=8000): print(colored("------\nStarting Potassium Server 🍌", 'yellow')) - self._init_func() + self._init_server() server = make_server(host, port, self._flask_app, threaded=True) print(colored(f"Serving at http://{host}:{port}\n------", 'green')) - self._idle_start_time = time.time() + server.serve_forever() + diff --git a/potassium/status.py b/potassium/status.py new file mode 100644 index 0000000..edbff22 --- /dev/null +++ b/potassium/status.py @@ -0,0 +1,124 @@ +from enum import Enum +import time +from typing import List, Tuple +from dataclasses import dataclass + +from .types import RequestID + +class InvalidStatusEvent(Exception): + pass + +class StatusEvent(Enum): + INFERENCE_REQUEST_RECEIVED = "INFERENCE_REQUEST_RECEIVED" + INFERENCE_START = "INFERENCE_START" + INFERENCE_END = "INFERENCE_END" + WORKER_STARTED = "WORKER_STARTED" + BAD_REQUEST_RECEIVED = "BAD_REQUEST_RECEIVED" + +@dataclass +class PotassiumStatus(): + """PotassiumStatus is a simple class that represents the status of a Potassium app.""" + num_started_inference_requests: int + num_completed_inference_requests: int + num_bad_requests: int + num_workers: int + num_workers_started: int + idle_start_timestamp: float + in_flight_request_start_times: List[Tuple[RequestID, float]] + + @staticmethod + def initial(num_workers: int) -> "PotassiumStatus": + return PotassiumStatus( + num_started_inference_requests=0, + num_completed_inference_requests=0, + num_bad_requests=0, + num_workers=num_workers, + num_workers_started=0, + idle_start_timestamp=time.time(), + in_flight_request_start_times=[] + ) + + @property + def requests_in_progress(self): + return self.num_started_inference_requests - self.num_completed_inference_requests + + @property + def gpu_available(self): + if self.num_workers_started < self.num_workers: + return False + return self.num_workers - self.requests_in_progress > 0 + + @property + def sequence_number(self): + return self.num_started_inference_requests + self.num_bad_requests + + @property + def idle_time(self): + num_received_requests_not_completed = self.num_started_inference_requests - self.num_completed_inference_requests + has_incomplete_requests = num_received_requests_not_completed > 0 + if not self.gpu_available or has_incomplete_requests: + return 0 + return time.time() - self.idle_start_timestamp + + @property + def longest_inference_time(self): + if self.in_flight_request_start_times == []: + return 0 + + oldest_start_time = min([start_time for _, start_time in self.in_flight_request_start_times]) + + return time.time() - oldest_start_time + + def update(self, event) -> "PotassiumStatus": + event_type = event[0] + event_data = event[1:] + if event_type not in event_handlers: + raise InvalidStatusEvent(f"Invalid status event: {event_type}") + return event_handlers[event_type](self.clone(), *event_data) + + + def clone(self): + return PotassiumStatus( + self.num_started_inference_requests, + self.num_completed_inference_requests, + self.num_bad_requests, + self.num_workers, + self.num_workers_started, + self.idle_start_timestamp, + self.in_flight_request_start_times + ) + +def handle_start_inference(status: PotassiumStatus, request_id: RequestID): + status.in_flight_request_start_times.append((request_id, time.time())) + return status + +def handle_end_inference(status: PotassiumStatus, request_id: RequestID): + status.num_completed_inference_requests += 1 + status.in_flight_request_start_times = [t for t in status.in_flight_request_start_times if t[0] != request_id] + + if status.gpu_available: + status.idle_start_timestamp = time.time() + + return status + +def handle_inference_request_received(status: PotassiumStatus): + status.num_started_inference_requests += 1 + return status + +def handle_worker_started(status: PotassiumStatus): + status.num_workers_started += 1 + return status + +def handle_bad_request_received(status: PotassiumStatus): + status.num_bad_requests += 1 + return status + +event_handlers = { + StatusEvent.INFERENCE_REQUEST_RECEIVED: handle_inference_request_received, + StatusEvent.INFERENCE_START: handle_start_inference, + StatusEvent.INFERENCE_END: handle_end_inference, + StatusEvent.WORKER_STARTED: handle_worker_started, + StatusEvent.BAD_REQUEST_RECEIVED: handle_bad_request_received +} + + diff --git a/potassium/types.py b/potassium/types.py new file mode 100644 index 0000000..11b8b47 --- /dev/null +++ b/potassium/types.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generator, Optional, Union, Generator, Optional, Union +import json as jsonlib + +class RequestHeaders(): + def __init__(self, headers: Dict[str, str]): + self._headers = {} + for key in headers: + self._headers[self._normalize_key(key)] = headers[key] + + def _normalize_key(self, key): + if not isinstance(key, str): + raise KeyError(key) + return key.upper().replace("-", "_") + + def __getitem__(self, key): + print(self._headers) + key = self._normalize_key(key) + return self._headers[key] + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + +@dataclass +class Request(): + id: str + headers: RequestHeaders + json: Dict[str, Any] + +ResponseBody = Union[bytes, Generator[bytes, None, None]] +RequestID = str + +class Response(): + def __init__(self, status: int = 200, json: Optional[dict] = None, headers: Optional[dict] = None, body: Optional[ResponseBody] = None): + assert json == None or body == None, "Potassium Response object cannot have both json and body set" + + + self.headers = headers if headers != None else {} + + # convert json to body if not None + if json != None: + self.body = jsonlib.dumps(json).encode("utf-8") + self.headers["Content-Type"] = "application/json" + else: + self.body = body + + self.status = status + + @property + def json(self): + if self.body == None: + return None + if type(self.body) == bytes: + try: + return jsonlib.loads(self.body.decode("utf-8")) + except: + return None + return None + + @json.setter + def json(self, json): + self.body = jsonlib.dumps(json).encode("utf-8") + self.headers["Content-Type"] = "application/json" + + diff --git a/potassium/worker.py b/potassium/worker.py new file mode 100644 index 0000000..25c71e1 --- /dev/null +++ b/potassium/worker.py @@ -0,0 +1,144 @@ +from multiprocessing import Queue +import os +import threading +from typing import Dict, Any, Generator +from dataclasses import dataclass +from flask import make_response, Response as FlaskResponse +from termcolor import colored +import traceback +import inspect + +from .status import StatusEvent +from .types import Response + +worker = None + +class FDRedirect(): + def __init__(self, fd: int): + self._fd = fd + self._fd_copy = os.dup(fd) + self._redirect_w = None + + def _run_redirect_loop(self, redirect_r, prefix): + redirect_r = os.fdopen(redirect_r, "r") + + for line in redirect_r: + os.write(self._fd_copy, (prefix + line).encode("utf-8")) + redirect_r.close() + + def set_prefix(self, prefix): + if self._redirect_w is not None: + os.dup2(self._fd_copy, self._fd) + os.close(self._redirect_w) + + fd = self._fd + redirect_r, redirect_w = os.pipe() + + self._fd_copy = os.dup(fd) + os.dup2(redirect_w, fd) + self._redirect_w = redirect_w + + t = threading.Thread(target=self._run_redirect_loop, args=(redirect_r, prefix)) + t.daemon = True + t.start() + + +@dataclass +class Worker(): + worker_num: int + total_workers: int + context: Dict[Any, Any] + event_queue: Queue + response_queue: Queue + stderr_redirect: FDRedirect + stdout_redirect: FDRedirect + +def init_worker(index_queue, event_queue, response_queue, init_func, total_workers): + global worker + worker_num = index_queue.get() + + stdout_redirect = FDRedirect(1) + stderr_redirect = FDRedirect(2) + + if total_workers > 1: + stderr_redirect.set_prefix(f"[worker {worker_num}] ") + stdout_redirect.set_prefix(f"[worker {worker_num}] ") + + # check if the init function takes in a worker number + print(colored("Running init()", 'yellow')) + try: + if len(inspect.signature(init_func).parameters) == 0: + context = init_func() + else: + context = init_func(worker_num) + except Exception as e: + tb_str = traceback.format_exc() + print(colored(tb_str, "red")) + raise e + + if not isinstance(context, dict): + raise Exception("Potassium init() must return a dictionary") + + event_queue.put((StatusEvent.WORKER_STARTED,)) + + worker = Worker( + worker_num, + total_workers, + context, + event_queue, + response_queue, + stdout_redirect, + stderr_redirect + ) + +def run_worker(func, request, internal_id, use_response=False): + assert worker is not None, "worker is not initialized" + + if worker.total_workers > 1: + prefix = f"[worker {worker.worker_num}, requestID {request.id}] " + else: + prefix = f"[requestID {request.id}] " + + + worker.stderr_redirect.set_prefix(prefix) + worker.stdout_redirect.set_prefix(prefix) + + resp = None + worker.event_queue.put((StatusEvent.INFERENCE_START, internal_id)) + + try: + resp = func(worker.context, request) + except: + tb_str = traceback.format_exc() + print(colored(tb_str, "red")) + resp = Response( + status=500, + body=tb_str.encode("utf-8"), + headers={ + "Content-Type": "text/plain" + } + ) + + if use_response: + generator = None + stream_id = None + if inspect.isgenerator(resp.body): + stream_id = 'stream-' + internal_id + generator = resp.body + resp.body = None + worker.response_queue.put((internal_id, (resp, stream_id))) + + # if the response is a generator, we need to iterate through it + if stream_id: + assert generator is not None + for chunk in generator: + worker.response_queue.put((stream_id, chunk)) + worker.response_queue.put((stream_id, None)) + + + if worker.total_workers == 1: + worker.stderr_redirect.set_prefix("") + worker.stdout_redirect.set_prefix("") + + worker.event_queue.put((StatusEvent.INFERENCE_END, internal_id)) + diff --git a/setup.py b/setup.py index 50b80e7..72a5579 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name='potassium', packages=['potassium'], - version='0.4.1', + version='0.5.0', license='Apache License 2.0', # Give a short description about your library description='The potassium package is a flask-like HTTP server for serving large AI models', diff --git a/tests/app.py b/tests/app.py new file mode 100644 index 0000000..de16028 --- /dev/null +++ b/tests/app.py @@ -0,0 +1,60 @@ +import potassium + +potassium_test_app = potassium.Potassium("test_app") + +@potassium_test_app.init +def init(): + return {} + +@potassium_test_app.handler() +def handler(context: dict, request: potassium.Request) -> potassium.Response: + return potassium.Response( + json={"hello": "root"}, + status=200 + ) + +@potassium_test_app.handler("/some_path") +def handler2(context: dict, request: potassium.Request) -> potassium.Response: + return potassium.Response( + json={"hello": "some_path"}, + status=200 + ) + +@potassium_test_app.handler("/some_binary_response") +def handler3(context: dict, request: potassium.Request) -> potassium.Response: + return potassium.Response( + body=b"hello", + status=200, + headers={"Content-Type": "application/octet-stream"} + ) + +@potassium_test_app.handler("/some_path_byte_stream_response") +def handler4(context: dict, request: potassium.Request) -> potassium.Response: + def stream(): + yield b"hello" + yield b"world" + + return potassium.Response( + body=stream(), + status=200, + headers={"Content-Type": "application/octet-stream"} + ) + +@potassium_test_app.handler("/some_path/child_path") +def handler2_id(context: dict, request: potassium.Request) -> potassium.Response: + return potassium.Response( + json={"hello": f"some_path/child_path"}, + status=200 + ) + +@potassium_test_app.handler("/some_headers_request") +def handler5(context: dict, request: potassium.Request) -> potassium.Response: + assert request.headers["A"] == "a" + assert request.headers["B"] == "b" + assert request.headers["X-Banana-Request-Id"] == request.id + return potassium.Response( + headers={"A": "a", "B": "b", "X-Banana-Request-Id": request.id}, + json={"hello": "some_headers_request", "id": request.id}, + status=200 + ) + diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 7622b02..89b451a 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -4,7 +4,6 @@ import pytest import potassium - def test_handler(): app = potassium.Potassium("my_app") @@ -101,10 +100,19 @@ def handler5(context: dict, request: potassium.Request) -> potassium.Response: assert res.status_code == 400 # check status - res = client.get("/__status__") - assert res.status_code == 200 - assert res.json is not None - assert res.json["gpu_available"] == True + count = 0 + while True: + res = client.get("/__status__") + assert res.status_code == 200 + assert res.json is not None + + if res.json["gpu_available"] == True: + break + elif count > 10: + assert False, "GPU never became available" + else: + time.sleep(0.1) + count += 1 # parameterized test for path collisions @pytest.mark.parametrize("paths", [ @@ -196,6 +204,8 @@ def background(context: dict, request: potassium.Request): res = client.post("/this_path_does_not_exist", json={}) assert res.status_code == 404 + # takes a split second for the status to update + time.sleep(0.1) res = client.get("/__status__", json={}) assert res.status_code == 200 assert res.json is not None @@ -219,29 +229,36 @@ def background(context: dict, request: potassium.Request): resolve_background_condition.wait() - def wait_for_background_task(): - app._read_event_chan() - order_of_execution_queue.put("background_task_completed") - - thread = threading.Thread(target=wait_for_background_task) - thread.start() - client = app.test_client() # send background post in separate thread - order_of_execution_queue.put("send_background_task") res = client.post("/background", json={}) assert res.status_code == 200 + + time.sleep(0.1) + + res = client.get("/__status__", json={}) + + assert res.status_code == 200 + assert res.json is not None + assert res.json["gpu_available"] == False + assert res.json["sequence_number"] == 1 + # notify background thread to continue with resolve_background_condition: resolve_background_condition.notify() - thread.join() + time.sleep(0.1) + + res = client.get("/__status__", json={}) + + assert res.status_code == 200 + assert res.json is not None + assert res.json["gpu_available"] == True + assert res.json["sequence_number"] == 1 + - # assert order of execution - assert order_of_execution_queue.get() == "send_background_task" - assert order_of_execution_queue.get() == "background_task_completed" def test_warmup(): app = potassium.Potassium("my_app") @@ -260,6 +277,7 @@ def handler(context: dict, request: potassium.Request) -> potassium.Response: assert res.status_code == 200 assert res.json == {"warm": True} + time.sleep(0.1) res = client.get("/__status__", json={}) assert res.status_code == 200 assert res.json is not None diff --git a/tests/test_status.py b/tests/test_status.py new file mode 100644 index 0000000..cba909f --- /dev/null +++ b/tests/test_status.py @@ -0,0 +1,207 @@ +import pytest +from potassium.status import StatusEvent, PotassiumStatus, InvalidStatusEvent +import time + +@pytest.mark.parametrize("worker_num", [ + 1, + 2, + 4 +]) +def test_workers_starting(worker_num): + status = PotassiumStatus.initial(worker_num) + assert status.num_workers == worker_num + assert status.num_workers_started == 0 + assert status.gpu_available == False + status = status.update((StatusEvent.WORKER_STARTED,)) + + if worker_num == 1: + assert status.gpu_available == True + else: + assert status.gpu_available == False + + for _ in range(worker_num-1): + status = status.update((StatusEvent.WORKER_STARTED,)) + assert status.num_workers_started == worker_num + assert status.gpu_available == True + +def test_bad_event(): + status = PotassiumStatus.initial(1) + with pytest.raises(InvalidStatusEvent): + status.update(("BAD_EVENT",)) + +def test_inference_requests_single_worker(): + status = PotassiumStatus.initial(1) + status = status.update((StatusEvent.WORKER_STARTED,)) + + status = status.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) + assert status.num_started_inference_requests == 1 + assert status.num_completed_inference_requests == 0 + assert status.gpu_available == False + + status = status.update((StatusEvent.INFERENCE_START, 0)) + status = status.update((StatusEvent.INFERENCE_END, 0)) + + assert status.num_started_inference_requests == 1 + assert status.num_completed_inference_requests == 1 + assert status.sequence_number == 1 + assert status.gpu_available == True + + status = status.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) + assert status.num_started_inference_requests == 2 + status = status.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) + assert status.num_started_inference_requests == 3 + + status = status.update((StatusEvent.INFERENCE_START, 1)) + status = status.update((StatusEvent.INFERENCE_END, 1)) + + assert status.num_started_inference_requests == 3 + assert status.num_completed_inference_requests == 2 + assert status.sequence_number == 3 + assert status.gpu_available == False + + status = status.update((StatusEvent.INFERENCE_START, 2)) + status = status.update((StatusEvent.INFERENCE_END, 2)) + + assert status.num_started_inference_requests == 3 + assert status.num_completed_inference_requests == 3 + assert status.sequence_number == 3 + assert status.gpu_available == True + +def test_inference_requests_multiple_workers(): + state = PotassiumStatus.initial(2) + + state = state.update((StatusEvent.WORKER_STARTED,)) + state = state.update((StatusEvent.WORKER_STARTED,)) + + assert state.gpu_available == True + assert state.sequence_number == 0 + + state = state.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) + assert state.num_started_inference_requests == 1 + assert state.num_completed_inference_requests == 0 + assert state.gpu_available == True + + state = state.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) + assert state.num_started_inference_requests == 2 + assert state.num_completed_inference_requests == 0 + assert state.gpu_available == False + + state = state.update((StatusEvent.INFERENCE_START, 0)) + state = state.update((StatusEvent.INFERENCE_END, 0)) + + assert state.num_started_inference_requests == 2 + assert state.num_completed_inference_requests == 1 + assert state.sequence_number == 2 + assert state.gpu_available == True + + state = state.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) + state = state.update((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) + + assert state.num_started_inference_requests == 4 + assert state.num_completed_inference_requests == 1 + assert state.sequence_number == 4 + assert state.gpu_available == False + + state = state.update((StatusEvent.INFERENCE_START, 1)) + state = state.update((StatusEvent.INFERENCE_END, 1)) + + assert state.num_started_inference_requests == 4 + assert state.num_completed_inference_requests == 2 + assert state.sequence_number == 4 + assert state.gpu_available == False + + state = state.update((StatusEvent.INFERENCE_START, 2)) + state = state.update((StatusEvent.INFERENCE_END, 2)) + + assert state.num_started_inference_requests == 4 + assert state.num_completed_inference_requests == 3 + assert state.sequence_number == 4 + assert state.gpu_available == True + + state = state.update((StatusEvent.INFERENCE_START, 3)) + state = state.update((StatusEvent.INFERENCE_END, 3)) + + assert state.num_started_inference_requests == 4 + assert state.num_completed_inference_requests == 4 + assert state.sequence_number == 4 + assert state.gpu_available == True + +@pytest.mark.parametrize("status_result_tuple", [ + (PotassiumStatus( + num_started_inference_requests=0, + num_completed_inference_requests=0, + num_bad_requests=0, + num_workers=1, + num_workers_started=0, + idle_start_timestamp=0, + in_flight_request_start_times=[] + ), 0), + (PotassiumStatus( + num_started_inference_requests=0, + num_completed_inference_requests=0, + num_bad_requests=0, + num_workers=1, + num_workers_started=1, + idle_start_timestamp=0, + in_flight_request_start_times=[] + ), time.time()), + (PotassiumStatus( + num_started_inference_requests=1, + num_completed_inference_requests=0, + num_bad_requests=0, + num_workers=1, + num_workers_started=1, + idle_start_timestamp=0, + in_flight_request_start_times=[] + ), 0), + (PotassiumStatus( + num_started_inference_requests=2, + num_completed_inference_requests=0, + num_bad_requests=0, + num_workers=4, + num_workers_started=4, + idle_start_timestamp=0, + in_flight_request_start_times=[] + ), 0), + (PotassiumStatus( + num_started_inference_requests=2, + num_completed_inference_requests=2, + num_bad_requests=0, + num_workers=4, + num_workers_started=4, + idle_start_timestamp=0, + in_flight_request_start_times=[] + ), time.time()), +]) +def test_idle_time(status_result_tuple): + status, result = status_result_tuple + delta = abs(status.idle_time - result) + ALLOWED_DELTA = 1 + assert delta < ALLOWED_DELTA + +def test_longest_inference_time(): + status = PotassiumStatus( + num_started_inference_requests=6, + num_completed_inference_requests=2, + num_bad_requests=0, + num_workers=4, + num_workers_started=4, + idle_start_timestamp=0, + in_flight_request_start_times=[ + ("b", time.time() - 2), + ("a", time.time() - 1), + ("c", time.time() - 3), + ("d", time.time()), + ] + ) + + longest_inference_time = status.longest_inference_time + EXPECTED_LONGEST_INFERENCE_TIME = 3 + delta = abs(longest_inference_time - EXPECTED_LONGEST_INFERENCE_TIME) + + ALLOWED_DELTA = 0.1 + assert delta < ALLOWED_DELTA + + + +