diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index d95b7bd03d6a..81f34bd9cde1 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -57,6 +57,7 @@ from .auth_checks_organization import organization_role_based_access_check from .auth_utils import get_model_from_request +from functools import lru_cache if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -1829,8 +1830,7 @@ def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool: bool: True if model matches the pattern, False otherwise """ if "*" in allowed_model_pattern: - pattern = f"^{allowed_model_pattern.replace('*', '.*')}$" - return bool(re.match(pattern, model)) + return bool(_compile_pattern(allowed_model_pattern).match(model)) return False @@ -1846,24 +1846,17 @@ def _model_matches_any_wildcard_pattern_in_list( - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True - model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False """ + # Split out patterns to test only wildcard ones per spec. + for allowed_model_pattern in allowed_model_list: + if _is_wildcard_pattern(allowed_model_pattern) and is_model_allowed_by_pattern(model, allowed_model_pattern): + return True - if any( - _is_wildcard_pattern(allowed_model_pattern) - and is_model_allowed_by_pattern( + for allowed_model_pattern in allowed_model_list: + if _is_wildcard_pattern(allowed_model_pattern) and _model_custom_llm_provider_matches_wildcard_pattern( model=model, allowed_model_pattern=allowed_model_pattern - ) - for allowed_model_pattern in allowed_model_list - ): - return True + ): + return True - if any( - _is_wildcard_pattern(allowed_model_pattern) - and _model_custom_llm_provider_matches_wildcard_pattern( - model=model, allowed_model_pattern=allowed_model_pattern - ) - for allowed_model_pattern in allowed_model_list - ): - return True return False @@ -1999,3 +1992,12 @@ def _can_object_call_vector_stores( ) return True + + +@lru_cache(maxsize=8192) +def _compile_pattern(allowed_model_pattern: str) -> re.Pattern: + """ + Cache compiled regex patterns for allowed model patterns containing wildcards. + """ + pattern = f"^{allowed_model_pattern.replace('*', '.*')}$" + return re.compile(pattern)