Skip to content

Commit e7a10d1

Browse files
committed
feat: adds util function to get available first factors
- Moves is_recipe_initialized to supertokens.asyncio - Cleans up supertokens __init__ file to reduce redundancy - Adds test to ensure FactorIds class and method are in sync ref: supertokens/supertokens-node#1021
1 parent 09b4216 commit e7a10d1

File tree

6 files changed

+126
-46
lines changed

6 files changed

+126
-46
lines changed

supertokens_python/__init__.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15-
from typing import Any, Dict, List, Optional
15+
from typing import List, Optional
1616

1717
from typing_extensions import Literal
1818

19-
from supertokens_python.framework.request import BaseRequest
2019
from supertokens_python.recipe_module import RecipeModule
2120
from supertokens_python.types import RecipeUserId
2221

@@ -30,6 +29,7 @@
3029
SupertokensExperimentalConfig,
3130
SupertokensInputConfig,
3231
SupertokensPublicConfig,
32+
get_request_from_user_context,
3333
)
3434

3535
# Some Pydantic models need a rebuild to resolve ForwardRefs
@@ -69,19 +69,10 @@ def get_all_cors_headers() -> List[str]:
6969
return Supertokens.get_instance().get_all_cors_headers()
7070

7171

72-
def get_request_from_user_context(
73-
user_context: Optional[Dict[str, Any]],
74-
) -> Optional[BaseRequest]:
75-
return Supertokens.get_instance().get_request_from_user_context(user_context)
76-
77-
7872
def convert_to_recipe_user_id(user_id: str) -> RecipeUserId:
7973
return RecipeUserId(user_id)
8074

8175

82-
is_recipe_initialized = Supertokens.is_recipe_initialized
83-
84-
8576
__all__ = [
8677
"AppInfo",
8778
"InputAppInfo",
@@ -95,5 +86,4 @@ def convert_to_recipe_user_id(user_id: str) -> RecipeUserId:
9586
"get_all_cors_headers",
9687
"get_request_from_user_context",
9788
"init",
98-
"is_recipe_initialized",
9989
]

supertokens_python/asyncio/__init__.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import Any, Dict, List, Optional, Union
1515

1616
from supertokens_python import Supertokens
17+
from supertokens_python.exceptions import BadInputError
1718
from supertokens_python.interfaces import (
1819
CreateUserIdMappingOkResult,
1920
DeleteUserIdMappingOkResult,
@@ -26,8 +27,9 @@
2627
)
2728
from supertokens_python.recipe.accountlinking.interfaces import GetUsersResult
2829
from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe
30+
from supertokens_python.recipe.session.interfaces import SessionContainer
2931
from supertokens_python.types import User
30-
from supertokens_python.types.base import AccountInfoInput
32+
from supertokens_python.types.base import AccountInfoInput, UserContext
3133

3234

