Skip to content
215 changes: 64 additions & 151 deletions target_postgres/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
VARCHAR,
TypeDecorator,
)
from singer_sdk.helpers.capabilities import TargetLoadMethods
from sshtunnel import SSHTunnelForwarder


Expand All @@ -41,6 +42,7 @@ class PostgresConnector(SQLConnector):
allow_column_rename: bool = True # Whether RENAME COLUMN is supported.
allow_column_alter: bool = False # Whether altering column types is supported.
allow_merge_upsert: bool = True # Whether MERGE UPSERT is supported.
allow_overwrite: bool = True # Whether overwrite load method is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.

def __init__(self, config: dict) -> None:
Expand Down Expand Up @@ -92,6 +94,24 @@ def interpret_content_encoding(self) -> bool:
"""
return self.config.get("interpret_content_encoding", False)

def get_table_from_metadata(
    self,
    full_table_name: str,
    connection: sa.engine.Connection,
) -> sa.Table:
    """Reflect an existing table from the database and return it.

    Args:
        full_table_name: The fully qualified table name.
        connection: The database connection used for reflection.

    Returns:
        The reflected SQLAlchemy table object.
    """
    _, schema_name, table_name = self.parse_full_table_name(full_table_name)
    # Reflect only the requested table into a schema-scoped MetaData, then
    # look it up by its fully qualified name.
    metadata = sa.MetaData(schema=schema_name)
    metadata.reflect(connection, only=[table_name])
    return metadata.tables[full_table_name]

def prepare_table( # type: ignore[override]
self,
full_table_name: str,
Expand All @@ -100,7 +120,7 @@ def prepare_table( # type: ignore[override]
connection: sa.engine.Connection,
partition_keys: list[str] | None = None,
as_temp_table: bool = False,
) -> sa.Table:
) -> None:
"""Adapt target table to provided schema if possible.

