Skip to content

Commit

Permalink
Allow overriding spawner config based on user group membership
Browse files Browse the repository at this point in the history
Similar to 'kubespawner_override' in KubeSpawner, this allows
admins to selectivel override spawner configuration based on
groups a user belongs to. This allows for low maintenance but
extremely powerful customization based on group membership.
This is particularly powerful when combined with
jupyterhub/oauthenticator#735

\#\# Dictionary vs List

Ordering is important here, but still I choose to implement this
configuration as a dictionary of dictionaries vs a list. This is
primarily to allow for easy overriding in z2jh (and similar places),
where Lists are just really hard to override. Ordering is provided
by lexicographically sorting the keys, similar to how we do it in z2jh.

\#\# Merging config

The merging code is literally copied from KubeSpawner, and provides
the exact same behavior. Documentation of how it acts is also copied.
  • Loading branch information
yuvipanda committed May 24, 2024
1 parent 282cc02 commit 5f3833b
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 3 deletions.
76 changes: 75 additions & 1 deletion jupyterhub/spawner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
exponential_backoff,
maybe_future,
random_port,
recursive_update,
url_escape_path,
url_path_join,
)
Expand Down Expand Up @@ -306,6 +307,31 @@ def _default_access_scopes(self):
f"access:servers!user={self.user.name}",
]

group_overrides = Union(
[Callable(), Dict()],
help="""
Override specific traitlets based on group membership of the user.
This can be a dict, or a callable that returns a dict. The key of the dict
is *only* used for lexicographical sorting, to guarantee a consistent
ordering of the overrides. If it is a callable, it may be async, and will
be passed one parameter - the spawner instance. It should return a dictionary.
The value of the dict is a dict, with the following keys:
- *groups* - If the user belongs to *any* of these groups, these overrides are
applied to their server before spawning.
- *spawner_override* - a dictionary with overrides to apply to the Spawner
settings. Each value can be either the final value to change or a callable that
take the `KubeSpawner` instance as parameter and return the final value. This can
be further overridden by 'profile_options'
If the traitlet being overriden is a *dictionary*, the dictionary
will be *recursively updated*, rather than overriden. If you want to
remove a key, set its value to `None`
""",
config=True,
)

handler = Any()

