Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions litellm/proxy/auth/auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@

from .auth_checks_organization import organization_role_based_access_check
from .auth_utils import get_model_from_request
from functools import lru_cache

if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Expand Down Expand Up @@ -1829,8 +1830,7 @@ def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
bool: True if model matches the pattern, False otherwise
"""
if "*" in allowed_model_pattern:
pattern = f"^{allowed_model_pattern.replace('*', '.*')}$"
return bool(re.match(pattern, model))
return bool(_compile_pattern(allowed_model_pattern).match(model))

return False

Expand All @@ -1846,24 +1846,17 @@ def _model_matches_any_wildcard_pattern_in_list(
- model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
- model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
"""
# Split out patterns to test only wildcard ones per spec.
for allowed_model_pattern in allowed_model_list:
if _is_wildcard_pattern(allowed_model_pattern) and is_model_allowed_by_pattern(model, allowed_model_pattern):
return True

if any(
_is_wildcard_pattern(allowed_model_pattern)
and is_model_allowed_by_pattern(
for allowed_model_pattern in allowed_model_list:
if _is_wildcard_pattern(allowed_model_pattern) and _model_custom_llm_provider_matches_wildcard_pattern(
model=model, allowed_model_pattern=allowed_model_pattern
)
for allowed_model_pattern in allowed_model_list
):
return True
):
return True

if any(
_is_wildcard_pattern(allowed_model_pattern)
and _model_custom_llm_provider_matches_wildcard_pattern(
model=model, allowed_model_pattern=allowed_model_pattern
)
for allowed_model_pattern in allowed_model_list
):
return True

return False

Expand Down Expand Up @@ -1999,3 +1992,12 @@ def _can_object_call_vector_stores(
)

return True


@lru_cache(maxsize=8192)
def _compile_pattern(allowed_model_pattern: str) -> re.Pattern:
"""
Cache compiled regex patterns for allowed model patterns containing wildcards.
"""
pattern = f"^{allowed_model_pattern.replace('*', '.*')}$"
return re.compile(pattern)