diff --git a/jupyterhub/spawner.py b/jupyterhub/spawner.py index 8e8f320bc..726fa5ff9 100644 --- a/jupyterhub/spawner.py +++ b/jupyterhub/spawner.py @@ -50,6 +50,7 @@ exponential_backoff, maybe_future, random_port, + recursive_update, url_escape_path, url_path_join, ) @@ -306,6 +307,31 @@ def _default_access_scopes(self): f"access:servers!user={self.user.name}", ] + group_overrides = Union( + [Callable(), Dict()], + help=""" + Override specific traitlets based on group membership of the user. + + This can be a dict, or a callable that returns a dict. The key of the dict + is *only* used for lexicographical sorting, to guarantee a consistent + ordering of the overrides. If it is a callable, it may be async, and will + be passed one parameter - the spawner instance. It should return a dictionary. + + The value of the dict is a dict, with the following keys: + + - *groups* - If the user belongs to *any* of these groups, these overrides are + applied to their server before spawning. + - *spawner_override* - a dictionary with overrides to apply to the Spawner + settings. Each value can be either the final value to change or a callable that + take the `KubeSpawner` instance as parameter and return the final value. This can + be further overridden by 'profile_options' + If the traitlet being overriden is a *dictionary*, the dictionary + will be *recursively updated*, rather than overriden. If you want to + remove a key, set its value to `None` + """, + config=True, + ) + handler = Any() oauth_roles = Union( @@ -504,7 +530,7 @@ async def _get_oauth_client_allowed_scopes(self): max=1, help=""" Jitter fraction for poll_interval. - + Avoids alignment of poll calls for many Spawners, e.g. when restarting JupyterHub, which restarts all polls for running Spawners. @@ -1479,6 +1505,54 @@ async def _wait_for_death(): except AnyTimeoutError: return False + def _apply_overrides(self, spawner_override: dict): + """ + Apply set of overrides onto the current spawner instance + + spawner_override is a dict with key being the name of the traitlet + to override, and value is either a callable or the value for the + traitlet. If the value is a dictionary, it is *merged* with the + existing value (rather than replaced). Callables are called with + one parameter - the current spawner instance. + """ + for k, v in spawner_override.items(): + if callable(v): + v = v(self) + self.log.info( + f".. overriding {self.__class__.__name__} value %s=%s (callable result)", + k, + v, + ) + else: + self.log.info( + f".. overriding {self.__class__.__name__} value %s=%s", k, v + ) + + # If v is a dict, *merge* it with existing values, rather than completely + # resetting it. This allows *adding* things like environment variables rather + # than completely replacing them. If value is set to None, the key + # will be removed + if isinstance(v, dict) and isinstance(getattr(self, k), dict): + recursive_update(getattr(self, k), v) + else: + setattr(self, k, v) + + async def apply_group_overrides(self): + """ + Apply group overrides before starting a server + """ + user_group_names = {g.name for g in self.user.groups} + if callable(self.group_overrides): + group_overrides = await maybe_future(self.group_overrides(self)) + else: + group_overrides = self.group_overrides + for key in sorted(group_overrides): + go = group_overrides[key] + if user_group_names & set(go['groups']): + # If there is *any* overlap between the groups user is in + # and the groups for this override, apply overrides + self._apply_overrides(go['spawner_override']) + def _try_setcwd(path): """Try to set CWD to path, walking up until a valid directory is found. diff --git a/jupyterhub/tests/test_spawner.py b/jupyterhub/tests/test_spawner.py index 6155864f0..8fdeb4d1f 100644 --- a/jupyterhub/tests/test_spawner.py +++ b/jupyterhub/tests/test_spawner.py @@ -23,8 +23,7 @@ from ..user import User from ..utils import AnyTimeoutError, maybe_future, new_token, url_path_join from .mocking import public_url -from .test_api import add_user -from .utils import async_requests +from .utils import add_user, async_requests, find_user _echo_sleep = """ import sys, time @@ -598,3 +597,90 @@ def test_spawner_server(db): spawner.server = Server.from_url("http://1.2.3.4") assert spawner.server is not None assert spawner.server.ip == "1.2.3.4" + + +async def test_group_override(app): + app.load_groups = { + "admin": {"users": ["admin"]}, + "user": {"users": ["admin", "user"]}, + } + await app.init_groups() + + group_overrides = { + "01-admin-mem-limit": { + "groups": ["admin"], + "spawner_override": {"start_timeout": 120}, + } + } + + admin_user = find_user(app.db, "admin") + s = Spawner(user=admin_user) + s.start_timeout = 60 + s.group_overrides = group_overrides + await s.apply_group_overrides() + assert s.start_timeout == 120 + + non_admin_user = find_user(app.db, "user") + s = Spawner(user=non_admin_user) + s.start_timeout = 60 + s.group_overrides = group_overrides + await s.apply_group_overrides() + assert s.start_timeout == 60 + + +async def test_group_override_lexical_ordering(app): + app.load_groups = { + "admin": {"users": ["admin"]}, + "user": {"users": ["admin", "user"]}, + } + await app.init_groups() + + group_overrides = { + # this should be applied last, even though it is specified first, + # due to lexical ordering based on key names + "02-admin-mem-limit": { + "groups": ["admin"], + "spawner_override": {"start_timeout": 300}, + }, + "01-admin-mem-limit": { + "groups": ["admin"], + "spawner_override": {"start_timeout": 120}, + }, + } + + admin_user = find_user(app.db, "admin") + s = Spawner(user=admin_user) + s.start_timeout = 60 + s.group_overrides = group_overrides + await s.apply_group_overrides() + assert s.start_timeout == 300 + + +async def test_group_override_callable(app): + app.load_groups = { + "admin": {"users": ["admin"]}, + "user": {"users": ["admin", "user"]}, + } + await app.init_groups() + + def group_overrides(spawner): + return { + "01-admin-mem-limit": { + "groups": ["admin"], + "spawner_override": {"start_timeout": 120}, + } + } + + admin_user = find_user(app.db, "admin") + s = Spawner(user=admin_user) + s.start_timeout = 60 + s.group_overrides = group_overrides + await s.apply_group_overrides() + assert s.start_timeout == 120 + + non_admin_user = find_user(app.db, "user") + s = Spawner(user=non_admin_user) + s.start_timeout = 60 + s.group_overrides = group_overrides + await s.apply_group_overrides() + assert s.start_timeout == 60 diff --git a/jupyterhub/user.py b/jupyterhub/user.py index c86d1089a..f4c1da189 100644 --- a/jupyterhub/user.py +++ b/jupyterhub/user.py @@ -905,6 +905,7 @@ async def spawn(self, server_name='', options=None, handler=None): # wait for spawner.start to return # run optional preparation work to bootstrap the notebook await maybe_future(spawner.run_pre_spawn_hook()) + await spawner.apply_group_overrides() if self.settings.get('internal_ssl'): self.log.debug("Creating internal SSL certs for %s", spawner._log_name) hub_paths = await maybe_future(spawner.create_certs()) diff --git a/jupyterhub/utils.py b/jupyterhub/utils.py index 2eb38c0d5..3035e4875 100644 --- a/jupyterhub/utils.py +++ b/jupyterhub/utils.py @@ -942,3 +942,23 @@ def subdomain_hook_idna(name, domain, kind): else: suffix = f"--{kind}" return f"{safe_name}{suffix}.{domain}" + + +# From https://github.com/jupyter-server/jupyter_server/blob/fc0ac3236fdd92778ea765db6e8982212c8389ee/jupyter_server/config_manager.py#L14 +def recursive_update(target, new): + """ + Recursively update one dictionary in-place using another. + + None values will delete their keys. + """ + for k, v in new.items(): + if isinstance(v, dict): + if k not in target: + target[k] = {} + recursive_update(target[k], v) + + elif v is None: + target.pop(k, None) + + else: + target[k] = v