diff --git a/internal/internal_event_handlers.go b/internal/internal_event_handlers.go index 2bbeabc97..2ca555e9b 100644 --- a/internal/internal_event_handlers.go +++ b/internal/internal_event_handlers.go @@ -484,7 +484,7 @@ func validateAndSerializeSearchAttributes(attributes map[string]interface{}) (*c func (wc *workflowEnvironmentImpl) UpsertMemo(memoMap map[string]interface{}) error { // This has to be used in WorkflowEnvironment implementations instead of in Workflow for testsuite mock purpose. - memo, err := validateAndSerializeMemo(memoMap, wc.dataConverter) + memo, err := validateAndSerializeMemo(memoMap, wc.dataConverter, wc) if err != nil { return err } @@ -520,11 +520,11 @@ func mergeMemo(current, upsert *commonpb.Memo) *commonpb.Memo { return current } -func validateAndSerializeMemo(memoMap map[string]interface{}, dc converter.DataConverter) (*commonpb.Memo, error) { +func validateAndSerializeMemo(memoMap map[string]interface{}, dc converter.DataConverter, accessor memoFlagAccessor) (*commonpb.Memo, error) { if len(memoMap) == 0 { return nil, errMemoNotSet } - return getWorkflowMemo(memoMap, dc) + return getWorkflowMemo(memoMap, dc, accessor) } func (wc *workflowEnvironmentImpl) RegisterCancelHandler(handler func()) { @@ -540,7 +540,7 @@ func (wc *workflowEnvironmentImpl) ExecuteChildWorkflow( if params.WorkflowID == "" { params.WorkflowID = wc.workflowInfo.currentRunID + "_" + wc.GenerateSequenceID() } - memo, err := getWorkflowMemo(params.Memo, wc.dataConverter) + memo, err := getWorkflowMemo(params.Memo, wc.dataConverter, wc) if err != nil { if wc.sdkFlags.tryUse(SDKFlagChildWorkflowErrorExecution, !wc.isReplay) { startedHandler(WorkflowExecution{}, &ChildWorkflowExecutionAlreadyStartedError{}) diff --git a/internal/internal_event_handlers_test.go b/internal/internal_event_handlers_test.go index d19856e4f..c2eefa7f7 100644 --- a/internal/internal_event_handlers_test.go +++ b/internal/internal_event_handlers_test.go @@ -213,13 +213,13 @@ func Test_MergeSearchAttributes(t *testing.T) { func Test_ValidateAndSerializeMemo(t *testing.T) { t.Parallel() - _, err := validateAndSerializeMemo(nil, nil) + _, err := validateAndSerializeMemo(nil, nil, nil) require.EqualError(t, err, "memo is empty") attr := map[string]interface{}{ "JustKey": make(chan int), } - _, err = validateAndSerializeMemo(attr, nil) + _, err = validateAndSerializeMemo(attr, nil, nil) require.EqualError( t, err, @@ -229,7 +229,7 @@ func Test_ValidateAndSerializeMemo(t *testing.T) { attr = map[string]interface{}{ "key": 1, } - memo, err := validateAndSerializeMemo(attr, nil) + memo, err := validateAndSerializeMemo(attr, nil, nil) require.NoError(t, err) require.Equal(t, 1, len(memo.Fields)) var resp int @@ -244,6 +244,8 @@ func Test_UpsertMemo(t *testing.T) { env := &workflowEnvironmentImpl{ commandsHelper: helper, workflowInfo: GetWorkflowInfo(ctx), + sdkFlags: newSDKFlags(nil), + dataConverter: converter.GetDefaultDataConverter(), } helper.setCurrentWorkflowTaskStartedEventID(4) err := env.UpsertMemo(nil) diff --git a/internal/internal_flags.go b/internal/internal_flags.go index 456a7fe81..d6d514ce6 100644 --- a/internal/internal_flags.go +++ b/internal/internal_flags.go @@ -29,13 +29,20 @@ const ( // SDKFlagBlockedSelectorSignalReceive will cause a signal to not be lost // when the Default path is blocked. SDKFlagBlockedSelectorSignalReceive = 5 - SDKFlagUnknown = math.MaxUint32 + // SDKFlagMemoUserDCEncode will use the user data converter when encoding a memo. If user data converter fails, + // we will fallback onto the default data converter. If the default DC fails, the user DC error will be returned. + SDKFlagMemoUserDCEncode = 6 + SDKFlagUnknown = math.MaxUint32 ) // unblockSelectorSignal exists to allow us to configure the default behavior of // SDKFlagBlockedSelectorSignalReceive. This is primarily useful with tests. var unblockSelectorSignal = os.Getenv("UNBLOCK_SIGNAL_SELECTOR") != "" +// memoUserDCEncode exists to allow us to configure the default behavior of +// SDKFlagMemoUserDCEncode. This is primarily useful with tests. +var memoUserDCEncode = os.Getenv("MEMO_USER_DC_ENCODE") != "" + func sdkFlagFromUint(value uint32) sdkFlag { switch value { case uint32(SDKFlagUnset): @@ -50,6 +57,8 @@ func sdkFlagFromUint(value uint32) sdkFlag { return SDKPriorityUpdateHandling case uint32(SDKFlagBlockedSelectorSignalReceive): return SDKFlagBlockedSelectorSignalReceive + case uint32(SDKFlagMemoUserDCEncode): + return SDKFlagMemoUserDCEncode default: return SDKFlagUnknown } @@ -130,3 +139,7 @@ func (sf *sdkFlags) gatherNewSDKFlags() []sdkFlag { func SetUnblockSelectorSignal(unblockSignal bool) { unblockSelectorSignal = unblockSignal } + +func SetMemoUserDCEncode(userEncode bool) { + memoUserDCEncode = userEncode +} diff --git a/internal/internal_schedule_client.go b/internal/internal_schedule_client.go index eb08ef907..42f5e53bc 100644 --- a/internal/internal_schedule_client.go +++ b/internal/internal_schedule_client.go @@ -70,7 +70,7 @@ func (w *workflowClientInterceptor) CreateSchedule(ctx context.Context, in *Sche return nil, err } - memo, err := getWorkflowMemo(in.Options.Memo, dataConverter) + memo, err := getWorkflowMemo(in.Options.Memo, dataConverter, nil) if err != nil { return nil, err } @@ -875,11 +875,16 @@ func encodeScheduleWorkflowMemo(dc converter.DataConverter, input map[string]int } memo := make(map[string]*commonpb.Payload) + if dc == nil { + dc = converter.GetDefaultDataConverter() + } + + useUserDC := shouldUseMemoUserDataConverter(nil) for k, v := range input { if enc, ok := v.(*commonpb.Payload); ok { memo[k] = enc } else { - memoBytes, err := converter.GetDefaultDataConverter().ToPayload(v) + memoBytes, err := encodeMemoValue(v, dc, useUserDC) if err != nil { return nil, fmt.Errorf("encode workflow memo error: %v", err.Error()) } diff --git a/internal/internal_schedule_client_test.go b/internal/internal_schedule_client_test.go index 9976c2728..0faa01c7e 100644 --- a/internal/internal_schedule_client_test.go +++ b/internal/internal_schedule_client_test.go @@ -2,6 +2,7 @@ package internal import ( "context" + iconverter "go.temporal.io/sdk/internal/converter" "testing" "github.com/golang/mock/gomock" @@ -219,3 +220,109 @@ func (s *scheduleClientTestSuite) TestIteratorError() { s.Nil(event) s.NotNil(err) } + +func (s *scheduleClientTestSuite) TestCreateScheduleWorkflowMemoDataConverter() { + testFn := func() { + dc := iconverter.NewTestDataConverter() + s.client = NewServiceClient(s.service, nil, ClientOptions{DataConverter: dc}) + + memo := map[string]interface{}{ + "testMemo": "memo value", + } + wf := func(ctx Context) string { panic("this is just a stub") } + + options := ScheduleOptions{ + ID: scheduleID, + Spec: ScheduleSpec{ + CronExpressions: []string{"*"}, + }, + Action: &ScheduleWorkflowAction{ + Workflow: wf, + ID: workflowID, + TaskQueue: taskqueue, + WorkflowExecutionTimeout: timeoutInSeconds, + WorkflowTaskTimeout: timeoutInSeconds, + Memo: memo, + }, + } + createResp := &workflowservice.CreateScheduleResponse{} + s.service.EXPECT().CreateSchedule(gomock.Any(), gomock.Any(), gomock.Any()).Return(createResp, nil). + Do(func(_ interface{}, req *workflowservice.CreateScheduleRequest, _ ...interface{}) { + startWorkflow := req.Schedule.Action.GetStartWorkflow() + encoding := string(startWorkflow.Memo.Fields["testMemo"].Metadata[converter.MetadataEncoding]) + if memoUserDCEncode { + s.Equal("binary/gob", encoding) + } else { + s.Equal("json/plain", encoding) + } + }) + + _, err := s.client.ScheduleClient().Create(context.Background(), options) + s.NoError(err) + } + s.T().Run("old behavior", func(t *testing.T) { + previousFlag := memoUserDCEncode + SetMemoUserDCEncode(false) + defer SetMemoUserDCEncode(previousFlag) + testFn() + }) + s.T().Run("new behavior", func(t *testing.T) { + previousFlag := memoUserDCEncode + SetMemoUserDCEncode(true) + defer SetMemoUserDCEncode(previousFlag) + testFn() + }) + +} + +func (s *scheduleClientTestSuite) TestCreateScheduleWorkflowMemoUserAndDefaultConverterFail() { + testFn := func() { + dc := failingMemoDataConverter{ + delegate: converter.GetDefaultDataConverter(), + } + s.client = NewServiceClient(s.service, nil, ClientOptions{DataConverter: dc}) + + memo := map[string]interface{}{ + "testMemo": make(chan int), + } + wf := func(ctx Context) string { panic("this is just a stub") } + + options := ScheduleOptions{ + ID: scheduleID, + Spec: ScheduleSpec{ + CronExpressions: []string{"*"}, + }, + Action: &ScheduleWorkflowAction{ + Workflow: wf, + ID: workflowID, + TaskQueue: taskqueue, + WorkflowExecutionTimeout: timeoutInSeconds, + WorkflowTaskTimeout: timeoutInSeconds, + Memo: memo, + }, + } + + s.service.EXPECT().CreateSchedule(gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + _, err := s.client.ScheduleClient().Create(context.Background(), options) + s.Error(err) + if memoUserDCEncode { + s.ErrorContains(err, "failingMemoDataConverter memo encoding failed") + } else { + s.ErrorContains(err, "unsupported type: chan int") + } + } + + s.T().Run("old behavior", func(t *testing.T) { + previousFlag := memoUserDCEncode + SetMemoUserDCEncode(false) + defer SetMemoUserDCEncode(previousFlag) + testFn() + }) + s.T().Run("new behavior", func(t *testing.T) { + previousFlag := memoUserDCEncode + SetMemoUserDCEncode(true) + defer SetMemoUserDCEncode(previousFlag) + testFn() + }) +} diff --git a/internal/internal_workflow_client.go b/internal/internal_workflow_client.go index 777c214f1..a2904f824 100644 --- a/internal/internal_workflow_client.go +++ b/internal/internal_workflow_client.go @@ -1629,15 +1629,58 @@ func (workflowRun *workflowRunImpl) follow( return workflowRun.GetWithOptions(ctx, valuePtr, options) } -func getWorkflowMemo(input map[string]interface{}, dc converter.DataConverter) (*commonpb.Memo, error) { +type memoFlagAccessor interface { + TryUse(flag sdkFlag) bool + GetFlag(flag sdkFlag) bool +} + +func encodeMemoValue(value interface{}, dc converter.DataConverter, useUserDC bool) (*commonpb.Payload, error) { + if useUserDC { + payload, dcErr := dc.ToPayload(value) + if dcErr == nil { + return payload, nil + } + + payload, err := converter.GetDefaultDataConverter().ToPayload(value) + + // If fallback default data converter fails, return original user data converter error + if err != nil { + return nil, dcErr + } + return payload, nil + } + payload, err := converter.GetDefaultDataConverter().ToPayload(value) + if err != nil { + return nil, err + } + return payload, nil +} + +func shouldUseMemoUserDataConverter(accessor memoFlagAccessor) bool { + if accessor == nil { + return memoUserDCEncode + } + + if memoUserDCEncode { + return accessor.TryUse(SDKFlagMemoUserDCEncode) + } + + return accessor.GetFlag(SDKFlagMemoUserDCEncode) +} + +func getWorkflowMemo(input map[string]interface{}, dc converter.DataConverter, accessor memoFlagAccessor) (*commonpb.Memo, error) { if input == nil { return nil, nil } - memo := make(map[string]*commonpb.Payload) + if dc == nil { + dc = converter.GetDefaultDataConverter() + } + + memo := make(map[string]*commonpb.Payload, len(input)) + useUserDC := shouldUseMemoUserDataConverter(accessor) for k, v := range input { - // TODO (shtin): use dc here??? - memoBytes, err := converter.GetDefaultDataConverter().ToPayload(v) + memoBytes, err := encodeMemoValue(v, dc, useUserDC) if err != nil { return nil, fmt.Errorf("encode workflow memo error: %v", err.Error()) } @@ -1699,7 +1742,7 @@ func (w *workflowClientInterceptor) createStartWorkflowRequest( return nil, err } - memo, err := getWorkflowMemo(in.Options.Memo, dataConverter) + memo, err := getWorkflowMemo(in.Options.Memo, dataConverter, nil) if err != nil { return nil, err } @@ -2080,7 +2123,7 @@ func (w *workflowClientInterceptor) SignalWithStartWorkflow( return nil, err } - memo, err := getWorkflowMemo(in.Options.Memo, dataConverter) + memo, err := getWorkflowMemo(in.Options.Memo, dataConverter, nil) if err != nil { return nil, err } diff --git a/internal/internal_workflow_client_test.go b/internal/internal_workflow_client_test.go index 9555d2b9e..c1c1bc900 100644 --- a/internal/internal_workflow_client_test.go +++ b/internal/internal_workflow_client_test.go @@ -1901,31 +1901,150 @@ func (s *workflowClientTestSuite) TestSignalWithStartWorkflowWithVersioningOverr func (s *workflowClientTestSuite) TestGetWorkflowMemo() { var input1 map[string]interface{} - result1, err := getWorkflowMemo(input1, s.dataConverter) + result1, err := getWorkflowMemo(input1, s.dataConverter, nil) s.NoError(err) s.Nil(result1) input1 = make(map[string]interface{}) - result2, err := getWorkflowMemo(input1, s.dataConverter) + result2, err := getWorkflowMemo(input1, s.dataConverter, nil) s.NoError(err) s.NotNil(result2) s.Equal(0, len(result2.Fields)) input1["t1"] = "v1" - result3, err := getWorkflowMemo(input1, s.dataConverter) + result3, err := getWorkflowMemo(input1, s.dataConverter, nil) s.NoError(err) s.NotNil(result3) s.Equal(1, len(result3.Fields)) var resultString string - // TODO (shtin): use s.DataConverter here??? _ = converter.GetDefaultDataConverter().FromPayload(result3.Fields["t1"], &resultString) s.Equal("v1", resultString) input1["non-serializable"] = make(chan int) - _, err = getWorkflowMemo(input1, s.dataConverter) + _, err = getWorkflowMemo(input1, s.dataConverter, nil) s.Error(err) } +func (s *workflowClientTestSuite) TestStartWorkflowWithMemoDataConverter() { + testFn := func() { + // User data converter uses binary/gob encoding + dc := iconverter.NewTestDataConverter() + s.client = NewServiceClient(s.service, nil, ClientOptions{DataConverter: dc}) + + memo := map[string]interface{}{ + "testMemo": "memo value", + } + options := StartWorkflowOptions{ + ID: workflowID, + TaskQueue: taskqueue, + WorkflowExecutionTimeout: timeoutInSeconds, + WorkflowTaskTimeout: timeoutInSeconds, + Memo: memo, + } + wf := func(ctx Context) string { return "" } + + startResp := &workflowservice.StartWorkflowExecutionResponse{} + s.service.EXPECT().StartWorkflowExecution(gomock.Any(), gomock.Any(), gomock.Any()).Return(startResp, nil). + Do(func(_ interface{}, req *workflowservice.StartWorkflowExecutionRequest, _ ...interface{}) { + encoding := string(req.Memo.Fields["testMemo"].Metadata[converter.MetadataEncoding]) + if memoUserDCEncode { + s.Equal("binary/gob", encoding) + } else { + s.Equal("json/plain", encoding) + } + + }) + + _, err := s.client.ExecuteWorkflow(context.Background(), options, wf) + s.NoError(err) + } + + s.T().Run("old behavior", func(t *testing.T) { + previousFlag := memoUserDCEncode + SetMemoUserDCEncode(false) + defer SetMemoUserDCEncode(previousFlag) + testFn() + }) + s.T().Run("new behavior", func(t *testing.T) { + previousFlag := memoUserDCEncode + SetMemoUserDCEncode(true) + defer SetMemoUserDCEncode(previousFlag) + testFn() + }) +} + +type failingMemoDataConverter struct { + delegate converter.DataConverter +} + +func (f failingMemoDataConverter) ToPayload(value interface{}) (*commonpb.Payload, error) { + return nil, fmt.Errorf("failingMemoDataConverter memo encoding failed") +} + +func (f failingMemoDataConverter) FromPayload(payload *commonpb.Payload, valuePtr interface{}) error { + return f.delegate.FromPayload(payload, valuePtr) +} + +func (f failingMemoDataConverter) ToPayloads(values ...interface{}) (*commonpb.Payloads, error) { + return f.delegate.ToPayloads(values...) +} + +func (f failingMemoDataConverter) FromPayloads(payloads *commonpb.Payloads, valuePtrs ...interface{}) error { + return f.delegate.FromPayloads(payloads, valuePtrs...) +} + +func (f failingMemoDataConverter) ToString(input *commonpb.Payload) string { + return f.delegate.ToString(input) +} + +func (f failingMemoDataConverter) ToStrings(input *commonpb.Payloads) []string { + return f.delegate.ToStrings(input) +} + +func (s *workflowClientTestSuite) TestStartWorkflowWithMemoUserAndDefaultConverterFail() { + testFn := func() { + dc := failingMemoDataConverter{ + delegate: converter.GetDefaultDataConverter(), + } + s.client = NewServiceClient(s.service, nil, ClientOptions{DataConverter: dc}) + + memo := map[string]interface{}{ + "testMemo": make(chan int), + } + options := StartWorkflowOptions{ + ID: workflowID, + TaskQueue: taskqueue, + WorkflowExecutionTimeout: timeoutInSeconds, + WorkflowTaskTimeout: timeoutInSeconds, + Memo: memo, + } + wf := func(ctx Context) string { return "" } + + s.service.EXPECT().StartWorkflowExecution(gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + _, err := s.client.ExecuteWorkflow(context.Background(), options, wf) + s.Error(err) + if memoUserDCEncode { + s.ErrorContains(err, "failingMemoDataConverter memo encoding failed") + } else { + s.ErrorContains(err, "unsupported type: chan int") + } + } + + s.T().Run("old behavior", func(t *testing.T) { + previousFlag := memoUserDCEncode + SetMemoUserDCEncode(false) + defer SetMemoUserDCEncode(previousFlag) + testFn() + }) + s.T().Run("new behavior", func(t *testing.T) { + previousFlag := memoUserDCEncode + SetMemoUserDCEncode(true) + defer SetMemoUserDCEncode(previousFlag) + testFn() + }) +} + func (s *workflowClientTestSuite) TestSerializeSearchAttributes() { var input1 map[string]interface{} result1, err := serializeUntypedSearchAttributes(input1) diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index d1e0eb3f9..c26b6c7d2 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -2856,7 +2856,7 @@ func (env *testWorkflowEnvironmentImpl) UpsertTypedSearchAttributes(attributes S } func (env *testWorkflowEnvironmentImpl) UpsertMemo(memoMap map[string]interface{}) error { - memo, err := validateAndSerializeMemo(memoMap, env.dataConverter) + memo, err := validateAndSerializeMemo(memoMap, env.dataConverter, env) env.workflowInfo.Memo = mergeMemo(env.workflowInfo.Memo, memo) diff --git a/internal/internal_workflow_testsuite_test.go b/internal/internal_workflow_testsuite_test.go index 6bf136647..0fa7bcdad 100644 --- a/internal/internal_workflow_testsuite_test.go +++ b/internal/internal_workflow_testsuite_test.go @@ -1877,7 +1877,8 @@ func (s *WorkflowTestSuiteUnitTest) Test_MockUpsertMemo() { s.NotNil(wfInfo.Memo) valBytes := wfInfo.Memo.Fields["CustomIntField"] var result int - _ = converter.GetDefaultDataConverter().FromPayload(valBytes, &result) + err = converter.GetDefaultDataConverter().FromPayload(valBytes, &result) + s.NoError(err) s.Equal(1, result) return nil diff --git a/internal/workflow_testsuite.go b/internal/workflow_testsuite.go index 6cc380e0a..c0d670f62 100644 --- a/internal/workflow_testsuite.go +++ b/internal/workflow_testsuite.go @@ -1191,7 +1191,7 @@ func (e *TestWorkflowEnvironment) SetLastError(err error) { // SetMemoOnStart sets the memo when start workflow. func (e *TestWorkflowEnvironment) SetMemoOnStart(memo map[string]interface{}) error { - memoStruct, err := getWorkflowMemo(memo, e.impl.GetDataConverter()) + memoStruct, err := getWorkflowMemo(memo, e.impl.GetDataConverter(), e.impl) if err != nil { return err }