-
Notifications
You must be signed in to change notification settings - Fork 7
APP-8642 : Add OAuth to pyatlan #767
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 14 commits
6dab7aa
ffee571
088ba3b
64027a4
eaf6dd7
c7903dc
c383729
c397fb2
ec8c0ec
e7bfefd
88c909a
0ff1ca6
9c08e1f
519fafb
899363b
3d77849
f5d618a
74723ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,2 +1,8 @@ | ||||||||||||||||||||||||||
| ATLAN_BASE_URL=your_tenant_base_url | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| #API KEY based authentication | ||||||||||||||||||||||||||
| ATLAN_API_KEY=your_api_key | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| #OAuth based authentication | ||||||||||||||||||||||||||
| ATLAN_OAUTH_CLIENT_ID=your_oauth_client_id | ||||||||||||||||||||||||||
| ATLAN_OAUTH_CLIENT_SECRET=your_oauth_client_secret | ||||||||||||||||||||||||||
|
Comment on lines
+3
to
+8
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| from contextlib import _AsyncGeneratorContextManager | ||
| from http import HTTPStatus | ||
| from types import SimpleNamespace | ||
| from typing import Optional | ||
| from typing import Any, Optional | ||
|
|
||
| import httpx | ||
| from httpx_retries.retry import Retry | ||
|
|
@@ -90,6 +90,7 @@ class AsyncAtlanClient(AtlanClient): | |
| """ | ||
|
|
||
| _async_session: Optional[httpx.AsyncClient] = PrivateAttr(default=None) | ||
| _async_oauth_token_manager: Optional[Any] = PrivateAttr(default=None) | ||
| _async_admin_client: Optional[AsyncAdminClient] = PrivateAttr(default=None) | ||
| _async_asset_client: Optional[AsyncAssetClient] = PrivateAttr(default=None) | ||
| _async_audit_client: Optional[AsyncAuditClient] = PrivateAttr(default=None) | ||
|
|
@@ -134,6 +135,32 @@ def __init__(self, **kwargs): | |
| # Initialize sync client (handles all validation, env vars, etc.) | ||
| super().__init__(**kwargs) | ||
|
|
||
| if self.oauth_client_id and self.oauth_client_secret and self.api_key is None: | ||
| LOGGER.debug( | ||
| "API Key not provided. Using Async OAuth flow for authentication" | ||
| ) | ||
| from pyatlan.client.aio.oauth import AsyncOAuthTokenManager | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we move this import to the top of the file (hoping we don't get any circular import errors 🤞)? |
||
|
|
||
| if self._oauth_token_manager: | ||
| LOGGER.debug("Sync oauth flow open. Closing it for Async oauth flow") | ||
| self._oauth_token_manager.close() | ||
| self._oauth_token_manager = None | ||
|
|
||
| final_base_url = self.base_url or os.environ.get( | ||
| "ATLAN_BASE_URL", "INTERNAL" | ||
| ) | ||
| final_oauth_client_id = self.oauth_client_id or os.environ.get( | ||
| "ATLAN_OAUTH_CLIENT_ID" | ||
| ) | ||
| final_oauth_client_secret = self.oauth_client_secret or os.environ.get( | ||
| "ATLAN_OAUTH_CLIENT_SECRET" | ||
| ) | ||
| self._async_oauth_token_manager = AsyncOAuthTokenManager( | ||
| base_url=final_base_url, | ||
| client_id=final_oauth_client_id, | ||
| client_secret=final_oauth_client_secret, | ||
| ) | ||
cursor[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Build proxy/SSL configuration (reuse from sync client) | ||
| transport_kwargs = self._build_transport_proxy_config(kwargs) | ||
|
|
||
|
|
@@ -438,6 +465,9 @@ async def _create_params( | |
| Async version of _create_params that uses AsyncAtlanRequest for AtlanObject instances. | ||
| """ | ||
| params = copy.deepcopy(self._request_params) | ||
| if self._async_oauth_token_manager: | ||
| token = await self._async_oauth_token_manager.get_token() | ||
| params["headers"]["authorization"] = f"Bearer {token}" | ||
| params["headers"]["Accept"] = api.consumes | ||
| params["headers"]["content-type"] = api.produces | ||
| if query_params is not None: | ||
|
|
@@ -687,7 +717,7 @@ async def _handle_error_response( | |
|
|
||
| # Retry with impersonation (if _user_id is present) on authentication failure | ||
| if ( | ||
| self._user_id | ||
| (self._user_id or self._async_oauth_token_manager) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why did we update this condition? (Can you please add the reason for it?) 🙏 |
||
| and not self._401_has_retried.get() | ||
| and response.status_code | ||
| == ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code | ||
|
|
@@ -746,6 +776,21 @@ async def _handle_401_token_refresh( | |
| Async version of token refresh and retry logic. | ||
| Handles token refresh and retries the API request upon a 401 Unauthorized response. | ||
| """ | ||
| if self._async_oauth_token_manager: | ||
| await self._async_oauth_token_manager.invalidate_token() | ||
| token = await self._async_oauth_token_manager.get_token() | ||
| params["headers"]["authorization"] = f"Bearer {token}" | ||
| self._401_has_retried.set(True) | ||
| LOGGER.debug("Successfully refreshed OAuth token after 401.") | ||
| return await self._call_api_internal( | ||
| api, | ||
| path, | ||
| params, | ||
| binary_data=binary_data, | ||
| download_file_path=download_file_path, | ||
| text_response=text_response, | ||
| ) | ||
|
Comment on lines
+781
to
+794
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please add comments for this change? 🙏 |
||
|
|
||
| try: | ||
| # Use sync impersonation call since it's a quick API call | ||
| new_token = await self.impersonate.user(user_id=self._user_id) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright 2025 Atlan Pte. Ltd. | ||
| import asyncio | ||
| import time | ||
| from typing import Optional | ||
|
|
||
| import httpx | ||
| from authlib.oauth2.rfc6749 import OAuth2Token | ||
|
|
||
| from pyatlan.client.constants import GET_OAUTH_CLIENT | ||
| from pyatlan.utils import API | ||
|
|
||
|
|
||
| class AsyncOAuthTokenManager: | ||
| def __init__( | ||
| self, | ||
| base_url: str, | ||
| client_id: str, | ||
| client_secret: str, | ||
| http_client: Optional[httpx.AsyncClient] = None, | ||
| ): | ||
| self.base_url = base_url | ||
| self.client_id = client_id | ||
| self.client_secret = client_secret | ||
| self.token_url = self._create_path(GET_OAUTH_CLIENT) | ||
| self._lock = asyncio.Lock() | ||
| self._http_client = http_client or httpx.AsyncClient(timeout=30.0) | ||
| self._token: Optional[OAuth2Token] = None | ||
| self._owns_client = http_client is None | ||
|
|
||
| async def get_token(self) -> str: | ||
| async with self._lock: | ||
| if self._token and not self._token.is_expired(): | ||
| return str(self._token["access_token"]) | ||
|
|
||
| response = await self._http_client.post( | ||
| self.token_url, | ||
| json={ | ||
| "clientId": self.client_id, | ||
| "clientSecret": self.client_secret, | ||
| }, | ||
| headers={"Content-Type": "application/json"}, | ||
| ) | ||
| response.raise_for_status() | ||
|
|
||
| data = response.json() | ||
| access_token = data.get("accessToken") or data.get("access_token") | ||
|
|
||
| if not access_token: | ||
| raise ValueError( | ||
| f"OAuth token response missing 'accessToken' field. " | ||
| f"Response keys: {list(data.keys())}" | ||
| ) | ||
|
|
||
| expires_in = data.get("expiresIn") or data.get("expires_in", 600) | ||
vaibhavatlan marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+37
to
+61
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Docstring missing ^^ |
||
|
|
||
| self._token = OAuth2Token( | ||
| { | ||
| "access_token": access_token, | ||
| "token_type": data.get("tokenType") | ||
| or data.get("token_type", "Bearer"), | ||
| "expires_in": expires_in, | ||
| "expires_at": int(time.time()) + expires_in, | ||
| } | ||
| ) | ||
|
|
||
| return access_token | ||
|
|
||
| async def invalidate_token(self): | ||
| async with self._lock: | ||
| self._token = None | ||
|
|
||
| def _create_path(self, api: API): | ||
| from urllib.parse import urljoin | ||
|
|
||
| if self.base_url == "INTERNAL": | ||
| base_with_prefix = urljoin(api.endpoint.service, api.endpoint.prefix) | ||
| return urljoin(base_with_prefix, api.path) | ||
| else: | ||
| base_with_prefix = urljoin(self.base_url, api.endpoint.prefix) | ||
| return urljoin(base_with_prefix, api.path) | ||
|
|
||
| async def aclose(self): | ||
| if self._owns_client: | ||
| await self._http_client.aclose() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,6 +48,7 @@ | |
| from pyatlan.client.file import FileClient | ||
| from pyatlan.client.group import GroupClient | ||
| from pyatlan.client.impersonate import ImpersonationClient | ||
| from pyatlan.client.oauth import OAuthTokenManager | ||
| from pyatlan.client.open_lineage import OpenLineageClient | ||
| from pyatlan.client.query import QueryClient | ||
| from pyatlan.client.role import RoleClient | ||
|
|
@@ -127,7 +128,9 @@ def log_response(response, *args, **kwargs): | |
|
|
||
| class AtlanClient(BaseSettings): | ||
| base_url: Union[Literal["INTERNAL"], HttpUrl] | ||
| api_key: str | ||
| api_key: Optional[str] = None | ||
| oauth_client_id: Optional[str] = None | ||
| oauth_client_secret: Optional[str] = None | ||
| connect_timeout: float = 30.0 # 30 secs | ||
| read_timeout: float = 900.0 # 15 mins | ||
| retry: Retry = DEFAULT_RETRY | ||
|
|
@@ -137,6 +140,7 @@ class AtlanClient(BaseSettings): | |
| _session: httpx.Client = PrivateAttr() | ||
| _request_params: dict = PrivateAttr() | ||
| _user_id: Optional[str] = PrivateAttr(default=None) | ||
| _oauth_token_manager: Optional[Any] = PrivateAttr(default=None) | ||
| _workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None) | ||
| _credential_client: Optional[CredentialClient] = PrivateAttr(default=None) | ||
| _admin_client: Optional[AdminClient] = PrivateAttr(default=None) | ||
|
|
@@ -172,11 +176,31 @@ class Config: | |
|
|
||
| def __init__(self, **data): | ||
| super().__init__(**data) | ||
| self._request_params = ( | ||
| {"headers": {"authorization": f"Bearer {self.api_key}"}} | ||
| if self.api_key and self.api_key.strip() | ||
| else {"headers": {}} | ||
| ) | ||
|
|
||
| if self.oauth_client_id and self.oauth_client_secret and self.api_key is None: | ||
| LOGGER.debug("API KEY not provided. Using OAuth flow for authentication") | ||
|
|
||
| final_base_url = self.base_url or os.environ.get( | ||
| "ATLAN_BASE_URL", "INTERNAL" | ||
| ) | ||
| final_oauth_client_id = self.oauth_client_id or os.environ.get( | ||
| "ATLAN_OAUTH_CLIENT_ID" | ||
| ) | ||
| final_oauth_client_secret = self.oauth_client_secret or os.environ.get( | ||
| "ATLAN_OAUTH_CLIENT_SECRET" | ||
| ) | ||
| self._oauth_token_manager = OAuthTokenManager( | ||
| base_url=final_base_url, | ||
| client_id=final_oauth_client_id, | ||
| client_secret=final_oauth_client_secret, | ||
| ) | ||
| self._request_params = {"headers": {}} | ||
| else: | ||
| self._request_params = ( | ||
| {"headers": {"authorization": f"Bearer {self.api_key}"}} | ||
| if self.api_key and self.api_key.strip() | ||
| else {"headers": {}} | ||
| ) | ||
vaibhavatlan marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: Sync Client OAuth Leaks HTTP Resources. The sync |
||
|
|
||
| # Build proxy/SSL configuration with environment variable fallback | ||
| transport_kwargs = self._build_transport_proxy_config(data) | ||
|
|
@@ -691,7 +715,7 @@ def _call_api_internal( | |
| # Retry with impersonation (if _user_id is present) | ||
| # on authentication failure (token may have expired) | ||
| if ( | ||
| self._user_id | ||
| (self._user_id or self._oauth_token_manager) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why did we update this condition? (Can you please add the reason for it?) 🙏 |
||
| and not self._401_has_retried.get() | ||
| and response.status_code | ||
| == ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code | ||
|
|
@@ -813,6 +837,9 @@ def _create_params( | |
| self, api: API, query_params, request_obj, exclude_unset: bool = True | ||
| ): | ||
| params = copy.deepcopy(self._request_params) | ||
| if self._oauth_token_manager: | ||
| token = self._oauth_token_manager.get_token() | ||
| params["headers"]["authorization"] = f"Bearer {token}" | ||
| params["headers"]["Accept"] = api.consumes | ||
| params["headers"]["content-type"] = api.produces | ||
| if query_params is not None: | ||
|
|
@@ -846,6 +873,21 @@ def _handle_401_token_refresh( | |
|
|
||
| returns: HTTP response received after retrying the request with the refreshed token | ||
| """ | ||
| if self._oauth_token_manager: | ||
| self._oauth_token_manager.invalidate_token() | ||
| token = self._oauth_token_manager.get_token() | ||
| params["headers"]["authorization"] = f"Bearer {token}" | ||
| self._401_has_retried.set(True) | ||
| LOGGER.debug("Successfully refreshed OAuth token after 401.") | ||
| return self._call_api_internal( | ||
| api, | ||
| path, | ||
| params, | ||
| binary_data=binary_data, | ||
| download_file_path=download_file_path, | ||
| text_response=text_response, | ||
| ) | ||
|
|
||
|
Comment on lines
+878
to
+892
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please add comments for this change? 🙏 |
||
| try: | ||
| new_token = self.impersonate.user(user_id=self._user_id) | ||
| except Exception as e: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -88,6 +88,14 @@ | |||||
| GET_WHOAMI_USER = API( | ||||||
| WHOAMI_API, HTTPMethod.GET, HTTPStatus.OK, endpoint=EndPoint.HERACLES | ||||||
| ) | ||||||
|
|
||||||
| # oauth client authentinatication | ||||||
|
||||||
| # oauth client authentinatication | |
| # oauth client authentication |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright 2025 Atlan Pte. Ltd. | ||
| import threading | ||
| import time | ||
| from typing import Optional | ||
|
|
||
| import httpx | ||
| from authlib.oauth2.rfc6749 import OAuth2Token | ||
|
|
||
| from pyatlan.client.constants import GET_OAUTH_CLIENT | ||
| from pyatlan.utils import API | ||
|
|
||
|
|
||
| class OAuthTokenManager: | ||
| def __init__( | ||
| self, | ||
| base_url: str, | ||
| client_id: str, | ||
| client_secret: str, | ||
| http_client: Optional[httpx.Client] = None, | ||
| ): | ||
| self.base_url = base_url | ||
| self.client_id = client_id | ||
| self.client_secret = client_secret | ||
| self.token_url = self._create_path(GET_OAUTH_CLIENT) | ||
| self._lock = threading.Lock() | ||
| self._http_client = http_client or httpx.Client(timeout=30.0) | ||
| self._token: Optional[OAuth2Token] = None | ||
| self._owns_client = http_client is None | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't see any usage of this |
||
|
|
||
| def get_token(self) -> str: | ||
| with self._lock: | ||
| if self._token and not self._token.is_expired(): | ||
| return str(self._token["access_token"]) | ||
|
|
||
| response = self._http_client.post( | ||
| self.token_url, | ||
| json={ | ||
| "clientId": self.client_id, | ||
| "clientSecret": self.client_secret, | ||
| }, | ||
| headers={"Content-Type": "application/json"}, | ||
| ) | ||
| response.raise_for_status() | ||
|
|
||
| data = response.json() | ||
| access_token = data.get("accessToken") or data.get("access_token") | ||
|
|
||
| if not access_token: | ||
| raise ValueError( | ||
| f"OAuth token response missing 'accessToken' field. " | ||
| f"Response keys: {list(data.keys())}" | ||
| ) | ||
|
|
||
| expires_in = data.get("expiresIn") or data.get("expires_in", 600) | ||
vaibhavatlan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| self._token = OAuth2Token( | ||
| { | ||
| "access_token": access_token, | ||
| "token_type": data.get("tokenType") | ||
| or data.get("token_type", "Bearer"), | ||
| "expires_in": expires_in, | ||
| "expires_at": int(time.time()) + expires_in, | ||
| } | ||
| ) | ||
|
|
||
| return access_token | ||
|
Comment on lines
+37
to
+73
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Method docstring missing |
||
|
|
||
| def invalidate_token(self): | ||
| with self._lock: | ||
| self._token = None | ||
|
|
||
| def _create_path(self, api: API): | ||
| from urllib.parse import urljoin | ||
|
|
||
| if self.base_url == "INTERNAL": | ||
| base_with_prefix = urljoin(api.endpoint.service, api.endpoint.prefix) | ||
| return urljoin(base_with_prefix, api.path) | ||
| else: | ||
| base_with_prefix = urljoin(self.base_url, api.endpoint.prefix) | ||
| return urljoin(base_with_prefix, api.path) | ||
|
|
||
| def close(self): | ||
| if self._owns_client: | ||
| self._http_client.close() | ||
Uh oh!
There was an error while loading. Please reload this page.