11"""Handles Postgres interactions."""
22
3+
34from __future__ import annotations
45
56import atexit
67import io
8+ import itertools
79import signal
10+ import sys
811import typing as t
912from contextlib import contextmanager
1013from functools import cached_property
@@ -92,7 +95,7 @@ def interpret_content_encoding(self) -> bool:
9295 """
9396 return self .config .get ("interpret_content_encoding" , False )
9497
95- def prepare_table ( # type: ignore[override]
98+ def prepare_table ( # type: ignore[override] # noqa: PLR0913
9699 self ,
97100 full_table_name : str ,
98101 schema : dict ,
@@ -118,7 +121,7 @@ def prepare_table( # type: ignore[override]
118121 meta = sa .MetaData (schema = schema_name )
119122 table : sa .Table
120123 if not self .table_exists (full_table_name = full_table_name ):
121- table = self .create_empty_table (
124+ return self .create_empty_table (
122125 table_name = table_name ,
123126 meta = meta ,
124127 schema = schema ,
@@ -127,7 +130,6 @@ def prepare_table( # type: ignore[override]
127130 as_temp_table = as_temp_table ,
128131 connection = connection ,
129132 )
130- return table
131133 meta .reflect (connection , only = [table_name ])
132134 table = meta .tables [
133135 full_table_name
@@ -174,19 +176,19 @@ def copy_table_structure(
174176 _ , schema_name , table_name = self .parse_full_table_name (full_table_name )
175177 meta = sa .MetaData (schema = schema_name )
176178 new_table : sa .Table
177- columns = []
178179 if self .table_exists (full_table_name = full_table_name ):
179180 raise RuntimeError ("Table already exists" )
180- for column in from_table .columns :
181- columns .append (column ._copy ())
181+
182+ columns = [column ._copy () for column in from_table .columns ]
183+
182184 if as_temp_table :
183185 new_table = sa .Table (table_name , meta , * columns , prefixes = ["TEMPORARY" ])
184186 new_table .create (bind = connection )
185187 return new_table
186- else :
187- new_table = sa .Table (table_name , meta , * columns )
188- new_table .create (bind = connection )
189- return new_table
188+
189+ new_table = sa .Table (table_name , meta , * columns )
190+ new_table .create (bind = connection )
191+ return new_table
190192
191193 @contextmanager
192194 def _connect (self ) -> t .Iterator [sa .engine .Connection ]:
@@ -197,18 +199,17 @@ def drop_table(self, table: sa.Table, connection: sa.engine.Connection):
197199 """Drop table data."""
198200 table .drop (bind = connection )
199201
200- def clone_table (
202+ def clone_table ( # noqa: PLR0913
201203 self , new_table_name , table , metadata , connection , temp_table
202204 ) -> sa .Table :
203205 """Clone a table."""
204- new_columns = []
205- for column in table .columns :
206- new_columns .append (
207- sa .Column (
208- column .name ,
209- column .type ,
210- )
206+ new_columns = [
207+ sa .Column (
208+ column .name ,
209+ column .type ,
211210 )
211+ for column in table .columns
212+ ]
212213 if temp_table is True :
213214 new_table = sa .Table (
214215 new_table_name , metadata , * new_columns , prefixes = ["TEMPORARY" ]
@@ -293,9 +294,8 @@ def pick_individual_type(self, jsonschema_type: dict):
293294 ):
294295 return HexByteString ()
295296 individual_type = th .to_sql_type (jsonschema_type )
296- if isinstance (individual_type , VARCHAR ):
297- return TEXT ()
298- return individual_type
297+
298+ return TEXT () if isinstance (individual_type , VARCHAR ) else individual_type
299299
300300 @staticmethod
301301 def pick_best_sql_type (sql_type_array : list ):
@@ -323,13 +323,12 @@ def pick_best_sql_type(sql_type_array: list):
323323 NOTYPE ,
324324 ]
325325
326- for sql_type in precedence_order :
327- for obj in sql_type_array :
328- if isinstance (obj , sql_type ):
329- return obj
326+ for sql_type , obj in itertools .product (precedence_order , sql_type_array ):
327+ if isinstance (obj , sql_type ):
328+ return obj
330329 return TEXT ()
331330
332- def create_empty_table ( # type: ignore[override]
331+ def create_empty_table ( # type: ignore[override] # noqa: PLR0913
333332 self ,
334333 table_name : str ,
335334 meta : sa .MetaData ,
@@ -343,7 +342,7 @@ def create_empty_table( # type: ignore[override]
343342
344343 Args:
345344 table_name: the target table name.
346- meta: the SQLAchemy metadata object.
345+ meta: the SQLAlchemy metadata object.
347346 schema: the JSON schema for the new table.
348347 connection: the database connection.
349348 primary_keys: list of key properties.
@@ -386,7 +385,7 @@ def create_empty_table( # type: ignore[override]
386385 new_table .create (bind = connection )
387386 return new_table
388387
389- def prepare_column (
388+ def prepare_column ( # noqa: PLR0913
390389 self ,
391390 full_table_name : str ,
392391 column_name : str ,
@@ -434,7 +433,7 @@ def prepare_column(
434433 column_object = column_object ,
435434 )
436435
437- def _create_empty_column ( # type: ignore[override]
436+ def _create_empty_column ( # type: ignore[override] # noqa: PLR0913
438437 self ,
439438 schema_name : str ,
440439 table_name : str ,
@@ -499,7 +498,7 @@ def get_column_add_ddl( # type: ignore[override]
499498 },
500499 )
501500
502- def _adapt_column_type ( # type: ignore[override]
501+ def _adapt_column_type ( # type: ignore[override] # noqa: PLR0913
503502 self ,
504503 schema_name : str ,
505504 table_name : str ,
@@ -542,7 +541,7 @@ def _adapt_column_type( # type: ignore[override]
542541 return
543542
544543 # Not the same type, generic type or compatible types
545- # calling merge_sql_types for assistnace
544+ # calling merge_sql_types for assistance
546545 compatible_sql_type = self .merge_sql_types ([current_type , sql_type ])
547546
548547 if str (compatible_sql_type ) == str (current_type ):
@@ -612,17 +611,16 @@ def get_sqlalchemy_url(self, config: dict) -> str:
612611 if config .get ("sqlalchemy_url" ):
613612 return cast (str , config ["sqlalchemy_url" ])
614613
615- else :
616- sqlalchemy_url = URL .create (
617- drivername = config ["dialect+driver" ],
618- username = config ["user" ],
619- password = config ["password" ],
620- host = config ["host" ],
621- port = config ["port" ],
622- database = config ["database" ],
623- query = self .get_sqlalchemy_query (config ),
624- )
625- return cast (str , sqlalchemy_url )
614+ sqlalchemy_url = URL .create (
615+ drivername = config ["dialect+driver" ],
616+ username = config ["user" ],
617+ password = config ["password" ],
618+ host = config ["host" ],
619+ port = config ["port" ],
620+ database = config ["database" ],
621+ query = self .get_sqlalchemy_query (config ),
622+ )
623+ return cast (str , sqlalchemy_url )
626624
627625 def get_sqlalchemy_query (self , config : dict ) -> dict :
628626 """Get query values to be used for sqlalchemy URL creation.
@@ -638,7 +636,7 @@ def get_sqlalchemy_query(self, config: dict) -> dict:
638636 # ssl_enable is for verifying the server's identity to the client.
639637 if config ["ssl_enable" ]:
640638 ssl_mode = config ["ssl_mode" ]
641- query . update ({ "sslmode" : ssl_mode })
639+ query [ "sslmode" ] = ssl_mode
642640 query ["sslrootcert" ] = self .filepath_or_certificate (
643641 value = config ["ssl_certificate_authority" ],
644642 alternative_name = config ["ssl_storage_directory" ] + "/root.crt" ,
@@ -684,12 +682,11 @@ def filepath_or_certificate(
684682 """
685683 if path .isfile (value ):
686684 return value
687- else :
688- with open (alternative_name , "wb" ) as alternative_file :
689- alternative_file .write (value .encode ("utf-8" ))
690- if restrict_permissions :
691- chmod (alternative_name , 0o600 )
692- return alternative_name
685+ with open (alternative_name , "wb" ) as alternative_file :
686+ alternative_file .write (value .encode ("utf-8" ))
687+ if restrict_permissions :
688+ chmod (alternative_name , 0o600 )
689+ return alternative_name
693690
694691 def guess_key_type (self , key_data : str ) -> paramiko .PKey :
695692 """Guess the type of the private key.
@@ -714,7 +711,7 @@ def guess_key_type(self, key_data: str) -> paramiko.PKey:
714711 ):
715712 try :
716713 key = key_class .from_private_key (io .StringIO (key_data )) # type: ignore[attr-defined]
717- except paramiko .SSHException :
714+ except paramiko .SSHException : # noqa: PERF203
718715 continue
719716 else :
720717 return key
@@ -734,7 +731,7 @@ def catch_signal(self, signum, frame) -> None:
734731 signum: The signal number
735732 frame: The current stack frame
736733 """
737- exit (1 ) # Calling this to be sure atexit is called, so clean_up gets called
734+ sys . exit (1 ) # Calling this to be sure atexit is called, so clean_up gets called
738735
739736 def _get_column_type ( # type: ignore[override]
740737 self ,
0 commit comments