Skip to content

Commit 6dab7aa

Browse files
committed
Inital version for oauth_client support in client
1 parent 1540259 commit 6dab7aa

File tree

6 files changed

+244
-46
lines changed

6 files changed

+244
-46
lines changed

pyatlan/client/aio/client.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class AsyncAtlanClient(AtlanClient):
9090
"""
9191

9292
_async_session: Optional[httpx.AsyncClient] = PrivateAttr(default=None)
93+
_async_oauth_token_manager: Optional[Any] = PrivateAttr(default=None)
9394
_async_admin_client: Optional[AsyncAdminClient] = PrivateAttr(default=None)
9495
_async_asset_client: Optional[AsyncAssetClient] = PrivateAttr(default=None)
9596
_async_audit_client: Optional[AsyncAuditClient] = PrivateAttr(default=None)
@@ -131,13 +132,19 @@ class AsyncAtlanClient(AtlanClient):
131132
_async_user_cache: Optional[AsyncUserCache] = PrivateAttr(default=None)
132133

133134
def __init__(self, **kwargs):
134-
# Initialize sync client (handles all validation, env vars, etc.)
135135
super().__init__(**kwargs)
136136

137-
# Build proxy/SSL configuration (reuse from sync client)
137+
if self.oauth_client_id and self.oauth_client_secret:
138+
from pyatlan.client.aio.oauth import AsyncOAuthTokenManager
139+
140+
self._async_oauth_token_manager = AsyncOAuthTokenManager(
141+
base_url=str(self.base_url),
142+
client_id=self.oauth_client_id,
143+
client_secret=self.oauth_client_secret,
144+
)
145+
138146
transport_kwargs = self._build_transport_proxy_config(kwargs)
139147

140-
# Create async session with custom transport that supports retry and proxy
141148
self._async_session = httpx.AsyncClient(
142149
transport=PyatlanAsyncTransport(retry=self.retry, **transport_kwargs),
143150
headers={
@@ -434,17 +441,16 @@ def _api_logger(self, api, path):
434441
async def _create_params(
435442
self, api, query_params, request_obj, exclude_unset: bool = True
436443
):
437-
"""
438-
Async version of _create_params that uses AsyncAtlanRequest for AtlanObject instances.
439-
"""
440444
params = copy.deepcopy(self._request_params)
445+
if self._async_oauth_token_manager:
446+
token = await self._async_oauth_token_manager.get_token()
447+
params["headers"]["authorization"] = f"Bearer {token}"
441448
params["headers"]["Accept"] = api.consumes
442449
params["headers"]["content-type"] = api.produces
443450
if query_params is not None:
444451
params["params"] = query_params
445452
if request_obj is not None:
446453
if isinstance(request_obj, AtlanObject):
447-
# Use AsyncAtlanRequest for async retranslation
448454
async_request = AsyncAtlanRequest(instance=request_obj, client=self)
449455
params["data"] = await async_request.json()
450456
elif api.consumes == APPLICATION_ENCODED_FORM:
@@ -685,9 +691,8 @@ async def _handle_error_response(
685691
"\n".join(error_cause_details) if error_cause_details else ""
686692
)
687693

688-
# Retry with impersonation (if _user_id is present) on authentication failure
689694
if (
690-
self._user_id
695+
(self._user_id or self._async_oauth_token_manager)
691696
and not self._401_has_retried.get()
692697
and response.status_code
693698
== ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
@@ -742,12 +747,22 @@ async def _handle_401_token_refresh(
742747
download_file_path=None,
743748
text_response=False,
744749
):
745-
"""
746-
Async version of token refresh and retry logic.
747-
Handles token refresh and retries the API request upon a 401 Unauthorized response.
748-
"""
750+
if self._async_oauth_token_manager:
751+
await self._async_oauth_token_manager.invalidate_token()
752+
token = await self._async_oauth_token_manager.get_token()
753+
params["headers"]["authorization"] = f"Bearer {token}"
754+
self._401_has_retried.set(True)
755+
LOGGER.debug("Successfully refreshed OAuth token after 401.")
756+
return await self._call_api_internal(
757+
api,
758+
path,
759+
params,
760+
binary_data=binary_data,
761+
download_file_path=download_file_path,
762+
text_response=text_response,
763+
)
764+
749765
try:
750-
# Use sync impersonation call since it's a quick API call
751766
new_token = await self.impersonate.user(user_id=self._user_id)
752767
except Exception as e:
753768
LOGGER.debug(
@@ -763,11 +778,9 @@ async def _handle_401_token_refresh(
763778
self._request_params["headers"]["authorization"] = f"Bearer {self.api_key}"
764779
LOGGER.debug("Successfully completed async 401 automatic token refresh.")
765780

766-
# Async retry loop to ensure token is active before retrying original request
767781
retry_count = 1
768782
while retry_count <= self.retry.total:
769783
try:
770-
# Use async typedef call to validate token
771784
response = await self.typedef.get(
772785
type_category=[AtlanTypeCategory.STRUCT]
773786
)
@@ -778,10 +791,9 @@ async def _handle_401_token_refresh(
778791
"Retrying async to get typedefs (to ensure token is active) after token refresh failed: %s",
779792
e,
780793
)
781-
await asyncio.sleep(retry_count) # Linear backoff with async sleep
794+
await asyncio.sleep(retry_count)
782795
retry_count += 1
783796

784-
# Retry the API call with the new token
785797
return await self._call_api_internal(
786798
api,
787799
path,

pyatlan/client/aio/oauth.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# Copyright 2025 Atlan Pte. Ltd.
3+
import asyncio
4+
import time
5+
from typing import Optional
6+
7+
import httpx
8+
from authlib.oauth2.rfc6749 import OAuth2Token
9+
10+
11+
class AsyncOAuthTokenManager:
12+
def __init__(
13+
self,
14+
base_url: str,
15+
client_id: str,
16+
client_secret: str,
17+
http_client: Optional[httpx.AsyncClient] = None,
18+
):
19+
self.base_url = base_url.rstrip("/")
20+
self.client_id = client_id
21+
self.client_secret = client_secret
22+
self.token_url = f"{self.base_url}/api/service/oauth-clients/token"
23+
self._lock = asyncio.Lock()
24+
self._http_client = http_client or httpx.AsyncClient(timeout=30.0)
25+
self._token: Optional[OAuth2Token] = None
26+
self._owns_client = http_client is None
27+
28+
async def get_token(self) -> str:
29+
async with self._lock:
30+
if self._token and not self._token.is_expired():
31+
return self._token["access_token"]
32+
33+
response = await self._http_client.post(
34+
self.token_url,
35+
json={
36+
"clientId": self.client_id,
37+
"clientSecret": self.client_secret,
38+
},
39+
headers={"Content-Type": "application/json"},
40+
)
41+
response.raise_for_status()
42+
43+
data = response.json()
44+
access_token = data.get("accessToken") or data.get("access_token")
45+
expires_in = data.get("expiresIn") or data.get("expires_in", 600)
46+
47+
self._token = OAuth2Token(
48+
{
49+
"access_token": access_token,
50+
"token_type": data.get("tokenType")
51+
or data.get("token_type", "Bearer"),
52+
"expires_in": expires_in,
53+
"expires_at": int(time.time()) + expires_in,
54+
}
55+
)
56+
57+
return access_token
58+
59+
async def invalidate_token(self):
60+
async with self._lock:
61+
self._token = None
62+
63+
async def aclose(self):
64+
if self._owns_client:
65+
await self._http_client.aclose()

pyatlan/client/atlan.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ def log_response(response, *args, **kwargs):
127127

128128
class AtlanClient(BaseSettings):
129129
base_url: Union[Literal["INTERNAL"], HttpUrl]
130-
api_key: str
130+
api_key: Optional[str] = None
131+
oauth_client_id: Optional[str] = None
132+
oauth_client_secret: Optional[str] = None
131133
connect_timeout: float = 30.0 # 30 secs
132134
read_timeout: float = 900.0 # 15 mins
133135
retry: Retry = DEFAULT_RETRY
@@ -137,6 +139,7 @@ class AtlanClient(BaseSettings):
137139
_session: httpx.Client = PrivateAttr()
138140
_request_params: dict = PrivateAttr()
139141
_user_id: Optional[str] = PrivateAttr(default=None)
142+
_oauth_token_manager: Optional[Any] = PrivateAttr(default=None)
140143
_workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None)
141144
_credential_client: Optional[CredentialClient] = PrivateAttr(default=None)
142145
_admin_client: Optional[AdminClient] = PrivateAttr(default=None)
@@ -172,17 +175,24 @@ class Config:
172175

173176
def __init__(self, **data):
174177
super().__init__(**data)
175-
self._request_params = (
176-
{"headers": {"authorization": f"Bearer {self.api_key}"}}
177-
if self.api_key and self.api_key.strip()
178-
else {"headers": {}}
179-
)
180178

181-
# Build proxy/SSL configuration with environment variable fallback
179+
if self.oauth_client_id and self.oauth_client_secret:
180+
from pyatlan.client.oauth import OAuthTokenManager
181+
182+
self._oauth_token_manager = OAuthTokenManager(
183+
base_url=str(self.base_url),
184+
client_id=self.oauth_client_id,
185+
client_secret=self.oauth_client_secret,
186+
)
187+
self._request_params = {"headers": {}}
188+
else:
189+
self._request_params = (
190+
{"headers": {"authorization": f"Bearer {self.api_key}"}}
191+
if self.api_key and self.api_key.strip()
192+
else {"headers": {}}
193+
)
194+
182195
transport_kwargs = self._build_transport_proxy_config(data)
183-
# Configure httpx client with custom transport that supports retry and proxy
184-
# Note: We pass proxy/SSL config to the transport, not the client,
185-
# so that retry logic properly respects these settings
186196
self._session = httpx.Client(
187197
transport=PyatlanSyncTransport(retry=self.retry, **transport_kwargs),
188198
headers={
@@ -691,7 +701,7 @@ def _call_api_internal(
691701
# Retry with impersonation (if _user_id is present)
692702
# on authentication failure (token may have expired)
693703
if (
694-
self._user_id
704+
(self._user_id or self._oauth_token_manager)
695705
and not self._401_has_retried.get()
696706
and response.status_code
697707
== ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
@@ -813,15 +823,15 @@ def _create_params(
813823
self, api: API, query_params, request_obj, exclude_unset: bool = True
814824
):
815825
params = copy.deepcopy(self._request_params)
826+
if self._oauth_token_manager:
827+
token = self._oauth_token_manager.get_token()
828+
params["headers"]["authorization"] = f"Bearer {token}"
816829
params["headers"]["Accept"] = api.consumes
817830
params["headers"]["content-type"] = api.produces
818831
if query_params is not None:
819832
params["params"] = query_params
820833
if request_obj is not None:
821834
if isinstance(request_obj, AtlanObject):
822-
# Always use AtlanRequest, which accepts a Pydantic model instance and the client
823-
# Behind the scenes, it handles retranslation tasks—such as converting
824-
# human-readable Atlan tag names back into hashed IDs as required by the backend
825835
params["data"] = AtlanRequest(instance=request_obj, client=self).json()
826836
elif api.consumes == APPLICATION_ENCODED_FORM:
827837
params["data"] = request_obj
@@ -838,14 +848,21 @@ def _handle_401_token_refresh(
838848
download_file_path=None,
839849
text_response=False,
840850
):
841-
"""
842-
Handles token refresh and retries the API request upon a 401 Unauthorized response.
843-
1. Impersonates the user (if a user ID is available) to fetch a new token.
844-
2. Updates the authorization header with the refreshed token.
845-
3. Retries the API request with the new token.
851+
if self._oauth_token_manager:
852+
self._oauth_token_manager.invalidate_token()
853+
token = self._oauth_token_manager.get_token()
854+
params["headers"]["authorization"] = f"Bearer {token}"
855+
self._401_has_retried.set(True)
856+
LOGGER.debug("Successfully refreshed OAuth token after 401.")
857+
return self._call_api_internal(
858+
api,
859+
path,
860+
params,
861+
binary_data=binary_data,
862+
download_file_path=download_file_path,
863+
text_response=text_response,
864+
)
846865

