Skip to content

Commit 402d023

Browse files
chore: Enforce importing sqlalchemy as sa (#280)
cc @amotl
1 parent 78a1063 commit 402d023

File tree

6 files changed

+117
-112
lines changed

6 files changed

+117
-112
lines changed

pyproject.toml

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,25 @@ vcs = "git"
7575
style = "semver"
7676

7777
[tool.ruff]
78+
target-version = "py38"
79+
80+
[tool.ruff.lint]
7881
select = [
79-
"F", # Pyflakes
80-
"W", # pycodestyle warnings
81-
"E", # pycodestyle errors
82-
"I", # isort
83-
"N", # pep8-naming
84-
"D", # pydocsyle
82+
"F", # Pyflakes
83+
"W", # pycodestyle warnings
84+
"E", # pycodestyle errors
85+
"I", # isort
86+
"N", # pep8-naming
87+
"D", # pydocsyle
88+
"ICN", # flake8-import-conventions
89+
"RUF", # ruff
8590
]
86-
target-version = "py38"
8791

88-
[tool.ruff.pydocstyle]
92+
[tool.ruff.lint.flake8-import-conventions]
93+
banned-from = ["sqlalchemy"]
94+
95+
[tool.ruff.lint.flake8-import-conventions.extend-aliases]
96+
sqlalchemy = "sa"
97+
98+
[tool.ruff.lint.pydocstyle]
8999
convention = "google"

target_postgres/connector.py

Lines changed: 54 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import paramiko
1414
import simplejson
15-
import sqlalchemy
15+
import sqlalchemy as sa
1616
from singer_sdk import SQLConnector
1717
from singer_sdk import typing as th
1818
from sqlalchemy.dialects.postgresql import ARRAY, BIGINT, JSONB
@@ -84,10 +84,10 @@ def prepare_table( # type: ignore[override]
8484
full_table_name: str,
8585
schema: dict,
8686
primary_keys: list[str],
87-
connection: sqlalchemy.engine.Connection,
87+
connection: sa.engine.Connection,
8888
partition_keys: list[str] | None = None,
8989
as_temp_table: bool = False,
90-
) -> sqlalchemy.Table:
90+
) -> sa.Table:
9191
"""Adapt target table to provided schema if possible.
9292
9393
Args:
@@ -102,8 +102,8 @@ def prepare_table( # type: ignore[override]
102102
The table object.
103103
"""
104104
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
105-
meta = sqlalchemy.MetaData(schema=schema_name)
106-
table: sqlalchemy.Table
105+
meta = sa.MetaData(schema=schema_name)
106+
table: sa.Table
107107
if not self.table_exists(full_table_name=full_table_name):
108108
table = self.create_empty_table(
109109
table_name=table_name,
@@ -143,10 +143,10 @@ def prepare_table( # type: ignore[override]
143143
def copy_table_structure(
144144
self,
145145
full_table_name: str,
146-
from_table: sqlalchemy.Table,
147-
connection: sqlalchemy.engine.Connection,
146+
from_table: sa.Table,
147+
connection: sa.engine.Connection,
148148
as_temp_table: bool = False,
149-
) -> sqlalchemy.Table:
149+
) -> sa.Table:
150150
"""Copy table structure.
151151
152152
Args:
@@ -159,58 +159,54 @@ def copy_table_structure(
159159
The new table object.
160160
"""
161161
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
162-
meta = sqlalchemy.MetaData(schema=schema_name)
163-
new_table: sqlalchemy.Table
162+
meta = sa.MetaData(schema=schema_name)
163+
new_table: sa.Table
164164
columns = []
165165
if self.table_exists(full_table_name=full_table_name):
166166
raise RuntimeError("Table already exists")
167167
for column in from_table.columns:
168168
columns.append(column._copy())
169169
if as_temp_table:
170-
new_table = sqlalchemy.Table(
171-
table_name, meta, *columns, prefixes=["TEMPORARY"]
172-
)
170+
new_table = sa.Table(table_name, meta, *columns, prefixes=["TEMPORARY"])
173171
new_table.create(bind=connection)
174172
return new_table
175173
else:
176-
new_table = sqlalchemy.Table(table_name, meta, *columns)
174+
new_table = sa.Table(table_name, meta, *columns)
177175
new_table.create(bind=connection)
178176
return new_table
179177

180178
@contextmanager
181-
def _connect(self) -> t.Iterator[sqlalchemy.engine.Connection]:
179+
def _connect(self) -> t.Iterator[sa.engine.Connection]:
182180
with self._engine.connect().execution_options() as conn:
183181
yield conn
184182

185-
def drop_table(
186-
self, table: sqlalchemy.Table, connection: sqlalchemy.engine.Connection
187-
):
183+
def drop_table(self, table: sa.Table, connection: sa.engine.Connection):
188184
"""Drop table data."""
189185
table.drop(bind=connection)
190186

191187
def clone_table(
192188
self, new_table_name, table, metadata, connection, temp_table
193-
) -> sqlalchemy.Table:
189+
) -> sa.Table:
194190
"""Clone a table."""
195191
new_columns = []
196192
for column in table.columns:
197193
new_columns.append(
198-
sqlalchemy.Column(
194+
sa.Column(
199195
column.name,
200196
column.type,
201197
)
202198
)
203199
if temp_table is True:
204-
new_table = sqlalchemy.Table(
200+
new_table = sa.Table(
205201
new_table_name, metadata, *new_columns, prefixes=["TEMPORARY"]
206202
)
207203
else:
208-
new_table = sqlalchemy.Table(new_table_name, metadata, *new_columns)
204+
new_table = sa.Table(new_table_name, metadata, *new_columns)
209205
new_table.create(bind=connection)
210206
return new_table
211207

212208
@staticmethod
213-
def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
209+
def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine:
214210
"""Return a JSON Schema representation of the provided type.
215211
216212
By default will call `typing.to_sql_type()`.
@@ -317,13 +313,13 @@ def pick_best_sql_type(sql_type_array: list):
317313
def create_empty_table( # type: ignore[override]
318314
self,
319315
table_name: str,
320-
meta: sqlalchemy.MetaData,
316+
meta: sa.MetaData,
321317
schema: dict,
322-
connection: sqlalchemy.engine.Connection,
318+
connection: sa.engine.Connection,
323319
primary_keys: list[str] | None = None,
324320
partition_keys: list[str] | None = None,
325321
as_temp_table: bool = False,
326-
) -> sqlalchemy.Table:
322+
) -> sa.Table:
327323
"""Create an empty target table.
328324
329325
Args:
@@ -342,7 +338,7 @@ def create_empty_table( # type: ignore[override]
342338
NotImplementedError: if temp tables are unsupported and as_temp_table=True.
343339
RuntimeError: if a variant schema is passed with no properties defined.
344340
"""
345-
columns: list[sqlalchemy.Column] = []
341+
columns: list[sa.Column] = []
346342
primary_keys = primary_keys or []
347343
try:
348344
properties: dict = schema["properties"]
@@ -355,31 +351,29 @@ def create_empty_table( # type: ignore[override]
355351
for property_name, property_jsonschema in properties.items():
356352
is_primary_key = property_name in primary_keys
357353
columns.append(
358-
sqlalchemy.Column(
354+
sa.Column(
359355
property_name,
360356
self.to_sql_type(property_jsonschema),
361357
primary_key=is_primary_key,
362358
autoincrement=False, # See: https://github.com/MeltanoLabs/target-postgres/issues/193 # noqa: E501
363359
)
364360
)
365361
if as_temp_table:
366-
new_table = sqlalchemy.Table(
367-
table_name, meta, *columns, prefixes=["TEMPORARY"]
368-
)
362+
new_table = sa.Table(table_name, meta, *columns, prefixes=["TEMPORARY"])
369363
new_table.create(bind=connection)
370364
return new_table
371365

372-
new_table = sqlalchemy.Table(table_name, meta, *columns)
366+
new_table = sa.Table(table_name, meta, *columns)
373367
new_table.create(bind=connection)
374368
return new_table
375369

376370
def prepare_column(
377371
self,
378372
full_table_name: str,
379373
column_name: str,
380-
sql_type: sqlalchemy.types.TypeEngine,
381-
connection: sqlalchemy.engine.Connection | None = None,
382-
column_object: sqlalchemy.Column | None = None,
374+
sql_type: sa.types.TypeEngine,
375+
connection: sa.engine.Connection | None = None,
376+
column_object: sa.Column | None = None,
383377
) -> None:
384378
"""Adapt target table to provided schema if possible.
385379
@@ -402,7 +396,7 @@ def prepare_column(
402396

403397
if not column_exists:
404398
self._create_empty_column(
405-
# We should migrate every function to use sqlalchemy.Table
399+
# We should migrate every function to use sa.Table
406400
# instead of having to know what the function wants
407401
table_name=table_name,
408402
column_name=column_name,
@@ -426,8 +420,8 @@ def _create_empty_column( # type: ignore[override]
426420
schema_name: str,
427421
table_name: str,
428422
column_name: str,
429-
sql_type: sqlalchemy.types.TypeEngine,
430-
connection: sqlalchemy.engine.Connection,
423+
sql_type: sa.types.TypeEngine,
424+
connection: sa.engine.Connection,
431425
) -> None:
432426
"""Create a new column.
433427
@@ -458,8 +452,8 @@ def get_column_add_ddl( # type: ignore[override]
458452
table_name: str,
459453
schema_name: str,
460454
column_name: str,
461-
column_type: sqlalchemy.types.TypeEngine,
462-
) -> sqlalchemy.DDL:
455+
column_type: sa.types.TypeEngine,
456+
) -> sa.DDL:
463457
"""Get the create column DDL statement.
464458
465459
Args:
@@ -471,9 +465,9 @@ def get_column_add_ddl( # type: ignore[override]
471465
Returns:
472466
A sqlalchemy DDL instance.
473467
"""
474-
column = sqlalchemy.Column(column_name, column_type)
468+
column = sa.Column(column_name, column_type)
475469

476-
return sqlalchemy.DDL(
470+
return sa.DDL(
477471
(
478472
'ALTER TABLE "%(schema_name)s"."%(table_name)s"'
479473
"ADD COLUMN %(column_name)s %(column_type)s"
@@ -491,9 +485,9 @@ def _adapt_column_type( # type: ignore[override]
491485
schema_name: str,
492486
table_name: str,
493487
column_name: str,
494-
sql_type: sqlalchemy.types.TypeEngine,
495-
connection: sqlalchemy.engine.Connection,
496-
column_object: sqlalchemy.Column | None,
488+
sql_type: sa.types.TypeEngine,
489+
connection: sa.engine.Connection,
490+
column_object: sa.Column | None,
497491
) -> None:
498492
"""Adapt table column type to support the new JSON schema type.
499493
@@ -508,9 +502,9 @@ def _adapt_column_type( # type: ignore[override]
508502
Raises:
509503
NotImplementedError: if altering columns is not supported.
510504
"""
511-
current_type: sqlalchemy.types.TypeEngine
505+
current_type: sa.types.TypeEngine
512506
if column_object is not None:
513-
current_type = t.cast(sqlalchemy.types.TypeEngine, column_object.type)
507+
current_type = t.cast(sa.types.TypeEngine, column_object.type)
514508
else:
515509
current_type = self._get_column_type(
516510
schema_name=schema_name,
@@ -561,8 +555,8 @@ def get_column_alter_ddl( # type: ignore[override]
561555
schema_name: str,
562556
table_name: str,
563557
column_name: str,
564-
column_type: sqlalchemy.types.TypeEngine,
565-
) -> sqlalchemy.DDL:
558+
column_type: sa.types.TypeEngine,
559+
) -> sa.DDL:
566560
"""Get the alter column DDL statement.
567561
568562
Override this if your database uses a different syntax for altering columns.
@@ -576,8 +570,8 @@ def get_column_alter_ddl( # type: ignore[override]
576570
Returns:
577571
A sqlalchemy DDL instance.
578572
"""
579-
column = sqlalchemy.Column(column_name, column_type)
580-
return sqlalchemy.DDL(
573+
column = sa.Column(column_name, column_type)
574+
return sa.DDL(
581575
(
582576
'ALTER TABLE "%(schema_name)s"."%(table_name)s"'
583577
"ALTER COLUMN %(column_name)s %(column_type)s"
@@ -700,7 +694,7 @@ def guess_key_type(self, key_data: str) -> paramiko.PKey:
700694
paramiko.Ed25519Key,
701695
):
702696
try:
703-
key = key_class.from_private_key(io.StringIO(key_data)) # type: ignore[attr-defined] # noqa: E501
697+
key = key_class.from_private_key(io.StringIO(key_data)) # type: ignore[attr-defined]
704698
except paramiko.SSHException:
705699
continue
706700
else:
@@ -728,8 +722,8 @@ def _get_column_type( # type: ignore[override]
728722
schema_name: str,
729723
table_name: str,
730724
column_name: str,
731-
connection: sqlalchemy.engine.Connection,
732-
) -> sqlalchemy.types.TypeEngine:
725+
connection: sa.engine.Connection,
726+
) -> sa.types.TypeEngine:
733727
"""Get the SQL type of the declared column.
734728
735729
Args:
@@ -757,15 +751,15 @@ def _get_column_type( # type: ignore[override]
757751
)
758752
raise KeyError(msg) from ex
759753

760-
return t.cast(sqlalchemy.types.TypeEngine, column.type)
754+
return t.cast(sa.types.TypeEngine, column.type)
761755

762756
def get_table_columns( # type: ignore[override]
763757
self,
764758
schema_name: str,
765759
table_name: str,
766-
connection: sqlalchemy.engine.Connection,
760+
connection: sa.engine.Connection,
767761
column_names: list[str] | None = None,
768-
) -> dict[str, sqlalchemy.Column]:
762+
) -> dict[str, sa.Column]:
769763
"""Return a list of table columns.
770764
771765
Overrode to support schema_name
@@ -779,11 +773,11 @@ def get_table_columns( # type: ignore[override]
779773
Returns:
780774
An ordered list of column objects.
781775
"""
782-
inspector = sqlalchemy.inspect(connection)
776+
inspector = sa.inspect(connection)
783777
columns = inspector.get_columns(table_name, schema_name)
784778

785779
return {
786-
col_meta["name"]: sqlalchemy.Column(
780+
col_meta["name"]: sa.Column(
787781
col_meta["name"],
788782
col_meta["type"],
789783
nullable=col_meta.get("nullable", False),
@@ -797,7 +791,7 @@ def column_exists( # type: ignore[override]
797791
self,
798792
full_table_name: str,
799793
column_name: str,
800-
connection: sqlalchemy.engine.Connection,
794+
connection: sa.engine.Connection,
801795
) -> bool:
802796
"""Determine if the target column already exists.
803797

0 commit comments

Comments
 (0)