class OIDCProvider(JWTProviderMixin):
    """OIDC identity provider.

    This provider authenticates users using the OIDC password flow, retrieves
    user information from OIDC claims, and validates tokens using either a
    token factory or OIDC token introspection. It can be configured with custom
    claim mappings and options to populate groups, permissions, and raw claims
    on the Identity object.
    """

    protocol = "oidc"
    token_factory: TokenFactory | None = None

    # Lower-cased string values treated as "true" for boolean claims (the
    # ``admin`` identity field and the introspection ``active`` flag).
    # Without this, ``bool("false")`` would evaluate to True.
    _TRUE_STRINGS = frozenset({"true", "1", "yes", "on"})

    def __init__(
        self,
        connector: OIDCConnector,
        token_factory: TokenFactory | None = None,
        claim_mappings: Mapping[str, str | Sequence[str]] | None = None,
        populate_groups: bool = True,
        populate_permissions: bool = False,
        populate_claims: bool = False,
        change_password_supported: bool = False,
    ) -> None:
        """Initialize OIDCProvider.

        Parameters
        ----------
        connector
            Connector to use for OIDC operations.
        token_factory
            Factory used to issue/validate local tokens.
        claim_mappings
            Mapping of OIDC claims to Identity fields. Defaults to common OIDC
            claim mappings (also used when an empty mapping is passed). The
            mapping values can be either a single claim name or a sequence of
            claim names. If a sequence is provided, the claims will be checked
            in order and the first non-empty value will be used. A claim name
            may be a literal key or a dot-separated path into nested claims.
        populate_groups
            Whether to populate group memberships on the Identity.
        populate_permissions
            Whether to populate permissions on the Identity.
        populate_claims
            Whether to include raw claims on the Identity.
        change_password_supported
            Whether this provider supports changing passwords.
        """
        self._connector = connector
        self.token_factory = token_factory
        # Copy the defaults as well, so this instance never holds a reference
        # to the shared module-level DEFAULT_OIDC_MAPPINGS object.
        self._claim_mappings = (
            dict(claim_mappings)
            if claim_mappings
            else dict(DEFAULT_OIDC_MAPPINGS)
        )
        self._populate_groups = populate_groups
        self._populate_permissions = populate_permissions
        self._populate_claims = populate_claims
        self._change_password_supported = change_password_supported

    def authenticate(self, credentials: PasswordCredentials) -> Identity:
        """Authenticate a user using OIDC password flow.

        Parameters
        ----------
        credentials
            User credentials.

        Returns
        -------
        Identity
            Authenticated user identity.

        Raises
        ------
        exceptions.InvalidCredentialsException
            Raised if the token response has no ``access_token``.
        """
        token_data = self._connector.request_password_token(
            username=credentials.username,
            password=credentials.password,
        )
        access_token = token_data.get("access_token")
        if not access_token:
            raise exceptions.InvalidCredentialsException(
                "OIDC token response did not include access_token"
            )
        claims = {}
        if self._connector.userinfo_url:
            claims = self._connector.get_userinfo(access_token)
        # Userinfo claims take precedence over token-response fields. The
        # token response is kept in the merge so fields such as ``scope`` can
        # feed permission extraction.
        # NOTE(review): with populate_claims=True this merge also copies raw
        # token-response fields (e.g. access_token) onto Identity.claims;
        # confirm that is acceptable wherever those claims are persisted.
        merged_claims = {**token_data, **claims}
        return self._convert_claims_to_identity(merged_claims)

    def get_user(self, subject: str) -> Identity:
        """Retrieve a user by subject using an admin lookup.

        Parameters
        ----------
        subject
            User subject identifier.

        Returns
        -------
        Identity
            Retrieved user identity.

        Raises
        ------
        exceptions.NotSupportedException
            Raised if the connector has no user lookup URL template.
        """
        if not self._connector.user_lookup_url_template:
            raise exceptions.NotSupportedException(
                "User lookup is not configured for this provider"
            )
        claims = self._connector.get_user_by_subject(subject)
        return self._convert_claims_to_identity(claims)

    def change_password(
        self, credentials: PasswordCredentials, new_password: str
    ) -> None:
        """Change user password (if supported).

        This always raises; only the message depends on configuration.

        Parameters
        ----------
        credentials
            Current user credentials.
        new_password
            New password to set.

        Raises
        ------
        exceptions.NotSupportedException
            Always raised.
        """
        if not self._change_password_supported:
            message = (
                "Change password operation is not supported by this provider"
            )
        else:
            message = "Change password is not implemented for OIDC providers"
        raise exceptions.NotSupportedException(message)

    def validate(self, token: Token) -> Identity:
        """Validate a token using token factory or introspection.

        Parameters
        ----------
        token
            Token to validate.

        Returns
        -------
        Identity
            Validated user identity.

        Raises
        ------
        exceptions.NotSupportedException
            Raised if no token factory is set and introspection is not
            configured.
        exceptions.InvalidTokenException
            Raised if introspection reports the token as not active.
        """
        if self.token_factory:
            return super().validate(token)
        if not self._connector.introspection_url:
            raise exceptions.NotSupportedException(
                "Token introspection is not configured for this provider"
            )
        claims = self._connector.introspect_token(token.value)
        # Parse ``active`` strictly: a string "false" must not pass as
        # truthy.
        if not self._coerce_bool(claims.get("active", False)):
            raise exceptions.InvalidTokenException("Token is not active")
        return self._convert_claims_to_identity(claims)

    def _convert_claims_to_identity(
        self, claims: Mapping[str, Any]
    ) -> Identity:
        """Convert OIDC claims to an Identity.

        Parameters
        ----------
        claims
            OIDC claims to convert.

        Returns
        -------
        Converted Identity object.

        Raises
        ------
        exceptions.IdentityError
            Raised if the claims do not include a required subject.
        """
        subject = self._get_claim(claims, "subject")
        if not subject:
            raise exceptions.IdentityError(
                "OIDC claims did not include a subject"
            )
        username = self._get_claim(claims, "username")
        if not username:
            # Fall back to the subject when no username claim is present.
            username = subject
        # The Identity's issue time is the conversion time, not a claim value.
        issued_at = datetime.now(tz=timezone.utc)
        groups = (
            self._extract_sequence(self._get_claim(claims, "groups"))
            if self._populate_groups
            else []
        )
        permissions = (
            self._extract_permissions(claims)
            if self._populate_permissions
            else []
        )
        audience = self._extract_audience(claims)
        identity = Identity(
            subject=str(subject),
            username=username,
            email=self._get_claim(claims, "email"),
            display_name=self._get_claim(claims, "display_name"),
            groups=groups,
            permissions=permissions,
            claims=dict(claims) if self._populate_claims else {},
            issued_at=issued_at,
            audience=audience,
            role=self._get_claim(claims, "role"),
            # Parse strictly: bool("false") would otherwise grant admin.
            admin=self._coerce_bool(self._get_claim(claims, "admin", False)),
        )
        return identity

    def _get_claim(
        self,
        claims: Mapping[str, Any],
        field: str,
        default: Any | None = None,
    ) -> Any:
        """Get a claim value by field name, using mappings if configured.

        When ``field`` is mapped, each configured path is tried in order.
        The first non-empty value is returned. Empty means ``None``, or an
        empty string, sequence, or mapping.

        Parameters
        ----------
        claims
            OIDC claims to retrieve the value from.
        field
            Name of the claim field to retrieve.
        default
            Default value to return if the claim is not found or empty,
            by default None.

        Returns
        -------
        Value of the claim field, or the default if not found.
        """
        mapping = self._claim_mappings.get(field)
        if mapping is None:
            value = claims.get(field)
            return default if self._is_empty(value) else value
        for path in self._ensure_sequence(mapping):
            value = self._get_claim_by_path(claims, path)
            if not self._is_empty(value):
                return value
        return default

    @staticmethod
    def _is_empty(value: Any) -> bool:
        """Return True for ``None`` and for empty strings/sequences/mappings.

        Booleans and numbers (including ``False`` and ``0``) are not empty.

        Parameters
        ----------
        value
            Claim value to test.

        Returns
        -------
        Whether the value should be treated as missing.
        """
        if value is None:
            return True
        if isinstance(value, (Sequence, Mapping)):
            return len(value) == 0
        return False

    @classmethod
    def _coerce_bool(cls, value: Any) -> bool:
        """Interpret a claim value as a boolean.

        Strings are true only if they match ``_TRUE_STRINGS``, after
        stripping whitespace and lower-casing. Real booleans are returned
        as-is. Anything else uses ``bool(value)``.

        Parameters
        ----------
        value
            Claim value to interpret.

        Returns
        -------
        Boolean interpretation of the value.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.strip().lower() in cls._TRUE_STRINGS
        return bool(value)

    @staticmethod
    def _ensure_sequence(value: str | Sequence[str]) -> Sequence[str]:
        """Ensure the value is a sequence of strings.

        Parameters
        ----------
        value
            Value to ensure as a sequence of strings.

        Returns
        -------
        Sequence of strings.
        """
        if isinstance(value, str):
            return [value]
        return value

    @staticmethod
    def _get_claim_by_path(claims: Mapping[str, Any], path: str) -> Any:
        """Get a claim value by literal key or dot-separated path.

        A top-level key equal to ``path`` is preferred. This lets claim names
        that themselves contain dots resolve, e.g. URL-namespaced custom
        claims such as ``"https://example.com/roles"``. Otherwise the path
        is split on ``.`` and walked through nested mappings.

        Parameters
        ----------
        claims
            OIDC claims to retrieve the value from.
        path
            Literal claim key or dot-separated path to the claim value.

        Returns
        -------
        Value of the claim at the specified path, or None if not found.
        """
        if path in claims:
            return claims[path]
        current: Any = claims
        for segment in path.split("."):
            if not isinstance(current, Mapping):
                return None
            current = cast(Mapping[str, Any], current).get(segment)
        return current

    @staticmethod
    def _extract_sequence(value: Any) -> list[str]:
        """Extract a sequence of strings from the given value.

        Parameters
        ----------
        value
            Value to extract the sequence from.

        Returns
        -------
        List of strings. Empty for ``None``; a single-item list for a
        scalar.
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        if isinstance(value, Sequence):
            return [str(item) for item in cast(Sequence[Any], value)]
        return [str(value)]

    def _extract_permissions(self, claims: Mapping[str, Any]) -> list[str]:
        """Extract permissions from the given claims.

        Uses the mapped ``permissions`` claim, falling back to the raw
        ``scope`` claim. String values are split on whitespace.

        Parameters
        ----------
        claims
            OIDC claims to retrieve the permissions from.

        Returns
        -------
        List of strings; empty if no usable value is present.
        """
        raw = self._get_claim(claims, "permissions")
        if raw is None:
            raw = claims.get("scope")
        if isinstance(raw, str):
            return [value for value in raw.split() if value]
        if isinstance(raw, Sequence):
            return [str(item) for item in cast(Sequence[Any], raw)]
        return []

    @staticmethod
    def _extract_audience(claims: Mapping[str, Any]) -> list[str] | None:
        """Extract audience from the given claims.

        Parameters
        ----------
        claims
            OIDC claims to retrieve the audience from.

        Returns
        -------
        List of strings from the ``aud`` claim, or None if absent.
        """
        aud = claims.get("aud")
        if aud is None:
            return None
        if isinstance(aud, str):
            return [aud]
        if isinstance(aud, Sequence):
            return [str(item) for item in cast(Sequence[Any], aud)]
        return [str(aud)]