Skip to content

Commit 54908d2

Browse files
committed
feat: handle payment error on trial
1 parent df621a2 commit 54908d2

File tree

4 files changed

+275
-41
lines changed

4 files changed

+275
-41
lines changed

enterprise_access/apps/api_client/license_manager_client.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,56 @@ class LicenseManagerApiClient(BaseOAuthClient):
2727
customer_agreement_provisioning_endpoint = api_base_url + 'provisioning-admins/customer-agreement/'
2828
subscription_provisioning_endpoint = api_base_url + 'provisioning-admins/subscriptions/'
2929

30+
def list_subscriptions(self, enterprise_customer_uuid, current=None):
31+
"""
32+
List subscription plans for an enterprise.
33+
34+
Returns a paginated DRF list response: { count, next, previous, results: [...] }
35+
If current is True, returns only the current plan (results length 0 or 1).
36+
"""
37+
try:
38+
params = {
39+
'enterprise_customer_uuid': enterprise_customer_uuid,
40+
}
41+
if current is not None:
42+
params['current'] = 'true' if current else 'false'
43+
44+
response = self.client.get(
45+
self.subscriptions_endpoint,
46+
params=params,
47+
timeout=settings.LICENSE_MANAGER_CLIENT_TIMEOUT,
48+
)
49+
response.raise_for_status()
50+
return response.json()
51+
except requests.exceptions.HTTPError as exc:
52+
logger.exception(
53+
'Failed to list subscriptions for enterprise %s, response: %s, exc: %s',
54+
enterprise_customer_uuid, safe_error_response_content(exc), exc,
55+
)
56+
raise
57+
58+
def update_subscription_plan(self, subscription_uuid, **payload):
59+
"""
60+
Partially update a subscription plan by UUID.
61+
62+
Example payload: { 'is_active': False, 'change_reason': 'delayed_payment' }
63+
"""
64+
endpoint = f"{self.subscription_provisioning_endpoint}{subscription_uuid}/"
65+
try:
66+
response = self.client.patch(
67+
endpoint,
68+
json=payload,
69+
timeout=settings.LICENSE_MANAGER_CLIENT_TIMEOUT,
70+
)
71+
response.raise_for_status()
72+
return response.json()
73+
except requests.exceptions.HTTPError as exc:
74+
logger.exception(
75+
'Failed to update subscription %s, payload=%s, response: %s, exc: %s',
76+
subscription_uuid, payload, safe_error_response_content(exc), exc,
77+
)
78+
raise
79+
3080
def get_subscription_overview(self, subscription_uuid):
3181
"""
3282
Call license-manager API for data about a SubscriptionPlan.

enterprise_access/apps/api_client/tests/test_license_manager_client.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,49 @@ def test_create_customer_agreement(self, mock_oauth_client):
8888
json=expected_payload,
8989
)
9090

91+
@mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True)
92+
def test_list_subscriptions_current_flag(self, mock_oauth_client):
93+
mock_get = mock_oauth_client.return_value.get
94+
mock_get.return_value.json.return_value = {'results': []}
95+
96+
lm_client = LicenseManagerApiClient()
97+
enterprise_uuid = 'ec-uuid-123'
98+
99+
# current=True should set ?current=true
100+
result = lm_client.list_subscriptions(enterprise_uuid, current=True)
101+
self.assertEqual(result, {'results': []})
102+
103+
# Verify URL and params
104+
expected_url = (
105+
'http://license-manager.example.com'
106+
'/api/v1/subscriptions/'
107+
)
108+
mock_get.assert_called_with(
109+
expected_url,
110+
params={'enterprise_customer_uuid': enterprise_uuid, 'current': 'true'},
111+
timeout=settings.LICENSE_MANAGER_CLIENT_TIMEOUT,
112+
)
113+
114+
@mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True)
115+
def test_update_subscription_plan_patch(self, mock_oauth_client):
116+
mock_patch = mock_oauth_client.return_value.patch
117+
mock_patch.return_value.json.return_value = {'uuid': 'plan-uuid', 'is_active': False}
118+
119+
lm_client = LicenseManagerApiClient()
120+
payload = {'is_active': False, 'change_reason': 'delayed_payment'}
121+
result = lm_client.update_subscription_plan('plan-uuid', **payload)
122+
123+
self.assertEqual(result, mock_patch.return_value.json.return_value)
124+
expected_url = (
125+
'http://license-manager.example.com'
126+
'/api/v1/provisioning-admins/subscriptions/plan-uuid/'
127+
)
128+
mock_patch.assert_called_once_with(
129+
expected_url,
130+
json=payload,
131+
timeout=settings.LICENSE_MANAGER_CLIENT_TIMEOUT,
132+
)
133+
91134
@mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True)
92135
def test_create_subscription_plan(self, mock_oauth_client):
93136
mock_post = mock_oauth_client.return_value.post

enterprise_access/apps/customer_billing/stripe_event_handlers.py

Lines changed: 137 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,117 @@ def link_event_data_to_checkout_intent(event, checkout_intent):
115115
event_data.save()
116116

117117

118+
def handle_pending_update(subscription_id: str, checkout_intent_id: int, pending_update):
119+
"""
120+
Log pending update information for visibility.
121+
Assumes a pending_update is present.
122+
"""
123+
# TODO: take necessary action on the actual SubscriptionPlan and update the CheckoutIntent.
124+
logger.warning(
125+
"Subscription %s has pending update: %s. checkout_intent_id: %s",
126+
subscription_id,
127+
pending_update,
128+
checkout_intent_id,
129+
)
130+
131+
132+
def handle_trial_cancellation(checkout_intent: CheckoutIntent, checkout_intent_id: int, subscription_id: str, trial_end):
133+
"""
134+
Send cancellation email for a trial subscription that has just transitioned to canceled.
135+
Assumes caller validated status transition and presence of trial_end.
136+
"""
137+
logger.info(
138+
f"Subscription {subscription_id} transitioned to 'canceled'. "
139+
f"Queuing trial cancellation email for checkout_intent_id={checkout_intent_id}"
140+
)
141+
142+
send_trial_cancellation_email_task.delay(
143+
checkout_intent_id=checkout_intent.id,
144+
trial_end_timestamp=trial_end,
145+
)
146+
147+
148+
def future_plans_of_current(current_plan_uuid: str, plans: list[dict]) -> list[dict]:
149+
"""
150+
Return plans that are future renewals of the current plan, based on prior_renewals linkage.
151+
"""
152+
def is_future_of_current(plan_dict):
153+
if str(plan_dict.get('uuid')) == current_plan_uuid:
154+
return False
155+
for renewal in plan_dict.get('prior_renewals', []) or []:
156+
if str(renewal.get('prior_subscription_plan_id')) == current_plan_uuid:
157+
return True
158+
return False
159+
160+
return [p for p in plans if is_future_of_current(p)]
161+
162+
163+
def cancel_all_future_plans(enterprise_uuid: str, reason: str = 'delayed_payment', subscription_id_for_logs: str | None = None) -> list[str]:
164+
"""
165+
Deactivate (cancel) all future plans for the current plan of the given enterprise.
166+
167+
Returns list of deactivated plan UUIDs. Logs warnings/info for observability.
168+
"""
169+
from enterprise_access.apps.api_client.license_manager_client import LicenseManagerApiClient
170+
171+
client = LicenseManagerApiClient()
172+
deactivated = []
173+
try:
174+
current_list = client.list_subscriptions(enterprise_uuid, current=True)
175+
current_results = (current_list or {}).get('results', [])
176+
current_plan = current_results[0] if current_results else None
177+
if not current_plan:
178+
logger.warning(
179+
"No current subscription plan found for enterprise %s when canceling future plans (subscription %s)",
180+
enterprise_uuid, subscription_id_for_logs,
181+
)
182+
return deactivated
183+
184+
current_plan_uuid = str(current_plan.get('uuid'))
185+
186+
# Fetch all active plans for enterprise
187+
all_list = client.list_subscriptions(enterprise_uuid)
188+
all_plans = (all_list or {}).get('results', [])
189+
190+
future_plans = future_plans_of_current(current_plan_uuid, all_plans)
191+
if not future_plans:
192+
logger.info(
193+
"No future plans to deactivate for enterprise %s (current plan %s) (subscription %s)",
194+
enterprise_uuid, current_plan_uuid, subscription_id_for_logs,
195+
)
196+
return deactivated
197+
198+
# Deactivate all future plans
199+
for future in future_plans:
200+
future_uuid = future.get('uuid')
201+
try:
202+
client.update_subscription_plan(
203+
future_uuid,
204+
is_active=False,
205+
change_reason=reason,
206+
)
207+
deactivated.append(str(future_uuid))
208+
logger.info(
209+
"Deactivated future plan %s for enterprise %s (reason=%s) (subscription %s)",
210+
future_uuid, enterprise_uuid, reason, subscription_id_for_logs,
211+
)
212+
except Exception as exc: # pylint: disable=broad-except
213+
logger.exception(
214+
"Failed to deactivate future plan %s for enterprise %s (reason=%s): %s",
215+
future_uuid, enterprise_uuid, reason, exc,
216+
)
217+
except Exception as exc: # pylint: disable=broad-except
218+
logger.exception(
219+
"Unexpected error canceling future plans for enterprise %s (subscription %s): %s",
220+
enterprise_uuid, subscription_id_for_logs, exc,
221+
)
222+
223+
return deactivated
224+
225+
226+
227+
228+
118229
class StripeEventHandler:
119230
"""
120231
Container for Stripe event handler logic.
@@ -219,54 +330,40 @@ def subscription_updated(event: stripe.Event) -> None:
219330
Send cancellation notification email when a trial subscription is canceled.
220331
"""
221332
subscription = event.data.object
222-
pending_update = getattr(subscription, "pending_update", None)
223-
224-
checkout_intent_id = get_checkout_intent_id_from_subscription(
225-
subscription
226-
)
227-
checkout_intent = get_checkout_intent_or_raise(
228-
checkout_intent_id, event.id
229-
)
333+
checkout_intent_id = get_checkout_intent_id_from_subscription(subscription)
334+
checkout_intent = get_checkout_intent_or_raise(checkout_intent_id, event.id)
230335
link_event_data_to_checkout_intent(event, checkout_intent)
231336

337+
# Pending update
338+
pending_update = getattr(subscription, "pending_update", None)
232339
if pending_update:
233-
# TODO: take necessary action on the actual SubscriptionPlan
234-
# and update the CheckoutIntent.
235-
logger.warning(
236-
"Subscription %s has pending update: %s. checkout_intent_id: %s",
237-
subscription.id,
238-
pending_update,
239-
get_checkout_intent_id_from_subscription(subscription),
240-
)
340+
handle_pending_update(subscription.id, checkout_intent_id, pending_update)
241341

242-
# Handle trial subscription cancellation
243-
# Check if status changed to canceled to avoid duplicate emails
342+
# Trial cancellation transition
244343
current_status = subscription.get("status")
245-
if current_status == "canceled":
246-
prior_status = getattr(checkout_intent.previous_summary(event), 'subscription_status', None)
247-
248-
# Only send email if status changed from non-canceled to canceled
249-
if prior_status != 'canceled':
250-
trial_end = subscription.get("trial_end")
251-
if trial_end:
252-
logger.info(
253-
f"Subscription {subscription.id} status changed from '{prior_status}' to 'canceled'. "
254-
f"Queuing trial cancellation email for checkout_intent_id={checkout_intent_id}"
255-
)
256-
257-
send_trial_cancellation_email_task.delay(
258-
checkout_intent_id=checkout_intent.id,
259-
trial_end_timestamp=trial_end,
260-
)
261-
else:
262-
logger.info(
263-
f"Subscription {subscription.id} canceled but has no trial_end, skipping cancellation email"
264-
)
344+
prior_status = getattr(checkout_intent.previous_summary(event), 'subscription_status', None)
345+
if current_status == "canceled" and prior_status != "canceled":
346+
trial_end = subscription.get("trial_end")
347+
if trial_end:
348+
handle_trial_cancellation(checkout_intent, checkout_intent_id, subscription.id, trial_end)
349+
350+
# Past due transition
351+
if current_status == "past_due" and prior_status != "past_due":
352+
enterprise_uuid = checkout_intent.enterprise_uuid
353+
if enterprise_uuid:
354+
cancel_all_future_plans(
355+
enterprise_uuid=enterprise_uuid,
356+
reason='delayed_payment',
357+
subscription_id_for_logs=subscription.id,
358+
)
265359
else:
266-
logger.info(
267-
f"Subscription {subscription.id} already canceled (status unchanged), skipping cancellation email"
360+
logger.error(
361+
"Cannot deactivate future plans for subscription %s: missing enterprise_uuid on CheckoutIntent %s",
362+
subscription.id, checkout_intent.id,
268363
)
269364

365+
366+
270367
@on_stripe_event("customer.subscription.deleted")
271368
@staticmethod
272369
def subscription_deleted(event: stripe.Event) -> None:

enterprise_access/apps/customer_billing/tests/test_stripe_event_handlers.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from enterprise_access.apps.core.tests.factories import UserFactory
1717
from enterprise_access.apps.customer_billing.constants import CheckoutIntentState
1818
from enterprise_access.apps.customer_billing.models import CheckoutIntent, StripeEventData
19-
from enterprise_access.apps.customer_billing.stripe_event_handlers import StripeEventHandler
19+
from enterprise_access.apps.customer_billing.stripe_event_handlers import (
20+
StripeEventHandler,
21+
future_plans_of_current,
22+
cancel_all_future_plans,
23+
)
2024

2125

2226
class AttrDict(dict):
@@ -277,3 +281,43 @@ def test_subscription_updated_skips_email_when_no_trial_end(self):
277281
) as mock_task:
278282
StripeEventHandler.dispatch(mock_event)
279283
mock_task.delay.assert_not_called()
284+
285+
def test_future_plans_of_current_selects_children(self):
286+
"""future_plans_of_current returns plans whose prior_renewals link to current plan."""
287+
current_uuid = "1111-aaaa"
288+
plans = [
289+
{"uuid": current_uuid, "prior_renewals": []},
290+
{"uuid": "2222-bbbb", "prior_renewals": [{"prior_subscription_plan_id": current_uuid}]},
291+
{"uuid": "3333-cccc", "prior_renewals": [{"prior_subscription_plan_id": current_uuid}]},
292+
{"uuid": "4444-dddd", "prior_renewals": []},
293+
]
294+
295+
result = future_plans_of_current(current_uuid, plans)
296+
self.assertEqual({p["uuid"] for p in result}, {"2222-bbbb", "3333-cccc"})
297+
298+
@mock.patch("enterprise_access.apps.customer_billing.stripe_event_handlers.LicenseManagerApiClient")
299+
def test_cancel_all_future_plans_deactivates_all(self, MockClient):
300+
"""cancel_all_future_plans patches all future plans and returns their uuids."""
301+
enterprise_uuid = "ent-123"
302+
current_uuid = "1111-aaaa"
303+
future1 = {"uuid": "2222-bbbb", "prior_renewals": [{"prior_subscription_plan_id": current_uuid}]}
304+
future2 = {"uuid": "3333-cccc", "prior_renewals": [{"prior_subscription_plan_id": current_uuid}]}
305+
306+
mock_client = MockClient.return_value
307+
# list_subscriptions(current=True)
308+
mock_client.list_subscriptions.side_effect = [
309+
{"results": [{"uuid": current_uuid}]}, # current=True response
310+
{"results": [
311+
{"uuid": current_uuid, "prior_renewals": []},
312+
future1,
313+
future2,
314+
]}, # all plans response
315+
]
316+
317+
deactivated = cancel_all_future_plans(enterprise_uuid, reason="delayed_payment", subscription_id_for_logs="sub-1")
318+
319+
# Should have patched both future plans
320+
self.assertEqual(set(deactivated), {"2222-bbbb", "3333-cccc"})
321+
self.assertEqual(mock_client.update_subscription_plan.call_count, 2)
322+
mock_client.update_subscription_plan.assert_any_call("2222-bbbb", is_active=False, change_reason="delayed_payment")
323+
mock_client.update_subscription_plan.assert_any_call("3333-cccc", is_active=False, change_reason="delayed_payment")

0 commit comments

Comments
 (0)