Skip to content

Commit 13ec9d1

Browse files
authored
Add driver parameter to get_uri() to facilitate psycopg3 support (#10)
Also updates the tests to check that SQLAlchemy works with both psycopg2 and psycopg3.
1 parent a633959 commit 13ec9d1

File tree

4 files changed

+53
-42
lines changed

4 files changed

+53
-42
lines changed

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
66

77
[project]
88
name = "pixeltable-pgserver"
9-
version = "0.2.6"
9+
version = "0.2.7"
1010
description = "Embedded Postgres Server for Pixeltable"
1111
readme = "README.md"
1212
requires-python = ">=3.9"
@@ -26,11 +26,15 @@ dev = [
2626
]
2727
test = [
2828
"pytest",
29-
"psycopg2-binary",
29+
"psycopg2-binary==2.9.9",
30+
"psycopg[binary]==3.1.18",
3031
"sqlalchemy>=2",
3132
"sqlalchemy-utils"
3233
]
3334

35+
[tool.isort]
36+
line_length = 120
37+
3438
[tool.setuptools.packages.find]
3539
where = ["src"] # list of folders that contain the packages (["."] by default)
3640
include = ["pixeltable_pgserver*"] # package names should match these glob patterns (["*"] by default)

src/pixeltable_pgserver/postgres_server.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,30 @@
1-
from pathlib import Path
2-
from typing import Optional, Dict, Union
3-
import shutil
41
import atexit
5-
import subprocess
6-
import os
72
import logging
3+
import os
84
import platform
9-
import psutil
5+
import shutil
6+
import subprocess
107
import time
8+
from pathlib import Path
9+
from typing import Optional, Union
10+
11+
import psutil
1112

1213
from ._commands import POSTGRES_BIN_PATH, initdb, pg_ctl
13-
from .utils import find_suitable_port, find_suitable_socket_dir, DiskList, PostmasterInfo, process_is_running
14+
from .utils import DiskList, PostmasterInfo, find_suitable_port, find_suitable_socket_dir
1415

1516
if platform.system() != 'Windows':
16-
from .utils import ensure_user_exists, ensure_prefix_permissions, ensure_folder_permissions
17+
from .utils import ensure_folder_permissions, ensure_prefix_permissions, ensure_user_exists
1718

1819
_logger = logging.getLogger('pixeltable_pgserver')
1920

2021
class PostgresServer:
2122
""" Provides a common interface for interacting with a server.
2223
"""
23-
import platformdirs
2424
import fasteners
25+
import platformdirs
2526

26-
_instances : Dict[Path, 'PostgresServer'] = {}
27+
_instances : dict[Path, 'PostgresServer'] = {}
2728

2829
# NB home does not always support locking, eg NFS or LUSTRE (eg some clusters)
2930
# so, use user_runtime_path instead, which seems to be in a local filesystem
@@ -75,10 +76,10 @@ def get_pid(self) -> Optional[int]:
7576
"""
7677
return self.get_postmaster_info().pid
7778

78-
def get_uri(self, database : Optional[str] = None) -> str:
79+
def get_uri(self, database: Optional[str] = None, driver: Optional[str] = None) -> str:
7980
""" Returns a connection string for the postgresql server.
8081
"""
81-
return self.get_postmaster_info().get_uri(database=database)
82+
return self.get_postmaster_info().get_uri(database=database, driver=driver)
8283

8384
def ensure_pgdata_inited(self) -> None:
8485
""" Initializes the pgdata directory if it is not already initialized.

src/pixeltable_pgserver/utils.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1-
from pathlib import Path
2-
import typing
3-
from typing import Optional, List, Dict
4-
import subprocess
1+
import datetime
2+
import hashlib
53
import json
64
import logging
7-
import hashlib
8-
import socket
95
import platform
6+
import socket
107
import stat
8+
import subprocess
9+
from pathlib import Path
10+
from typing import Dict, List, Optional
11+
1112
import psutil
12-
import datetime
13-
import shutil
1413

1514
_logger = logging.getLogger('pixeltable_pgserver')
1615

@@ -87,16 +86,20 @@ def read_from_pgdata(cls, pgdata : Path) -> Optional['PostmasterInfo']:
8786
lines = postmaster_file.read_text().splitlines()
8887
return cls(lines)
8988

90-
def get_uri(self, user : str = 'postgres', database : Optional[str] = None) -> str:
89+
def get_uri(self, user: str = 'postgres', database: Optional[str] = None, driver: Optional[str] = None) -> str:
9190
""" Returns a connection uri string for the postgresql server using the information in postmaster.pid"""
9291
if database is None:
9392
database = user
93+
if driver is None:
94+
driver_suffix = ''
95+
else:
96+
driver_suffix = f'+{driver}'
9497

9598
if self.socket_dir is not None:
96-
return f"postgresql://{user}:@/{database}?host={self.socket_dir}"
99+
return f"postgresql{driver_suffix}://{user}:@/{database}?host={self.socket_dir}"
97100
elif self.port is not None:
98101
assert self.hostname is not None
99-
return f"postgresql://{user}:@{self.hostname}:{self.port}/{database}"
102+
return f"postgresql{driver_suffix}://{user}:@{self.hostname}:{self.port}/{database}"
100103
else:
101104
raise RuntimeError("postmaster.pid does not contain port or socket information")
102105

tests/test_pgserver.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
1-
import pytest
2-
import pixeltable_pgserver
3-
import subprocess
4-
import tempfile
5-
from typing import Optional, Union
1+
import logging
62
import multiprocessing as mp
3+
import os
4+
import platform
75
import shutil
8-
from pathlib import Path
9-
import pixeltable_pgserver.utils
106
import socket
11-
from pixeltable_pgserver.utils import find_suitable_port, process_is_running
7+
import subprocess
8+
import tempfile
9+
from pathlib import Path
10+
from typing import Optional, Union
11+
1212
import psutil
13-
import platform
13+
import pytest
1414
import sqlalchemy as sa
15-
import datetime
16-
from sqlalchemy_utils import database_exists, create_database
17-
import logging
18-
import os
15+
from sqlalchemy_utils import create_database, database_exists
1916

20-
def _check_sqlalchemy_works(srv : pixeltable_pgserver.PostgresServer):
17+
import pixeltable_pgserver
18+
import pixeltable_pgserver.utils
19+
from pixeltable_pgserver.utils import find_suitable_port, process_is_running
20+
21+
22+
def _check_sqlalchemy_works(srv: pixeltable_pgserver.PostgresServer, driver: Optional[str] = None):
2123
database_name = 'testdb'
22-
uri = srv.get_uri(database_name)
24+
uri = srv.get_uri(database_name, driver)
2325

2426
if not database_exists(uri):
2527
create_database(uri)
@@ -39,7 +41,7 @@ def _check_sqlalchemy_works(srv : pixeltable_pgserver.PostgresServer):
3941
assert result
4042
assert result[0] == 1
4143

42-
def _check_time_zones(srv : pixeltable_pgserver.PostgresServer):
44+
def _check_time_zones(srv: pixeltable_pgserver.PostgresServer):
4345
# Check that time zone information was properly compiled
4446
database_name = 'testdb'
4547
uri = srv.get_uri(database_name)
@@ -78,7 +80,8 @@ def _check_server(pg : pixeltable_pgserver.PostgresServer) -> int:
7880
# parse second row (first two are headers)
7981
ret_path = Path(ret.splitlines()[2].strip())
8082
assert pg.pgdata == ret_path
81-
_check_sqlalchemy_works(pg)
83+
_check_sqlalchemy_works(pg, None) # Test with psycopg2 (default)
84+
_check_sqlalchemy_works(pg, 'psycopg') # Test with psycopg3
8285
_check_time_zones(pg)
8386
return postmaster_info.pid
8487

0 commit comments

Comments
 (0)