Skip to content

Commit

Permalink
some cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
Peddle committed Dec 8, 2023
1 parent d8559b9 commit 86e8571
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
26 changes: 13 additions & 13 deletions potassium/potassium.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,14 @@ def _stream_body(self, stream_id):
with self._lock:
del self._mailbox[stream_id]
raise result


print('generator exit')

with self._lock:
del self._mailbox[stream_id]


class Potassium():
"Potassium is a simple, stateful, GPU-enabled, and autoscaleable web framework for deploying machine learning models."

def __init__(self, name):
def __init__(self, name, experimental_num_workers=1):
self.name = name

# default init function, if the user doesn't specify one
Expand All @@ -113,7 +109,7 @@ def __init__(self, name):
self._response_queue = ProcessQueue()
self._response_mailbox = ResponseMailbox(self._response_queue)

self._num_workers = int(os.environ.get("POTASSIUM_NUM_WORKERS", 1))
self._num_workers = experimental_num_workers

self._worker_pool = None

Expand Down Expand Up @@ -150,12 +146,6 @@ def init(self, func):
- the context is not shared between multiple replicas of the app
"""

# def wrapper(worker_num):
# print(colored("Running init()", 'yellow'))
# self._context = func(worker_num)
# if not isinstance(self._context, dict):
# raise Exception("Potassium init() must return a dictionary")

self._init_func = func
return func

Expand Down Expand Up @@ -303,7 +293,17 @@ def _init_server(self):
Pool = ThreadPool
else:
Pool = ProcessPool
self._worker_pool = Pool(self._num_workers, init_worker, (index_queue, self._event_queue, self._response_queue, self._init_func, self._num_workers))
self._worker_pool = Pool(
self._num_workers,
init_worker,
(
index_queue,
self._event_queue,
self._response_queue,
self._init_func,
self._num_workers
)
)

while True:
if self._status.num_workers_started == self._num_workers:
Expand Down
5 changes: 4 additions & 1 deletion potassium/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

from .types import RequestID

class InvalidStatusEvent(Exception):
pass

class StatusEvent(Enum):
INFERENCE_REQUEST_RECEIVED = "INFERENCE_REQUEST_RECEIVED"
INFERENCE_START = "INFERENCE_START"
Expand Down Expand Up @@ -54,7 +57,7 @@ def update(self, event) -> "PotassiumStatus":
event_type = event[0]
event_data = event[1:]
if event_type not in event_handlers:
raise Exception(f"Invalid event {event}")
raise InvalidStatusEvent(f"Invalid status event: {event_type}")
return event_handlers[event_type](self.clone(), *event_data)


Expand Down
3 changes: 3 additions & 0 deletions potassium/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def init_worker(index_queue, event_queue, response_queue, init_func, total_worke
stdout_redirect.set_prefix(f"[worker {worker_num}] ")

# check if the init function takes in a worker number
print(colored("Running init()", 'yellow'))
try:
if len(inspect.signature(init_func).parameters) == 0:
context = init_func()
Expand All @@ -75,6 +76,8 @@ def init_worker(index_queue, event_queue, response_queue, init_func, total_worke
print(colored(tb_str, "red"))
raise e

if not isinstance(context, dict):
raise Exception("Potassium init() must return a dictionary")

event_queue.put((StatusEvent.WORKER_STARTED,))

Expand Down

0 comments on commit 86e8571

Please sign in to comment.