From 6e81b6aece36728b748cd46a23441c2e88ea614a Mon Sep 17 00:00:00 2001 From: Fantix King Date: Thu, 28 Aug 2025 16:08:03 -0400 Subject: [PATCH] OIDC providers --- gel/_internal/_auth/_oauth.py | 133 ++++++++++++ gel/_internal/_auth/_pkce.py | 26 ++- .../_integration/_fastapi/_auth/__init__.py | 78 +++++++- .../_fastapi/_auth/_email_password.py | 3 +- .../_integration/_fastapi/_auth/_oidc.py | 189 ++++++++++++++++++ gel/auth/__init__.py | 2 + gel/auth/oauth.py | 21 ++ 7 files changed, 445 insertions(+), 7 deletions(-) create mode 100644 gel/_internal/_auth/_oauth.py create mode 100644 gel/_internal/_integration/_fastapi/_auth/_oidc.py create mode 100644 gel/auth/oauth.py diff --git a/gel/_internal/_auth/_oauth.py b/gel/_internal/_auth/_oauth.py new file mode 100644 index 000000000..cb6aaca02 --- /dev/null +++ b/gel/_internal/_auth/_oauth.py @@ -0,0 +1,133 @@ +# SPDX-PackageName: gel-python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors. + +from __future__ import annotations +from typing import Any, Optional, TypeVar + +import dataclasses + +import httpx +import jwt +import logging + +import gel +from gel import blocking_client + +from . import _base as base +from . import _pkce as pkce_mod +from . import _token_data as td_mod + + +logger = logging.getLogger("gel.auth") + + +@dataclasses.dataclass +class AuthorizeData: + verifier: str + redirect_url: str + + +@dataclasses.dataclass +class TokenData(td_mod.TokenData): + provider_id_token: Optional[str] + + def get_id_token_claims(self) -> Optional[Any]: + if self.provider_id_token is None: + return None + return jwt.decode( + self.provider_id_token, options={"verify_signature": False} + ) + + +C = TypeVar("C", bound=httpx.Client | httpx.AsyncClient) + + +class BaseOAuth(base.BaseClient[C]): + def __init__( + self, + provider_name: str, + *, + connection_info: gel.ConnectionInfo, + **kwargs: Any, + ) -> None: + super().__init__(connection_info=connection_info, **kwargs) + self._provider_name = provider_name + + def authorize( + self, + *, + redirect_to: str, + redirect_to_on_signup: Optional[str], + callback_url: Optional[str] = None, + ) -> AuthorizeData: + pkce = self._generate_pkce() + redirect_url = ( + self._client.base_url.join("authorize") + .copy_set_param("provider", self._provider_name) + .copy_set_param("redirect_to", redirect_to) + .copy_set_param("challenge", pkce.challenge) + ) + if redirect_to_on_signup is not None: + redirect_url = redirect_url.copy_set_param( + "redirect_to_on_signup", redirect_to_on_signup + ) + if callback_url is not None: + redirect_url = redirect_url.copy_set_param( + "callback_url", callback_url + ) + + return AuthorizeData( + verifier=pkce.verifier, + redirect_url=str(redirect_url), + ) + + async def _get_token(self, *, verifier: str, code: str) -> TokenData: + pkce = self._pkce_from_verifier(verifier) + logger.info("exchanging code for token: %s", code) + return await pkce.internal_exchange_code_for_token(code, cls=TokenData) + + +class OAuth(BaseOAuth[httpx.Client]): + def _init_http_client(self, **kwargs: Any) -> httpx.Client: + return httpx.Client(**kwargs) + + def _generate_pkce(self) -> pkce_mod.PKCE: + return pkce_mod.generate_pkce(self._client) + + def _pkce_from_verifier(self, verifier: str) -> pkce_mod.PKCE: + return pkce_mod.PKCE(self._client, verifier) + + def get_token(self, *, verifier: str, code: str) -> TokenData: + return blocking_client.iter_coroutine( + self._get_token(verifier=verifier, code=code) + ) + + +def make( + client: gel.Client, *, provider_name: str, cls: type[OAuth] = OAuth +) -> OAuth: + return cls(provider_name, connection_info=client.check_connection()) + + +class AsyncOAuth(BaseOAuth[httpx.AsyncClient]): + def _init_http_client(self, **kwargs: Any) -> httpx.AsyncClient: + return httpx.AsyncClient(**kwargs) + + def _generate_pkce(self) -> pkce_mod.AsyncPKCE: + return pkce_mod.generate_async_pkce(self._client) + + def _pkce_from_verifier(self, verifier: str) -> pkce_mod.AsyncPKCE: + return pkce_mod.AsyncPKCE(self._client, verifier) + + async def get_token(self, *, verifier: str, code: str) -> TokenData: + return await self._get_token(verifier=verifier, code=code) + + +async def make_async( + client: gel.AsyncIOClient, + *, + provider_name: str, + cls: type[AsyncOAuth] = AsyncOAuth, +) -> AsyncOAuth: + return cls(provider_name, connection_info=await client.check_connection()) diff --git a/gel/_internal/_auth/_pkce.py b/gel/_internal/_auth/_pkce.py index c4d744a21..33e01a10e 100644 --- a/gel/_internal/_auth/_pkce.py +++ b/gel/_internal/_auth/_pkce.py @@ -3,7 +3,7 @@ # SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors. from __future__ import annotations -from typing import Generic, TypeVar +from typing import Generic, TypeVar, overload import base64 import dataclasses @@ -20,6 +20,7 @@ logger = logging.getLogger("gel.auth") C = TypeVar("C", bound=httpx.Client | httpx.AsyncClient) +TokenData_T = TypeVar("TokenData_T", bound=token_data.TokenData) class BasePKCE(Generic[C]): @@ -47,8 +48,25 @@ async def _send_http_request( ) -> httpx.Response: raise NotImplementedError + @overload async def internal_exchange_code_for_token( - self, code: str + self, + code: str, + *, + cls: type[TokenData_T], + ) -> TokenData_T: ... + + @overload + async def internal_exchange_code_for_token( + self, + code: str, + ) -> token_data.TokenData: ... + + async def internal_exchange_code_for_token( + self, + code: str, + *, + cls: type[token_data.TokenData] = token_data.TokenData, ) -> token_data.TokenData: request = self._http_client.build_request( "GET", @@ -70,9 +88,9 @@ async def internal_exchange_code_for_token( token_json = token_response.json() args = { field.name: token_json[field.name] - for field in dataclasses.fields(token_data.TokenData) + for field in dataclasses.fields(cls) } - return token_data.TokenData(**args) + return cls(**args) class PKCE(BasePKCE[httpx.Client]): diff --git a/gel/_internal/_integration/_fastapi/_auth/__init__.py b/gel/_internal/_integration/_fastapi/_auth/__init__.py index 1d51a19cb..64eef588e 100644 --- a/gel/_internal/_integration/_fastapi/_auth/__init__.py +++ b/gel/_internal/_integration/_fastapi/_auth/__init__.py @@ -3,6 +3,8 @@ # SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors. from __future__ import annotations + +import functools from typing import Any, cast, Optional, TYPE_CHECKING from typing_extensions import Self @@ -26,6 +28,17 @@ import gel from ._email_password import EmailPassword from ._builtin_ui import BuiltinUI + from ._oidc import OpenIDConnect + + +_BUILTIN_OIDC_PROVIDERS = { + "apple": "builtin::oauth_apple", + "azure": "builtin::oauth_azure", + "discord": "builtin::oauth_discord", # actually OAuth2 without id_token + "slack": "builtin::oauth_slack", + "github": "builtin::oauth_github", # actually OAuth2 without id_token + "google": "builtin::oauth_google", +} class Installable: @@ -53,11 +66,14 @@ class GelAuth(client_mod.Extension): secure_cookie = utils.Config(True) # noqa: FBT003 redirect_to: utils.Config[Optional[str]] = utils.Config("/") redirect_to_page_name: utils.Config[Optional[str]] = utils.Config(None) + error_page_name = utils.Config("error_page") _email_password: Optional[EmailPassword] = None _auto_email_password: bool = True _builtin_ui: Optional[BuiltinUI] = None _auto_builtin_ui: bool = True + _manual_oidc_providers: list[str] + _oidc_providers: dict[str, OpenIDConnect] _on_new_identity_path = utils.Config("/") _on_new_identity_name = utils.Config("gel.fastapi.auth.on_new_identity") @@ -72,6 +88,11 @@ class GelAuth(client_mod.Extension): _maybe_auth_token: params.Depends _auth_token: params.Depends + def __init__(self, lifespan: client_mod.GelLifespan) -> None: + super().__init__(lifespan) + self._manual_oidc_providers = [] + self._oidc_providers = {} + def get_unchecked_exp(self, token: str) -> Optional[datetime.datetime]: jwt_payload = jwt.decode(token, options={"verify_signature": False}) if "exp" not in jwt_payload: @@ -245,6 +266,49 @@ def without_builtin_ui(self) -> Self: self._auto_builtin_ui = False return self + def openid_connect(self, name: str) -> OpenIDConnect: + if name in self._oidc_providers: + provider = self._oidc_providers[name] + else: + if self.installed: + raise ValueError("Cannot add OIDC provider after installation") + + from ._oidc import OpenIDConnect # noqa: PLC0415 + + provider = OpenIDConnect(self, provider_name=name) + self._oidc_providers[name] = provider + return provider + + def with_openid_connect(self, name: str, **kwargs: Any) -> Self: + provider = self.openid_connect(name) + for key, value in kwargs.items(): + getattr(provider, key)(value) + return self + + def without_openid_connect(self, name: str) -> Self: + if self.installed: + raise ValueError("Cannot remove OIDC provider after installation") + + if name in self._oidc_providers: + del self._oidc_providers[name] + self._manual_oidc_providers.append(name) + return self + + def __getattr__(self, item: str) -> Any: + if item.startswith("with_"): + name = _BUILTIN_OIDC_PROVIDERS.get(item.removeprefix("with_")) + if name is not None: + return functools.partial(self.with_openid_connect, name) + elif item.startswith("without_"): + name = _BUILTIN_OIDC_PROVIDERS.get(item.removeprefix("without_")) + if name is not None: + return functools.partial(self.without_openid_connect, name) + elif item in _BUILTIN_OIDC_PROVIDERS: + return self.openid_connect(_BUILTIN_OIDC_PROVIDERS[item]) + raise AttributeError( + f"{type(self).__name__!r} has no attribute {item!r}" + ) + async def on_startup(self, app: fastapi.FastAPI) -> None: router = fastapi.APIRouter( prefix=self.auth_path_prefix.value, @@ -258,13 +322,24 @@ async def on_startup(self, app: fastapi.FastAPI) -> None: select assert_single( cfg::Config.extensions[is ext::auth::AuthConfig] ) { - providers: { id, name }, + providers: { id, name, type := .__type__.name }, ui, } """ ) if config: for provider in config.providers: + if ( + provider.name in _BUILTIN_OIDC_PROVIDERS.values() + or provider.type == "ext::auth::OpenIDConnectProvider" + ): + if ( + provider.name not in self._manual_oidc_providers + and provider.name not in self._oidc_providers + ): + self.openid_connect(provider.name) + continue + match provider.name: case "builtin::local_emailpassword": if ( @@ -281,6 +356,7 @@ async def on_startup(self, app: fastapi.FastAPI) -> None: _ = self.builtin_ui insts.extend([self._email_password, self._builtin_ui]) + insts.extend(self._oidc_providers.values()) for inst in insts: if inst is not None: await inst.install(router) diff --git a/gel/_internal/_integration/_fastapi/_auth/_email_password.py b/gel/_internal/_integration/_fastapi/_auth/_email_password.py index 3df018f6c..06cdcfc60 100644 --- a/gel/_internal/_integration/_fastapi/_auth/_email_password.py +++ b/gel/_internal/_integration/_fastapi/_auth/_email_password.py @@ -46,7 +46,6 @@ class ResetPasswordBody(pydantic.BaseModel): class EmailPassword(Installable): - error_page_name = utils.Config("error_page") sign_in_page_name = utils.Config("sign_in_page") reset_password_page_name = utils.Config("reset_password_page") @@ -202,7 +201,7 @@ def _redirect_error( ).value return response_class( url=request.url_for( - self.error_page_name.value + self._auth.error_page_name.value ).include_query_params(**query_params), status_code=getattr(self, f"{key}_default_status_code").value, ) diff --git a/gel/_internal/_integration/_fastapi/_auth/_oidc.py b/gel/_internal/_integration/_fastapi/_auth/_oidc.py new file mode 100644 index 000000000..4ebd87b23 --- /dev/null +++ b/gel/_internal/_integration/_fastapi/_auth/_oidc.py @@ -0,0 +1,189 @@ +# SPDX-PackageName: gel-python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors. + +from __future__ import annotations +from typing import Annotated, Optional + +import http + +import fastapi +from fastapi import responses +from starlette import concurrency + +from gel.auth import oauth +from gel._internal._auth._oauth import TokenData # noqa: TC001 + +from . import GelAuth, Installable +from .. import _utils as utils + + +class OpenIDConnect(Installable): + _auth: GelAuth + _core: oauth.AsyncOAuth + _blocking_io_core: oauth.OAuth + + install_endpoints = utils.Config(True) # noqa: FBT003 + + # Authorize + authorize_path = utils.Config("/{provider_name}/authorize") + authorize_name = utils.Config("gel.auth.oidc.{provider_name}.authorize") + authorize_summary = utils.Config( + "Authorize with OpenID Connect provider: {provider_name}" + ) + authorize_status_code = utils.Config(http.HTTPStatus.SEE_OTHER) + + # Callback + callback_path = utils.Config("/{provider_name}/callback") + callback_name = utils.Config("gel.auth.oidc.{provider_name}.callback") + callback_summary = utils.Config( + "Handle the OpenID Connect callback from provider: {provider_name}" + ) + callback_default_response_class = utils.Config(responses.RedirectResponse) + callback_default_status_code = utils.Config(http.HTTPStatus.SEE_OTHER) + on_sign_in_complete: utils.Hook[TokenData] = utils.Hook("callback") + on_sign_up_complete: utils.Hook[TokenData] = utils.Hook("callback") + + def __init__(self, auth: GelAuth, *, provider_name: str) -> None: + self._auth = auth + self._provider_name = provider_name + + def _redirect_success(self, request: fastapi.Request) -> fastapi.Response: + response_class = self.callback_default_response_class.value + response_code = self.callback_default_status_code.value + redirect_to = self._auth.redirect_to.value + redirect_to_page_name = self._auth.redirect_to_page_name.value + if redirect_to_page_name is not None: + return response_class( + url=request.url_for(redirect_to_page_name), + status_code=response_code, + ) + elif redirect_to is not None: + return response_class(url=redirect_to, status_code=response_code) + else: + raise RuntimeError( + "GelAuth should have either redirect_to or " + "redirect_to_page_name set" + ) + + def _redirect_error( + self, request: fastapi.Request, **query_params: str + ) -> fastapi.Response: + response_class = self.callback_default_response_class.value + return response_class( + url=request.url_for( + self._auth.error_page_name.value + ).include_query_params(**query_params), + status_code=self.callback_default_status_code.value, + ) + + def __install_authorize(self, router: fastapi.APIRouter) -> None: + callback_name = self.callback_name.value.format( + provider_name=self._provider_name + ) + + @router.get( + self.authorize_path.value.format( + provider_name=self._provider_name + ), + name=self.authorize_name.value.format( + provider_name=self._provider_name + ), + summary=self.authorize_summary.value.format( + provider_name=self._provider_name.title() + ), + response_class=responses.RedirectResponse, + status_code=self.authorize_status_code.value, + ) + async def authorize( + request: fastapi.Request, response: fastapi.Response + ) -> str: + callback_url = request.url_for(callback_name) + auth_data = self._core.authorize( + redirect_to=str(callback_url), + redirect_to_on_signup=str( + callback_url.replace_query_params(isSignUp=True) + ), + ) + self._auth.set_verifier_cookie(auth_data.verifier, response) + return auth_data.redirect_url + + @router.get( + self.callback_path.value.format(provider_name=self._provider_name), + name=callback_name, + summary=self.callback_summary.value.format( + provider_name=self._provider_name.title() + ), + ) + async def callback( + request: fastapi.Request, + *, + code: Optional[str] = None, + error: Optional[str] = None, + error_description: Optional[str] = None, + verifier: str = fastapi.Depends(self._auth.pkce_verifier), + is_sign_up: Annotated[ + bool, fastapi.Query(alias="isSignUp") + ] = False, + ) -> fastapi.Response: + if code is None: + assert error is not None + args = {"error": error} + if error_description is not None: + args["error_description"] = error_description + return self._redirect_error(request, **args) + + token_data = await self._core.get_token( + verifier=verifier, code=code + ) + if is_sign_up: + response = await self._auth.handle_new_identity( + request, token_data.identity_id, token_data + ) + if response is None: + if self.on_sign_up_complete.is_set(): + with self._auth.with_auth_token( + token_data.auth_token, request + ): + response = await self.on_sign_up_complete.call( + request, token_data + ) + else: + response = self._redirect_success(request) + else: + if self.on_sign_in_complete.is_set(): + with self._auth.with_auth_token( + token_data.auth_token, request + ): + response = await self.on_sign_in_complete.call( + request, token_data + ) + else: + response = self._redirect_success(request) + self._auth.set_auth_cookie( + token_data.auth_token, response=response + ) + return response + + @property + def blocking_io_core(self) -> oauth.OAuth: + return self._blocking_io_core + + @property + def core(self) -> oauth.AsyncOAuth: + return self._core + + async def install(self, router: fastapi.APIRouter) -> None: + self._core = await oauth.make_async( + self._auth.client, provider_name=self._provider_name + ) + self._blocking_io_core = await concurrency.run_in_threadpool( + oauth.make, + self._auth.blocking_io_client, + provider_name=self._provider_name, + ) + + if self.install_endpoints.value: + self.__install_authorize(router) + + await super().install(router) diff --git a/gel/auth/__init__.py b/gel/auth/__init__.py index 75c0fabed..0876a2a73 100644 --- a/gel/auth/__init__.py +++ b/gel/auth/__init__.py @@ -19,10 +19,12 @@ from . import builtin_ui from . import email_password + from . import oauth __all__ = [ "builtin_ui", "email_password", + "oauth", "TokenData", "PKCE", "generate_pkce", diff --git a/gel/auth/oauth.py b/gel/auth/oauth.py new file mode 100644 index 000000000..4dabbb967 --- /dev/null +++ b/gel/auth/oauth.py @@ -0,0 +1,21 @@ +# SPDX-PackageName: gel-python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors. + +from gel._internal._auth._oauth import ( + AsyncOAuth, + AuthorizeData, + OAuth, + TokenData, + make, + make_async, +) + +__all__ = [ + "AsyncOAuth", + "AuthorizeData", + "OAuth", + "TokenData", + "make", + "make_async", +]