Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions litellm/proxy/auth/auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,19 +253,13 @@ def _is_ui_route(
"""
# this token is only used for managing the ui
allowed_routes = LiteLLMRoutes.ui_routes.value
# check if the current route startswith any of the allowed routes
if (
route is not None
and isinstance(route, str)
and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
):
# Do something if the current route starts with any of the allowed routes
return True
elif any(
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
for allowed_route in allowed_routes
):
return True
# Combine both checks in a single loop for efficiency
if route is not None and isinstance(route, str):
for allowed_route in allowed_routes:
if route.startswith(allowed_route):
return True
elif RouteChecks._route_matches_pattern(route=route, pattern=allowed_route):
return True
return False


Expand Down
133 changes: 43 additions & 90 deletions litellm/proxy/auth/route_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

from .auth_checks_organization import _user_is_org_admin
from functools import lru_cache


class RouteChecks:
Expand All @@ -31,15 +32,11 @@ def should_call_route(route: str, valid_token: UserAPIKeyAuth):
pass

# Check if Virtual Key is allowed to call the route - Applies to all Roles
RouteChecks.is_virtual_key_allowed_to_call_route(
route=route, valid_token=valid_token
)
RouteChecks.is_virtual_key_allowed_to_call_route(route=route, valid_token=valid_token)
return True

@staticmethod
def is_virtual_key_allowed_to_call_route(
route: str, valid_token: UserAPIKeyAuth
) -> bool:
def is_virtual_key_allowed_to_call_route(route: str, valid_token: UserAPIKeyAuth) -> bool:
"""
Raises Exception if Virtual Key is not allowed to call the route
"""
Expand All @@ -54,16 +51,11 @@ def is_virtual_key_allowed_to_call_route(

# explicit check for allowed routes (exact match or prefix match)
for allowed_route in valid_token.allowed_routes:
if RouteChecks._route_matches_allowed_route(
route=route, allowed_route=allowed_route
):
if RouteChecks._route_matches_allowed_route(route=route, allowed_route=allowed_route):
return True

## check if 'allowed_route' is a field name in LiteLLMRoutes
if any(
allowed_route in LiteLLMRoutes._member_names_
for allowed_route in valid_token.allowed_routes
):
if any(allowed_route in LiteLLMRoutes._member_names_ for allowed_route in valid_token.allowed_routes):
for allowed_route in valid_token.allowed_routes:
if allowed_route in LiteLLMRoutes._member_names_:
if RouteChecks.check_route_access(
Expand All @@ -80,16 +72,12 @@ def is_virtual_key_allowed_to_call_route(
InitPassThroughEndpointHelpers,
)

if InitPassThroughEndpointHelpers.is_registered_pass_through_route(
route=route
):
if InitPassThroughEndpointHelpers.is_registered_pass_through_route(route=route):
return True

# check if wildcard pattern is allowed
for allowed_route in valid_token.allowed_routes:
if RouteChecks._route_matches_wildcard_pattern(
route=route, pattern=allowed_route
):
if RouteChecks._route_matches_wildcard_pattern(route=route, pattern=allowed_route):
return True

raise Exception(
Expand Down Expand Up @@ -163,19 +151,15 @@ def non_proxy_admin_allowed_routes_check(

if RouteChecks.is_llm_api_route(route=route):
pass
elif (
route in LiteLLMRoutes.info_routes.value
): # check if user allowed to call an info route
elif route in LiteLLMRoutes.info_routes.value: # check if user allowed to call an info route
if route == "/key/info":
# handled by function itself
pass
elif route == "/user/info":
# check if user can access this route
query_params = request.query_params
user_id = query_params.get("user_id")
verbose_proxy_logger.debug(
f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}"
)
verbose_proxy_logger.debug(f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}")
if user_id and user_id != valid_token.user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
Expand All @@ -200,25 +184,17 @@ def non_proxy_admin_allowed_routes_check(
_user_role=_user_role,
request_data=request_data,
)
elif (
_user_role == LitellmUserRoles.INTERNAL_USER.value
and RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.internal_user_routes.value
)
elif _user_role == LitellmUserRoles.INTERNAL_USER.value and RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.internal_user_routes.value
):
pass
elif _user_is_org_admin(
request_data=request_data, user_object=user_obj
) and RouteChecks.check_route_access(
elif _user_is_org_admin(request_data=request_data, user_object=user_obj) and RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.org_admin_allowed_routes.value
):
pass
elif (
_user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
and RouteChecks.check_route_access(
route=route,
allowed_routes=LiteLLMRoutes.internal_user_view_only_routes.value,
)
elif _user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value and RouteChecks.check_route_access(
route=route,
allowed_routes=LiteLLMRoutes.internal_user_view_only_routes.value,
):
pass
elif RouteChecks.check_route_access(
Expand All @@ -227,28 +203,20 @@ def non_proxy_admin_allowed_routes_check(
pass
elif route.startswith("/v1/mcp/") or route.startswith("/mcp-rest/"):
pass # authN/authZ handled by api itself
elif RouteChecks.check_passthrough_route_access(
route=route, user_api_key_dict=valid_token
):
elif RouteChecks.check_passthrough_route_access(route=route, user_api_key_dict=valid_token):
pass
elif valid_token.allowed_routes is not None:
# check if route is in allowed_routes (exact match or prefix match)
route_allowed = False
for allowed_route in valid_token.allowed_routes:
if RouteChecks._route_matches_allowed_route(
route=route, allowed_route=allowed_route
):
if RouteChecks._route_matches_allowed_route(route=route, allowed_route=allowed_route):
route_allowed = True
break

if not route_allowed:
RouteChecks._raise_admin_only_route_exception(
user_obj=user_obj, route=route
)
RouteChecks._raise_admin_only_route_exception(user_obj=user_obj, route=route)
else:
RouteChecks._raise_admin_only_route_exception(
user_obj=user_obj, route=route
)
RouteChecks._raise_admin_only_route_exception(user_obj=user_obj, route=route)

@staticmethod
def custom_admin_only_route_check(route: str):
Expand Down Expand Up @@ -287,9 +255,7 @@ def is_llm_api_route(route: str) -> bool:
if route in LiteLLMRoutes.anthropic_routes.value:
return True

if RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.mcp_routes.value
):
if RouteChecks.check_route_access(route=route, allowed_routes=LiteLLMRoutes.mcp_routes.value):
return True

# fuzzy match routes like "/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
Expand All @@ -298,9 +264,7 @@ def is_llm_api_route(route: str) -> bool:
# Replace placeholders with regex pattern
# placeholders are written as "/threads/{thread_id}"
if "{" in openai_route:
if RouteChecks._route_matches_pattern(
route=route, pattern=openai_route
):
if RouteChecks._route_matches_pattern(route=route, pattern=openai_route):
return True

if RouteChecks._is_azure_openai_route(route=route):
Expand Down Expand Up @@ -356,10 +320,10 @@ def _route_matches_pattern(route: str, pattern: str) -> bool:
# Ensure route is a string before attempting regex matching
if not isinstance(route, str):
return False
pattern = re.sub(r"\{[^}]+\}", r"[^/]+", pattern)
# Anchor the pattern to match the entire string
pattern = f"^{pattern}$"
if re.match(pattern, route):

# Use LRU cache to avoid recompiling regex for repeated patterns
regex = _get_compiled_route_regex(pattern)
if regex.match(route):
return True
return False

Expand Down Expand Up @@ -447,15 +411,9 @@ def check_route_access(route: str, allowed_routes: List[str]) -> bool:
# wildcard match route is in allowed_routes
# e.g calling /anthropic/v1/messages is allowed if allowed_routes has /anthropic/*
#########################################################
wildcard_allowed_routes = [
route
for route in allowed_routes
if RouteChecks._is_wildcard_pattern(pattern=route)
]
wildcard_allowed_routes = [route for route in allowed_routes if RouteChecks._is_wildcard_pattern(pattern=route)]
for allowed_route in wildcard_allowed_routes:
if RouteChecks._route_matches_wildcard_pattern(
route=route, pattern=allowed_route
):
if RouteChecks._route_matches_wildcard_pattern(route=route, pattern=allowed_route):
return True

#########################################################
Expand All @@ -465,17 +423,14 @@ def check_route_access(route: str, allowed_routes: List[str]) -> bool:
# returns: True
#########################################################
if any( # Check pattern match
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
for allowed_route in allowed_routes
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route) for allowed_route in allowed_routes
):
return True

return False

@staticmethod
def check_passthrough_route_access(
route: str, user_api_key_dict: UserAPIKeyAuth
) -> bool:
def check_passthrough_route_access(route: str, user_api_key_dict: UserAPIKeyAuth) -> bool:
"""
Check if route is a passthrough route.
Supports both exact match and prefix match.
Expand All @@ -484,10 +439,7 @@ def check_passthrough_route_access(
team_metadata = user_api_key_dict.team_metadata or {}
if metadata is None and team_metadata is None:
return False
if (
"allowed_passthrough_routes" not in metadata
and "allowed_passthrough_routes" not in team_metadata
):
if "allowed_passthrough_routes" not in metadata and "allowed_passthrough_routes" not in team_metadata:
return False
if (
metadata.get("allowed_passthrough_routes") is None
Expand All @@ -496,16 +448,12 @@ def check_passthrough_route_access(
return False

allowed_passthrough_routes = (
metadata.get("allowed_passthrough_routes")
or team_metadata.get("allowed_passthrough_routes")
or []
metadata.get("allowed_passthrough_routes") or team_metadata.get("allowed_passthrough_routes") or []
)

# Check if route matches any allowed passthrough route (exact or prefix match)
for allowed_route in allowed_passthrough_routes:
if RouteChecks._route_matches_allowed_route(
route=route, allowed_route=allowed_route
):
if RouteChecks._route_matches_allowed_route(route=route, allowed_route=allowed_route):
return True

return False
Expand Down Expand Up @@ -554,9 +502,7 @@ def _check_proxy_admin_viewer_access(
)

# Check if this is a write operation on management routes
if RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.management_routes.value
):
if RouteChecks.check_route_access(route=route, allowed_routes=LiteLLMRoutes.management_routes.value):
# For management routes, only allow read operations or specific allowed updates
if route == "/user/update":
# Check the Request params are valid for PROXY_ADMIN_VIEW_ONLY
Expand Down Expand Up @@ -597,9 +543,7 @@ def _check_proxy_admin_viewer_access(
)
# Allow read operations on management routes (like /user/info, /team/info, /model/info)
return
elif RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.admin_viewer_routes.value
):
elif RouteChecks.check_route_access(route=route, allowed_routes=LiteLLMRoutes.admin_viewer_routes.value):
# Allow access to admin viewer routes (read-only admin endpoints)
return
else:
Expand All @@ -608,3 +552,12 @@ def _check_proxy_admin_viewer_access(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
)


# cache compiled regex patterns for route patterns
@lru_cache(maxsize=128)
def _get_compiled_route_regex(pattern: str) -> re.Pattern:
# Transform pattern into regex
p = re.sub(r"\{[^}]+\}", r"[^/]+", pattern)
p = f"^{p}$"
return re.compile(p)