diff --git a/oauthenticator/generic.py b/oauthenticator/generic.py index c11b5eb7..c2e4e36c 100644 --- a/oauthenticator/generic.py +++ b/oauthenticator/generic.py @@ -3,12 +3,11 @@ """ import os -from functools import reduce from jupyterhub.auth import LocalAuthenticator from jupyterhub.traitlets import Callable from tornado.httpclient import AsyncHTTPClient -from traitlets import Bool, Dict, Set, Unicode, Union, default, observe +from traitlets import Bool, Dict, Unicode, Union, default, observe from .oauth2 import OAuthenticator @@ -37,9 +36,13 @@ def _login_service_default(self): def _claim_groups_key_changed(self, change): if callable(change.new): # Automatically wrap the claim_gorups_key call so it gets what it thinks it should get - self.auth_model_groups_key = lambda auth_model: self.claim_groups_key(auth_model["auth_state"][self.user_auth_state_key]) + self.auth_model_groups_key = lambda auth_model: self.claim_groups_key( + auth_model["auth_state"][self.user_auth_state_key] + ) else: - self.auth_model_groups_key = f"auth_state.{self.user_auth_state_key}.{self.claim_groups_key}" + self.auth_model_groups_key = ( + f"auth_state.{self.user_auth_state_key}.{self.claim_groups_key}" + ) @default("http_client") def _default_http_client(self): diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index c6846ef9..cde8e95e 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -7,8 +7,8 @@ import base64 import json import os -from functools import reduce import uuid +from functools import reduce from urllib.parse import quote, urlencode, urlparse, urlunparse import jwt @@ -21,7 +21,18 @@ from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPRequest from tornado.httputil import url_concat from tornado.log import app_log -from traitlets import Any, Bool, Callable, Dict, List, Unicode, Union, default, Set, validate +from traitlets import ( + Any, + Bool, + Callable, + Dict, + List, + Set, + Unicode, + Union, + default, + validate, +) def guess_callback_uri(protocol, host, hub_server_url): @@ -1047,7 +1058,9 @@ def get_user_groups(self, auth_model: dict): if callable(self.claim_groups_key): return set(self.auth_model_groups_key(auth_model)) try: - return set(reduce(dict.get, self.auth_model_groups_key.split("."), auth_model)) + return set( + reduce(dict.get, self.auth_model_groups_key.split("."), auth_model) + ) except TypeError: self.log.error( f"The auth_model_groups_key {self.auth_model_groups_key} does not exist in the auth_model. Available keys are: {auth_model.keys()}"