Skip to content

Commit f1fab97

Browse files
yiminctdeebswihart
authored and committed
Add validation for a few string fields (#5487)
Add string validation for a few string fields. Since we disable UTF-8 string validation at the proto level, we want to enforce minimal validation for some key fields. Unit tests. No. <!-- Have you made sure this change doesn't falsify anything currently stated in `docs/`? If significant new behavior is added, have you described that in `docs/`? --> <!-- Is this PR a hotfix candidate or does it require a notification to be sent to the broader community? (Yes/No) -->
1 parent 4581099 commit f1fab97

File tree

6 files changed

+170
-25
lines changed

6 files changed

+170
-25
lines changed

common/searchattribute/encode_value.go

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,19 @@
2525
package searchattribute
2626

2727
import (
28+
"errors"
2829
"fmt"
2930
"time"
31+
"unicode/utf8"
3032

3133
commonpb "go.temporal.io/api/common/v1"
3234
enumspb "go.temporal.io/api/enums/v1"
3335

3436
"go.temporal.io/server/common/payload"
3537
)
3638

39+
var ErrInvalidString = errors.New("SearchAttribute value is not a valid UTF-8 string")
40+
3741
// EncodeValue encodes search attribute value and IndexedValueType to Payload.
3842
func EncodeValue(val interface{}, t enumspb.IndexedValueType) (*commonpb.Payload, error) {
3943
valPayload, err := payload.Encode(val)
@@ -70,16 +74,37 @@ func DecodeValue(
7074
case enumspb.INDEXED_VALUE_TYPE_INT:
7175
return decodeValueTyped[int64](value, allowList)
7276
case enumspb.INDEXED_VALUE_TYPE_KEYWORD:
73-
return decodeValueTyped[string](value, allowList)
77+
return validateStrings(decodeValueTyped[string](value, allowList))
7478
case enumspb.INDEXED_VALUE_TYPE_TEXT:
75-
return decodeValueTyped[string](value, allowList)
79+
return validateStrings(decodeValueTyped[string](value, allowList))
7680
case enumspb.INDEXED_VALUE_TYPE_KEYWORD_LIST:
77-
return decodeValueTyped[[]string](value, false)
81+
return validateStrings(decodeValueTyped[[]string](value, false))
7882
default:
7983
return nil, fmt.Errorf("%w: %v", ErrInvalidType, t)
8084
}
8185
}
8286

87+
func validateStrings(anyValue any, err error) (any, error) {
88+
if err != nil {
89+
return anyValue, err
90+
}
91+
92+
// validate strings
93+
switch value := anyValue.(type) {
94+
case string:
95+
if !utf8.ValidString(value) {
96+
return nil, fmt.Errorf("%w: %s", ErrInvalidString, value)
97+
}
98+
case []string:
99+
for _, item := range value {
100+
if !utf8.ValidString(item) {
101+
return nil, fmt.Errorf("%w: %s", ErrInvalidString, item)
102+
}
103+
}
104+
}
105+
return anyValue, err
106+
}
107+
83108
// decodeValueTyped tries to decode to the given type.
84109
// If the input is a list and allowList is false, then it will return only the first element.
85110
// If the input is a list and allowList is true, then it will return the decoded list.

common/searchattribute/encode_value_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
package searchattribute
2626

2727
import (
28+
"errors"
2829
"testing"
2930
"time"
3031

@@ -372,3 +373,21 @@ func Test_EncodeValue(t *testing.T) {
372373
"Datetime Search Attribute is expected to be encoded in RFC 3339 format")
373374
s.Equal("Datetime", string(encodedPayload.Metadata["type"]))
374375
}
376+
377+
func Test_ValidateStrings(t *testing.T) {
378+
_, err := validateStrings("anything here", errors.New("test error"))
379+
assert.Error(t, err)
380+
assert.Contains(t, err.Error(), "test error")
381+
382+
_, err = validateStrings("\x87\x01", nil)
383+
assert.Error(t, err)
384+
assert.Contains(t, err.Error(), "is not a valid UTF-8 string")
385+
386+
value, err := validateStrings("anything here", nil)
387+
assert.Nil(t, err)
388+
assert.Equal(t, "anything here", value)
389+
390+
_, err = validateStrings([]string{"abc", "\x87\x01"}, nil)
391+
assert.Error(t, err)
392+
assert.Contains(t, err.Error(), "is not a valid UTF-8 string")
393+
}

common/util.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"strings"
3333
"sync"
3434
"time"
35+
"unicode/utf8"
3536

3637
"github.com/dgryski/go-farm"
3738
"github.com/gogo/protobuf/proto"
@@ -720,3 +721,10 @@ func OverrideWorkflowTaskTimeout(
720721
func CloneProto[T proto.Message](v T) T {
721722
return proto.Clone(v).(T)
722723
}
724+
725+
func ValidateUTF8String(fieldName string, strValue string) error {
726+
if !utf8.ValidString(strValue) {
727+
return serviceerror.NewInvalidArgument(fmt.Sprintf("%s %v is not a valid UTF-8 string", fieldName, strValue))
728+
}
729+
return nil
730+
}

service/frontend/workflow_handler.go

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,9 @@ func (wh *WorkflowHandler) StartWorkflowExecution(ctx context.Context, request *
380380
return nil, errWorkflowTypeTooLong
381381
}
382382

383+
if err := common.ValidateUTF8String("WorkflowType", request.WorkflowType.GetName()); err != nil {
384+
return nil, err
385+
}
383386
if err := wh.validateTaskQueue(request.TaskQueue); err != nil {
384387
return nil, err
385388
}
@@ -388,12 +391,12 @@ func (wh *WorkflowHandler) StartWorkflowExecution(ctx context.Context, request *
388391
return nil, err
389392
}
390393

391-
if request.GetRequestId() == "" {
394+
if request.RequestId == "" {
392395
return nil, errRequestIDNotSet
393396
}
394397

395-
if len(request.GetRequestId()) > wh.config.MaxIDLengthLimit() {
396-
return nil, errRequestIDTooLong
398+
if err := validateRequestId(&request.RequestId, wh.config.MaxIDLengthLimit()); err != nil {
399+
return nil, err
397400
}
398401

399402
request, err := wh.unaliasStartWorkflowExecutionRequestSearchAttributes(request, namespaceName)
@@ -2002,12 +2005,16 @@ func (wh *WorkflowHandler) SignalWithStartWorkflowExecution(ctx context.Context,
20022005
return nil, errWorkflowTypeTooLong
20032006
}
20042007

2008+
if err := common.ValidateUTF8String("WorkflowType", request.WorkflowType.GetName()); err != nil {
2009+
return nil, err
2010+
}
2011+
20052012
if err := wh.validateTaskQueue(request.TaskQueue); err != nil {
20062013
return nil, err
20072014
}
20082015

2009-
if len(request.GetRequestId()) > wh.config.MaxIDLengthLimit() {
2010-
return nil, errRequestIDTooLong
2016+
if err := validateRequestId(&request.RequestId, wh.config.MaxIDLengthLimit()); err != nil {
2017+
return nil, err
20112018
}
20122019

20132020
if err := wh.validateSignalWithStartWorkflowTimeouts(request); err != nil {
@@ -4348,6 +4355,9 @@ func (wh *WorkflowHandler) validateTaskQueue(t *taskqueuepb.TaskQueue) error {
43484355
if len(t.GetName()) > wh.config.MaxIDLengthLimit() {
43494356
return errTaskQueueTooLong
43504357
}
4358+
if err := common.ValidateUTF8String("TaskQueue", t.GetName()); err != nil {
4359+
return err
4360+
}
43514361

43524362
enums.SetDefaultTaskQueueKind(&t.Kind)
43534363
return nil
@@ -4356,26 +4366,33 @@ func (wh *WorkflowHandler) validateTaskQueue(t *taskqueuepb.TaskQueue) error {
43564366
func (wh *WorkflowHandler) validateBuildIdOrderingUpdate(
43574367
req *workflowservice.UpdateWorkerBuildIdOrderingRequest,
43584368
) error {
4359-
errstr := "request to update worker build id ordering requires:"
4360-
hadErr := false
4369+
errDeets := []string{"request to update worker build id compatability requires: "}
4370+
4371+
checkIdLen := func(id string) {
4372+
if len(id) > wh.config.WorkerBuildIdSizeLimit() {
4373+
errDeets = append(errDeets, fmt.Sprintf(" Worker build IDs to be no larger than %v characters",
4374+
wh.config.WorkerBuildIdSizeLimit()))
4375+
}
4376+
4377+
if err := common.ValidateUTF8String("BuildId", id); err != nil {
4378+
errDeets = append(errDeets, err.Error())
4379+
}
4380+
}
4381+
43614382
if req.GetNamespace() == "" {
4362-
errstr += " `namespace` to be set"
4363-
hadErr = true
4383+
errDeets = append(errDeets, "`namespace` to be set")
43644384
}
43654385
if req.GetTaskQueue() == "" {
4366-
errstr += " `task_queue` to be set"
4367-
hadErr = true
4386+
errDeets = append(errDeets, "`task_queue` to be set")
43684387
}
43694388
if req.GetVersionId().GetWorkerBuildId() == "" {
4370-
errstr += " targeting a valid version identifier"
4371-
hadErr = true
4372-
}
4373-
if len(req.GetVersionId().GetWorkerBuildId()) > wh.config.WorkerBuildIdSizeLimit() {
4374-
errstr += fmt.Sprintf(" Worker build IDs to be no larger than %v characters", wh.config.WorkerBuildIdSizeLimit())
4375-
hadErr = true
4389+
errDeets = append(errDeets, "targeting a valid version identifier")
4390+
} else {
4391+
checkIdLen(req.GetVersionId().GetWorkerBuildId())
43764392
}
4377-
if hadErr {
4378-
return serviceerror.NewInvalidArgument(errstr)
4393+
4394+
if len(errDeets) > 1 {
4395+
return serviceerror.NewInvalidArgument(strings.Join(errDeets, ", "))
43794396
}
43804397
return nil
43814398
}
@@ -4657,6 +4674,24 @@ func (wh *WorkflowHandler) validateRetryPolicy(namespaceName namespace.Name, ret
46574674
return common.ValidateRetryPolicy(retryPolicy)
46584675
}
46594676

4677+
func validateRequestId(requestID *string, lenLimit int) error {
4678+
if requestID == nil {
4679+
// should never happen, but just in case.
4680+
return serviceerror.NewInvalidArgument("RequestId is nil")
4681+
}
4682+
if *requestID == "" {
4683+
// For easy direct API use, we default the request ID here but expect all
4684+
// SDKs and other auto-retrying clients to set it
4685+
*requestID = uuid.New()
4686+
}
4687+
4688+
if len(*requestID) > lenLimit {
4689+
return errRequestIDTooLong
4690+
}
4691+
4692+
return common.ValidateUTF8String("RequestId", *requestID)
4693+
}
4694+
46604695
func (wh *WorkflowHandler) validateStartWorkflowTimeouts(
46614696
request *workflowservice.StartWorkflowExecutionRequest,
46624697
) error {
@@ -4753,7 +4788,7 @@ func (wh *WorkflowHandler) makeFakeContinuedAsNewEvent(
47534788
func (wh *WorkflowHandler) validateNamespace(
47544789
namespace string,
47554790
) error {
4756-
if err := wh.validateUTF8String(namespace); err != nil {
4791+
if err := common.ValidateUTF8String("Namespace", namespace); err != nil {
47574792
return err
47584793
}
47594794
if len(namespace) > wh.config.MaxIDLengthLimit() {
@@ -4768,7 +4803,7 @@ func (wh *WorkflowHandler) validateWorkflowID(
47684803
if workflowID == "" {
47694804
return errWorkflowIDNotSet
47704805
}
4771-
if err := wh.validateUTF8String(workflowID); err != nil {
4806+
if err := common.ValidateUTF8String("WorkflowId", workflowID); err != nil {
47724807
return err
47734808
}
47744809
if len(workflowID) > wh.config.MaxIDLengthLimit() {

service/frontend/workflow_handler_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2655,6 +2655,18 @@ func TestContextNearDeadline(t *testing.T) {
26552655
assert.False(t, contextNearDeadline(ctx, time.Millisecond))
26562656
}
26572657

2658+
func TestValidateRequestId(t *testing.T) {
2659+
req := workflowservice.StartWorkflowExecutionRequest{RequestId: ""}
2660+
err := validateRequestId(&req.RequestId, 100)
2661+
assert.Nil(t, err)
2662+
assert.Len(t, req.RequestId, 36) // new UUID length
2663+
2664+
req.RequestId = "\x87\x01"
2665+
err = validateRequestId(&req.RequestId, 100)
2666+
assert.Error(t, err)
2667+
assert.Contains(t, err.Error(), "not a valid UTF-8 string")
2668+
}
2669+
26582670
func (s *workflowHandlerSuite) Test_DeleteWorkflowExecution() {
26592671
config := s.newConfig()
26602672
wh := s.getWorkflowHandler(config)

service/history/commandChecker.go

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,9 @@ func (v *commandAttrValidator) validateTimerScheduleAttributes(
380380
if len(attributes.GetTimerId()) > v.maxIDLengthLimit {
381381
return failedCause, serviceerror.NewInvalidArgument("TimerId exceeds length limit.")
382382
}
383+
if err := common.ValidateUTF8String("TimerId", attributes.TimerId); err != nil {
384+
return failedCause, err
385+
}
383386
if timestamp.DurationValue(attributes.GetStartToFireTimeout()) <= 0 {
384387
return failedCause, serviceerror.NewInvalidArgument("A valid StartToFireTimeout is not set on command.")
385388
}
@@ -498,7 +501,22 @@ func (v *commandAttrValidator) validateCancelExternalWorkflowExecutionAttributes
498501
if len(attributes.GetWorkflowId()) > v.maxIDLengthLimit {
499502
return failedCause, serviceerror.NewInvalidArgument("WorkflowId exceeds length limit.")
500503
}
501-
runID := attributes.GetRunId()
504+
runID := attributes.RunId
505+
workflowID := attributes.WorkflowId
506+
ns := attributes.Namespace
507+
508+
if workflowID == "" {
509+
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowId is not set on RequestCancelExternalWorkflowExecutionCommand. Namespace=%s RunId=%s", ns, runID))
510+
}
511+
if len(ns) > v.maxIDLengthLimit {
512+
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Namespace on RequestCancelExternalWorkflowExecutionCommand exceeds length limit. WorkflowId=%s RunId=%s Namespace=%s Length=%d Limit=%d", workflowID, runID, ns, len(ns), v.maxIDLengthLimit))
513+
}
514+
if len(workflowID) > v.maxIDLengthLimit {
515+
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowId on RequestCancelExternalWorkflowExecutionCommand exceeds length limit. WorkflowId=%s Length=%d Limit=%d RunId=%s Namespace=%s", workflowID, len(workflowID), v.maxIDLengthLimit, runID, ns))
516+
}
517+
if err := common.ValidateUTF8String("WorkflowId", workflowID); err != nil {
518+
return failedCause, err
519+
}
502520
if runID != "" && uuid.Parse(runID) == nil {
503521
return failedCause, serviceerror.NewInvalidArgument("Invalid RunId set on command.")
504522
}
@@ -540,6 +558,22 @@ func (v *commandAttrValidator) validateSignalExternalWorkflowExecutionAttributes
540558
}
541559

542560
targetRunID := attributes.Execution.GetRunId()
561+
signalName := attributes.SignalName
562+
workflowID := attributes.Execution.WorkflowId
563+
ns := attributes.Namespace
564+
565+
if workflowID == "" {
566+
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowId is not set on SignalExternalWorkflowExecutionCommand. Namespace=%s RunId=%s SignalName=%s", ns, targetRunID, signalName))
567+
}
568+
if len(ns) > v.maxIDLengthLimit {
569+
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Namespace on SignalExternalWorkflowExecutionCommand exceeds length limit. WorkflowId=%s Namespace=%s Length=%d Limit=%d RunId=%s SignalName=%s", workflowID, ns, len(ns), v.maxIDLengthLimit, targetRunID, signalName))
570+
}
571+
if len(workflowID) > v.maxIDLengthLimit {
572+
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowId on SignalExternalWorkflowExecutionCommand exceeds length limit. WorkflowId=%s Length=%d Limit=%d Namespace=%s RunId=%s SignalName=%s", workflowID, len(workflowID), v.maxIDLengthLimit, ns, targetRunID, signalName))
573+
}
574+
if err := common.ValidateUTF8String("WorkflowId", workflowID); err != nil {
575+
return failedCause, err
576+
}
543577
if targetRunID != "" && uuid.Parse(targetRunID) == nil {
544578
return failedCause, serviceerror.NewInvalidArgument("Invalid RunId set on command.")
545579
}
@@ -791,6 +825,14 @@ func (v *commandAttrValidator) validateStartChildExecutionAttributes(
791825
return failedCause, serviceerror.NewInvalidArgument("WorkflowType exceeds length limit.")
792826
}
793827

828+
if err := common.ValidateUTF8String("WorkflowId", attributes.WorkflowId); err != nil {
829+
return failedCause, err
830+
}
831+
832+
if err := common.ValidateUTF8String("WorkflowType", attributes.WorkflowType.Name); err != nil {
833+
return failedCause, err
834+
}
835+
794836
if timestamp.DurationValue(attributes.GetWorkflowExecutionTimeout()) < 0 {
795837
return failedCause, serviceerror.NewInvalidArgument("Invalid WorkflowExecutionTimeout.")
796838
}
@@ -868,6 +910,10 @@ func (v *commandAttrValidator) validateTaskQueue(
868910
return taskQueue, serviceerror.NewInvalidArgument(fmt.Sprintf("task queue name exceeds length limit of %v", v.maxIDLengthLimit))
869911
}
870912

913+
if err := common.ValidateUTF8String("TaskQueue", name); err != nil {
914+
return taskQueue, err
915+
}
916+
871917
if strings.HasPrefix(name, reservedTaskQueuePrefix) {
872918
return taskQueue, serviceerror.NewInvalidArgument(fmt.Sprintf("task queue name cannot start with reserved prefix %v", reservedTaskQueuePrefix))
873919
}

0 commit comments

Comments
 (0)