3335
async def get_users_oldest_first(
@@ -172,3 +174,44 @@ async def list_users_by_account_info(
172174
do_union_of_account_info,
173175
user_context,
174176
)
177+
178+
179+
# Async not really required, but keeping for consistency
180+
async def is_recipe_initialized(recipe_id: str) -> bool:
181+
"""
182+
Check if a recipe is initialized.
183+
:param recipe_id: The ID of the recipe to check.
184+
:return: Whether the recipe is initialized.
185+
"""
186+
return any(
187+
recipe.get_recipe_id() == recipe_id
188+
for recipe in Supertokens.get_instance().recipe_modules
189+
)
190+
191+
192+
async def get_available_first_factors(
193+
tenant_id: str,
194+
session: Optional[SessionContainer],
195+
user_context: Optional[UserContext],
196+
):
197+
from supertokens_python.auth_utils import (
198+
filter_out_invalid_first_factors_or_throw_if_all_are_invalid,
199+
)
200+
from supertokens_python.recipe.multifactorauth.types import FactorIds
201+
202+
available_first_factors: List[str] = []
203+
204+
try:
205+
available_first_factors = (
206+
await filter_out_invalid_first_factors_or_throw_if_all_are_invalid(
207+
factor_ids=FactorIds.get_all_factors(),
208+
tenant_id=tenant_id,
209+
has_session=session is not None,
210+
user_context=user_context if user_context is not None else {},
211+
)
212+
)
213+
except BadInputError:
214+
# All the factors were invalid, so we let it pass through and return the empty list
215+
pass
216+
217+
return available_first_factors

supertokens_python/recipe/multifactorauth/types.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,27 @@ class NormalisedMultiFactorAuthConfig(
6161

6262

6363
class FactorIds:
64-
EMAILPASSWORD: Literal["emailpassword"] = "emailpassword"
65-
OTP_EMAIL: Literal["otp-email"] = "otp-email"
66-
OTP_PHONE: Literal["otp-phone"] = "otp-phone"
67-
LINK_EMAIL: Literal["link-email"] = "link-email"
68-
LINK_PHONE: Literal["link-phone"] = "link-phone"
69-
THIRDPARTY: Literal["thirdparty"] = "thirdparty"
70-
TOTP: Literal["totp"] = "totp"
71-
WEBAUTHN: Literal["webauthn"] = "webauthn"
64+
EMAILPASSWORD = "emailpassword"
65+
OTP_EMAIL = "otp-email"
66+
OTP_PHONE = "otp-phone"
67+
LINK_EMAIL = "link-email"
68+
LINK_PHONE = "link-phone"
69+
THIRDPARTY = "thirdparty"
70+
TOTP = "totp"
71+
WEBAUTHN = "webauthn"
72+
73+
@staticmethod
74+
def get_all_factors():
75+
return [
76+
FactorIds.EMAILPASSWORD,
77+
FactorIds.OTP_EMAIL,
78+
FactorIds.OTP_PHONE,
79+
FactorIds.LINK_EMAIL,
80+
FactorIds.LINK_PHONE,
81+
FactorIds.THIRDPARTY,
82+
FactorIds.TOTP,
83+
FactorIds.WEBAUTHN,
84+
]
7285

7386

7487
class FactorIdsAndType:

supertokens_python/supertokens.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
SuperTokensPlugin,
4343
SuperTokensPublicPlugin,
4444
)
45+
from supertokens_python.types.base import UserContext
4546
from supertokens_python.types.response import CamelCaseBaseModel
4647

4748
from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT
@@ -181,13 +182,13 @@ def __init__(
181182
self.mode = mode
182183

183184
def get_top_level_website_domain(
184-
self, request: Optional[BaseRequest], user_context: Dict[str, Any]
185+
self, request: Optional[BaseRequest], user_context: UserContext
185186
) -> str:
186187
return get_top_level_domain_for_same_site_resolution(
187188
self.get_origin(request, user_context).get_as_string_dangerous()
188189
)
189190

190-
def get_origin(self, request: Optional[BaseRequest], user_context: Dict[str, Any]):
191+
def get_origin(self, request: Optional[BaseRequest], user_context: UserContext):
191192
origin = self.__origin
192193
if origin is None:
193194
origin = self.__website_domain
@@ -211,7 +212,7 @@ def defaultImpl(o: Any):
211212

212213

213214
def manage_session_post_response(
214-
session: SessionContainer, response: BaseResponse, user_context: Dict[str, Any]
215+
session: SessionContainer, response: BaseResponse, user_context: UserContext
215216
):
216217
# Something similar happens in handle_error of session/recipe.py
217218
for mutator in session.response_mutators:
@@ -571,7 +572,7 @@ async def get_user_count(
571572
self,
572573
include_recipe_ids: Union[None, List[str]],
573574
tenant_id: Optional[str] = None,
574-
user_context: Optional[Dict[str, Any]] = None,
575+
user_context: Optional[UserContext] = None,
575576
) -> int:
576577
querier = Querier.get_instance(None)
577578
include_recipe_ids_str = None
@@ -595,7 +596,7 @@ async def create_user_id_mapping(
595596
external_user_id: str,
596597
external_user_id_info: Optional[str],
597598
force: Optional[bool],
598-
user_context: Optional[Dict[str, Any]],
599+
user_context: Optional[UserContext],
599600
) -> Union[
600601
CreateUserIdMappingOkResult,
601602
UnknownSupertokensUserIDError,
@@ -635,7 +636,7 @@ async def get_user_id_mapping(
635636
self,
636637
user_id: str,
637638
user_id_type: Optional[UserIDTypes],
638-
user_context: Optional[Dict[str, Any]],
639+
user_context: Optional[UserContext],
639640
) -> Union[GetUserIdMappingOkResult, UnknownMappingError]:
640641
querier = Querier.get_instance(None)
641642

@@ -670,7 +671,7 @@ async def delete_user_id_mapping(
670671
user_id: str,
671672
user_id_type: Optional[UserIDTypes],
672673
force: Optional[bool],
673-
user_context: Optional[Dict[str, Any]],
674+
user_context: Optional[UserContext],
674675
) -> DeleteUserIdMappingOkResult:
675676
querier = Querier.get_instance(None)
676677

@@ -702,7 +703,7 @@ async def update_or_delete_user_id_mapping_info(
702703
user_id: str,
703704
user_id_type: Optional[UserIDTypes],
704705
external_user_id_info: Optional[str],
705-
user_context: Optional[Dict[str, Any]],
706+
user_context: Optional[UserContext],
706707
) -> Union[UpdateOrDeleteUserIdMappingInfoOkResult, UnknownMappingError]:
707708
querier = Querier.get_instance(None)
708709

@@ -728,7 +729,7 @@ async def update_or_delete_user_id_mapping_info(
728729
raise_general_exception("Please upgrade the SuperTokens core to >= 3.15.0")
729730

730731
async def middleware(
731-
self, request: BaseRequest, response: BaseResponse, user_context: Dict[str, Any]
732+
self, request: BaseRequest, response: BaseResponse, user_context: UserContext
732733
) -> Union[BaseResponse, None]:
733734
from supertokens_python.recipe.session.recipe import SessionRecipe
734735

@@ -901,7 +902,7 @@ async def handle_supertokens_error(
901902
request: BaseRequest,
902903
err: Exception,
903904
response: BaseResponse,
904-
user_context: Dict[str, Any],
905+
user_context: UserContext,
905906
) -> Optional[BaseResponse]:
906907
log_debug_message("errorHandler: Started")
907908
log_debug_message(
@@ -929,7 +930,7 @@ async def handle_supertokens_error(
929930

930931
def get_request_from_user_context(
931932
self,
932-
user_context: Optional[Dict[str, Any]] = None,
933+
user_context: Optional[UserContext] = None,
933934
) -> Optional[BaseRequest]:
934935
if user_context is None:
935936
return None
@@ -942,20 +943,8 @@ def get_request_from_user_context(
942943

943944
return user_context.get("_default", {}).get("request")
944945

945-
@staticmethod
946-
def is_recipe_initialized(recipe_id: str) -> bool:
947-
"""
948-
Check if a recipe is initialized.
949-
:param recipe_id: The ID of the recipe to check.
950-
:return: Whether the recipe is initialized.
951-
"""
952-
return any(
953-
recipe.get_recipe_id() == recipe_id
954-
for recipe in Supertokens.get_instance().recipe_modules
955-
)
956-
957946

958947
def get_request_from_user_context(
959-
user_context: Optional[Dict[str, Any]],
948+
user_context: Optional[UserContext],
960949
) -> Optional[BaseRequest]:
961950
return Supertokens.get_instance().get_request_from_user_context(user_context)

supertokens_python/syncio/__init__.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
UserIdMappingAlreadyExistsError,
2626
UserIDTypes,
2727
)
28+
from supertokens_python.recipe.session.interfaces import SessionContainer
2829
from supertokens_python.types import User
29-
from supertokens_python.types.base import AccountInfoInput
30+
from supertokens_python.types.base import AccountInfoInput, UserContext
3031

3132

3233
def get_users_oldest_first(
@@ -178,3 +179,32 @@ def list_users_by_account_info(
178179
tenant_id, account_info, do_union_of_account_info, user_context
179180
)
180181
)
182+
183+
184+
def is_recipe_initialized(recipe_id: str) -> bool:
185+
"""
186+
Check if a recipe is initialized.
187+
:param recipe_id: The ID of the recipe to check.
188+
:return: Whether the recipe is initialized.
189+
"""
190+
from supertokens_python.asyncio import (
191+
is_recipe_initialized as async_is_recipe_initialized,
192+
)
193+
194+
return sync(async_is_recipe_initialized(recipe_id))
195+
196+
197+
def get_available_first_factors(
198+
tenant_id: str,
199+
session: Optional[SessionContainer],
200+
user_context: Optional[UserContext],
201+
):
202+
from supertokens_python.asyncio import (
203+
get_available_first_factors as async_get_available_first_factors,
204+
)
205+
206+
return sync(
207+
async_get_available_first_factors(
208+
tenant_id=tenant_id, session=session, user_context=user_context
209+
)
210+
)

tests/test_mfa.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from supertokens_python.recipe.multifactorauth.types import FactorIds
2+
3+
4+
def test_get_all_factors():
5+
"""Test that FactorIds.get_all_factors returns all factors defined in FactorIds class."""
6+
factors_from_dict: list[str] = []
7+
for k, v in FactorIds.__dict__.items():
8+
if (
9+
(not k.startswith("__") or not k.endswith("__"))
10+
and not k.startswith("<")
11+
and isinstance(v, str)
12+
):
13+
factors_from_dict.append(v)
14+
15+
assert factors_from_dict == FactorIds.get_all_factors()

0 commit comments

Comments
 (0)