-
Notifications
You must be signed in to change notification settings - Fork 7
APP-8642 : Add OAuth to pyatlan #767
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0128272
a76fc87
8c0b9c0
94255f2
073974f
ca05037
eee84f0
3c6d7f9
73418cf
7b1e7b3
984c10c
5d46a8e
fb1598d
028feb9
d1b3311
bae36f0
da044c1
c4c96f8
00ea476
fb69b17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,2 +1,8 @@ | ||
| ATLAN_BASE_URL=your_tenant_base_url | ||
|
|
||
| #API KEY based authentication | ||
| ATLAN_API_KEY=your_api_key | ||
|
|
||
| #OAuth based authentication | ||
| ATLAN_OAUTH_CLIENT_ID=your_oauth_client_id | ||
| ATLAN_OAUTH_CLIENT_SECRET=your_oauth_client_secret | ||
Aryamanz29 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| from contextlib import _AsyncGeneratorContextManager | ||
| from http import HTTPStatus | ||
| from types import SimpleNamespace | ||
| from typing import Optional | ||
| from typing import Any, Optional | ||
|
|
||
| import httpx | ||
| from httpx_retries.retry import Retry | ||
|
|
@@ -41,6 +41,7 @@ | |
| from pyatlan.client.aio.file import AsyncFileClient | ||
| from pyatlan.client.aio.group import AsyncGroupClient | ||
| from pyatlan.client.aio.impersonate import AsyncImpersonationClient | ||
| from pyatlan.client.aio.oauth import AsyncOAuthTokenManager | ||
| from pyatlan.client.aio.open_lineage import AsyncOpenLineageClient | ||
| from pyatlan.client.aio.query import AsyncQueryClient | ||
| from pyatlan.client.aio.role import AsyncRoleClient | ||
|
|
@@ -90,6 +91,7 @@ class AsyncAtlanClient(AtlanClient): | |
| """ | ||
|
|
||
| _async_session: Optional[httpx.AsyncClient] = PrivateAttr(default=None) | ||
| _async_oauth_token_manager: Optional[Any] = PrivateAttr(default=None) | ||
| _async_admin_client: Optional[AsyncAdminClient] = PrivateAttr(default=None) | ||
| _async_asset_client: Optional[AsyncAssetClient] = PrivateAttr(default=None) | ||
| _async_audit_client: Optional[AsyncAuditClient] = PrivateAttr(default=None) | ||
|
|
@@ -133,6 +135,31 @@ class AsyncAtlanClient(AtlanClient): | |
| def __init__(self, **kwargs): | ||
| # Initialize sync client (handles all validation, env vars, etc.) | ||
| super().__init__(**kwargs) | ||
| if self.oauth_client_id and self.oauth_client_secret and self.api_key is None: | ||
| LOGGER.debug( | ||
| "API Key not provided. Using Async OAuth flow for authentication" | ||
| ) | ||
| if self._oauth_token_manager: | ||
| LOGGER.debug("Sync oauth flow open. Closing it for Async oauth flow") | ||
| self._oauth_token_manager.close() | ||
| self._oauth_token_manager = None | ||
|
|
||
| final_base_url = self.base_url or os.environ.get( | ||
| "ATLAN_BASE_URL", "INTERNAL" | ||
| ) | ||
| final_oauth_client_id = self.oauth_client_id or os.environ.get( | ||
| "ATLAN_OAUTH_CLIENT_ID" | ||
| ) | ||
| final_oauth_client_secret = self.oauth_client_secret or os.environ.get( | ||
| "ATLAN_OAUTH_CLIENT_SECRET" | ||
| ) | ||
| self._async_oauth_token_manager = AsyncOAuthTokenManager( | ||
| base_url=final_base_url, | ||
| client_id=final_oauth_client_id, | ||
| client_secret=final_oauth_client_secret, | ||
| connect_timeout=self.connect_timeout, | ||
| read_timeout=self.read_timeout, | ||
| ) | ||
cursor[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Build proxy/SSL configuration (reuse from sync client) | ||
| transport_kwargs = self._build_transport_proxy_config(kwargs) | ||
|
|
@@ -438,6 +465,9 @@ async def _create_params( | |
| Async version of _create_params that uses AsyncAtlanRequest for AtlanObject instances. | ||
| """ | ||
| params = copy.deepcopy(self._request_params) | ||
| if self._async_oauth_token_manager: | ||
| token = await self._async_oauth_token_manager.get_token() | ||
| params["headers"]["authorization"] = f"Bearer {token}" | ||
| params["headers"]["Accept"] = api.consumes | ||
| params["headers"]["content-type"] = api.produces | ||
| if query_params is not None: | ||
|
|
@@ -687,7 +717,7 @@ async def _handle_error_response( | |
|
|
||
| # Retry with impersonation (if _user_id is present) on authentication failure | ||
| if ( | ||
| self._user_id | ||
| (self._user_id or self._async_oauth_token_manager) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we updating this condition? (Can you please add the reason for it?) 🙏 |
||
| and not self._401_has_retried.get() | ||
| and response.status_code | ||
| == ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code | ||
|
|
@@ -746,6 +776,21 @@ async def _handle_401_token_refresh( | |
| Async version of token refresh and retry logic. | ||
| Handles token refresh and retries the API request upon a 401 Unauthorized response. | ||
| """ | ||
| if self._async_oauth_token_manager: | ||
| await self._async_oauth_token_manager.invalidate_token() | ||
| token = await self._async_oauth_token_manager.get_token() | ||
| params["headers"]["authorization"] = f"Bearer {token}" | ||
| self._401_has_retried.set(True) | ||
| LOGGER.debug("Successfully refreshed OAuth token after 401.") | ||
| return await self._call_api_internal( | ||
| api, | ||
| path, | ||
| params, | ||
| binary_data=binary_data, | ||
| download_file_path=download_file_path, | ||
| text_response=text_response, | ||
| ) | ||
|
Comment on lines
+779
to
+792
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please add comments for this change? 🙏 |
||
|
|
||
| try: | ||
| # Use sync impersonation call since it's a quick API call | ||
| new_token = await self.impersonate.user(user_id=self._user_id) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright 2025 Atlan Pte. Ltd. | ||
| import asyncio | ||
| import time | ||
| from typing import Optional | ||
| from urllib.parse import urljoin | ||
|
|
||
| import httpx | ||
| from authlib.oauth2.rfc6749 import OAuth2Token | ||
|
|
||
| from pyatlan.client.constants import GET_OAUTH_CLIENT | ||
| from pyatlan.utils import API | ||
|
|
||
|
|
||
class AsyncOAuthTokenManager:
    """
    Manages OAuth tokens for asynchronous HTTP clients.

    A token is requested lazily on the first call to :meth:`get_token`,
    cached on the instance, and requested again once the cached token reports
    itself as expired or after :meth:`invalidate_token` has been called.

    The manager can be used as an async context manager, in which case
    :meth:`aclose` is awaited on exit.

    :param base_url: Base URL of the Atlan tenant.
    :param client_id: OAuth client ID.
    :param client_secret: OAuth client secret.
    :param http_client: Optional asynchronous HTTP client to use.
    :param connect_timeout: Timeout for establishing connections.
    :param read_timeout: Timeout for reading data.
    :param write_timeout: Timeout for writing data.
    :param pool_timeout: Timeout for acquiring a connection from the pool.
    """

    # Token lifetime (seconds) assumed when the response carries no usable expiry.
    _DEFAULT_EXPIRES_IN = 600

    def __init__(
        self,
        base_url: str,
        client_id: str,
        client_secret: str,
        http_client: "Optional[httpx.AsyncClient]" = None,
        connect_timeout: float = 30.0,
        read_timeout: float = 900.0,
        write_timeout: float = 30.0,
        pool_timeout: float = 30.0,
    ):
        self.base_url = base_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = self._create_path(GET_OAUTH_CLIENT)
        # Serializes token reads/refreshes so concurrent callers share one fetch.
        self._lock = asyncio.Lock()
        self._http_client = http_client or httpx.AsyncClient(
            timeout=httpx.Timeout(
                connect=connect_timeout,
                read=read_timeout,
                write=write_timeout,
                pool=pool_timeout,
            )
        )
        self._token: Optional[OAuth2Token] = None
        # Only close the HTTP client in aclose() when this manager created it.
        self._owns_client = http_client is None

    async def __aenter__(self):
        """Return the manager itself for use in ``async with`` blocks."""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Release the underlying HTTP client (if owned) on context exit."""
        await self.aclose()

    @classmethod
    def _coerce_expires_in(cls, value) -> int:
        """
        Normalize the expiry reported by the token endpoint to whole seconds.

        :param value: raw expiry from the response (number, numeric string or None)
        :returns: a positive number of seconds; the class default when the
            value is missing, not numeric, or not positive
        """
        try:
            seconds = int(float(value))
        except (TypeError, ValueError, OverflowError):
            return cls._DEFAULT_EXPIRES_IN
        return seconds if seconds > 0 else cls._DEFAULT_EXPIRES_IN

    async def get_token(self) -> str:
        """
        Retrieves a valid OAuth token, refreshing it if necessary.

        :returns: the access token string
        :raises ValueError: if the token response is not a JSON object or
            does not contain an access token
        """
        async with self._lock:
            if self._token and not self._token.is_expired():
                return str(self._token["access_token"])

            response = await self._http_client.post(
                self.token_url,
                json={
                    "clientId": self.client_id,
                    "clientSecret": self.client_secret,
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()

            data = response.json()
            if not isinstance(data, dict):
                raise ValueError(
                    f"OAuth token response is not a JSON object. "
                    f"Response type: {type(data).__name__}"
                )

            # Both camelCase and snake_case field names are accepted.
            access_token = data.get("accessToken") or data.get("access_token")

            if not access_token:
                raise ValueError(
                    f"OAuth token response missing 'accessToken' field. "
                    f"Response keys: {list(data.keys())}"
                )
            access_token = str(access_token)

            raw_expires_in = data.get("expiresIn")
            if raw_expires_in is None:
                raw_expires_in = data.get("expires_in")
            # Coerce before doing arithmetic: a string or null expiry would
            # otherwise raise TypeError when added to the current time.
            expires_in = self._coerce_expires_in(raw_expires_in)

            self._token = OAuth2Token(
                {
                    "access_token": access_token,
                    "token_type": data.get("tokenType")
                    or data.get("token_type")
                    or "Bearer",
                    "expires_in": expires_in,
                    "expires_at": int(time.time()) + expires_in,
                }
            )

            return access_token

    async def invalidate_token(self):
        """
        Invalidates the current OAuth token.

        The next call to :meth:`get_token` requests a new token.
        """
        async with self._lock:
            self._token = None

    def _create_path(self, api: "API") -> str:
        """
        Creates the full URL for the given API endpoint.

        :param api: API definition providing ``endpoint.service``,
            ``endpoint.prefix`` and ``path``
        :returns: the URL built from the endpoint's service address when
            ``base_url`` is ``"INTERNAL"``, otherwise from ``base_url`` joined
            with the endpoint prefix
        """
        # base_url may be a URL object rather than a plain str; urljoin needs str.
        base = str(self.base_url)
        if base == "INTERNAL":
            return urljoin(api.endpoint.service, api.path)
        base_with_prefix = urljoin(base, api.endpoint.prefix)
        return urljoin(base_with_prefix, api.path)

    async def aclose(self):
        """
        Closes the underlying HTTP client if owned by this manager.
        """
        if self._owns_client:
            await self._http_client.aclose()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,6 +48,7 @@ | |
| from pyatlan.client.file import FileClient | ||
| from pyatlan.client.group import GroupClient | ||
| from pyatlan.client.impersonate import ImpersonationClient | ||
| from pyatlan.client.oauth import OAuthTokenManager | ||
| from pyatlan.client.open_lineage import OpenLineageClient | ||
| from pyatlan.client.query import QueryClient | ||
| from pyatlan.client.role import RoleClient | ||
|
|
@@ -127,7 +128,9 @@ def log_response(response, *args, **kwargs): | |
|
|
||
| class AtlanClient(BaseSettings): | ||
| base_url: Union[Literal["INTERNAL"], HttpUrl] | ||
| api_key: str | ||
| api_key: Optional[str] = None | ||
| oauth_client_id: Optional[str] = None | ||
| oauth_client_secret: Optional[str] = None | ||
| connect_timeout: float = 30.0 # 30 secs | ||
| read_timeout: float = 900.0 # 15 mins | ||
| retry: Retry = DEFAULT_RETRY | ||
|
|
@@ -137,6 +140,7 @@ class AtlanClient(BaseSettings): | |
| _session: httpx.Client = PrivateAttr() | ||
| _request_params: dict = PrivateAttr() | ||
| _user_id: Optional[str] = PrivateAttr(default=None) | ||
| _oauth_token_manager: Optional[Any] = PrivateAttr(default=None) | ||
| _workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None) | ||
| _credential_client: Optional[CredentialClient] = PrivateAttr(default=None) | ||
| _admin_client: Optional[AdminClient] = PrivateAttr(default=None) | ||
|
|
@@ -172,11 +176,33 @@ class Config: | |
|
|
||
| def __init__(self, **data): | ||
| super().__init__(**data) | ||
| self._request_params = ( | ||
| {"headers": {"authorization": f"Bearer {self.api_key}"}} | ||
| if self.api_key and self.api_key.strip() | ||
| else {"headers": {}} | ||
| ) | ||
|
|
||
| if self.oauth_client_id and self.oauth_client_secret and self.api_key is None: | ||
| LOGGER.debug("API KEY not provided. Using OAuth flow for authentication") | ||
|
|
||
| final_base_url = self.base_url or os.environ.get( | ||
| "ATLAN_BASE_URL", "INTERNAL" | ||
| ) | ||
| final_oauth_client_id = self.oauth_client_id or os.environ.get( | ||
| "ATLAN_OAUTH_CLIENT_ID" | ||
| ) | ||
| final_oauth_client_secret = self.oauth_client_secret or os.environ.get( | ||
| "ATLAN_OAUTH_CLIENT_SECRET" | ||
| ) | ||
| self._oauth_token_manager = OAuthTokenManager( | ||
| base_url=final_base_url, | ||
| client_id=final_oauth_client_id, | ||
| client_secret=final_oauth_client_secret, | ||
| connect_timeout=self.connect_timeout, | ||
| read_timeout=self.read_timeout, | ||
| ) | ||
| self._request_params = {"headers": {}} | ||
| else: | ||
| self._request_params = ( | ||
| {"headers": {"authorization": f"Bearer {self.api_key}"}} | ||
| if self.api_key and self.api_key.strip() | ||
| else {"headers": {}} | ||
| ) | ||
vaibhavatlan marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: Sync Client OAuth Leaks HTTP Resources. The sync |
||
|
|
||
| # Build proxy/SSL configuration with environment variable fallback | ||
| transport_kwargs = self._build_transport_proxy_config(data) | ||
|
|
@@ -691,7 +717,7 @@ def _call_api_internal( | |
| # Retry with impersonation (if _user_id is present) | ||
| # on authentication failure (token may have expired) | ||
| if ( | ||
| self._user_id | ||
| (self._user_id or self._oauth_token_manager) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we updating this condition? (Can you please add the reason for it?) 🙏 |
||
| and not self._401_has_retried.get() | ||
| and response.status_code | ||
| == ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code | ||
|
|
@@ -813,6 +839,9 @@ def _create_params( | |
| self, api: API, query_params, request_obj, exclude_unset: bool = True | ||
| ): | ||
| params = copy.deepcopy(self._request_params) | ||
| if self._oauth_token_manager: | ||
| token = self._oauth_token_manager.get_token() | ||
| params["headers"]["authorization"] = f"Bearer {token}" | ||
| params["headers"]["Accept"] = api.consumes | ||
| params["headers"]["content-type"] = api.produces | ||
| if query_params is not None: | ||
|
|
@@ -846,6 +875,21 @@ def _handle_401_token_refresh( | |
|
|
||
| returns: HTTP response received after retrying the request with the refreshed token | ||
| """ | ||
| if self._oauth_token_manager: | ||
| self._oauth_token_manager.invalidate_token() | ||
| token = self._oauth_token_manager.get_token() | ||
| params["headers"]["authorization"] = f"Bearer {token}" | ||
| self._401_has_retried.set(True) | ||
| LOGGER.debug("Successfully refreshed OAuth token after 401.") | ||
| return self._call_api_internal( | ||
| api, | ||
| path, | ||
| params, | ||
| binary_data=binary_data, | ||
| download_file_path=download_file_path, | ||
| text_response=text_response, | ||
| ) | ||
|
|
||
|
Comment on lines
+878
to
+892
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please add comments for this change? 🙏 |
||
| try: | ||
| new_token = self.impersonate.user(user_id=self._user_id) | ||
| except Exception as e: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.