Skip to content

Commit 4745854

Browse files
authored
Merge pull request #934 from openedx/pwnage101/temp-backwards-compatible-1-plan-provisioning
fix: temporarily allow 1-plan provisioning.
2 parents f5e4ddf + e84e1b4 commit 4745854

File tree

3 files changed

+131
-3
lines changed

3 files changed

+131
-3
lines changed

enterprise_access/apps/api/v1/tests/test_provisioning_views.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from unittest import mock
88

99
import ddt
10+
from django.conf import settings
1011
from django.contrib.auth import get_user_model
1112
from django.utils import timezone
1213
from edx_rbac.constants import ALL_ACCESS_CONTEXT
@@ -820,6 +821,80 @@ def test_new_subscription_plan_created(
820821
)
821822
assert mock_license_client.create_subscription_plan.call_count == 2
822823

824+
@mock.patch('enterprise_access.apps.provisioning.api.LicenseManagerApiClient')
825+
@mock.patch('enterprise_access.apps.provisioning.api.LmsApiClient')
826+
@mock.patch('enterprise_access.apps.api.v1.views.provisioning.logger')
827+
def test_legacy_single_plan_request_transformation(
828+
self, mock_logger, mock_lms_api_client, mock_license_manager_client
829+
):
830+
"""
831+
Test that legacy requests with single 'subscription_plan' key are transformed
832+
to the new two-plan format and successfully provision resources.
833+
"""
834+
# Setup mocks for successful provisioning.
835+
mock_lms_client = mock_lms_api_client.return_value
836+
mock_lms_client.get_enterprise_customer_data.return_value = None
837+
mock_lms_client.create_enterprise_customer.return_value = DEFAULT_CUSTOMER_RECORD
838+
mock_lms_client.get_enterprise_admin_users.return_value = []
839+
mock_lms_client.get_enterprise_pending_admin_users.return_value = []
840+
mock_lms_client.get_enterprise_catalogs.return_value = [DEFAULT_CATALOG_RECORD]
841+
842+
mock_license_client = mock_license_manager_client.return_value
843+
mock_license_client.get_customer_agreement.return_value = None
844+
mock_license_client.create_customer_agreement.return_value = {
845+
**DEFAULT_AGREEMENT_RECORD, "subscriptions": []
846+
}
847+
mock_license_client.create_subscription_plan.side_effect = [
848+
DEFAULT_TRIAL_SUBSCRIPTION_PLAN_RECORD,
849+
DEFAULT_FIRST_PAID_SUBSCRIPTION_PLAN_RECORD,
850+
]
851+
mock_license_client.create_subscription_plan_renewal.return_value = (
852+
EXPECTED_SUBSCRIPTION_PLAN_RENEWAL_RESPONSE
853+
)
854+
855+
# Create a legacy request payload with 'subscription_plan' instead of the new format.
856+
legacy_request_payload = {**DEFAULT_REQUEST_PAYLOAD}
857+
legacy_request_payload.pop('first_paid_subscription_plan')
858+
legacy_request_payload['subscription_plan'] = legacy_request_payload.pop('trial_subscription_plan')
859+
860+
# Make the provisioning request.
861+
response = self.client.post(PROVISIONING_CREATE_ENDPOINT, data=legacy_request_payload)
862+
863+
# Should succeed despite using legacy format.
864+
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
865+
866+
# Verify warning was logged about deprecated format.
867+
mock_logger.warning.assert_called_once()
868+
warning_message = mock_logger.warning.call_args[0][0]
869+
self.assertIn('Deprecated request format detected', warning_message)
870+
self.assertIn('subscription_plan', warning_message)
871+
872+
# Verify info log about transformation.
873+
mock_logger.info.assert_called()
874+
info_calls = [call[0][0] for call in mock_logger.info.call_args_list]
875+
self.assertTrue(
876+
any('Transformed legacy subscription_plan' in msg for msg in info_calls),
877+
"Expected transformation log message not found"
878+
)
879+
880+
# Verify response has both trial and paid subscription plans.
881+
response_data = response.json()
882+
self.assertIn('trial_subscription_plan', response_data)
883+
self.assertIn('first_paid_subscription_plan', response_data)
884+
885+
# Verify the workflow input_data contains correct trial and paid plan data
886+
workflow = ProvisionNewCustomerWorkflow.objects.first()
887+
trial_plan_input = workflow.input_data.get('create_trial_subscription_plan_input')
888+
first_paid_plan_input = workflow.input_data.get('create_first_paid_subscription_plan_input')
889+
assert trial_plan_input['title'] == legacy_request_payload['subscription_plan']['title']
890+
assert trial_plan_input['salesforce_opportunity_line_item'] == (
891+
legacy_request_payload['subscription_plan']['salesforce_opportunity_line_item']
892+
)
893+
assert trial_plan_input['product_id'] == legacy_request_payload['subscription_plan']['product_id']
894+
assert 'First Paid Plan' in first_paid_plan_input['title']
895+
assert first_paid_plan_input['product_id'] == settings.PROVISIONING_PAID_SUBSCRIPTION_PRODUCT_ID
896+
assert first_paid_plan_input['salesforce_opportunity_line_item'] is None
897+
823898

