diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySamplerImpl.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySamplerImpl.java index c51bd46ef44..5be83b3e702 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySamplerImpl.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySamplerImpl.java @@ -72,12 +72,20 @@ public boolean preSampleRequest(final @Nonnull AppSecRequestContext ctx) { if (counter.tryAcquire()) { log.debug("API security sampling is required for this request (presampled)"); ctx.setKeepOpenForApiSecurityPostProcessing(true); + // Update immediately to prevent concurrent requests from seeing the same expired state + updateApiAccessIfExpired(hash); return true; } return false; } - /** Get the final sampling decision. This method is NOT thread-safe. */ + /** + * Confirms the final sampling decision. + * + *

This method is called after the span completes. The actual sampling decision and map update + * already happened in {@link #preSampleRequest(AppSecRequestContext)} to prevent race conditions. + * This method only serves as a final confirmation gate before schema extraction. + */ @Override public boolean sampleRequest(AppSecRequestContext ctx) { if (ctx == null) { @@ -88,7 +96,7 @@ public boolean sampleRequest(AppSecRequestContext ctx) { // This should never happen, it should have been short-circuited before. return false; } - return updateApiAccessIfExpired(hash); + return true; } @Override diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/AppSecSpanPostProcessor.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/AppSecSpanPostProcessor.java index af97152609e..d7ca7f5da11 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/AppSecSpanPostProcessor.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/AppSecSpanPostProcessor.java @@ -57,11 +57,15 @@ public void process(@Nonnull AgentSpan span, @Nonnull BooleanSupplier timeoutChe extractSchemas(ctx, ctx_.getTraceSegment()); } finally { ctx.setKeepOpenForApiSecurityPostProcessing(false); + // XXX: Close the additive first. This is not strictly needed, but it'll prevent getting it + // detected as a + // missed request-ended event. try { - // XXX: Close the additive first. This is not strictly needed, but it'll prevent getting it - // detected as a - // missed request-ended event. ctx.closeWafContext(); + } catch (Exception e) { + log.debug("Error closing WAF context", e); + } + try { ctx.close(); } catch (Exception e) { log.debug("Error closing AppSecRequestContext", e); diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java index 65ae99aa17c..cb4d16c831e 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java @@ -836,17 +836,19 @@ private NoopFlow onRequestEnded(RequestContext ctx_, IGSpanInfo spanInfo) { TraceSegment traceSeg = ctx_.getTraceSegment(); Map tags = spanInfo.getTags(); - if (maybeSampleForApiSecurity(ctx, spanInfo, tags)) { - if (!Config.get().isApmTracingEnabled()) { - traceSeg.setTagTop(Tags.ASM_KEEP, true); - traceSeg.setTagTop(Tags.PROPAGATED_TRACE_SOURCE, ProductTraceSource.ASM); - } - } else { + boolean sampledForApiSec = maybeSampleForApiSecurity(ctx, spanInfo, tags); + + if (!sampledForApiSec) { ctx.closeWafContext(); } // AppSec report metric and events for web span only if (traceSeg != null) { + if (sampledForApiSec && !Config.get().isApmTracingEnabled()) { + traceSeg.setTagTop(Tags.ASM_KEEP, true); + traceSeg.setTagTop(Tags.PROPAGATED_TRACE_SOURCE, ProductTraceSource.ASM); + } + traceSeg.setTagTop("_dd.appsec.enabled", 1); traceSeg.setTagTop("_dd.runtime_family", "jvm"); diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecuritySamplerTest.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecuritySamplerTest.groovy index a4ef9984786..4d60e7a8527 100644 --- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecuritySamplerTest.groovy +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecuritySamplerTest.groovy @@ -138,35 +138,40 @@ class ApiSecuritySamplerTest extends DDSpecification { !sampled } - void 'sampleRequest honors expiration'() { + void 'preSampleRequest honors expiration'() { given: - def ctx = createContext('route1', 'GET', 200) - ctx.setApiSecurityEndpointHash(42L) - ctx.setKeepOpenForApiSecurityPostProcessing(true) + def ctx1 = createContext('route1', 'GET', 200) + def ctx2 = createContext('route1', 'GET', 200) + def ctx3 = createContext('route1', 'GET', 200) final timeSource = new ControllableTimeSource() timeSource.set(0) final long expirationTimeInMs = 10L final long expirationTimeInNs = expirationTimeInMs * 1_000_000 def sampler = new ApiSecuritySamplerImpl(10, expirationTimeInMs, timeSource) - when: - def sampled = sampler.sampleRequest(ctx) + when: 'first request samples' + def preSampled1 = sampler.preSampleRequest(ctx1) + def sampled1 = sampler.sampleRequest(ctx1) then: - sampled + preSampled1 + sampled1 - when: - sampled = sampler.sampleRequest(ctx) + when: 'second request to same endpoint before expiration' + def preSampled2 = sampler.preSampleRequest(ctx2) then: 'second request is not sampled' - !sampled + !preSampled2 when: 'expiration time has passed' + sampler.releaseOne() timeSource.advance(expirationTimeInNs) - sampled = sampler.sampleRequest(ctx) + def preSampled3 = sampler.preSampleRequest(ctx3) + def sampled3 = sampler.sampleRequest(ctx3) then: 'request is sampled again' - sampled + preSampled3 + sampled3 } void 'internal accessMap never goes beyond capacity'() { @@ -198,10 +203,13 @@ class ApiSecuritySamplerTest extends DDSpecification { expect: for (int i = 0; i < maxCapacity * 10; i++) { - final ctx = createContext('route1', 'GET', 200 + 1) - ctx.setApiSecurityEndpointHash(i as long) - ctx.setKeepOpenForApiSecurityPostProcessing(true) - assert sampler.sampleRequest(ctx) + final ctx = createContext('route1', 'GET', 200 + i) + def preSampled = sampler.preSampleRequest(ctx) + // First request always samples, then we advance time so each subsequent request expires + assert preSampled + def sampled = sampler.sampleRequest(ctx) + assert sampled + sampler.releaseOne() assert sampler.accessMap.size() <= 2 if (i % 2) { timeSource.advance(expirationTimeInMs * 1_000_000) diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/AppSecSpanPostProcessorTest.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/AppSecSpanPostProcessorTest.groovy index 321f3876d94..d83b5fa29e6 100644 --- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/AppSecSpanPostProcessorTest.groovy +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/AppSecSpanPostProcessorTest.groovy @@ -248,4 +248,85 @@ class AppSecSpanPostProcessorTest extends DDSpecification { 1 * sampler.releaseOne() 0 * _ } + + void 'permit is released even if extractSchemas throws exception'() { + given: + def sampler = Mock(ApiSecuritySamplerImpl) + def producer = Mock(EventProducerService) + def span = Mock(AgentSpan) + def reqCtx = Mock(RequestContext) + def ctx = Mock(AppSecRequestContext) + def processor = new AppSecSpanPostProcessor(sampler, producer) + + when: + processor.process(span, { false }) + + then: + def ex = thrown(RuntimeException) + ex.message == "Unexpected error" + 1 * span.getRequestContext() >> reqCtx + 1 * reqCtx.getData(_) >> ctx + 1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> true + 1 * sampler.sampleRequest(_) >> true + 1 * reqCtx.getTraceSegment() >> { throw new RuntimeException("Unexpected error") } + 1 * ctx.setKeepOpenForApiSecurityPostProcessing(false) + 1 * ctx.closeWafContext() + 1 * ctx.close() + 1 * sampler.releaseOne() // Critical: permit is still released despite exception + 0 * _ + } + + void 'multiple requests do not exhaust semaphore permits'() { + given: + // Use real ApiSecuritySamplerImpl which has a semaphore with 4 permits + def realSampler = new ApiSecuritySamplerImpl() + def producer = Mock(EventProducerService) + def processor = new AppSecSpanPostProcessor(realSampler, producer) + + when: 'Process 5 consecutive requests that acquire permits' + 5.times { i -> + def span = Mock(AgentSpan) + def reqCtx = Mock(RequestContext) + def ctx = Mock(AppSecRequestContext) + + // Mock the interactions + span.getRequestContext() >> reqCtx + reqCtx.getData(_) >> ctx + ctx.isKeepOpenForApiSecurityPostProcessing() >> true + ctx.setKeepOpenForApiSecurityPostProcessing(false) + ctx.closeWafContext() + ctx.close() + + // Process should complete without issues, releasing permit each time + processor.process(span, { false }) + } + + then: 'All requests complete successfully without permit exhaustion' + noExceptionThrown() + } + + void 'permit is released when ctx cleanup operations fail'() { + given: + def sampler = Mock(ApiSecuritySamplerImpl) + def producer = Mock(EventProducerService) + def span = Mock(AgentSpan) + def reqCtx = Mock(RequestContext) + def ctx = Mock(AppSecRequestContext) + def processor = new AppSecSpanPostProcessor(sampler, producer) + + when: + processor.process(span, { false }) + + then: + noExceptionThrown() + 1 * span.getRequestContext() >> reqCtx + 1 * reqCtx.getData(_) >> ctx + 1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> true + 1 * sampler.sampleRequest(_) >> false + 1 * ctx.setKeepOpenForApiSecurityPostProcessing(false) + 1 * ctx.closeWafContext() >> { throw new RuntimeException("WAF context close failed") } + 1 * ctx.close() >> { throw new RuntimeException("Context close failed") } + 1 * sampler.releaseOne() // Critical: permit is still released despite cleanup failures + 0 * _ + } }