Skip to content

Commit a2a086c

Browse files
authored
Make TLS certificates creation reusable (#31)
* feat: introduce a centralized TLS certificate manager * refactor: move web_ui cert management to TLSCertificateManager * refactor: expose url and local_url from WebUI * test: add basic tests for WebUI
1 parent 73774cc commit a2a086c

File tree

7 files changed

+842
-133
lines changed

7 files changed

+842
-133
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dev = [
3131
"setuptools",
3232
"build",
3333
"pytest",
34+
"websocket-client",
3435
"ruff",
3536
"docstring_parser>=0.16",
3637
"arduino_app_bricks[all]",

src/arduino/app_bricks/web_ui/certs.py

Lines changed: 0 additions & 95 deletions
This file was deleted.

src/arduino/app_bricks/web_ui/web_ui.py

Lines changed: 76 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@
22
#
33
# SPDX-License-Identifier: MPL-2.0
44

5-
from collections.abc import Callable
6-
import asyncio
75
import os
6+
import asyncio
87
import threading
8+
from contextlib import asynccontextmanager
9+
from typing import Any
10+
from collections.abc import Callable
11+
912
import uvicorn
1013
from fastapi import FastAPI
1114
from fastapi.responses import FileResponse
1215
from fastapi_socketio import SocketManager
16+
1317
from arduino.app_utils import brick, Logger
1418

1519
logger = Logger("WebUI")
@@ -32,7 +36,8 @@ def __init__(
3236
api_path_prefix: str = "",
3337
assets_dir_path: str = "/app/assets",
3438
certs_dir_path: str = "/app/certs",
35-
use_ssl: bool = False,
39+
use_tls: bool = False,
40+
use_ssl: bool | None = None, # Deprecated alias for use_tls
3641
):
3742
"""Initialize the web server.
3843
@@ -42,35 +47,72 @@ def __init__(
4247
ui_path_prefix (str, optional): URL prefix for UI routes. Defaults to "" (root).
4348
api_path_prefix (str, optional): URL prefix for API routes. Defaults to "" (root).
4449
assets_dir_path (str, optional): Path to static assets directory. Defaults to "/app/assets".
45-
certs_dir_path (str, optional): Path to SSL certificates directory. Defaults to "/app/certs".
46-
use_ssl (bool, optional): Enable SSL/HTTPS. Defaults to False.
50+
certs_dir_path (str, optional): Path to TLS certificates directory. Defaults to "/app/certs".
51+
use_tls (bool, optional): Enable TLS/HTTPS. Defaults to False.
52+
use_ssl (bool, optional): Deprecated. Use use_tls instead. Defaults to None.
4753
"""
48-
self.app = FastAPI(title=__name__, openapi_url=None, on_startup=[self._on_startup])
54+
# Handle deprecated use_ssl parameter
55+
if use_ssl is not None:
56+
logger.warning("'use_ssl' parameter is deprecated. Use 'use_tls' instead.")
57+
use_tls = use_ssl
58+
59+
@asynccontextmanager
60+
async def lifespan(app):
61+
await self._on_startup()
62+
yield
63+
64+
self.app = FastAPI(title=__name__, openapi_url=None, lifespan=lifespan)
4965
self.sio = SocketManager(app=self.app, mount_location="/socket.io", socketio_path="", max_http_buffer_size=10 * 1024 * 1024)
5066

5167
self._addr = addr
52-
self._port = port
68+
69+
def pick_free_port():
70+
import socket
71+
72+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
73+
s.bind(("127.0.0.1", 0))
74+
return s.getsockname()[1]
75+
76+
self._port = port if port != 0 else pick_free_port()
5377
self._ui_path_prefix = ui_path_prefix
5478
self._api_path_prefix = api_path_prefix
5579
self._assets_dir_path = os.path.abspath(assets_dir_path)
5680
self._certs_dir_path = os.path.abspath(certs_dir_path)
57-
self._use_ssl = use_ssl
58-
self._protocol = "https" if self._use_ssl else "http"
59-
self._server: uvicorn.Server = None
81+
self._use_tls = use_tls
82+
self._protocol = "https" if self._use_tls else "http"
83+
self._server: uvicorn.Server | None = None
6084
self._server_loop: asyncio.AbstractEventLoop | None = None
61-
self._on_connect_cb: Callable[[str], None] = None
62-
self._on_disconnect_cb: Callable[[str], None] = None
85+
self._on_connect_cb: Callable[[str], None] | None = None
86+
self._on_disconnect_cb: Callable[[str], None] | None = None
6387
self._on_message_cbs = {}
6488
self._on_message_cbs_lock = threading.Lock()
6589

90+
@property
91+
def local_url(self) -> str:
92+
"""Get the locally addressable URL of the web server.
93+
94+
Returns:
95+
str: The server's URL (including protocol, address, and port).
96+
"""
97+
return f"{self._protocol}://localhost:{self._port}"
98+
99+
@property
100+
def url(self) -> str:
101+
"""Get the externally addressable URL of the web server.
102+
103+
Returns:
104+
str: The server's URL (including protocol, address, and port).
105+
"""
106+
return f"{self._protocol}://{os.getenv('HOST_IP') or self._addr}:{self._port}"
107+
66108
def start(self):
67109
"""Start the web server asynchronously.
68110
69-
This sets up static file routing and WebSocket event handlers, configures SSL if enabled, and launches the server using Uvicorn.
111+
This sets up static file routing and WebSocket event handlers, configures TLS if enabled, and launches the server using Uvicorn.
70112
71113
Raises:
72114
RuntimeError: If 'index.html' is missing in the static assets directory.
73-
RuntimeError: If SSL is enabled but certificates are missing or fail to generate.
115+
RuntimeError: If TLS is enabled but certificates fail to generate.
74116
RuntimeWarning: If the server is already running.
75117
"""
76118
# Setup static routes and SocketIO events
@@ -82,18 +124,15 @@ def start(self):
82124
self._init_socketio()
83125

84126
config = uvicorn.Config(self.app, host=self._addr, port=self._port, log_level="warning")
85-
if self._use_ssl:
86-
from . import certs
127+
if self._use_tls:
128+
from arduino.app_utils.tls_cert_manager import TLSCertificateManager
87129

88-
if not certs.cert_exists(self._certs_dir_path):
89-
try:
90-
certs.generate_self_signed_cert(self._certs_dir_path)
91-
except Exception as e:
92-
logger.exception(f"Failed to generate SSL certificate: {e}")
93-
raise RuntimeError("Failed to generate SSL certificate. Please check the certs directory.") from e
94-
95-
config.ssl_keyfile = certs.get_pkey(self._certs_dir_path)
96-
config.ssl_certfile = certs.get_cert(self._certs_dir_path)
130+
try:
131+
cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=self._certs_dir_path, common_name=self._addr)
132+
config.ssl_certfile = cert_path
133+
config.ssl_keyfile = key_path
134+
except Exception as e:
135+
raise RuntimeError("Failed to configure TLS certificate. Please check the certs directory.") from e
97136

98137
self._server = uvicorn.Server(config)
99138

@@ -108,33 +147,31 @@ def stop(self):
108147

109148
def execute(self):
110149
logger.debug(f"Serving static web files from {self._assets_dir_path}")
111-
if self._use_ssl:
112-
logger.debug(f"Serving certificates from {self._certs_dir_path}")
150+
if self._use_tls:
151+
logger.debug(f"Using TLS certificates from {self._certs_dir_path}")
113152

114153
logger.debug("Starting server...")
115154

116155
startup_log = "The application interface is available here:\n"
117-
startup_log += f" - Local URL: {self._protocol}://localhost:{self._port}"
118-
host_ip = os.getenv("HOST_IP")
119-
if host_ip:
120-
network_url = f"{self._protocol}://{host_ip}:{self._port}"
121-
startup_log += f"\n - Network URL: {network_url}"
156+
startup_log += f" - Local URL: {self.local_url}"
157+
if os.getenv("HOST_IP"):
158+
startup_log += f"\n - Network URL: {self.url}"
122159
logger.info(startup_log)
123160

124161
try:
125162
self._server.run()
126163
except Exception as e:
127164
logger.exception(f"Error running server: {e}")
128165

129-
def expose_api(self, method: str, path: str, function: callable):
166+
def expose_api(self, method: str, path: str, function: Callable):
130167
"""Register a route with the specified HTTP method and path.
131168
132169
The path will be prefixed with the api_path_prefix configured during initialization.
133170
134171
Args:
135172
method (str): HTTP method to use (e.g., "GET", "POST").
136173
path (str): URL path for the API endpoint (without the prefix).
137-
function (callable): Function to execute when the route is accessed.
174+
function (Callable): Function to execute when the route is accessed.
138175
"""
139176
self.app.add_api_route(self._api_path_prefix + path, function, methods=[method])
140177

@@ -160,7 +197,7 @@ def on_disconnect(self, callback: Callable[[str], None]):
160197
"""
161198
self._on_disconnect_cb = callback
162199

163-
def on_message(self, message_type: str, callback: Callable[[str, any], any]):
200+
def on_message(self, message_type: str, callback: Callable[[str, Any], Any]):
164201
"""Register a callback function for a specific WebSocket message type received by clients.
165202
166203
The client should send messages named as message_type for this callback to be triggered.
@@ -170,7 +207,7 @@ def on_message(self, message_type: str, callback: Callable[[str, any], any]):
170207
171208
Args:
172209
message_type (str): The message type name to listen for.
173-
callback (Callable[[str, any], any]): Function to handle the message. Receives two arguments:
210+
callback (Callable[[str, Any], Any]): Function to handle the message. Receives two arguments:
174211
the session ID (sid) and the incoming message data.
175212
176213
"""
@@ -180,7 +217,7 @@ def on_message(self, message_type: str, callback: Callable[[str, any], any]):
180217
self._on_message_cbs[message_type] = callback
181218
logger.debug(f"Registered listener for message '{message_type}'")
182219

183-
def send_message(self, message_type: str, message: dict | str, room: str = None):
220+
def send_message(self, message_type: str, message: dict | str, room: str | None = None):
184221
"""Send a message to connected WebSocket clients.
185222
186223
Args:
@@ -200,7 +237,8 @@ def send_message(self, message_type: str, message: dict | str, room: str = None)
200237
logger.exception(f"Failed to send WebSocket message '{message_type}': {e}")
201238

202239
async def _on_startup(self):
203-
"""This function is called by uvicorn when the server starts up, it is necessary to capture the running
240+
"""
241+
This function is called by uvicorn when the server starts up, it is necessary to capture the running
204242
asyncio event loop and reuse it later for emitting socket.io events as it requires an asyncio context.
205243
"""
206244
self._server_loop = asyncio.get_running_loop()

0 commit comments

Comments
 (0)