824899
@ddt.ddt
825900
class TestCheckoutIntentSynchronization(APITest):

enterprise_access/apps/api/v1/views/provisioning.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Rest API views for the browse and request app.
33
"""
4+
import copy
45
import logging
56

67
from django.conf import settings
@@ -52,8 +53,60 @@ class ProvisioningCreateView(PermissionRequiredMixin, generics.CreateAPIView):
5253
permission_classes = (permissions.IsAuthenticated,)
5354
permission_required = constants.PROVISIONING_CREATE_PERMISSION
5455

56+
def _transform_legacy_request_data(self, request_data: dict) -> dict:
57+
"""
58+
Transform legacy request format to support backward compatibility.
59+
60+
Converts old `subscription_plan` key to new `trial_subscription_plan` and
61+
`first_paid_subscription_plan` keys if the old format is detected.
62+
63+
Args:
64+
request_data (dict): The incoming request data
65+
66+
Returns:
67+
dict: Transformed request data with new keys
68+
"""
69+
# If new format is already present, return as-is.
70+
if 'trial_subscription_plan' in request_data and 'first_paid_subscription_plan' in request_data:
71+
return request_data
72+
73+
# If old format is present, transform it.
74+
if 'subscription_plan' in request_data:
75+
logger.warning(
76+
'Deprecated request format detected: `subscription_plan` key should be replaced with '
77+
'`trial_subscription_plan` and `first_paid_subscription_plan`'
78+
)
79+
80+
subscription_plan = request_data.pop('subscription_plan')
81+
82+
# Use the subscription_plan data as trial_subscription_plan
83+
request_data['trial_subscription_plan'] = subscription_plan
84+
85+
# Synthesize first_paid_subscription_plan with required fields only
86+
request_data['first_paid_subscription_plan'] = {
87+
'title': f"{subscription_plan.get('title', 'Subscription')} - First Paid Plan",
88+
'product_id': settings.PROVISIONING_PAID_SUBSCRIPTION_PRODUCT_ID,
89+
'salesforce_opportunity_line_item': None,
90+
}
91+
92+
logger.info(
93+
'Transformed legacy subscription_plan to trial_subscription_plan and '
94+
'synthesized first_paid_subscription_plan'
95+
)
96+
logger.info(
97+
'Transformed request payload: %s',
98+
str(request_data),
99+
)
100+
return request_data
101+
102+
# If neither format detected, passthrough to serializer to inevitably fail validation and return HTTP 400.
103+
return request_data
104+
55105
def create(self, request, *args, **kwargs):
56-
request_serializer = serializers.ProvisioningRequestSerializer(data=request.data)
106+
# TEMP: Transform legacy request data before serialization. See docstring.
107+
transformed_data = self._transform_legacy_request_data(copy.deepcopy(dict(request.data)))
108+
109+
request_serializer = serializers.ProvisioningRequestSerializer(data=transformed_data)
57110
request_serializer.is_valid(raise_exception=True)
58111

59112
customer_request_data = request_serializer.validated_data['enterprise_customer']

enterprise_access/apps/provisioning/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,8 @@ class GetCreateFirstPaidSubscriptionPlanStepInput(BaseInputOutput):
557557

558558
title: str = field(validator=is_str)
559559
product_id: int = field(validator=is_int)
560-
start_date: Optional[datetime] = field(validator=validators.optional(is_datetime))
561-
expiration_date: Optional[datetime] = field(validator=is_datetime)
560+
start_date: Optional[datetime] = field(default=None, validator=validators.optional(is_datetime))
561+
expiration_date: Optional[datetime] = field(default=None, validator=validators.optional(is_datetime))
562562
salesforce_opportunity_line_item: Optional[str] = field(default=None, validator=validators.optional(is_str))
563563
enterprise_catalog_uuid: Optional[UUID] = field(default=None, validator=validators.optional(is_uuid))
564564

0 commit comments

Comments
 (0)