Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 126 additions & 30 deletions ddtrace/_trace/_inferred_proxy.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from dataclasses import dataclass
import logging
from typing import Callable
from typing import Dict
from typing import Union
from typing import Optional

from ddtrace import config
from ddtrace._trace.span import Span
from ddtrace.ext import SpanKind
from ddtrace.ext import SpanTypes
from ddtrace.ext import http
from ddtrace.internal.constants import COMPONENT
Expand All @@ -13,18 +16,64 @@

log = logging.getLogger(__name__)


@dataclass
class ProxyHeaderContext:
    """Values read from the inferred-proxy request headers of a single request.

    ``system_name`` and ``request_time`` are typed as plain strings; every other
    field is ``Optional[str]``. Instances are built positionally by
    ``extract_inferred_proxy_context``, so the field order below is significant.
    """

    # Value of the ``x-dd-proxy`` header; used as the lookup key into ``supported_proxies``.
    system_name: str
    # Value of the ``x-dd-proxy-request-time-ms`` header; parsed with ``int()`` and
    # multiplied by 1_000_000 to obtain the span start time in nanoseconds.
    request_time: str
    # Value of the ``x-dd-proxy-httpmethod`` header.
    method: Optional[str]
    # Value of the ``x-dd-proxy-path`` header.
    path: Optional[str]
    # Value of the ``x-dd-proxy-resource-path`` header; preferred over ``path`` when
    # building the span resource.
    resource_path: Optional[str]
    # Value of the ``x-dd-proxy-domain-name`` header; used as the span service when set.
    domain_name: Optional[str]
    # Value of the ``x-dd-proxy-stage`` header.
    stage: Optional[str]
    # Value of the ``x-dd-proxy-account-id`` header.
    account_id: Optional[str]
    # Value of the ``x-dd-proxy-api-id`` header.
    api_id: Optional[str]
    # Value of the ``x-dd-proxy-region`` header.
    region: Optional[str]
    # Value of the ``x-dd-proxy-user`` header.
    user: Optional[str]
    # Value of the ``user-agent`` header.
    useragent: Optional[str]


@dataclass
class ProxyInfo:
    """Static description of one supported proxy type."""

    # Name given to the inferred span started for this proxy.
    span_name: str
    # Value written to the span's component tag.
    component: str
    # Optional callable that derives a resource ARN from the header context; when it
    # returns a non-empty value, that value is stored in the ``dd_resource_key`` tag.
    resource_arn_builder: Optional[Callable[[ProxyHeaderContext], Optional[str]]] = None


def _api_gateway_rest_api_arn(proxy_context: ProxyHeaderContext) -> Optional[str]:
    """Return the ARN of the API Gateway REST API described by ``proxy_context``.

    The ARN is built from ``proxy_context.region`` and ``proxy_context.api_id``.
    ``None`` is returned when either of them is missing or empty.
    """
    region = proxy_context.region
    if not region:
        return None

    api_id = proxy_context.api_id
    if not api_id:
        return None

    return f"arn:aws:apigateway:{region}::/restapis/{api_id}"


def _api_gateway_http_api_arn(proxy_context: ProxyHeaderContext) -> Optional[str]:
    """Return the ARN of the API Gateway HTTP API described by ``proxy_context``.

    The ARN is built from ``proxy_context.region`` and ``proxy_context.api_id``.
    ``None`` is returned when either of them is missing or empty.
    """
    region = proxy_context.region
    if not region:
        return None

    api_id = proxy_context.api_id
    if not api_id:
        return None

    return f"arn:aws:apigateway:{region}::/apis/{api_id}"


# Registry of proxies for which an inferred span can be created. Keys are the values
# expected in the ``x-dd-proxy`` header (looked up via ``proxy_context.system_name``).
supported_proxies: Dict[str, ProxyInfo] = {
    "aws-apigateway": ProxyInfo("aws.apigateway", "aws-apigateway", _api_gateway_rest_api_arn),
    "aws-httpapi": ProxyInfo("aws.httpapi", "aws-httpapi", _api_gateway_http_api_arn),
}

# Span names of all supported proxies, so that callers can recognise an inferred proxy
# span by its name without hard-coding a single proxy type.
SUPPORTED_PROXY_SPAN_NAMES = {info.span_name for info in supported_proxies.values()}