oauth_roles = Union(
Expand Down Expand Up @@ -504,7 +530,7 @@ async def _get_oauth_client_allowed_scopes(self):
max=1,
help="""
Jitter fraction for poll_interval.
Avoids alignment of poll calls for many Spawners,
e.g. when restarting JupyterHub, which restarts all polls for running Spawners.
Expand Down Expand Up @@ -1479,6 +1505,54 @@ async def _wait_for_death():
except AnyTimeoutError:
return False

def _apply_overrides(self, spawner_override: dict):
"""
Apply set of overrides onto the current spawner instance
spawner_override is a dict with key being the name of the traitlet
to override, and value is either a callable or the value for the
traitlet. If the value is a dictionary, it is *merged* with the
existing value (rather than replaced). Callables are called with
one parameter - the current spawner instance.
"""
for k, v in spawner_override.items():
if callable(v):
v = v(self)
self.log.info(
f".. overriding {self.__class__.__name__} value %s=%s (callable result)",
k,
v,
)
else:
self.log.info(
f".. overriding {self.__class__.__name__} value %s=%s", k, v
)

# If v is a dict, *merge* it with existing values, rather than completely
# resetting it. This allows *adding* things like environment variables rather
# than completely replacing them. If value is set to None, the key
# will be removed
if isinstance(v, dict) and isinstance(getattr(self, k), dict):
recursive_update(getattr(self, k), v)
else:
setattr(self, k, v)

async def apply_group_overrides(self):
"""
Apply group overrides before starting a server
"""
user_group_names = {g.name for g in self.user.groups}
if callable(self.group_overrides):
group_overrides = await maybe_future(self.group_overrides(self))
else:
group_overrides = self.group_overrides
for key in sorted(group_overrides):
go = group_overrides[key]
if user_group_names & set(go['groups']):
# If there is *any* overlap between the groups user is in
# and the groups for this override, apply overrides
self._apply_overrides(go['spawner_override'])


def _try_setcwd(path):
"""Try to set CWD to path, walking up until a valid directory is found.
Expand Down
90 changes: 88 additions & 2 deletions jupyterhub/tests/test_spawner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from ..user import User
from ..utils import AnyTimeoutError, maybe_future, new_token, url_path_join
from .mocking import public_url
from .test_api import add_user
from .utils import async_requests
from .utils import add_user, async_requests, find_user

_echo_sleep = """
import sys, time
Expand Down Expand Up @@ -598,3 +597,90 @@ def test_spawner_server(db):
spawner.server = Server.from_url("http://1.2.3.4")
assert spawner.server is not None
assert spawner.server.ip == "1.2.3.4"


async def test_group_override(app):
app.load_groups = {
"admin": {"users": ["admin"]},
"user": {"users": ["admin", "user"]},
}
await app.init_groups()

group_overrides = {
"01-admin-mem-limit": {
"groups": ["admin"],
"spawner_override": {"start_timeout": 120},
}
}

admin_user = find_user(app.db, "admin")
s = Spawner(user=admin_user)
s.start_timeout = 60
s.group_overrides = group_overrides
await s.apply_group_overrides()
assert s.start_timeout == 120

non_admin_user = find_user(app.db, "user")
s = Spawner(user=non_admin_user)
s.start_timeout = 60
s.group_overrides = group_overrides
await s.apply_group_overrides()
assert s.start_timeout == 60


async def test_group_override_lexical_ordering(app):
app.load_groups = {
"admin": {"users": ["admin"]},
"user": {"users": ["admin", "user"]},
}
await app.init_groups()

group_overrides = {
# this should be applied last, even though it is specified first,
# due to lexical ordering based on key names
"02-admin-mem-limit": {
"groups": ["admin"],
"spawner_override": {"start_timeout": 300},
},
"01-admin-mem-limit": {
"groups": ["admin"],
"spawner_override": {"start_timeout": 120},
},
}

admin_user = find_user(app.db, "admin")
s = Spawner(user=admin_user)
s.start_timeout = 60
s.group_overrides = group_overrides
await s.apply_group_overrides()
assert s.start_timeout == 300


async def test_group_override_callable(app):
app.load_groups = {
"admin": {"users": ["admin"]},
"user": {"users": ["admin", "user"]},
}
await app.init_groups()

def group_overrides(spawner):
return {
"01-admin-mem-limit": {
"groups": ["admin"],
"spawner_override": {"start_timeout": 120},
}
}

admin_user = find_user(app.db, "admin")
s = Spawner(user=admin_user)
s.start_timeout = 60
s.group_overrides = group_overrides
await s.apply_group_overrides()
assert s.start_timeout == 120

non_admin_user = find_user(app.db, "user")
s = Spawner(user=non_admin_user)
s.start_timeout = 60
s.group_overrides = group_overrides
await s.apply_group_overrides()
assert s.start_timeout == 60
1 change: 1 addition & 0 deletions jupyterhub/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,7 @@ async def spawn(self, server_name='', options=None, handler=None):
# wait for spawner.start to return
# run optional preparation work to bootstrap the notebook
await maybe_future(spawner.run_pre_spawn_hook())
await spawner.apply_group_overrides()
if self.settings.get('internal_ssl'):
self.log.debug("Creating internal SSL certs for %s", spawner._log_name)
hub_paths = await maybe_future(spawner.create_certs())
Expand Down
20 changes: 20 additions & 0 deletions jupyterhub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,3 +942,23 @@ def subdomain_hook_idna(name, domain, kind):
else:
suffix = f"--{kind}"
return f"{safe_name}{suffix}.{domain}"


# From https://github.com/jupyter-server/jupyter_server/blob/fc0ac3236fdd92778ea765db6e8982212c8389ee/jupyter_server/config_manager.py#L14
def recursive_update(target, new):
"""
Recursively update one dictionary in-place using another.
None values will delete their keys.
"""
for k, v in new.items():
if isinstance(v, dict):
if k not in target:
target[k] = {}
recursive_update(target[k], v)

elif v is None:
target.pop(k, None)

else:
target[k] = v

0 comments on commit 5f3833b

Please sign in to comment.