Skip to content

Commit e6bbf3c

Browse files
fix: app_roles missing from jwt payload (#16448)
* fix: jwt app_roles missing * add test
1 parent 086e557 commit e6bbf3c

File tree

2 files changed

+94
-55
lines changed

2 files changed

+94
-55
lines changed

litellm/proxy/management_endpoints/ui_sso.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -564,10 +564,12 @@ def apply_user_info_values_to_sso_user_defined_values(
564564
if user_info is not None and user_info.user_id is not None:
565565
user_defined_values["user_id"] = user_info.user_id
566566

567-
if user_info is None or user_info.user_role is None:
568-
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
569-
else:
570-
user_defined_values["user_role"] = user_info.user_role
567+
# Check if user_role already exists in user_defined_values (from JWT/SSO response)
568+
if user_defined_values.get("user_role") is None:
569+
if user_info is None or user_info.user_role is None:
570+
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
571+
else:
572+
user_defined_values["user_role"] = user_info.user_role
571573

572574
# Preserve the user's existing models from the database
573575
if user_info is not None and hasattr(user_info, "models") and user_info.models:
@@ -1581,7 +1583,7 @@ async def get_redirect_response_from_openid( # noqa: PLR0915
15811583
user_id=user_id,
15821584
user_email=user_email,
15831585
max_budget=max_internal_user_budget,
1584-
user_role=None,
1586+
user_role=user_role,
15851587
budget_duration=internal_user_budget_duration,
15861588
)
15871589

tests/test_litellm/proxy/management_endpoints/test_entraid_app_roles.py

Lines changed: 87 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,92 @@
1-
"""
2-
Unit tests for EntraID app roles JWT claim extraction.
3-
4-
This module tests the get_app_roles_from_id_token method to ensure it correctly
5-
extracts app roles from Microsoft EntraID JWT tokens and prevents regressions.
6-
"""
7-
8-
import pytest
91
import jwt
102

113
from litellm.proxy.management_endpoints.ui_sso import MicrosoftSSOHandler
4+
from litellm.proxy.management_endpoints.types import get_litellm_user_role
5+
from litellm.proxy._types import LitellmUserRoles
6+
7+
8+
def test_extracts_proxy_admin_role_from_jwt():
9+
"""Ensure supported app roles like 'proxy_admin' are extracted from the id_token."""
10+
payload = {
11+
"sub": "user123",
12+
"email": "[email protected]",
13+
"app_roles": ["proxy_admin"],
14+
"aud": "litellm-app",
15+
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
16+
"exp": 9999999999,
17+
}
18+
19+
token = jwt.encode(payload, "secret", algorithm="HS256")
20+
roles = MicrosoftSSOHandler.get_app_roles_from_id_token(token)
21+
22+
assert roles == ["proxy_admin"]
23+
24+
25+
def test_maps_internal_user_role():
26+
"""Ensure internal_user role is correctly mapped to LitellmUserRoles."""
27+
payload = {
28+
"sub": "user456",
29+
"email": "[email protected]",
30+
"app_roles": ["internal_user"],
31+
"aud": "litellm-app",
32+
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
33+
"exp": 9999999999,
34+
}
35+
36+
token = jwt.encode(payload, "secret", algorithm="HS256")
37+
roles = MicrosoftSSOHandler.get_app_roles_from_id_token(token)
38+
39+
# Map to LitellmUserRoles
40+
chosen = None
41+
for r in roles:
42+
mapped = get_litellm_user_role(r)
43+
if mapped is not None:
44+
chosen = mapped
45+
break
46+
47+
assert chosen == LitellmUserRoles.INTERNAL_USER
48+
49+
50+
def test_maps_proxy_admin_viewer_role():
51+
"""Ensure proxy_admin_viewer role is correctly mapped."""
52+
payload = {
53+
"sub": "user789",
54+
"email": "[email protected]",
55+
"app_roles": ["proxy_admin_viewer"],
56+
"aud": "litellm-app",
57+
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
58+
"exp": 9999999999,
59+
}
60+
61+
token = jwt.encode(payload, "secret", algorithm="HS256")
62+
roles = MicrosoftSSOHandler.get_app_roles_from_id_token(token)
63+
64+
chosen = None
65+
for r in roles:
66+
mapped = get_litellm_user_role(r)
67+
if mapped is not None:
68+
chosen = mapped
69+
break
70+
71+
assert chosen == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY
72+
73+
74+
def test_defaults_to_internal_user_viewer_when_no_role():
75+
"""Ensure default role is internal_user_viewer when no app role is present."""
76+
payload = {
77+
"sub": "user_no_role",
78+
"email": "[email protected]",
79+
"aud": "litellm-app",
80+
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
81+
"exp": 9999999999,
82+
}
83+
84+
token = jwt.encode(payload, "secret", algorithm="HS256")
85+
roles = MicrosoftSSOHandler.get_app_roles_from_id_token(token)
86+
87+
assert roles == []
1288

89+
# Default role would be internal_user_viewer
90+
default_role = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
91+
assert default_role.value == "internal_user_viewer"
1392

14-
class TestEntraIDAppRoles:
15-
"""Test EntraID app roles extraction from JWT tokens"""
16-
17-
def test_get_app_roles_from_id_token_works_without_roles(self):
18-
"""Test that JWT token works fine without app_roles claim"""
19-
# Arrange - Token without app_roles (normal user)
20-
payload = {
21-
"sub": "user123",
22-
"email": "[email protected]",
23-
"aud": "litellm-app",
24-
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
25-
"exp": 9999999999,
26-
}
27-
no_roles_token = jwt.encode(payload, "secret", algorithm="HS256")
28-
29-
# Act
30-
result = MicrosoftSSOHandler.get_app_roles_from_id_token(no_roles_token)
31-
32-
# Assert - Should return empty list, not error
33-
assert result == []
34-
assert len(result) == 0
35-
36-
def test_get_app_roles_from_id_token_assigns_roles_when_present(self):
37-
"""Test that valid app roles are properly assigned when present"""
38-
# Arrange - Token with valid roles
39-
payload = {
40-
"sub": "user123",
41-
"email": "[email protected]",
42-
"app_roles": ["proxy_admin"],
43-
"aud": "litellm-app",
44-
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
45-
"exp": 9999999999,
46-
}
47-
valid_roles_token = jwt.encode(payload, "secret", algorithm="HS256")
48-
49-
# Act
50-
result = MicrosoftSSOHandler.get_app_roles_from_id_token(valid_roles_token)
51-
52-
# Assert - Should extract the role
53-
assert result == ["proxy_admin"]
54-
assert len(result) == 1
55-
assert "proxy_admin" in result

0 commit comments

Comments
 (0)