Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions metadata-ingestion/src/datahub/testing/state_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional, cast

from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionSourceBase,
)


def get_current_checkpoint_from_pipeline(
    pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
    """Retrieve the pipeline's current stale-entity-removal checkpoint.

    Looks up the checkpoint stored by the source's state provider under the
    job id that ``StaleEntityRemovalHandler`` derives from the source's
    ``platform`` attribute (falling back to ``"default"`` when the source
    does not define one). Returns ``None`` when no checkpoint exists.
    """
    source = cast(StatefulIngestionSourceBase, pipeline.source)
    # Sources without a `platform` attribute fall back to the "default" job id.
    platform = getattr(source, "platform", "default")
    job_id = StaleEntityRemovalHandler.compute_job_id(platform)
    return source.state_provider.get_current_checkpoint(job_id)

22 changes: 2 additions & 20 deletions metadata-ingestion/tests/test_helpers/state_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,8 @@
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.graph.config import DatahubClientConfig
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionSourceBase,
from datahub.testing.state_helpers import (
get_current_checkpoint_from_pipeline as get_current_checkpoint_from_pipeline,
)


Expand Down Expand Up @@ -107,16 +102,3 @@ def mock_datahub_graph_instance(
mock_datahub_graph: Callable[[DatahubClientConfig], DataHubGraph],
) -> DataHubGraph:
return mock_datahub_graph(DatahubClientConfig(server="http://fake.domain.local"))


def get_current_checkpoint_from_pipeline(
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
# TODO: This only works for stale entity removal. We need to generalize this.

stateful_source = cast(StatefulIngestionSourceBase, pipeline.source)
return stateful_source.state_provider.get_current_checkpoint(
StaleEntityRemovalHandler.compute_job_id(
getattr(stateful_source, "platform", "default")
)
)
21 changes: 3 additions & 18 deletions smoke-test/tests/test_stateful_ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
)
from datahub.testing.state_helpers import get_current_checkpoint_from_pipeline
from tests.utils import get_db_password, get_db_type, get_db_url, get_db_username


Expand Down Expand Up @@ -50,22 +51,6 @@ def validate_all_providers_have_committed_successfully(pipeline: Pipeline) -> No
assert stateful_committable.state_to_commit
assert provider_count == 1

def get_current_checkpoint_from_pipeline(
auth_session,
pipeline: Pipeline,
) -> Optional[Checkpoint[GenericCheckpointState]]:
# TODO: Refactor to use the helper method in the metadata-ingestion tests, instead of copying it here.
sql_source: Union[MySQLSource, PostgresSource]
if get_db_type() == "mysql":
sql_source = cast(MySQLSource, pipeline.source)
else:
sql_source = cast(PostgresSource, pipeline.source)
return sql_source.state_provider.get_current_checkpoint(
StaleEntityRemovalHandler.compute_job_id(
getattr(sql_source, "platform", "default")
)
)

source_config_dict: Dict[str, Any] = {
"host_port": get_db_url(),
"username": get_db_username(),
Expand Down Expand Up @@ -113,14 +98,14 @@ def get_current_checkpoint_from_pipeline(

# 3. Do the first run of the pipeline and get the default job's checkpoint.
pipeline_run1 = run_and_get_pipeline(pipeline_config_dict)
checkpoint1 = get_current_checkpoint_from_pipeline(auth_session, pipeline_run1)
checkpoint1 = get_current_checkpoint_from_pipeline(pipeline_run1)
assert checkpoint1
assert checkpoint1.state

# 4. Drop table t1 created during step 2 + rerun the pipeline and get the checkpoint state.
drop_table(sql_engine, table_names[0])
pipeline_run2 = run_and_get_pipeline(pipeline_config_dict)
checkpoint2 = get_current_checkpoint_from_pipeline(auth_session, pipeline_run2)
checkpoint2 = get_current_checkpoint_from_pipeline(pipeline_run2)
assert checkpoint2
assert checkpoint2.state

Expand Down
Loading