847-
returns: HTTP response received after retrying the request with the refreshed token
848-
"""
849866
try:
850867
new_token = self.impersonate.user(user_id=self._user_id)
851868
except Exception as e:
@@ -861,11 +878,6 @@ def _handle_401_token_refresh(
861878
self._request_params["headers"]["authorization"] = f"Bearer {self.api_key}"
862879
LOGGER.debug("Successfully completed 401 automatic token refresh.")
863880

864-
# Added a retry loop to ensure a token is active before retrying original request
865-
# This helps ensure that when we fetch typedefs using the new token,
866-
# the backend has fully recognized the token as valid.
867-
# Without this delay, we occasionally get an empty response `[]` from the API,
868-
# likely because the backend hasn’t fully propagated token validity yet.
869881
import time
870882

871883
retry_count = 1
@@ -879,10 +891,9 @@ def _handle_401_token_refresh(
879891
"Retrying to get typedefs (to ensure token is active) after token refresh failed: %s",
880892
e,
881893
)
882-
time.sleep(retry_count) # Linear backoff
894+
time.sleep(retry_count)
883895
retry_count += 1
884896

885-
# Retry the API call with the new token
886897
return self._call_api_internal(
887898
api,
888899
path,

pyatlan/client/oauth.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# Copyright 2025 Atlan Pte. Ltd.
3+
import threading
4+
import time
5+
from typing import Optional
6+
7+
import httpx
8+
from authlib.oauth2.rfc6749 import OAuth2Token
9+
10+
11+
class OAuthTokenManager:
12+
def __init__(
13+
self,
14+
base_url: str,
15+
client_id: str,
16+
client_secret: str,
17+
http_client: Optional[httpx.Client] = None,
18+
):
19+
self.base_url = base_url.rstrip("/")
20+
self.client_id = client_id
21+
self.client_secret = client_secret
22+
self.token_url = f"{self.base_url}/api/service/oauth-clients/token"
23+
self._lock = threading.Lock()
24+
self._http_client = http_client or httpx.Client(timeout=30.0)
25+
self._token: Optional[OAuth2Token] = None
26+
self._owns_client = http_client is None
27+
28+
def get_token(self) -> str:
29+
with self._lock:
30+
if self._token and not self._token.is_expired():
31+
return self._token["access_token"]
32+
33+
response = self._http_client.post(
34+
self.token_url,
35+
json={
36+
"clientId": self.client_id,
37+
"clientSecret": self.client_secret,
38+
},
39+
headers={"Content-Type": "application/json"},
40+
)
41+
response.raise_for_status()
42+
43+
data = response.json()
44+
access_token = data.get("accessToken") or data.get("access_token")
45+
expires_in = data.get("expiresIn") or data.get("expires_in", 600)
46+
print(data.get("expiresIn"))
47+
self._token = OAuth2Token(
48+
{
49+
"access_token": access_token,
50+
"token_type": data.get("tokenType")
51+
or data.get("token_type", "Bearer"),
52+
"expires_in": expires_in,
53+
"expires_at": int(time.time()) + expires_in,
54+
}
55+
)
56+
57+
return access_token
58+
59+
def invalidate_token(self):
60+
with self._lock:
61+
self._token = None
62+
63+
def close(self):
64+
if self._owns_client:
65+
self._http_client.close()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
"PyYAML~=6.0.3",
3838
"httpx~=0.28.1",
3939
"httpx-retries~=0.4.5",
40+
"authlib~=1.3.0",
4041
]
4142

4243
[project.urls]

0 commit comments

Comments
 (0)