# Each constant holds the possible spellings of one proxy header. Both the lower-case and
# the upper-case (WSGI-style) versions are checked, following the header-extraction logic
# in ddtrace/propagation/http.py.
# NOTE(review): ``_possible_header`` is defined outside this section; presumably it returns
# the candidate header names for the given lower-case name — confirm against its definition.
POSSIBLE_PROXY_HEADER_SYSTEM = _possible_header("x-dd-proxy")
POSSIBLE_PROXY_HEADER_START_TIME_MS = _possible_header("x-dd-proxy-request-time-ms")
POSSIBLE_PROXY_HEADER_PATH = _possible_header("x-dd-proxy-path")
POSSIBLE_PROXY_HEADER_RESOURCE_PATH = _possible_header("x-dd-proxy-resource-path")
POSSIBLE_PROXY_HEADER_HTTPMETHOD = _possible_header("x-dd-proxy-httpmethod")
POSSIBLE_PROXY_HEADER_DOMAIN = _possible_header("x-dd-proxy-domain-name")
POSSIBLE_PROXY_HEADER_STAGE = _possible_header("x-dd-proxy-stage")
POSSIBLE_PROXY_HEADER_ACCOUNT_ID = _possible_header("x-dd-proxy-account-id")
POSSIBLE_PROXY_HEADER_API_ID = _possible_header("x-dd-proxy-api-id")
POSSIBLE_PROXY_HEADER_REGION = _possible_header("x-dd-proxy-region")
POSSIBLE_PROXY_HEADER_USER = _possible_header("x-dd-proxy-user")

supported_proxies: Dict[str, Dict[str, str]] = {
"aws-apigateway": {"span_name": "aws.apigateway", "component": "aws-apigateway"}
}
# Possible spellings of the ``user-agent`` request header; unlike the constants above it is
# not an ``x-dd-proxy-*`` header.
HEADER_USERAGENT = _possible_header("user-agent")


def create_inferred_proxy_span_if_headers_exist(ctx, headers, child_of, tracer) -> None:
Expand All @@ -38,19 +87,23 @@ def create_inferred_proxy_span_if_headers_exist(ctx, headers, child_of, tracer)
if not proxy_context:
return None

proxy_span_info = supported_proxies[proxy_context["proxy_system_name"]]
proxy_info = supported_proxies[proxy_context.system_name]

method = proxy_context.method
route_or_path = proxy_context.resource_path or proxy_context.path
resource = f"{method or ''} {route_or_path or ''}"

span = tracer.start_span(
proxy_span_info["span_name"],
service=proxy_context.get("domain_name", config._get_service()),
resource=proxy_context["method"] + " " + proxy_context["path"],
proxy_info.span_name,
service=proxy_context.domain_name or config._get_service(),
resource=resource,
span_type=SpanTypes.WEB,
activate=True,
child_of=child_of,
)
span.start_ns = int(proxy_context["request_time"]) * 1000000
span.start_ns = int(proxy_context.request_time) * 1000000

set_inferred_proxy_span_tags(span, proxy_context)
set_inferred_proxy_span_tags(span, proxy_context, proxy_info)

# We need a callback to finish the API Gateway span; this callback will be added to the child spans' finish callbacks.
def finish_callback(_):
Expand All @@ -62,24 +115,61 @@ def finish_callback(_):
ctx.set_item("headers", headers)


def set_inferred_proxy_span_tags(span, proxy_context) -> Span:
span._set_tag_str(COMPONENT, supported_proxies[proxy_context["proxy_system_name"]]["component"])
def set_inferred_proxy_span_tags(span: Span, proxy_context: ProxyHeaderContext, proxy_info: ProxyInfo) -> Span:
span._set_tag_str(COMPONENT, proxy_info.component)
span._set_tag_str("span.kind", SpanKind.SERVER)

span._set_tag_str(http.METHOD, proxy_context["method"])
span._set_tag_str(http.URL, f"{proxy_context['domain_name']}{proxy_context['path']}")
span._set_tag_str("stage", proxy_context["stage"])
span._set_tag_str(http.URL, f"https://{proxy_context.domain_name or ''}{proxy_context.path or ''}")

if proxy_context.method:
span._set_tag_str(http.METHOD, proxy_context.method)

if proxy_context.resource_path:
span._set_tag_str(http.ROUTE, proxy_context.resource_path)

if proxy_context.useragent:
span._set_tag_str(http.USER_AGENT, proxy_context.useragent)

if proxy_context.stage:
span._set_tag_str("stage", proxy_context.stage)

if proxy_context.account_id:
span._set_tag_str("account_id", proxy_context.account_id)

if proxy_context.api_id:
span._set_tag_str("apiid", proxy_context.api_id)

if proxy_context.region:
span._set_tag_str("region", proxy_context.region)

if proxy_context.user:
span._set_tag_str("aws_user", proxy_context.user)

if proxy_info.resource_arn_builder:
resource_arn = proxy_info.resource_arn_builder(proxy_context)
if resource_arn:
span._set_tag_str("dd_resource_key", resource_arn)

