Source code for cruds.auth

"""
Complex authentication flows used to gain access with the CRUDs Client
"""

import json
import logging
import secrets
from time import time
from urllib.parse import urlencode, parse_qs, urlparse
from typing import Any

import certifi
import urllib3
import hashlib
import hmac

from .core import AuthABC
from .exception import OAuthAccessTokenError, OAuthStateError

logger = logging.getLogger(__name__)

TOKEN_REFRESH_LEAD_TIME = 30


[docs] class OAuth2(AuthABC): """ A client for the OAuth 2.0 specification, supporting access token using the 'Client Credientials', 'Password', and 'Authorization Code' grant types. """ def __init__( self, url: str, client_id: str, client_secret: str, scope: str, authorization_details=None, username=None, password=None, encryption_key: str | None = None, # Authorization Code flow parameters authorization_url: str | None = None, redirect_uri: str | None = None, state_length: int = 32, ) -> None: """ Arguments --------- url: str The url name of the OAuth2 platform. client_id: str The ID for the client authentication. client_secret: str The Secret for the client authentication. scope: str The scope required that the token will have access too. authorization_details: list(dict) Fine-grained parameters for Rich Authorization Request (RAR) username: str (optional) Username used for 'Password' grant type. password: str (optional) Password used for 'Password' grant type. encryption_key: str (optional) Key for encrypting token state. If not provided, a key will be derived from client_secret. For production use, provide a strong encryption key. authorization_url: str (optional) Authorization endpoint URL for Authorization Code flow. redirect_uri: str (optional) Redirect URI for Authorization Code flow. state_length: int (optional) Length of the state parameter for CSRF protection (default: 32). """ self.url = url self.client_id = client_id self.client_secret = client_secret self.scope = scope self.authorization_details = authorization_details self.username = username self.password = password self.authorization_url = authorization_url self.redirect_uri = redirect_uri self.state_length = state_length # Initialize encryption self._encryption_key: str | None = self._initialize_encryption(encryption_key) self._encrypted_state = b"" self._http = urllib3.PoolManager(ca_certs=certifi.where()) # State parameter for CSRF protection self._pending_state: str | None = None def _initialize_encryption(self, encryption_key: str | None = None) -> str | None: """ Initialize encryption using provided key. Args: encryption_key: Optional encryption key. If not provided, encryption will be handled dynamically with random salts. Returns: Encryption key string, or None if using dynamic encryption """ if encryption_key: # Use provided key return encryption_key else: # For client_secret-based encryption, we use dynamic salts # No single key is created here return None def _generate_encryption_key(self, salt: bytes) -> bytes: """ Generate encryption key from client_secret using the provided salt. Args: salt: The salt to use for key derivation Returns: Encryption key bytes """ kdf = hashlib.pbkdf2_hmac( "sha256", self.client_secret.encode(), salt, 100000, ) return kdf def _encrypt_with_key(self, data: bytes, key: bytes) -> bytes: """ Encrypt data using a simple XOR-based encryption with HMAC for integrity. Args: data: Data to encrypt key: Encryption key Returns: Encrypted data with HMAC """ # Generate a random IV iv = secrets.token_bytes(16) # Pad data to be a multiple of 16 bytes padding_length = 16 - (len(data) % 16) padded_data = data + bytes([padding_length] * padding_length) # Simple XOR encryption with key cycling encrypted = bytearray() for i, byte in enumerate(padded_data): key_byte = key[i % len(key)] encrypted.append(byte ^ key_byte) # Add HMAC for integrity h = hmac.new(key, iv + bytes(encrypted), hashlib.sha256) hmac_digest = h.digest() # Return: IV + encrypted_data + HMAC return iv + bytes(encrypted) + hmac_digest def _decrypt_with_key(self, encrypted_data: bytes, key: bytes) -> bytes: """ Decrypt data using the same XOR-based encryption. Args: encrypted_data: Encrypted data with IV and HMAC key: Encryption key Returns: Decrypted data Raises: ValueError: If HMAC verification fails """ if len(encrypted_data) < 48: # IV(16) + HMAC(32) + at least 1 byte of data raise ValueError("Encrypted data too short") # Extract components iv = encrypted_data[:16] hmac_digest = encrypted_data[-32:] encrypted = encrypted_data[16:-32] # Verify HMAC h = hmac.new(key, iv + encrypted, hashlib.sha256) expected_hmac = h.digest() if not hmac.compare_digest(hmac_digest, expected_hmac): raise ValueError("HMAC verification failed") # Decrypt using XOR decrypted = bytearray() for i, byte in enumerate(encrypted): key_byte = key[i % len(key)] decrypted.append(byte ^ key_byte) # Remove padding padding_length = decrypted[-1] if padding_length > 16 or padding_length == 0: raise ValueError("Invalid padding") return bytes(decrypted[:-padding_length]) def _generate_state_parameter(self) -> str: """ Generate a cryptographically secure state parameter for CSRF protection. Returns: A secure random state string """ return secrets.token_urlsafe(self.state_length) def _validate_state_parameter(self, received_state: str) -> bool: """ Validate the received state parameter against the stored pending state. Args: received_state: The state parameter received from the authorization server Returns: True if state is valid, False otherwise """ if not self._pending_state: logger.warning("No pending state parameter found") return False if received_state != self._pending_state: logger.warning("State parameter mismatch - possible CSRF attack") return False # Clear the pending state after successful validation self._pending_state = None return True def _handle_http_response(self, response) -> None: """ Handle HTTP response and raise appropriate exceptions for errors. Args: response: The HTTP response from urllib3 Raises: OAuthAccessTokenError: For 4xx client errors urllib3.exceptions.HTTPError: For other HTTP errors """ if 400 <= response.status < 500: msg: str = response.json().get("error_description", "Unknown") raise OAuthAccessTokenError(msg) if not 200 <= response.status < 299: # Try to decode error message, fallback to raw data try: error_message = response.data.decode("utf-8") except UnicodeDecodeError: error_message = str(response.data) raise urllib3.exceptions.HTTPError( f"Error with status code {response.status} Message: {error_message}" ) def _add_authorization_details(self, params: dict[str, str]) -> None: """ Add authorization details to the parameters if provided. Args: params: Dictionary to add authorization details to """ if self.authorization_details is not None and isinstance( self.authorization_details, list ): params["authorization_details"] = json.dumps(self.authorization_details) elif self.authorization_details is not None: params["authorization_details"] = self.authorization_details def _make_oauth_request( self, fields: dict[str, str], use_auth: bool = True ) -> dict[str, Any]: """ Make an OAuth request to the token endpoint. Args: fields: The form fields to send in the request use_auth: Whether to include basic authentication headers Returns: The JSON response from the server Raises: OAuthAccessTokenError: For 4xx client errors urllib3.exceptions.HTTPError: For other HTTP errors """ # Setup Authentication if use_auth: request_headers = urllib3.make_headers( basic_auth=f"{self.client_id}:{self.client_secret}" ) else: request_headers = urllib3.make_headers() request_headers["Content-Type"] = ( "application/x-www-form-urlencoded; charset=utf-8" ) # Make request to the token endpoint response = self._http.request( "POST", self.url, body=urlencode(fields), headers=request_headers, redirect=False, ) # Handle response self._handle_http_response(response) # Return the JSON response return response.json()
[docs] def get_authorization_url( self, additional_params: dict[str, str] | None = None ) -> str: """ Generate the authorization URL for the Authorization Code flow. Args: additional_params: Additional parameters to include in the authorization URL Returns: The complete authorization URL with state parameter Raises: RuntimeError: If authorization_url or redirect_uri is not configured """ if not self.authorization_url or not self.redirect_uri: raise RuntimeError( "Authorization Code flow requires both authorization_url and redirect_uri to be configured" ) # Generate state parameter for CSRF protection state = self._generate_state_parameter() self._pending_state = state # Build authorization parameters params = { "response_type": "code", "client_id": self.client_id, "redirect_uri": self.redirect_uri, "scope": self.scope, "state": state, } # Add authorization details if provided self._add_authorization_details(params) # Add additional parameters if additional_params: params.update(additional_params) # Build the authorization URL query_string = urlencode(params) authorization_url = f"{self.authorization_url}?{query_string}" logger.debug(f"Generated authorization URL with state parameter: {state}") return authorization_url
[docs] def exchange_code_for_token(self, authorization_code: str, state: str) -> str: """ Exchange authorization code for access token (Authorization Code flow). Args: authorization_code: The authorization code received from the authorization server state: The state parameter received from the authorization server Returns: The access token Raises: RuntimeError: If state parameter validation fails OAuthAccessTokenError: If token exchange fails """ # Validate state parameter for CSRF protection if not self._validate_state_parameter(state): raise OAuthStateError("Invalid state parameter - possible CSRF attack") logger.debug("Exchanging authorization code for access token") if not self.redirect_uri: raise RuntimeError("redirect_uri is required for Authorization Code flow") # Prepare token exchange request fields: dict[str, str] = { "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": self.redirect_uri, } # Make request to the token endpoint access_token = self._make_oauth_request(fields) # Store the token response access_token["created"] = int(time()) self._state = access_token logger.debug("Successfully exchanged authorization code for access token") return access_token.get("access_token", "")
[docs] def parse_authorization_response(self, redirect_url: str) -> tuple[str, str]: """ Parse the authorization response from the redirect URL. Args: redirect_url: The full redirect URL received from the authorization server Returns: Tuple of (authorization_code, state) Raises: RuntimeError: If the redirect URL is invalid or missing required parameters """ parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) # Check for error response first if "error" in query_params: error = query_params["error"][0] error_description = query_params.get("error_description", ["Unknown"])[0] raise OAuthAccessTokenError( f"Authorization failed: {error} - {error_description}" ) # Extract authorization code if "code" not in query_params: raise RuntimeError("Authorization code not found in redirect URL") authorization_code = query_params["code"][0] # Extract state parameter if "state" not in query_params: raise RuntimeError("State parameter not found in redirect URL") state = query_params["state"][0] return authorization_code, state
def _encrypt_state(self, state: dict[str, Any]) -> bytes: """ Encrypt the state dictionary with a random salt. Args: state: Dictionary containing token state Returns: Encrypted bytes with salt prefix (for dynamic encryption) or without (for fixed key) """ if not state: return b"" state_json = json.dumps(state) state_bytes = state_json.encode("utf-8") if self._encryption_key is not None: # Use fixed encryption key key = hashlib.sha256(self._encryption_key.encode()).digest() return self._encrypt_with_key(state_bytes, key) else: # Use dynamic encryption with random salt salt = secrets.token_bytes(16) key = self._generate_encryption_key(salt) encrypted_data = self._encrypt_with_key(state_bytes, key) # Return salt + encrypted data return salt + encrypted_data def _decrypt_state(self, encrypted_data: bytes) -> dict[str, Any]: """ Decrypt the state dictionary using the stored salt. Args: encrypted_data: Encrypted state data with salt prefix (for dynamic encryption) or without (for fixed key) Returns: Decrypted state dictionary """ if not encrypted_data: return {} try: if self._encryption_key is not None: # Use fixed encryption key key = hashlib.sha256(self._encryption_key.encode()).digest() decrypted_data = self._decrypt_with_key(encrypted_data, key) return json.loads(decrypted_data.decode("utf-8")) else: # Use dynamic encryption with salt prefix # Try to decrypt with old fixed salt first for backward compatibility try: old_salt = b"cruds_oauth_salt" key = self._generate_encryption_key(old_salt) decrypted_data = self._decrypt_with_key(encrypted_data, key) return json.loads(decrypted_data.decode("utf-8")) except Exception: # If old salt fails, try new format with salt prefix if len(encrypted_data) >= 16: # Extract salt (first 16 bytes) and encrypted data salt = encrypted_data[:16] actual_encrypted_data = encrypted_data[16:] # Generate key with the stored salt key = self._generate_encryption_key(salt) # Decrypt the data decrypted_data = self._decrypt_with_key( actual_encrypted_data, key ) return json.loads(decrypted_data.decode("utf-8")) else: # Data too short, can't be valid raise ValueError("Encrypted data too short") except Exception as e: logger.warning(f"Failed to decrypt state: {e}") return {} @property def _state(self) -> dict[str, Any]: """Get decrypted state.""" return self._decrypt_state(self._encrypted_state) @_state.setter def _state(self, value: dict[str, Any] | None) -> None: """Set encrypted state.""" self._encrypted_state = self._encrypt_state(value or {})
[docs] def access_token(self) -> str: if self.is_valid(): logger.debug("OAuth Token is still valid") return self._state["access_token"] logger.debug("Retrieving OAuth token") # Determine if a refresh_token needs to be used. if refresh_token := self._state.get("refresh_token"): logger.debug("Use refresh token to renew access token") fields = { "grant_type": "refresh_token", "refresh_token": refresh_token, } # Refresh token requests don't use basic auth access_token = self._make_oauth_request(fields, use_auth=False) else: fields = { "grant_type": "client_credentials", "scope": self.scope, } # Support Password Grant Type in 2.0 if self.username is not None and self.password is not None: fields["grant_type"] = "password" fields["username"] = self.username fields["password"] = self.password # Rich Authorization Request (RAR) self._add_authorization_details(fields) # Make request to the Server access_token = self._make_oauth_request(fields) # Store the token response access_token["created"] = int(time()) self._state = access_token # Check if the token is valid return self._state["access_token"] if self.is_valid() else ""
[docs] def is_valid(self) -> bool: state: dict = self._state if state == {}: return False if "access_token" not in state or state.get("token_type", "") != "Bearer": raise RuntimeError("Auth state is missing critical information") if "expires_in" not in state: return True return int(state["created"] + state["expires_in"]) > ( int(time()) + TOKEN_REFRESH_LEAD_TIME )
[docs] def clear_state(self) -> None: """ Clear the encrypted state, effectively logging out the user. """ self._encrypted_state = b"" self._pending_state = None logger.debug("OAuth state cleared")