Skip to content

Commit 3e6d4da

Browse files
committed
Support non ddtool auth
1 parent 7707661 commit 3e6d4da

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed
Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
1-
import logging
1+
import os
22

3-
from dd_internal_authentication.libs.py.dd_internal_authentication.dd_internal_authentication.client import (
4-
JWTDDToolAuthClientTokenManager,
5-
)
63

7-
logger = logging.getLogger(__name__)
4+
class JWT:
5+
def __init__(self, audience: str, datacenter: str):
6+
self.audience = audience
87

8+
from dd_internal_authentication.libs.py.dd_internal_authentication.dd_internal_authentication.client import (
9+
JWTDDToolAuthClientTokenManager,
10+
JWTInternalServiceAuthClientTokenManager,
11+
)
912

10-
def get_token(datacenter: str, audience: str) -> str:
11-
token = JWTDDToolAuthClientTokenManager.instance(
12-
name=audience, datacenter=datacenter
13-
).get_token(audience)
14-
return str(token)
13+
if os.getenv("POD_NAME"):
14+
token_manager_class = JWTInternalServiceAuthClientTokenManager
15+
else:
16+
token_manager_class = JWTDDToolAuthClientTokenManager
17+
18+
self.token_manager = token_manager_class.instance(
19+
name=self.audience, datacenter=datacenter
20+
)
21+
22+
def get_token(self) -> str:
23+
try:
24+
return str(self.token_manager.get_token(self.audience))
25+
except Exception as e:
26+
raise RuntimeError(f"Failed to get authentication token: {str(e)}")

src/spark_history_mcp/common/yoshi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99

1010
from spark_history_mcp.common import vault
11+
from spark_history_mcp.common.vault import JWT
1112

1213

1314
class Yoshi:
@@ -19,7 +20,7 @@ def __init__(self, datacenter: str):
1920
host = f"https://yoshi.{self.AUDIENCE}.all-clusters.local-dc.fabric.dog:8443"
2021
self.configuration = Configuration(
2122
host=host,
22-
access_token=vault.get_token(datacenter, self.AUDIENCE),
23+
access_token=JWT(audience=self.AUDIENCE,datacenter=datacenter).get_token(),
2324
)
2425

2526
def get_job_definition(self, job_id: str) -> Job:

0 commit comments

Comments
 (0)