diff --git a/pkg/engine/jmespath/arithmetic.go b/pkg/engine/jmespath/arithmetic.go index 84d813c7af..2fbaac5e9f 100644 --- a/pkg/engine/jmespath/arithmetic.go +++ b/pkg/engine/jmespath/arithmetic.go @@ -30,25 +30,29 @@ type scalar struct { float64 } -func parseArithemticOperands(arguments []interface{}, operator string) (operand, operand, error) { - op := [2]operand{nil, nil} - for i := 0; i < 2; i++ { - if tmp, err := validateArg(divide, arguments, i, reflect.Float64); err == nil { - var sc scalar - sc.float64 = tmp.Float() - op[i] = sc - } else if tmp, err = validateArg(divide, arguments, i, reflect.String); err == nil { - if q, err := resource.ParseQuantity(tmp.String()); err == nil { - op[i] = quantity{Quantity: q} - } else if d, err := time.ParseDuration(tmp.String()); err == nil { - op[i] = duration{Duration: d} - } +func parseArithemticOperand(arguments []interface{}, index int, operator string) (operand, error) { + if tmp, err := validateArg(operator, arguments, index, reflect.Float64); err == nil { + return scalar{float64: tmp.Float()}, nil + } else if tmp, err = validateArg(operator, arguments, index, reflect.String); err == nil { + if q, err := resource.ParseQuantity(tmp.String()); err == nil { + return quantity{Quantity: q}, nil + } else if d, err := time.ParseDuration(tmp.String()); err == nil { + return duration{Duration: d}, nil } } - if op[0] == nil || op[1] == nil { - return nil, nil, formatError(genericError, operator, "invalid operands") + return nil, formatError(genericError, operator, "invalid operand") +} + +func parseArithemticOperands(arguments []interface{}, operator string) (operand, operand, error) { + left, err := parseArithemticOperand(arguments, 0, operator) + if err != nil { + return nil, nil, err } - return op[0], op[1], nil + right, err := parseArithemticOperand(arguments, 1, operator) + if err != nil { + return nil, nil, err + } + return left, right, nil } // Quantity +|- Quantity -> Quantity diff --git a/pkg/engine/jmespath/arithmetic_test.go b/pkg/engine/jmespath/arithmetic_test.go index 4b4c56fdd5..4da8c7ec6e 100644 --- a/pkg/engine/jmespath/arithmetic_test.go +++ b/pkg/engine/jmespath/arithmetic_test.go @@ -768,6 +768,75 @@ func Test_Modulo(t *testing.T) { } } +func Test_Round(t *testing.T) { + testCases := []struct { + name string + test string + expectedResult interface{} + err bool + retFloat bool + }{ + // Scalar + { + name: "Scalar roundoff Quantity -> error", + test: "round(`23`, '12Ki')", + err: true, + }, + { + name: "Scalar roundoff Duration -> error", + test: "round(`21`, '5s')", + err: true, + }, + { + name: "Scalar roundoff Scalar -> Scalar", + test: "round(`9.414675`, `2`)", + expectedResult: 9.41, + retFloat: true, + }, + { + name: "Scalar roundoff zero -> error", + test: "round(`14.123`, `6`)", + expectedResult: 14.123, + retFloat: true, + }, + // round with non int values + { + name: "Scalar roundoff Non int -> error", + test: "round(`14`, `1.5`)", + err: true, + }, + { + name: "Scalar roundoff negative int -> error", + test: "round(`14`, `-2`)", + err: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + jp, err := newJMESPath(cfg, tc.test) + assert.NilError(t, err) + + result, err := jp.Search("") + if !tc.err { + assert.NilError(t, err) + } else { + assert.Assert(t, err != nil) + return + } + + if tc.retFloat { + equal, ok := result.(float64) + assert.Assert(t, ok) + assert.Equal(t, equal, tc.expectedResult.(float64)) + } else { + equal, ok := result.(string) + assert.Assert(t, ok) + assert.Equal(t, equal, tc.expectedResult.(string)) + } + }) + } +} + func TestScalar_Multiply(t *testing.T) { type fields struct { float64 float64 diff --git a/pkg/engine/jmespath/error.go b/pkg/engine/jmespath/error.go index be462a754d..73808a5d29 100644 --- a/pkg/engine/jmespath/error.go +++ b/pkg/engine/jmespath/error.go @@ -12,6 +12,7 @@ const ( zeroDivisionError = errorPrefix + "Zero divisor passed" nonIntModuloError = errorPrefix + "Non-integer argument(s) passed for modulo" typeMismatchError = errorPrefix + "Types mismatch" + nonIntRoundError = errorPrefix + "Non-integer argument(s) passed for round off" ) func formatError(format string, function string, values ...interface{}) error { diff --git a/pkg/engine/jmespath/functions.go b/pkg/engine/jmespath/functions.go index bb00524633..d736aa2888 100644 --- a/pkg/engine/jmespath/functions.go +++ b/pkg/engine/jmespath/functions.go @@ -11,6 +11,7 @@ import ( "encoding/pem" "errors" "fmt" + "math" "path/filepath" "reflect" "regexp" @@ -58,6 +59,7 @@ var ( multiply = "multiply" divide = "divide" modulo = "modulo" + round = "round" base64Decode = "base64_decode" base64Encode = "base64_encode" pathCanonicalize = "path_canonicalize" @@ -307,6 +309,17 @@ func GetFunctions(configuration config.Configuration) []FunctionEntry { }, ReturnType: []jpType{jpAny}, Note: "divisor must be non-zero, arguments must be integers", + }, { + FunctionEntry: gojmespath.FunctionEntry{ + Name: round, + Arguments: []argSpec{ + {Types: []jpType{jpNumber}}, + {Types: []jpType{jpNumber}}, + }, + Handler: jpRound, + }, + ReturnType: []jpType{jpNumber}, + Note: "does roundoff to upto the given decimal places", }, { FunctionEntry: gojmespath.FunctionEntry{ Name: base64Decode, @@ -865,6 +878,27 @@ func jpModulo(arguments []interface{}) (interface{}, error) { return op1.Modulo(op2) } +func jpRound(arguments []interface{}) (interface{}, error) { + op, err := validateArg(round, arguments, 0, reflect.Float64) + if err != nil { + return nil, err + } + length, err := validateArg(round, arguments, 1, reflect.Float64) + if err != nil { + return nil, err + } + intLength, err := intNumber(length.Float()) + if err != nil { + return nil, formatError(nonIntRoundError, round) + } + if intLength < 0 { + return nil, formatError(argOutOfBoundsError, round) + } + shift := math.Pow(10, float64(intLength)) + rounded := math.Round(op.Float()*shift) / shift + return rounded, nil +} + func jpBase64Decode(arguments []interface{}) (interface{}, error) { var err error str, err := validateArg("", arguments, 0, reflect.String)