From 5f3833bc9529f4a35831266eed4ac0d05d81b580 Mon Sep 17 00:00:00 2001 From: YuviPanda Date: Thu, 23 May 2024 19:38:56 -0700 Subject: [PATCH] Allow overriding spawner config based on user group membership Similar to 'kubespawner_override' in KubeSpawner, this allows admins to selectively override spawner configuration based on groups a user belongs to. This allows for low maintenance but extremely powerful customization based on group membership. This is particularly powerful when combined with https://github.com/jupyterhub/oauthenticator/pull/735 \#\# Dictionary vs List Ordering is important here, but still I choose to implement this configuration as a dictionary of dictionaries vs a list. This is primarily to allow for easy overriding in z2jh (and similar places), where Lists are just really hard to override. Ordering is provided by lexicographically sorting the keys, similar to how we do it in z2jh. \#\# Merging config The merging code is literally copied from KubeSpawner, and provides the exact same behavior. Documentation of how it acts is also copied. --- jupyterhub/spawner.py | 76 ++++++++++++++++++++++++++- jupyterhub/tests/test_spawner.py | 90 +++++++++++++++++++++++++++++++- jupyterhub/user.py | 1 + jupyterhub/utils.py | 20 +++++++ 4 files changed, 184 insertions(+), 3 deletions(-) diff --git a/jupyterhub/spawner.py b/jupyterhub/spawner.py index 8e8f320bc..726fa5ff9 100644 --- a/jupyterhub/spawner.py +++ b/jupyterhub/spawner.py @@ -50,6 +50,7 @@ exponential_backoff, maybe_future, random_port, + recursive_update, url_escape_path, url_path_join, ) @@ -306,6 +307,31 @@ def _default_access_scopes(self): f"access:servers!user={self.user.name}", ] + group_overrides = Union( + [Callable(), Dict()], + help=""" + Override specific traitlets based on group membership of the user. + + This can be a dict, or a callable that returns a dict. The key of the dict + is *only* used for lexicographical sorting, to guarantee a consistent + ordering of the overrides. 
If it is a callable, it may be async, and will + be passed one parameter - the spawner instance. It should return a dictionary. + + The value of the dict is a dict, with the following keys: + + - *groups* - If the user belongs to *any* of these groups, these overrides are + applied to their server before spawning. + - *spawner_override* - a dictionary with overrides to apply to the Spawner + settings. Each value can be either the final value to change or a callable that + take the `KubeSpawner` instance as parameter and return the final value. This can + be further overridden by 'profile_options' + If the traitlet being overriden is a *dictionary*, the dictionary + will be *recursively updated*, rather than overriden. If you want to + remove a key, set its value to `None` + """, + config=True, + ) + handler = Any() oauth_roles = Union( @@ -504,7 +530,7 @@ async def _get_oauth_client_allowed_scopes(self): max=1, help=""" Jitter fraction for poll_interval. - + Avoids alignment of poll calls for many Spawners, e.g. when restarting JupyterHub, which restarts all polls for running Spawners. @@ -1479,6 +1505,54 @@ async def _wait_for_death(): except AnyTimeoutError: return False + def _apply_overrides(self, spawner_override: dict): + """ + Apply set of overrides onto the current spawner instance + + spawner_override is a dict with key being the name of the traitlet + to override, and value is either a callable or the value for the + traitlet. If the value is a dictionary, it is *merged* with the + existing value (rather than replaced). Callables are called with + one parameter - the current spawner instance. + """ + for k, v in spawner_override.items(): + if callable(v): + v = v(self) + self.log.info( + f".. overriding {self.__class__.__name__} value %s=%s (callable result)", + k, + v, + ) + else: + self.log.info( + f".. 
overriding {self.__class__.__name__} value %s=%s", k, v + ) + + # If v is a dict, *merge* it with existing values, rather than completely + # resetting it. This allows *adding* things like environment variables rather + # than completely replacing them. If value is set to None, the key + # will be removed + if isinstance(v, dict) and isinstance(getattr(self, k), dict): + recursive_update(getattr(self, k), v) + else: + setattr(self, k, v) + + async def apply_group_overrides(self): + """ + Apply group overrides before starting a server + """ + user_group_names = {g.name for g in self.user.groups} + if callable(self.group_overrides): + group_overrides = await maybe_future(self.group_overrides(self)) + else: + group_overrides = self.group_overrides + for key in sorted(group_overrides): + go = group_overrides[key] + if user_group_names & set(go['groups']): + # If there is *any* overlap between the groups user is in + # and the groups for this override, apply overrides + self._apply_overrides(go['spawner_override']) + def _try_setcwd(path): """Try to set CWD to path, walking up until a valid directory is found. 
diff --git a/jupyterhub/tests/test_spawner.py b/jupyterhub/tests/test_spawner.py index 6155864f0..8fdeb4d1f 100644 --- a/jupyterhub/tests/test_spawner.py +++ b/jupyterhub/tests/test_spawner.py @@ -23,8 +23,7 @@ from ..user import User from ..utils import AnyTimeoutError, maybe_future, new_token, url_path_join from .mocking import public_url -from .test_api import add_user -from .utils import async_requests +from .utils import add_user, async_requests, find_user _echo_sleep = """ import sys, time @@ -598,3 +597,90 @@ def test_spawner_server(db): spawner.server = Server.from_url("http://1.2.3.4") assert spawner.server is not None assert spawner.server.ip == "1.2.3.4" + + +async def test_group_override(app): + app.load_groups = { + "admin": {"users": ["admin"]}, + "user": {"users": ["admin", "user"]}, + } + await app.init_groups() + + group_overrides = { + "01-admin-mem-limit": { + "groups": ["admin"], + "spawner_override": {"start_timeout": 120}, + } + } + + admin_user = find_user(app.db, "admin") + s = Spawner(user=admin_user) + s.start_timeout = 60 + s.group_overrides = group_overrides + await s.apply_group_overrides() + assert s.start_timeout == 120 + + non_admin_user = find_user(app.db, "user") + s = Spawner(user=non_admin_user) + s.start_timeout = 60 + s.group_overrides = group_overrides + await s.apply_group_overrides() + assert s.start_timeout == 60 + + +async def test_group_override_lexical_ordering(app): + app.load_groups = { + "admin": {"users": ["admin"]}, + "user": {"users": ["admin", "user"]}, + } + await app.init_groups() + + group_overrides = { + # this should be applied last, even though it is specified first, + # due to lexical ordering based on key names + "02-admin-mem-limit": { + "groups": ["admin"], + "spawner_override": {"start_timeout": 300}, + }, + "01-admin-mem-limit": { + "groups": ["admin"], + "spawner_override": {"start_timeout": 120}, + }, + } + + admin_user = find_user(app.db, "admin") + s = Spawner(user=admin_user) + s.start_timeout = 
60 + s.group_overrides = group_overrides + await s.apply_group_overrides() + assert s.start_timeout == 300 + + +async def test_group_override_callable(app): + app.load_groups = { + "admin": {"users": ["admin"]}, + "user": {"users": ["admin", "user"]}, + } + await app.init_groups() + + def group_overrides(spawner): + return { + "01-admin-mem-limit": { + "groups": ["admin"], + "spawner_override": {"start_timeout": 120}, + } + } + + admin_user = find_user(app.db, "admin") + s = Spawner(user=admin_user) + s.start_timeout = 60 + s.group_overrides = group_overrides + await s.apply_group_overrides() + assert s.start_timeout == 120 + + non_admin_user = find_user(app.db, "user") + s = Spawner(user=non_admin_user) + s.start_timeout = 60 + s.group_overrides = group_overrides + await s.apply_group_overrides() + assert s.start_timeout == 60 diff --git a/jupyterhub/user.py b/jupyterhub/user.py index c86d1089a..f4c1da189 100644 --- a/jupyterhub/user.py +++ b/jupyterhub/user.py @@ -905,6 +905,7 @@ async def spawn(self, server_name='', options=None, handler=None): # wait for spawner.start to return # run optional preparation work to bootstrap the notebook await maybe_future(spawner.run_pre_spawn_hook()) + await spawner.apply_group_overrides() if self.settings.get('internal_ssl'): self.log.debug("Creating internal SSL certs for %s", spawner._log_name) hub_paths = await maybe_future(spawner.create_certs()) diff --git a/jupyterhub/utils.py b/jupyterhub/utils.py index 2eb38c0d5..3035e4875 100644 --- a/jupyterhub/utils.py +++ b/jupyterhub/utils.py @@ -942,3 +942,23 @@ def subdomain_hook_idna(name, domain, kind): else: suffix = f"--{kind}" return f"{safe_name}{suffix}.{domain}" + + +# From https://github.com/jupyter-server/jupyter_server/blob/fc0ac3236fdd92778ea765db6e8982212c8389ee/jupyter_server/config_manager.py#L14 +def recursive_update(target, new): + """ + Recursively update one dictionary in-place using another. + + None values will delete their keys. 
+ """ + for k, v in new.items(): + if isinstance(v, dict): + if k not in target: + target[k] = {} + recursive_update(target[k], v) + + elif v is None: + target.pop(k, None) + + else: + target[k] = v