Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions tap_github/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def url_base(self) -> str:
replication_key: str | None = None
tolerated_http_errors: ClassVar[list[int]] = []

# Save the context from the requests so it can be available to the parse_response method
context: dict | None = None

@property
def http_headers(self) -> dict[str, str]:
"""Return the http headers needed."""
Expand Down Expand Up @@ -142,6 +145,9 @@ def get_url_params(
context: dict | None,
next_page_token: Any | None, # noqa: ANN401
) -> dict[str, Any]:
# save the context from the requests so it can be available to the parse_response method
self.context = context
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK this is not necessary. The stream class already has a context attribute.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, I don't think it is available 😞 on the core RESTStream class. That is why all other method signatures include it to be passed in. For some reason it was excluded from this method.


"""Return a dictionary of values to be used in URL parameterization."""
params: dict = {"per_page": self.MAX_PER_PAGE}
if next_page_token:
Expand Down Expand Up @@ -250,7 +256,7 @@ def validate_response(self, response: requests.Response) -> None:

def parse_response(self, response: requests.Response) -> Iterable[dict]:
"""Parse the response and return an iterator of result rows."""
# TODO - Split into handle_reponse and parse_response.
# TODO - Split into handle_response and parse_response.
if response.status_code in (
[*self.tolerated_http_errors, EMPTY_REPO_ERROR_STATUS]
):
Expand All @@ -259,16 +265,30 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]:
# Update token rate limit info and loop through tokens if needed.
self.authenticator.update_rate_limit(response.headers)

# Get all items from the response
resp_json = response.json()

if isinstance(resp_json, list):
results = resp_json
elif resp_json.get("items") is not None:
results = resp_json.get("items")
else:
results = [resp_json]

yield from results
if not results:
return

# Filter items based on replication key's date if needed
since = self.get_starting_timestamp(self.context)
filtered_results = []
if self.replication_key and self.use_fake_since_parameter and since:
for item in results:
item_date = parse(item[self.replication_key])
if item_date >= since:
filtered_results.append(item)
else:
filtered_results = results

yield from filtered_results

def post_process(self, row: dict, context: dict[str, str] | None = None) -> dict:
"""Add `repo_id` by default to all streams."""
Expand Down
Loading