Args:
Expand All @@ -117,26 +137,39 @@ def prepare_table( # type: ignore[override]
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sa.MetaData(schema=schema_name)
table: sa.Table

if not self.table_exists(full_table_name=full_table_name):
table = self.create_empty_table(
table_name=table_name,
self.create_empty_table(
full_table_name=full_table_name,
meta=meta,
schema=schema,
primary_keys=primary_keys,
partition_keys=partition_keys,
as_temp_table=as_temp_table,
connection=connection,
)
return

if self.config["load_method"] == TargetLoadMethods.OVERWRITE:
self.get_table(full_table_name=full_table_name).drop(self._engine)
self.create_empty_table(
full_table_name=full_table_name,
meta=meta,
schema=schema,
primary_keys=primary_keys,
partition_keys=partition_keys,
as_temp_table=as_temp_table,
connection=connection,
)
return table
return

meta.reflect(connection, only=[table_name])
table = meta.tables[
full_table_name
] # So we don't mess up the casing of the Table reference

columns = self.get_table_columns(
schema_name=cast(str, schema_name),
table_name=table_name,
connection=connection,
full_table_name=full_table_name,
)

for property_name, property_def in schema["properties"].items():
Expand All @@ -151,8 +184,6 @@ def prepare_table( # type: ignore[override]
column_object=column_object,
)

return meta.tables[full_table_name]

def copy_table_structure(
self,
full_table_name: str,
Expand Down Expand Up @@ -331,7 +362,7 @@ def pick_best_sql_type(sql_type_array: list):

def create_empty_table( # type: ignore[override]
self,
table_name: str,
full_table_name: str,
meta: sa.MetaData,
schema: dict,
connection: sa.engine.Connection,
Expand All @@ -357,6 +388,9 @@ def create_empty_table( # type: ignore[override]
NotImplementedError: if temp tables are unsupported and as_temp_table=True.
RuntimeError: if a variant schema is passed with no properties defined.
"""

_, schema_name, table_name = self.parse_full_table_name(full_table_name)

columns: list[sa.Column] = []
primary_keys = primary_keys or []
try:
Expand Down Expand Up @@ -410,66 +444,31 @@ def prepare_column(
_, schema_name, table_name = self.parse_full_table_name(full_table_name)

column_exists = column_object is not None or self.column_exists(
full_table_name, column_name, connection=connection
full_table_name, column_name,
)

if not column_exists:
self._create_empty_column(
# We should migrate every function to use sa.Table
# instead of having to know what the function wants
table_name=table_name,
full_table_name=full_table_name,
column_name=column_name,
sql_type=sql_type,
schema_name=cast(str, schema_name),
connection=connection,
)
return

self._adapt_column_type(
schema_name=cast(str, schema_name),
table_name=table_name,
full_table_name=full_table_name,
column_name=column_name,
sql_type=sql_type,
connection=connection,
column_object=column_object,
)

def _create_empty_column(  # type: ignore[override]
    self,
    schema_name: str,
    table_name: str,
    column_name: str,
    sql_type: sa.types.TypeEngine,
    connection: sa.engine.Connection,
) -> None:
    """Add a new, empty column to an existing table.

    Args:
        schema_name: The schema name.
        table_name: The table name.
        column_name: The name of the new column.
        sql_type: SQLAlchemy type engine to be used in creating the new column.
        connection: The database connection.

    Raises:
        NotImplementedError: if adding columns is not supported.
    """
    # Guard clause: bail out early when the connector forbids column adds.
    if not self.allow_column_add:
        raise NotImplementedError("Adding columns is not supported.")

    ddl = self.get_column_add_ddl(
        schema_name=schema_name,
        table_name=table_name,
        column_name=column_name,
        column_type=sql_type,
    )
    connection.execute(ddl)

def get_column_add_ddl( # type: ignore[override]
self,
table_name: str,
schema_name: str,
column_name: str,
column_type: sa.types.TypeEngine,
) -> sa.DDL:
Expand All @@ -484,6 +483,8 @@ def get_column_add_ddl( # type: ignore[override]
Returns:
A sqlalchemy DDL instance.
"""
_, schema_name, table_name = self.parse_full_table_name(table_name)

column = sa.Column(column_name, column_type)

return sa.DDL(
Expand All @@ -501,12 +502,11 @@ def get_column_add_ddl( # type: ignore[override]

def _adapt_column_type( # type: ignore[override]
self,
schema_name: str,
table_name: str,
full_table_name: str,
column_name: str,
sql_type: sa.types.TypeEngine,
connection: sa.engine.Connection,
column_object: sa.Column | None,
connection: sa.engine.Connection | None = None,
column_object: sa.Column | None = None,
) -> None:
"""Adapt table column type to support the new JSON schema type.

Expand All @@ -521,15 +521,21 @@ def _adapt_column_type( # type: ignore[override]
Raises:
NotImplementedError: if altering columns is not supported.
"""
if connection is None:
super()._adapt_column_type(
full_table_name=full_table_name,
column_name=column_name,
sql_type=sql_type,
)
return

current_type: sa.types.TypeEngine
if column_object is not None:
current_type = t.cast(sa.types.TypeEngine, column_object.type)
else:
current_type = self._get_column_type(
schema_name=schema_name,
table_name=table_name,
full_table_name=full_table_name,
column_name=column_name,
connection=connection,
)

# remove collation if present and save it
Expand All @@ -556,22 +562,20 @@ def _adapt_column_type( # type: ignore[override]
if not self.allow_column_alter:
msg = (
"Altering columns is not supported. Could not convert column "
f"'{schema_name}.{table_name}.{column_name}' from '{current_type}' to "
f"'{full_table_name}.{column_name}' from '{current_type}' to "
f"'{compatible_sql_type}'."
)
raise NotImplementedError(msg)

alter_column_ddl = self.get_column_alter_ddl(
schema_name=schema_name,
table_name=table_name,
table_name=full_table_name,
column_name=column_name,
column_type=compatible_sql_type,
)
connection.execute(alter_column_ddl)

def get_column_alter_ddl( # type: ignore[override]
self,
schema_name: str,
table_name: str,
column_name: str,
column_type: sa.types.TypeEngine,
Expand All @@ -589,6 +593,7 @@ def get_column_alter_ddl( # type: ignore[override]
Returns:
A sqlalchemy DDL instance.
"""
_, schema_name, _ = self.parse_full_table_name(table_name)
column = sa.Column(column_name, column_type)
return sa.DDL(
(
Expand Down Expand Up @@ -736,98 +741,6 @@ def catch_signal(self, signum, frame) -> None:
"""
exit(1) # Calling this to be sure atexit is called, so clean_up gets called

def _get_column_type(  # type: ignore[override]
    self,
    schema_name: str,
    table_name: str,
    column_name: str,
    connection: sa.engine.Connection,
) -> sa.types.TypeEngine:
    """Get the SQL type of the declared column.

    Args:
        schema_name: The schema name.
        table_name: The table name.
        column_name: The name of the column.
        connection: The database connection.

    Returns:
        The type of the column.

    Raises:
        KeyError: If the provided column name does not exist.
    """
    try:
        column = self.get_table_columns(
            schema_name=schema_name,
            table_name=table_name,
            connection=connection,
        )[column_name]
    except KeyError as ex:
        # Bug fix: the second string fragment was missing the f-prefix, so
        # the message contained the literal text "{schema_name}.{table_name}"
        # instead of the actual names; it also lacked a space after "table".
        msg = (
            f"Column `{column_name}` does not exist in table "
            f"`{schema_name}.{table_name}`."
        )
        raise KeyError(msg) from ex

    return t.cast(sa.types.TypeEngine, column.type)

def get_table_columns(  # type: ignore[override]
    self,
    schema_name: str,
    table_name: str,
    connection: sa.engine.Connection,
    column_names: list[str] | None = None,
) -> dict[str, sa.Column]:
    """Return a list of table columns.

    Overrode to support schema_name

    Args:
        schema_name: schema name.
        table_name: table name to get columns for.
        connection: database connection.
        column_names: A list of column names to filter to.

    Returns:
        An ordered list of column objects.
    """
    inspector = sa.inspect(connection)

    # Build the case-insensitive filter once, up front; None means "keep all".
    wanted: set[str] | None = None
    if column_names:
        wanted = {name.casefold() for name in column_names}

    result: dict[str, sa.Column] = {}
    for col_meta in inspector.get_columns(table_name, schema_name):
        name = col_meta["name"]
        if wanted is not None and name.casefold() not in wanted:
            continue
        result[name] = sa.Column(
            name,
            col_meta["type"],
            nullable=col_meta.get("nullable", False),
        )
    return result

def column_exists(  # type: ignore[override]
    self,
    full_table_name: str,
    column_name: str,
    connection: sa.engine.Connection,
) -> bool:
    """Determine if the target column already exists.

    Args:
        full_table_name: the target table name.
        column_name: the target column name.
        connection: the database connection.

    Returns:
        True if table exists, False if not.
    """
    _, schema_name, table_name = self.parse_full_table_name(full_table_name)
    assert schema_name is not None
    assert table_name is not None
    existing = self.get_table_columns(
        schema_name=schema_name,
        table_name=table_name,
        connection=connection,
    )
    return column_name in existing


class NOTYPE(TypeDecorator):
Expand Down
Loading