Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AzureAD] Support manage_groups #710

Merged
merged 8 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/source/tutorials/provider-specific-setup/providers/azuread.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,20 @@ AzureAdOAuthenticator expands OAuthenticator with the following config that may
be relevant to read more about in the configuration reference:

- {attr}`.AzureAdOAuthenticator.tenant_id`

## Loading user groups

The `AzureAdOAuthenticator` can load the group-membership of users from the access token.
This is done by setting the `AzureAdOAuthenticator.groups_claim` to the name of the claim that contains the
group-membership.

```python
c.JupyterHub.authenticator_class = "azuread"

# {...} other settings (see above)

c.AzureAdOAuthenticator.manage_groups = True
c.AzureAdOAuthenticator.user_groups_claim = 'groups' # this is the default
```

This requires Azure AD to be configured to include the group-membership in the access token.
19 changes: 19 additions & 0 deletions oauthenticator/azuread.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ def _login_service_default(self):
def _username_claim_default(self):
return "name"

user_groups_claim = Unicode(
"groups",
config=True,
help="""
Name of claim containing user group memberships.

Will populate JupyterHub groups if Authenticator.manage_groups is True.
""",
)

tenant_id = Unicode(
config=True,
help="""
Expand All @@ -44,6 +54,15 @@ def _authorize_url_default(self):
def _token_url_default(self):
return f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/token"

async def update_auth_model(self, auth_model, **kwargs):
auth_model = await super().update_auth_model(auth_model, **kwargs)

if getattr(self, "manage_groups", False):
user_info = auth_model["auth_state"][self.user_auth_state_key]
auth_model["groups"] = user_info[self.user_groups_claim]

return auth_model

async def token_to_user(self, token_info):
id_token = token_info['id_token']
decoded = jwt.decode(
Expand Down
43 changes: 42 additions & 1 deletion oauthenticator/tests/test_azuread.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from unittest import mock

import jwt
import pytest
from pytest import fixture, mark
from traitlets.config import Config

Expand Down Expand Up @@ -44,6 +45,17 @@ def user_model(tenant_id, client_id, name):
"tid": tenant_id,
"nonce": "123523",
"aio": "Df2UVXL1ix!lMCWMSOJBcFatzcGfvFGhjKv8q5g0x732dR5MB5BisvGQO7YWByjd8iQDLq!eGbIDakyp5mnOrcdqHeYSnltepQmRp6AIZ8jY",
"groups": [
"96000b2c-7333-4f6e-a2c3-e7608fa2d131",
"a992b3d5-1966-4af4-abed-6ef021417be4",
"ceb90a42-030f-44f1-a0c7-825b572a3b07",
],
# different from 'groups' for tests
"grp": [
"96000b2c-7333-4f6e-a2c3",
"a992b3d5-1966-4af4-abed",
"ceb90a42-030f-44f1-a0c7",
],
},
os.urandom(5),
)
Expand Down Expand Up @@ -103,6 +115,23 @@ def user_model(tenant_id, client_id, name):
True,
None,
),
# test user_groups_claim
(
"30",
{"allow_all": True, "manage_groups": True},
True,
None,
),
(
"31",
{
"allow_all": True,
"manage_groups": True,
"user_groups_claim": "grp",
},
True,
None,
),
],
)
async def test_azuread(
Expand All @@ -119,6 +148,12 @@ async def test_azuread(
c.AzureAdOAuthenticator.client_id = str(uuid.uuid1())
c.AzureAdOAuthenticator.client_secret = str(uuid.uuid1())
authenticator = AzureAdOAuthenticator(config=c)
manage_groups = False
if "manage_groups" in class_config:
if hasattr(authenticator, "manage_groups"):
manage_groups = authenticator.manage_groups
else:
pytest.skip("manage_groups requires jupyterhub 2.2")

handled_user_model = user_model(
tenant_id=authenticator.tenant_id,
Expand All @@ -130,14 +165,20 @@ async def test_azuread(

if expect_allowed:
assert auth_model
assert set(auth_model) == {"name", "admin", "auth_state"}
expected_keys = {"name", "admin", "auth_state"}
if manage_groups:
expected_keys.add("groups")
assert set(auth_model) == expected_keys
assert auth_model["admin"] == expect_admin
auth_state = auth_model["auth_state"]
assert json.dumps(auth_state)
assert "access_token" in auth_state
user_info = auth_state[authenticator.user_auth_state_key]
assert user_info["aud"] == authenticator.client_id
assert auth_model["name"] == user_info[authenticator.username_claim]
if manage_groups:
groups = auth_model['groups']
assert groups == user_info[authenticator.user_groups_claim]
else:
assert auth_model == None

Expand Down