Skip to content

Commit e0e19f5

Browse files
committed
Fix #9: Fix enviroment variable configuration, some fields were lost
1 parent 15ceaf8 commit e0e19f5

File tree

2 files changed

+51
-14
lines changed

2 files changed

+51
-14
lines changed

savant_cloudpin/cfg/_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ def as_value_dict(obj: Any) -> dict[str, Any]:
4343

4444

4545
def env_interpolation(
46-
default: str | int | float | bool, name: str, alt: str | None = None
46+
default: str | int | float | bool | None, name: str, alt: str | None = None
4747
) -> str:
48-
if isinstance(default, bool):
48+
if default is None:
49+
default = "null"
50+
elif isinstance(default, bool):
4951
default = str(default).lower()
5052
if alt:
5153
default = f"${{oc.env:{alt},{default}}}"
@@ -72,12 +74,10 @@ def env_override[T](
7274
for name, val, alt in items:
7375
env_name = f"{prefix}_{name.upper()}" if prefix else name.upper()
7476
match val:
75-
case str() | int() | float() | bool():
76-
if default is None:
77-
updates[name] = env_interpolation(val, env_name, alt)
78-
else:
79-
updates[name] = env_interpolation(default, env_name, alt)
80-
case list() | None:
77+
case str() | int() | float() | bool() | None:
78+
val = val if default is None else default
79+
updates[name] = env_interpolation(val, env_name, alt)
80+
case list():
8181
continue
8282
case _:
8383
updates[name] = env_override(val, default, env_name)

tests/savant_cloudpin/test_cfg.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import functools
12
import os
23
import textwrap
34
import unittest.mock
45
from collections.abc import Generator
6+
from typing import Any
57
from unittest.mock import Mock
68

79
import pytest
@@ -94,6 +96,40 @@ def invalid_urls(request: pytest.FixtureRequest) -> tuple[str, str]:
9496
return request.param
9597

9698

99+
@pytest.fixture(
100+
params=[
101+
"websockets.endpoint",
102+
"websockets.ssl.cert_file",
103+
"metrics.otlp.endpoint",
104+
"metrics.prometheus.endpoint",
105+
"health.endpoint",
106+
]
107+
)
108+
def config_env_vars(request: pytest.FixtureRequest) -> tuple[str, dict[str, str], Any]:
109+
match request.param:
110+
case "websockets.endpoint":
111+
expected = fake.uri(["wss", "ws"])
112+
env_vars = {"CLOUDPIN_WEBSOCKETS_ENDPOINT": expected}
113+
case "websockets.ssl.cert_file":
114+
expected = fake.file_path()
115+
env_vars = {
116+
"CLOUDPIN_WEBSOCKETS_SSL_CERT_FILE": expected,
117+
"CLOUDPIN_WEBSOCKETS_SSL_KEY_FILE": fake.file_path(),
118+
}
119+
case "metrics.otlp.endpoint":
120+
expected = fake.uri(["http"])
121+
env_vars = {"CLOUDPIN_METRICS_OTLP_ENDPOINT": expected}
122+
case "metrics.prometheus.endpoint":
123+
expected = fake.uri(["http"])
124+
env_vars = {"CLOUDPIN_METRICS_PROMETHEUS_ENDPOINT": expected}
125+
case "health.endpoint":
126+
expected = fake.uri(["http"])
127+
env_vars = {"CLOUDPIN_HEALTH_ENDPOINT": expected}
128+
case _:
129+
raise ValueError
130+
return request.param, env_vars, expected
131+
132+
97133
def test_load_config_when_valid_urls(
98134
valid_urls: tuple[str, str], some_cli_config: dict[str, str]
99135
) -> None:
@@ -121,18 +157,19 @@ def test_load_config_when_invalid_urls(
121157
load_config(cli_args)
122158

123159

124-
def test_load_config_with_environ_var(some_cli_config: dict[str, str]) -> None:
125-
endpoint = fake.uri(["wss", "ws"])
126-
environ_vars = {"CLOUDPIN_WEBSOCKETS_ENDPOINT": endpoint}
160+
def test_load_config_with_environ_var(
161+
some_cli_config: dict[str, str], config_env_vars: tuple[str, str, Any]
162+
) -> None:
163+
attr, env_vars, expected = config_env_vars
127164
cli_config = some_cli_config.copy()
128-
del cli_config["websockets.endpoint"]
165+
cli_config.pop(attr, None)
129166
cli_args = ["=".join(arg) for arg in cli_config.items()]
130167

131-
with unittest.mock.patch.dict(os.environ, environ_vars):
168+
with unittest.mock.patch.dict(os.environ, env_vars):
132169
result = load_config(cli_args)
133170

134171
assert isinstance(result, (ServerServiceConfig, ClientServiceConfig))
135-
assert result.websockets.endpoint == endpoint
172+
assert expected == functools.reduce(getattr, attr.split("."), result)
136173

137174

138175
def test_load_config_with_file(some_cli_config: dict[str, str]) -> None:

0 commit comments

Comments
 (0)