span.set_metric("_dd.inferred_span", 1)
return span


def extract_inferred_proxy_context(headers) -> Union[None, Dict[str, str]]:
proxy_header_system = str(_extract_header_value(POSSIBLE_PROXY_HEADER_SYSTEM, headers))
proxy_header_start_time_ms = str(_extract_header_value(POSSIBLE_PROXY_HEADER_START_TIME_MS, headers))
proxy_header_path = str(_extract_header_value(POSSIBLE_PROXY_HEADER_PATH, headers))
proxy_header_httpmethod = str(_extract_header_value(POSSIBLE_PROXY_HEADER_HTTPMETHOD, headers))
proxy_header_domain = str(_extract_header_value(POSSIBLE_PROXY_HEADER_DOMAIN, headers))
proxy_header_stage = str(_extract_header_value(POSSIBLE_PROXY_HEADER_STAGE, headers))
def extract_inferred_proxy_context(headers) -> Optional[ProxyHeaderContext]:
proxy_header_system = _extract_header_value(POSSIBLE_PROXY_HEADER_SYSTEM, headers)
proxy_header_start_time_ms = _extract_header_value(POSSIBLE_PROXY_HEADER_START_TIME_MS, headers)
proxy_header_path = _extract_header_value(POSSIBLE_PROXY_HEADER_PATH, headers)
proxy_header_resource_path = _extract_header_value(POSSIBLE_PROXY_HEADER_RESOURCE_PATH, headers)

proxy_header_httpmethod = _extract_header_value(POSSIBLE_PROXY_HEADER_HTTPMETHOD, headers)
proxy_header_domain = _extract_header_value(POSSIBLE_PROXY_HEADER_DOMAIN, headers)
proxy_header_stage = _extract_header_value(POSSIBLE_PROXY_HEADER_STAGE, headers)

proxy_header_account_id = _extract_header_value(POSSIBLE_PROXY_HEADER_ACCOUNT_ID, headers)
proxy_header_api_id = _extract_header_value(POSSIBLE_PROXY_HEADER_API_ID, headers)
proxy_header_region = _extract_header_value(POSSIBLE_PROXY_HEADER_REGION, headers)
proxy_header_user = _extract_header_value(POSSIBLE_PROXY_HEADER_USER, headers)

header_user_agent = _extract_header_value(HEADER_USERAGENT, headers)

# Exit if start time header is not present
if proxy_header_start_time_ms is None:
Expand All @@ -92,14 +182,20 @@ def extract_inferred_proxy_context(headers) -> Union[None, Dict[str, str]]:
)
return None

return {
"request_time": proxy_header_start_time_ms,
"method": proxy_header_httpmethod,
"path": proxy_header_path,
"stage": proxy_header_stage,
"domain_name": proxy_header_domain,
"proxy_system_name": proxy_header_system,
}
return ProxyHeaderContext(
proxy_header_system,
proxy_header_start_time_ms,
proxy_header_httpmethod,
proxy_header_path,
proxy_header_resource_path,
proxy_header_domain,
proxy_header_stage,
proxy_header_account_id,
proxy_header_api_id,
proxy_header_region,
proxy_header_user,
header_user_agent,
)


