diff --git a/metadata-ingestion/src/datahub/testing/state_helpers.py b/metadata-ingestion/src/datahub/testing/state_helpers.py
new file mode 100644
index 0000000000000..96eaed21b72e2
--- /dev/null
+++ b/metadata-ingestion/src/datahub/testing/state_helpers.py
@@ -0,0 +1,29 @@
+from typing import Optional, cast
+
+from datahub.ingestion.run.pipeline import Pipeline
+from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
+from datahub.ingestion.source.state.stale_entity_removal_handler import (
+    StaleEntityRemovalHandler,
+)
+from datahub.ingestion.source.state.stateful_ingestion_base import (
+    StatefulIngestionSourceBase,
+)
+
+
+def get_current_checkpoint_from_pipeline(
+    pipeline: Pipeline,
+) -> Optional[Checkpoint[GenericCheckpointState]]:
+    """
+    Retrieve the current stale-entity-removal checkpoint from a pipeline.
+
+    The job id is computed from the source's ``platform`` attribute, falling
+    back to ``"default"`` when the source does not define one.
+    """
+    # TODO: This only works for stale entity removal. We need to generalize this.
+    stateful_source = cast(StatefulIngestionSourceBase, pipeline.source)
+    return stateful_source.state_provider.get_current_checkpoint(
+        StaleEntityRemovalHandler.compute_job_id(
+            getattr(stateful_source, "platform", "default")
+        )
+    )
diff --git a/metadata-ingestion/tests/test_helpers/state_helpers.py b/metadata-ingestion/tests/test_helpers/state_helpers.py
index c469db6ce8cf8..53dce902d2155 100644
--- a/metadata-ingestion/tests/test_helpers/state_helpers.py
+++ b/metadata-ingestion/tests/test_helpers/state_helpers.py
@@ -12,13 +12,8 @@
 from datahub.ingestion.graph.client import DataHubGraph
 from datahub.ingestion.graph.config import DatahubClientConfig
 from datahub.ingestion.run.pipeline import Pipeline
-from datahub.ingestion.source.state.checkpoint import Checkpoint
-from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
-from datahub.ingestion.source.state.stale_entity_removal_handler import (
-    StaleEntityRemovalHandler,
-)
-from datahub.ingestion.source.state.stateful_ingestion_base import (
-    StatefulIngestionSourceBase,
+from datahub.testing.state_helpers import (
+    get_current_checkpoint_from_pipeline as get_current_checkpoint_from_pipeline,
 )
 
 
@@ -107,16 +102,3 @@ def mock_datahub_graph_instance(
     mock_datahub_graph: Callable[[DatahubClientConfig], DataHubGraph],
 ) -> DataHubGraph:
     return mock_datahub_graph(DatahubClientConfig(server="http://fake.domain.local"))
-
-
-def get_current_checkpoint_from_pipeline(
-    pipeline: Pipeline,
-) -> Optional[Checkpoint[GenericCheckpointState]]:
-    # TODO: This only works for stale entity removal. We need to generalize this.
-
-    stateful_source = cast(StatefulIngestionSourceBase, pipeline.source)
-    return stateful_source.state_provider.get_current_checkpoint(
-        StaleEntityRemovalHandler.compute_job_id(
-            getattr(stateful_source, "platform", "default")
-        )
-    )
diff --git a/smoke-test/tests/test_stateful_ingestion.py b/smoke-test/tests/test_stateful_ingestion.py
index 0460d1168a518..e1b448050b2dc 100644
--- a/smoke-test/tests/test_stateful_ingestion.py
+++ b/smoke-test/tests/test_stateful_ingestion.py
@@ -12,6 +12,7 @@
 from datahub.ingestion.source.state.stale_entity_removal_handler import (
     StaleEntityRemovalHandler,
 )
+from datahub.testing.state_helpers import get_current_checkpoint_from_pipeline
 from tests.utils import get_db_password, get_db_type, get_db_url, get_db_username
 
 
@@ -50,22 +51,6 @@ def validate_all_providers_have_committed_successfully(pipeline: Pipeline) -> No
             assert stateful_committable.state_to_commit
         assert provider_count == 1
 
-    def get_current_checkpoint_from_pipeline(
-        auth_session,
-        pipeline: Pipeline,
-    ) -> Optional[Checkpoint[GenericCheckpointState]]:
-        # TODO: Refactor to use the helper method in the metadata-ingestion tests, instead of copying it here.
-        sql_source: Union[MySQLSource, PostgresSource]
-        if get_db_type() == "mysql":
-            sql_source = cast(MySQLSource, pipeline.source)
-        else:
-            sql_source = cast(PostgresSource, pipeline.source)
-        return sql_source.state_provider.get_current_checkpoint(
-            StaleEntityRemovalHandler.compute_job_id(
-                getattr(sql_source, "platform", "default")
-            )
-        )
-
     source_config_dict: Dict[str, Any] = {
         "host_port": get_db_url(),
         "username": get_db_username(),
@@ -113,14 +98,14 @@ def get_current_checkpoint_from_pipeline(
 
     # 3. Do the first run of the pipeline and get the default job's checkpoint.
     pipeline_run1 = run_and_get_pipeline(pipeline_config_dict)
-    checkpoint1 = get_current_checkpoint_from_pipeline(auth_session, pipeline_run1)
+    checkpoint1 = get_current_checkpoint_from_pipeline(pipeline_run1)
     assert checkpoint1
     assert checkpoint1.state
 
     # 4. Drop table t1 created during step 2 + rerun the pipeline and get the checkpoint state.
     drop_table(sql_engine, table_names[0])
     pipeline_run2 = run_and_get_pipeline(pipeline_config_dict)
-    checkpoint2 = get_current_checkpoint_from_pipeline(auth_session, pipeline_run2)
+    checkpoint2 = get_current_checkpoint_from_pipeline(pipeline_run2)
     assert checkpoint2
     assert checkpoint2.state
 