11"""Handles Postgres interactions."""
22
3+
34from __future__ import annotations
45
56import atexit
67import io
8+ import itertools
79import signal
10+ import sys
811import typing as t
912from contextlib import contextmanager
1013from os import chmod , path
@@ -79,7 +82,7 @@ def __init__(self, config: dict) -> None:
7982 sqlalchemy_url = url .render_as_string (hide_password = False ),
8083 )
8184
82- def prepare_table ( # type: ignore[override]
85+ def prepare_table ( # type: ignore[override] # noqa: PLR0913
8386 self ,
8487 full_table_name : str ,
8588 schema : dict ,
@@ -105,7 +108,7 @@ def prepare_table( # type: ignore[override]
105108 meta = sa .MetaData (schema = schema_name )
106109 table : sa .Table
107110 if not self .table_exists (full_table_name = full_table_name ):
108- table = self .create_empty_table (
111+ return self .create_empty_table (
109112 table_name = table_name ,
110113 meta = meta ,
111114 schema = schema ,
@@ -114,7 +117,6 @@ def prepare_table( # type: ignore[override]
114117 as_temp_table = as_temp_table ,
115118 connection = connection ,
116119 )
117- return table
118120 meta .reflect (connection , only = [table_name ])
119121 table = meta .tables [
120122 full_table_name
@@ -161,19 +163,19 @@ def copy_table_structure(
161163 _ , schema_name , table_name = self .parse_full_table_name (full_table_name )
162164 meta = sa .MetaData (schema = schema_name )
163165 new_table : sa .Table
164- columns = []
165166 if self .table_exists (full_table_name = full_table_name ):
166167 raise RuntimeError ("Table already exists" )
167- for column in from_table .columns :
168- columns .append (column ._copy ())
168+
169+ columns = [column ._copy () for column in from_table .columns ]
170+
169171 if as_temp_table :
170172 new_table = sa .Table (table_name , meta , * columns , prefixes = ["TEMPORARY" ])
171173 new_table .create (bind = connection )
172174 return new_table
173- else :
174- new_table = sa .Table (table_name , meta , * columns )
175- new_table .create (bind = connection )
176- return new_table
175+
176+ new_table = sa .Table (table_name , meta , * columns )
177+ new_table .create (bind = connection )
178+ return new_table
177179
178180 @contextmanager
179181 def _connect (self ) -> t .Iterator [sa .engine .Connection ]:
@@ -184,18 +186,17 @@ def drop_table(self, table: sa.Table, connection: sa.engine.Connection):
184186 """Drop table data."""
185187 table .drop (bind = connection )
186188
187- def clone_table (
189+ def clone_table ( # noqa: PLR0913
188190 self , new_table_name , table , metadata , connection , temp_table
189191 ) -> sa .Table :
190192 """Clone a table."""
191- new_columns = []
192- for column in table .columns :
193- new_columns .append (
194- sa .Column (
195- column .name ,
196- column .type ,
197- )
193+ new_columns = [
194+ sa .Column (
195+ column .name ,
196+ column .type ,
198197 )
198+ for column in table .columns
199+ ]
199200 if temp_table is True :
200201 new_table = sa .Table (
201202 new_table_name , metadata , * new_columns , prefixes = ["TEMPORARY" ]
@@ -275,9 +276,8 @@ def pick_individual_type(jsonschema_type: dict):
275276 if jsonschema_type .get ("format" ) == "date-time" :
276277 return TIMESTAMP ()
277278 individual_type = th .to_sql_type (jsonschema_type )
278- if isinstance (individual_type , VARCHAR ):
279- return TEXT ()
280- return individual_type
279+
280+ return TEXT () if isinstance (individual_type , VARCHAR ) else individual_type
281281
282282 @staticmethod
283283 def pick_best_sql_type (sql_type_array : list ):
@@ -304,13 +304,12 @@ def pick_best_sql_type(sql_type_array: list):
304304 NOTYPE ,
305305 ]
306306
307- for sql_type in precedence_order :
308- for obj in sql_type_array :
309- if isinstance (obj , sql_type ):
310- return obj
307+ for sql_type , obj in itertools .product (precedence_order , sql_type_array ):
308+ if isinstance (obj , sql_type ):
309+ return obj
311310 return TEXT ()
312311
313- def create_empty_table ( # type: ignore[override]
312+ def create_empty_table ( # type: ignore[override] # noqa: PLR0913
314313 self ,
315314 table_name : str ,
316315 meta : sa .MetaData ,
@@ -324,7 +323,7 @@ def create_empty_table( # type: ignore[override]
324323
325324 Args:
326325 table_name: the target table name.
327- meta: the SQLAchemy metadata object.
326+ meta: the SQLAlchemy metadata object.
328327 schema: the JSON schema for the new table.
329328 connection: the database connection.
330329 primary_keys: list of key properties.
@@ -367,7 +366,7 @@ def create_empty_table( # type: ignore[override]
367366 new_table .create (bind = connection )
368367 return new_table
369368
370- def prepare_column (
369+ def prepare_column ( # noqa: PLR0913
371370 self ,
372371 full_table_name : str ,
373372 column_name : str ,
@@ -415,7 +414,7 @@ def prepare_column(
415414 column_object = column_object ,
416415 )
417416
418- def _create_empty_column ( # type: ignore[override]
417+ def _create_empty_column ( # type: ignore[override] # noqa: PLR0913
419418 self ,
420419 schema_name : str ,
421420 table_name : str ,
@@ -480,7 +479,7 @@ def get_column_add_ddl( # type: ignore[override]
480479 },
481480 )
482481
483- def _adapt_column_type ( # type: ignore[override]
482+ def _adapt_column_type ( # type: ignore[override] # noqa: PLR0913
484483 self ,
485484 schema_name : str ,
486485 table_name : str ,
@@ -523,7 +522,7 @@ def _adapt_column_type( # type: ignore[override]
523522 return
524523
525524 # Not the same type, generic type or compatible types
526- # calling merge_sql_types for assistnace
525+ # calling merge_sql_types for assistance
527526 compatible_sql_type = self .merge_sql_types ([current_type , sql_type ])
528527
529528 if str (compatible_sql_type ) == str (current_type ):
@@ -593,17 +592,16 @@ def get_sqlalchemy_url(self, config: dict) -> str:
593592 if config .get ("sqlalchemy_url" ):
594593 return cast (str , config ["sqlalchemy_url" ])
595594
596- else :
597- sqlalchemy_url = URL .create (
598- drivername = config ["dialect+driver" ],
599- username = config ["user" ],
600- password = config ["password" ],
601- host = config ["host" ],
602- port = config ["port" ],
603- database = config ["database" ],
604- query = self .get_sqlalchemy_query (config ),
605- )
606- return cast (str , sqlalchemy_url )
595+ sqlalchemy_url = URL .create (
596+ drivername = config ["dialect+driver" ],
597+ username = config ["user" ],
598+ password = config ["password" ],
599+ host = config ["host" ],
600+ port = config ["port" ],
601+ database = config ["database" ],
602+ query = self .get_sqlalchemy_query (config ),
603+ )
604+ return cast (str , sqlalchemy_url )
607605
608606 def get_sqlalchemy_query (self , config : dict ) -> dict :
609607 """Get query values to be used for sqlalchemy URL creation.
@@ -619,7 +617,7 @@ def get_sqlalchemy_query(self, config: dict) -> dict:
619617 # ssl_enable is for verifying the server's identity to the client.
620618 if config ["ssl_enable" ]:
621619 ssl_mode = config ["ssl_mode" ]
622- query . update ({ "sslmode" : ssl_mode })
620+ query [ "sslmode" ] = ssl_mode
623621 query ["sslrootcert" ] = self .filepath_or_certificate (
624622 value = config ["ssl_certificate_authority" ],
625623 alternative_name = config ["ssl_storage_directory" ] + "/root.crt" ,
@@ -665,12 +663,11 @@ def filepath_or_certificate(
665663 """
666664 if path .isfile (value ):
667665 return value
668- else :
669- with open (alternative_name , "wb" ) as alternative_file :
670- alternative_file .write (value .encode ("utf-8" ))
671- if restrict_permissions :
672- chmod (alternative_name , 0o600 )
673- return alternative_name
666+ with open (alternative_name , "wb" ) as alternative_file :
667+ alternative_file .write (value .encode ("utf-8" ))
668+ if restrict_permissions :
669+ chmod (alternative_name , 0o600 )
670+ return alternative_name
674671
675672 def guess_key_type (self , key_data : str ) -> paramiko .PKey :
676673 """Guess the type of the private key.
@@ -695,7 +692,7 @@ def guess_key_type(self, key_data: str) -> paramiko.PKey:
695692 ):
696693 try :
697694 key = key_class .from_private_key (io .StringIO (key_data )) # type: ignore[attr-defined]
698- except paramiko .SSHException :
695+ except paramiko .SSHException : # noqa: PERF203
699696 continue
700697 else :
701698 return key
@@ -715,7 +712,7 @@ def catch_signal(self, signum, frame) -> None:
715712 signum: The signal number
716713 frame: The current stack frame
717714 """
718- exit (1 ) # Calling this to be sure atexit is called, so clean_up gets called
715+ sys . exit (1 ) # Calling this to be sure atexit is called, so clean_up gets called
719716
720717 def _get_column_type ( # type: ignore[override]
721718 self ,
0 commit comments