def normalize_headers(headers) -> Dict[str, str]:
Expand Down
3 changes: 2 additions & 1 deletion ddtrace/_trace/trace_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ddtrace
from ddtrace import config
from ddtrace._trace._inferred_proxy import SUPPORTED_PROXY_SPAN_NAMES
from ddtrace._trace._inferred_proxy import create_inferred_proxy_span_if_headers_exist
from ddtrace._trace._span_link import SpanLinkKind as _SpanLinkKind
from ddtrace._trace._span_pointer import _SpanPointerDescription
Expand Down Expand Up @@ -244,7 +245,7 @@ def _on_web_framework_finish_request(


def _set_inferred_proxy_tags(span, status_code):
if span._parent and span._parent.name == "aws.apigateway":
if span._parent and span._parent.name in SUPPORTED_PROXY_SPAN_NAMES:
inferred_span = span._parent
status_code = status_code if status_code else span.get_tag("http.status_code")
if status_code:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
tracing: Ensures inferred proxy spans for AWS API Gateway HTTP APIs are created when the
``x-dd-proxy`` header reports ``aws-httpapi``.
AAP: Updates inferred proxy span tags so that inferred services are discovered by the App and API Protection API Catalog.
2 changes: 1 addition & 1 deletion tests/contrib/aiohttp/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ async def test_inferred_spans_api_gateway(app_tracer, aiohttp_client, test_app,
api_gateway_resource="GET /",
method="GET",
status_code=str(test_app["status_code"]),
url="local/",
url="https://local/",
start=1736973768,
is_distributed=test_headers["type"] == "distributed",
distributed_trace_id=1,
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/asgi/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ async def test_inferred_spans_api_gateway_default(scope, tracer, test_spans, app
api_gateway_resource="GET /",
method="GET",
status_code=app_type["status_code"],
url="local/",
url="https://local/",
start=1736973768,
is_distributed=headers == distributed_headers,
distributed_trace_id=1,
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/bottle/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def handled_error_endpoint():
api_gateway_resource="GET /",
method="GET",
status_code=str(test_endpoint["status"]),
url="local/",
url="https://local/",
start=1736973768,
is_distributed=False,
distributed_trace_id=1,
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/bottle/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def default_endpoint():
api_gateway_resource="GET /",
method="GET",
status_code=200,
url="local/",
url="https://local/",
start=1736973768,
is_distributed=True,
distributed_trace_id=1,
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/cherrypy/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def test_inferred_spans_api_gateway_default(self):
api_gateway_resource="GET /",
method="GET",
status_code=test_endpoint["status"],
url="local/",
url="https://local/",
start=1736973768,
is_distributed=test_headers == distributed_headers,
distributed_trace_id=1,
Expand Down
6 changes: 3 additions & 3 deletions tests/contrib/django/test_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -1952,7 +1952,7 @@ def test_inferred_spans_api_gateway_default(client, test_spans):
api_gateway_resource="GET /",
method="GET",
status_code="200",
url="local/",
url="https://local/",
start=1736973768.0,
)

Expand All @@ -1975,7 +1975,7 @@ def test_inferred_spans_api_gateway_default(client, test_spans):
api_gateway_resource="GET /",
method="GET",
status_code="500",
url="local/",
url="https://local/",
start=1736973768.0,
)

Expand Down Expand Up @@ -2035,7 +2035,7 @@ def test_inferred_spans_api_gateway_distributed_tracing(client, test_spans):
api_gateway_resource="GET /",
method="GET",
status_code="200",
url="local/",
url="https://local/",
start=1736973768.0,
is_distributed=True,
distributed_trace_id=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_inferred_spans_api_gateway_default(client, test_spans, test_endpoint, i
api_gateway_resource="GET /",
method="GET",
status_code=test_endpoint["status_code"],
url="local/",
url="https://local/",
start=1736973768,
is_distributed=headers == distributed_headers,
distributed_trace_id=1,
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/falcon/test_distributed_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_inferred_spans_api_gateway_distributed_tracing_enabled(self):
api_gateway_resource="GET /",
method="GET",
status_code="200",
url="local/",
url="https://local/",
start=1736973768.0,
is_distributed=True,
distributed_trace_id=1,
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/falcon/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def test_inferred_spans_api_gateway_default(self):
api_gateway_resource="GET /",
method="GET",
status_code=test_endpoint["status"],
url="local/",
url="https://local/",
start=1736973768.0,
is_distributed=False,
distributed_trace_id=1,
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/fastapi/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ def test_inferred_spans_api_gateway(client, tracer, test_spans, test, inferred_p
api_gateway_resource="GET /",
method="GET",
status_code=test["status_code"],
url="local/",
url="https://local/",
start=1736973768,
is_distributed=test_headers["type"] == "distributed",
distributed_trace_id=1,
Expand Down
8 changes: 4 additions & 4 deletions tests/contrib/flask/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def error_status_code():
api_gateway_resource="GET /",
method="GET",
status_code="200",
url="local/",
url="https://local/",
start=1736973768,
)

Expand All @@ -330,7 +330,7 @@ def error_status_code():
api_gateway_resource="GET /",
method="GET",
status_code="500",
url="local/",
url="https://local/",
start=1736973768,
)

Expand All @@ -351,7 +351,7 @@ def error_status_code():
api_gateway_resource="GET /",
method="GET",
status_code="599",
url="local/",
url="https://local/",
start=1736973768,
)

Expand Down Expand Up @@ -405,7 +405,7 @@ def index():
api_gateway_resource="GET /",
method="GET",
status_code="200",
url="local/",
url="https://local/",
start=1736973768,
is_distributed=True,
distributed_trace_id=1,
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/molten/test_molten.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def test_inferred_spans_api_gateway_default(self):
api_gateway_resource="GET /",
method="GET",
status_code=test_endpoint["status"],
url="local/",
url="https://local/",
start=1736973768,
is_distributed=test_headers == distributed_headers,
distributed_trace_id=1,
Expand Down
Loading
Loading