Skip to content

Commit 7f5fa06

Browse files
chore: Apply non-functional refactoring and fix typos
1 parent 14a7377 commit 7f5fa06

File tree

5 files changed

+128
-142
lines changed

5 files changed

+128
-142
lines changed

pyproject.toml

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,22 @@ target-version = "py38"
8080

8181
[tool.ruff.lint]
8282
select = [
83-
"F", # Pyflakes
84-
"W", # pycodestyle warnings
85-
"E", # pycodestyle errors
86-
"I", # isort
87-
"N", # pep8-naming
88-
"D", # pydocstyle
89-
"ICN", # flake8-import-conventions
90-
"RUF", # ruff
83+
"F", # Pyflakes
84+
"W", # pycodestyle warnings
85+
"E", # pycodestyle errors
86+
"I", # isort
87+
"N", # pep8-naming
88+
"D", # pydocstyle
89+
"UP", # pyupgrade
90+
"ICN", # flake8-import-conventions
91+
"RET", # flake8-return
92+
"SIM", # flake8-simplify
93+
"TCH", # flake8-type-checking
94+
"ERA", # eradicate
95+
"PGH", # pygrep-hooks
96+
"PL", # Pylint
97+
"PERF", # Perflint
98+
"RUF", # ruff
9199
]
92100

93101
[tool.ruff.lint.flake8-import-conventions]

target_postgres/connector.py

Lines changed: 48 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Handles Postgres interactions."""
22

3+
34
from __future__ import annotations
45

56
import atexit
67
import io
8+
import itertools
79
import signal
10+
import sys
811
import typing as t
912
from contextlib import contextmanager
1013
from functools import cached_property
@@ -92,7 +95,7 @@ def interpret_content_encoding(self) -> bool:
9295
"""
9396
return self.config.get("interpret_content_encoding", False)
9497

