diff --git a/pkg/engine/jmespath/arithmetic_test.go b/pkg/engine/jmespath/arithmetic_test.go index b9b4f93ff5..03d76a4c3a 100644 --- a/pkg/engine/jmespath/arithmetic_test.go +++ b/pkg/engine/jmespath/arithmetic_test.go @@ -106,6 +106,152 @@ func Test_Add(t *testing.T) { } } +func Test_Sum(t *testing.T) { + testCases := []struct { + name string + test string + expectedResult interface{} + err bool + retFloat bool + }{ + // Scalar + { + name: "sum([]) -> error", + test: "sum([])", + err: true, + }, + { + name: "sum(Scalar[]) -> Scalar", + test: "sum([`12`])", + expectedResult: 12.0, + retFloat: true, + }, + { + name: "sum(Scalar[]) -> Scalar", + test: "sum([`12`, `13`, `1`, `4`])", + expectedResult: 30.0, + retFloat: true, + }, + { + name: "sum(Scalar[]) -> Scalar", + test: "sum([`12`, `13`])", + expectedResult: 25.0, + retFloat: true, + }, + { + name: "sum(Scalar[Scalar, Duration, ..]) -> error", + test: "sum(['12', '13s'])", + err: true, + }, + { + name: "sum(Scalar[Scalar, Quantity, ..]) -> error", + test: "sum([`12`, '13Ki'])", + err: true, + }, + { + name: "sum(Scalar[Scalar, Quatity, ..]) -> error", + test: "sum([`12`, '13'])", + err: true, + }, + // Quantity + { + name: "sum([]) -> error", + test: "sum([])", + err: true, + }, + { + name: "sum(Quantity[]) -> Quantity", + test: "sum(['12Ki'])", + expectedResult: `12Ki`, + }, + { + name: "sum(Quantity[]) -> Quantity", + test: "sum(['12Ki', '13Ki', '1Ki', '4Ki'])", + expectedResult: `30Ki`, + }, + { + name: "sum(Quantity[]) -> Quantity", + test: "sum(['12Ki', '13Ki'])", + expectedResult: `25Ki`, + }, + { + name: "sum(Quantity[]) -> Quantity", + test: "sum(['12Ki', '13'])", + expectedResult: `12301`, + }, + { + name: "sum(Quantity[Quantity, Duration, ..]) -> error", + test: "sum(['12Ki', '13s'])", + err: true, + }, + { + name: "sum(Quantity[Quantity, Scalar, ..]) -> error", + test: "sum(['12Ki', `13`])", + err: true, + }, + // Duration + { + name: "sum([]) -> error", + test: "sum([])", + err: true, + }, + { + name: "sum(Duration[]) -> Duration", + test: "sum(['12s'])", + expectedResult: `12s`, + }, + { + name: "sum(Duration[]) -> Duration", + test: "sum(['12s', '13s', '1s', '4s'])", + expectedResult: `30s`, + }, + { + name: "sum(Duration[]) -> Duration", + test: "sum(['12s', '13s'])", + expectedResult: `25s`, + }, + { + name: "sum(Duration[Duration, Scalar, ..]) -> error", + test: "sum(['12s', `13`])", + err: true, + }, + { + name: "sum(Duration[Duration, Quantity, ..]) -> error", + test: "sum(['12s', '13Ki'])", + err: true, + }, + { + name: "sum(Duration[Duration, Quantity, ..]) -> error", + test: "sum(['12s', '13'])", + err: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + jp, err := New(tc.test) + assert.NilError(t, err) + + result, err := jp.Search("") + if !tc.err { + assert.NilError(t, err) + } else { + assert.Assert(t, err != nil) + return + } + + if tc.retFloat { + equal, ok := result.(float64) + assert.Assert(t, ok) + assert.Equal(t, equal, tc.expectedResult.(float64)) + } else { + equal, ok := result.(string) + assert.Assert(t, ok) + assert.Equal(t, equal, tc.expectedResult.(string)) + } + }) + } +} + func Test_Subtract(t *testing.T) { testCases := []struct { name string diff --git a/pkg/engine/jmespath/functions.go b/pkg/engine/jmespath/functions.go index c6f01133d8..5d821af78f 100644 --- a/pkg/engine/jmespath/functions.go +++ b/pkg/engine/jmespath/functions.go @@ -51,6 +51,7 @@ var ( labelMatch = "label_match" toBoolean = "to_boolean" add = "add" + sum = "sum" subtract = "subtract" multiply = "multiply" divide = "divide" @@ -248,6 +249,16 @@ func GetFunctions() []FunctionEntry { }, ReturnType: []jpType{jpAny}, Note: "does arithmetic addition of two specified values of numbers, quantities, and durations", + }, { + FunctionEntry: gojmespath.FunctionEntry{ + Name: sum, + Arguments: []argSpec{ + {Types: []jpType{jpArray}}, + }, + Handler: jpSum, + }, + ReturnType: []jpType{jpAny}, + Note: "does arithmetic addition of specified array of values of numbers, quantities, and durations", }, { FunctionEntry: gojmespath.FunctionEntry{ Name: subtract, @@ -771,6 +782,25 @@ func jpAdd(arguments []interface{}) (interface{}, error) { return op1.Add(op2) } +func jpSum(arguments []interface{}) (interface{}, error) { + items, ok := arguments[0].([]interface{}) + if !ok { + return nil, formatError(typeMismatchError, sum) + } + if len(items) == 0 { + return nil, formatError(genericError, sum, "at least one element in the array is required") + } + var err error + sum := items[0] + for _, item := range items[1:] { + sum, err = jpAdd([]interface{}{sum, item}) + if err != nil { + return nil, err + } + } + return sum, nil +} + func jpSubtract(arguments []interface{}) (interface{}, error) { op1, op2, err := ParseArithemticOperands(arguments, subtract) if err != nil {