95-
def prepare_table( # type: ignore[override]
98+
def prepare_table( # type: ignore[override] # noqa: PLR0913
9699
self,
97100
full_table_name: str,
98101
schema: dict,
@@ -118,7 +121,7 @@ def prepare_table( # type: ignore[override]
118121
meta = sa.MetaData(schema=schema_name)
119122
table: sa.Table
120123
if not self.table_exists(full_table_name=full_table_name):
121-
table = self.create_empty_table(
124+
return self.create_empty_table(
122125
table_name=table_name,
123126
meta=meta,
124127
schema=schema,
@@ -127,7 +130,6 @@ def prepare_table( # type: ignore[override]
127130
as_temp_table=as_temp_table,
128131
connection=connection,
129132
)
130-
return table
131133
meta.reflect(connection, only=[table_name])
132134
table = meta.tables[
133135
full_table_name
@@ -174,19 +176,19 @@ def copy_table_structure(
174176
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
175177
meta = sa.MetaData(schema=schema_name)
176178
new_table: sa.Table
177-
columns = []
178179
if self.table_exists(full_table_name=full_table_name):
179180
raise RuntimeError("Table already exists")
180-
for column in from_table.columns:
181-
columns.append(column._copy())
181+
182+
columns = [column._copy() for column in from_table.columns]
183+
182184
if as_temp_table:
183185
new_table = sa.Table(table_name, meta, *columns, prefixes=["TEMPORARY"])
184186
new_table.create(bind=connection)
185187
return new_table
186-
else:
187-
new_table = sa.Table(table_name, meta, *columns)
188-
new_table.create(bind=connection)
189-
return new_table
188+
189+
new_table = sa.Table(table_name, meta, *columns)
190+
new_table.create(bind=connection)
191+
return new_table
190192

191193
@contextmanager
192194
def _connect(self) -> t.Iterator[sa.engine.Connection]:
@@ -197,18 +199,17 @@ def drop_table(self, table: sa.Table, connection: sa.engine.Connection):
197199
"""Drop table data."""
198200
table.drop(bind=connection)
199201

200-
def clone_table(
202+
def clone_table( # noqa: PLR0913
201203
self, new_table_name, table, metadata, connection, temp_table
202204
) -> sa.Table:
203205
"""Clone a table."""
204-
new_columns = []
205-
for column in table.columns:
206-
new_columns.append(
207-
sa.Column(
208-
column.name,
209-
column.type,
210-
)
206+
new_columns = [
207+
sa.Column(
208+
column.name,
209+
column.type,
211210
)
211+
for column in table.columns
212+
]
212213
if temp_table is True:
213214
new_table = sa.Table(
214215
new_table_name, metadata, *new_columns, prefixes=["TEMPORARY"]
@@ -293,9 +294,8 @@ def pick_individual_type(self, jsonschema_type: dict):
293294
):
294295
return HexByteString()
295296
individual_type = th.to_sql_type(jsonschema_type)
296-
if isinstance(individual_type, VARCHAR):
297-
return TEXT()
298-
return individual_type
297+
298+
return TEXT() if isinstance(individual_type, VARCHAR) else individual_type
299299

300300
@staticmethod
301301
def pick_best_sql_type(sql_type_array: list):
@@ -323,13 +323,12 @@ def pick_best_sql_type(sql_type_array: list):
323323
NOTYPE,
324324
]
325325

326-
for sql_type in precedence_order:
327-
for obj in sql_type_array:
328-
if isinstance(obj, sql_type):
329-
return obj
326+
for sql_type, obj in itertools.product(precedence_order, sql_type_array):
327+
if isinstance(obj, sql_type):
328+
return obj
330329
return TEXT()
331330

332-
def create_empty_table( # type: ignore[override]
331+
def create_empty_table( # type: ignore[override] # noqa: PLR0913
333332
self,
334333
table_name: str,
335334
meta: sa.MetaData,
@@ -343,7 +342,7 @@ def create_empty_table( # type: ignore[override]
343342
344343
Args:
345344
table_name: the target table name.
346-
meta: the SQLAchemy metadata object.
345+
meta: the SQLAlchemy metadata object.
347346
schema: the JSON schema for the new table.
348347
connection: the database connection.
349348
primary_keys: list of key properties.
@@ -386,7 +385,7 @@ def create_empty_table( # type: ignore[override]
386385
new_table.create(bind=connection)
387386
return new_table
388387

389-
def prepare_column(
388+
def prepare_column( # noqa: PLR0913
390389
self,
391390
full_table_name: str,
392391
column_name: str,
@@ -434,7 +433,7 @@ def prepare_column(
434433
column_object=column_object,
435434
)
436435

437-
def _create_empty_column( # type: ignore[override]
436+
def _create_empty_column( # type: ignore[override] # noqa: PLR0913
438437
self,
439438
schema_name: str,
440439
table_name: str,
@@ -499,7 +498,7 @@ def get_column_add_ddl( # type: ignore[override]
499498
},
500499
)
501500

502-
def _adapt_column_type( # type: ignore[override]
501+
def _adapt_column_type( # type: ignore[override] # noqa: PLR0913
503502
self,
504503
schema_name: str,
505504
table_name: str,
@@ -542,7 +541,7 @@ def _adapt_column_type( # type: ignore[override]
542541
return
543542

544543
# Not the same type, generic type or compatible types
545-
# calling merge_sql_types for assistnace
544+
# calling merge_sql_types for assistance
546545
compatible_sql_type = self.merge_sql_types([current_type, sql_type])
547546

548547
if str(compatible_sql_type) == str(current_type):
@@ -612,17 +611,16 @@ def get_sqlalchemy_url(self, config: dict) -> str:
612611
if config.get("sqlalchemy_url"):
613612
return cast(str, config["sqlalchemy_url"])
614613

615-
else:
616-
sqlalchemy_url = URL.create(
617-
drivername=config["dialect+driver"],
618-
username=config["user"],
619-
password=config["password"],
620-
host=config["host"],
621-
port=config["port"],
622-
database=config["database"],
623-
query=self.get_sqlalchemy_query(config),
624-
)
625-
return cast(str, sqlalchemy_url)
614+
sqlalchemy_url = URL.create(
615+
drivername=config["dialect+driver"],
616+
username=config["user"],
617+
password=config["password"],
618+
host=config["host"],
619+
port=config["port"],
620+
database=config["database"],
621+
query=self.get_sqlalchemy_query(config),
622+
)
623+
return cast(str, sqlalchemy_url)
626624

627625
def get_sqlalchemy_query(self, config: dict) -> dict:
628626
"""Get query values to be used for sqlalchemy URL creation.
@@ -638,7 +636,7 @@ def get_sqlalchemy_query(self, config: dict) -> dict:
638636
# ssl_enable is for verifying the server's identity to the client.
639637
if config["ssl_enable"]:
640638
ssl_mode = config["ssl_mode"]
641-
query.update({"sslmode": ssl_mode})
639+
query["sslmode"] = ssl_mode
642640
query["sslrootcert"] = self.filepath_or_certificate(
643641
value=config["ssl_certificate_authority"],
644642
alternative_name=config["ssl_storage_directory"] + "/root.crt",
@@ -684,12 +682,11 @@ def filepath_or_certificate(
684682
"""
685683
if path.isfile(value):
686684
return value
687-
else:
688-
with open(alternative_name, "wb") as alternative_file:
689-
alternative_file.write(value.encode("utf-8"))
690-
if restrict_permissions:
691-
chmod(alternative_name, 0o600)
692-
return alternative_name
685+
with open(alternative_name, "wb") as alternative_file:
686+
alternative_file.write(value.encode("utf-8"))
687+
if restrict_permissions:
688+
chmod(alternative_name, 0o600)
689+
return alternative_name
693690

694691
def guess_key_type(self, key_data: str) -> paramiko.PKey:
695692
"""Guess the type of the private key.
@@ -714,7 +711,7 @@ def guess_key_type(self, key_data: str) -> paramiko.PKey:
714711
):
715712
try:
716713
key = key_class.from_private_key(io.StringIO(key_data)) # type: ignore[attr-defined]
717-
except paramiko.SSHException:
714+
except paramiko.SSHException: # noqa: PERF203
718715
continue
719716
else:
720717
return key
@@ -734,7 +731,7 @@ def catch_signal(self, signum, frame) -> None:
734731
signum: The signal number
735732
frame: The current stack frame
736733
"""
737-
exit(1) # Calling this to be sure atexit is called, so clean_up gets called
734+
sys.exit(1) # Calling this to be sure atexit is called, so clean_up gets called
738735

739736
def _get_column_type( # type: ignore[override]
740737
self,

target_postgres/sinks.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,7 @@ def setup(self) -> None:
4747
This method is called on Sink creation, and creates the required Schema and
4848
Table entities in the target database.
4949
"""
50-
if self.key_properties is None or self.key_properties == []:
51-
self.append_only = True
52-
else:
53-
self.append_only = False
50+
self.append_only = self.key_properties is None or self.key_properties == []
5451
if self.schema_name:
5552
self.connector.prepare_schema(self.schema_name)
5653
with self.connector._connect() as connection, connection.begin():
@@ -109,14 +106,14 @@ def process_batch(self, context: dict) -> None:
109106

110107
def generate_temp_table_name(self):
111108
"""Uuid temp table name."""
112-
# sa.exc.IdentifierError: Identifier
109+
# sa.exc.IdentifierError: Identifier # noqa: ERA001
113110
# 'temp_test_optional_attributes_388470e9_fbd0_47b7_a52f_d32a2ee3f5f6'
114111
# exceeds maximum length of 63 characters
115112
# Is hit if we have a long table name, there is no limit on Temporary tables
116113
# in postgres, used a guid just in case we are using the same session
117114
return f"{str(uuid.uuid4()).replace('-', '_')}"
118115

119-
def bulk_insert_records( # type: ignore[override]
116+
def bulk_insert_records( # type: ignore[override] # noqa: PLR0913
120117
self,
121118
table: sa.Table,
122119
schema: dict,
@@ -156,24 +153,24 @@ def bulk_insert_records( # type: ignore[override]
156153
if self.append_only is False:
157154
insert_records: Dict[str, Dict] = {} # pk : record
158155
for record in records:
159-
insert_record = {}
160-
for column in columns:
161-
insert_record[column.name] = record.get(column.name)
156+
insert_record = {
157+
column.name: record.get(column.name) for column in columns
158+
}
162159
# No need to check for a KeyError here because the SDK already
163-
# guaruntees that all key properties exist in the record.
160+
# guarantees that all key properties exist in the record.
164161
primary_key_value = "".join([str(record[key]) for key in primary_keys])
165162
insert_records[primary_key_value] = insert_record
166163
data_to_insert = list(insert_records.values())
167164
else:
168165
for record in records:
169-
insert_record = {}
170-
for column in columns:
171-
insert_record[column.name] = record.get(column.name)
166+
insert_record = {
167+
column.name: record.get(column.name) for column in columns
168+
}
172169
data_to_insert.append(insert_record)
173170
connection.execute(insert, data_to_insert)
174171
return True
175172

176-
def upsert(
173+
def upsert( # noqa: PLR0913
177174
self,
178175
from_table: sa.Table,
179176
to_table: sa.Table,
@@ -232,7 +229,7 @@ def upsert(
232229
# Update
233230
where_condition = join_condition
234231
update_columns = {}
235-
for column_name in self.schema["properties"].keys():
232+
for column_name in self.schema["properties"]:
236233
from_table_column: sa.Column = from_table.columns[column_name]
237234
to_table_column: sa.Column = to_table.columns[column_name]
238235
update_columns[to_table_column] = from_table_column
@@ -249,14 +246,13 @@ def column_representation(
249246
schema: dict,
250247
) -> List[sa.Column]:
251248
"""Return a sqlalchemy table representation for the current schema."""
252-
columns: list[sa.Column] = []
253-
for property_name, property_jsonschema in schema["properties"].items():
254-
columns.append(
255-
sa.Column(
256-
property_name,
257-
self.connector.to_sql_type(property_jsonschema),
258-
)
249+
columns: list[sa.Column] = [
250+
sa.Column(
251+
property_name,
252+
self.connector.to_sql_type(property_jsonschema),
259253
)
254+
for property_name, property_jsonschema in schema["properties"].items()
255+
]
260256
return columns
261257

262258
def generate_insert_statement(
@@ -286,12 +282,12 @@ def schema_name(self) -> Optional[str]:
286282
"""Return the schema name or `None` if using names with no schema part.
287283
288284
Note that after the next SDK release (after 0.14.0) we can remove this
289-
as it's already upstreamed.
285+
as it's already up-streamed.
290286
291287
Returns:
292288
The target schema name.
293289
"""
294-
# Look for a default_target_scheme in the configuraion fle
290+
# Look for a default_target_scheme in the configuration file
295291
default_target_schema: str = self.config.get("default_target_schema", None)
296292
parts = self.stream_name.split("-")
297293

@@ -302,14 +298,7 @@ def schema_name(self) -> Optional[str]:
302298
if default_target_schema:
303299
return default_target_schema
304300

305-
if len(parts) in {2, 3}:
306-
# Stream name is a two-part or three-part identifier.
307-
# Use the second-to-last part as the schema name.
308-
stream_schema = self.conform_name(parts[-2], "schema")
309-
return stream_schema
310-
311-
# Schema name not detected.
312-
return None
301+
return self.conform_name(parts[-2], "schema") if len(parts) in {2, 3} else None
313302

314303
def activate_version(self, new_version: int) -> None:
315304
"""Bump the active version of the target table.

0 commit comments

Comments
 (0)