From aafc4fe97e99ccdc8517d536bab7a7bd7ad98eaa Mon Sep 17 00:00:00 2001
From: Khaled Emara <khaled.emara@nirmata.com>
Date: Fri, 5 Jul 2024 18:06:48 +0300
Subject: [PATCH] fix(json-ctx): overwrite element each iteration (#10615)

Signed-off-by: Khaled Emara <khaled.emara@nirmata.com>
---
 pkg/engine/context/context.go    | 36 ++++++++++++++++----------------
 pkg/engine/context/utils.go      | 12 +++++------
 pkg/engine/context/utils_test.go | 34 ++++++++++++++++++++++++++----
 3 files changed, 54 insertions(+), 28 deletions(-)

diff --git a/pkg/engine/context/context.go b/pkg/engine/context/context.go
index 21f9bfd856..2643eb6ba5 100644
--- a/pkg/engine/context/context.go
+++ b/pkg/engine/context/context.go
@@ -109,7 +109,7 @@ type Interface interface {
 	Reset()
 
 	// AddJSON  merges the json map with context
-	addJSON(dataMap map[string]interface{}) error
+	addJSON(dataMap map[string]interface{}, overwriteMaps bool) error
 }
 
 // Context stores the data resources as JSON
@@ -138,8 +138,8 @@ func NewContextFromRaw(jp jmespath.Interface, raw map[string]interface{}) Interf
 }
 
 // addJSON merges json data
-func (ctx *context) addJSON(dataMap map[string]interface{}) error {
-	mergeMaps(dataMap, ctx.jsonRaw)
+func (ctx *context) addJSON(dataMap map[string]interface{}, overwriteMaps bool) error {
+	mergeMaps(dataMap, ctx.jsonRaw, overwriteMaps)
 	return nil
 }
 
@@ -166,7 +166,7 @@ func (ctx *context) AddRequest(request admissionv1.AdmissionRequest) error {
 		return err
 	}
 
-	if err := addToContext(ctx, mapObj, "request"); err != nil {
+	if err := addToContext(ctx, mapObj, false, "request"); err != nil {
 		return err
 	}
 
@@ -180,7 +180,7 @@ func (ctx *context) AddVariable(key string, value interface{}) error {
 	if fields, err := reader.Read(); err != nil {
 		return err
 	} else {
-		return addToContext(ctx, value, fields...)
+		return addToContext(ctx, value, false, fields...)
 	}
 }
 
@@ -190,7 +190,7 @@ func (ctx *context) AddContextEntry(name string, dataRaw []byte) error {
 		logger.Error(err, "failed to unmarshal the resource")
 		return err
 	}
-	return addToContext(ctx, data, name)
+	return addToContext(ctx, data, false, name)
 }
 
 func (ctx *context) ReplaceContextEntry(name string, dataRaw []byte) error {
@@ -200,34 +200,34 @@ func (ctx *context) ReplaceContextEntry(name string, dataRaw []byte) error {
 		return err
 	}
 	// Adding a nil entry to clean out any existing data in the context with the entry name
-	if err := addToContext(ctx, nil, name); err != nil {
+	if err := addToContext(ctx, nil, false, name); err != nil {
 		logger.Error(err, "unable to replace context entry", "context entry name", name)
 		return err
 	}
-	return addToContext(ctx, data, name)
+	return addToContext(ctx, data, false, name)
 }
 
 // AddResource data at path: request.object
 func (ctx *context) AddResource(data map[string]interface{}) error {
 	clearLeafValue(ctx.jsonRaw, "request", "object")
-	return addToContext(ctx, data, "request", "object")
+	return addToContext(ctx, data, false, "request", "object")
 }
 
 // AddOldResource data at path: request.oldObject
 func (ctx *context) AddOldResource(data map[string]interface{}) error {
 	clearLeafValue(ctx.jsonRaw, "request", "oldObject")
-	return addToContext(ctx, data, "request", "oldObject")
+	return addToContext(ctx, data, false, "request", "oldObject")
 }
 
 // AddTargetResource adds data at path: target
 func (ctx *context) SetTargetResource(data map[string]interface{}) error {
 	clearLeafValue(ctx.jsonRaw, "target")
-	return addToContext(ctx, data, "target")
+	return addToContext(ctx, data, false, "target")
 }
 
 // AddOperation data at path: request.operation
 func (ctx *context) AddOperation(data string) error {
-	if err := addToContext(ctx, data, "request", "operation"); err != nil {
+	if err := addToContext(ctx, data, false, "request", "operation"); err != nil {
 		return err
 	}
 
@@ -238,7 +238,7 @@ func (ctx *context) AddOperation(data string) error {
 // AddUserInfo adds userInfo at path request.userInfo
 func (ctx *context) AddUserInfo(userRequestInfo kyvernov2.RequestInfo) error {
 	if data, err := toUnstructured(&userRequestInfo); err == nil {
-		return addToContext(ctx, data, "request")
+		return addToContext(ctx, data, false, "request")
 	} else {
 		return err
 	}
@@ -265,7 +265,7 @@ func (ctx *context) AddServiceAccount(userName string) error {
 		"serviceAccountName":      saName,
 		"serviceAccountNamespace": saNamespace,
 	}
-	if err := ctx.addJSON(data); err != nil {
+	if err := ctx.addJSON(data, false); err != nil {
 		return err
 	}
 
@@ -275,7 +275,7 @@ func (ctx *context) AddServiceAccount(userName string) error {
 
 // AddNamespace merges resource json under request.namespace
 func (ctx *context) AddNamespace(namespace string) error {
-	return addToContext(ctx, namespace, "request", "namespace")
+	return addToContext(ctx, namespace, false, "request", "namespace")
 }
 
 func (ctx *context) AddElement(data interface{}, index, nesting int) error {
@@ -287,7 +287,7 @@ func (ctx *context) AddElement(data interface{}, index, nesting int) error {
 		"elementIndex":     int64(index),
 		nestedElementIndex: int64(index),
 	}
-	return addToContext(ctx, data)
+	return addToContext(ctx, data, true)
 }
 
 func (ctx *context) AddImageInfo(info apiutils.ImageInfo, cfg config.Configuration) error {
@@ -300,7 +300,7 @@ func (ctx *context) AddImageInfo(info apiutils.ImageInfo, cfg config.Configurati
 		"tag":              info.Tag,
 		"digest":           info.Digest,
 	}
-	return addToContext(ctx, data, "image")
+	return addToContext(ctx, data, false, "image")
 }
 
 func (ctx *context) AddImageInfos(resource *unstructured.Unstructured, cfg config.Configuration) error {
@@ -323,7 +323,7 @@ func (ctx *context) addImageInfos(images map[string]map[string]apiutils.ImageInf
 	}
 
 	logging.V(4).Info("updated image info", "images", utm)
-	return addToContext(ctx, utm, "images")
+	return addToContext(ctx, utm, false, "images")
 }
 
 func convertImagesToUnstructured(images map[string]map[string]apiutils.ImageInfo) (map[string]interface{}, error) {
diff --git a/pkg/engine/context/utils.go b/pkg/engine/context/utils.go
index d13d83a868..85fbee0b26 100644
--- a/pkg/engine/context/utils.go
+++ b/pkg/engine/context/utils.go
@@ -8,7 +8,7 @@ import (
 
 // AddJSONObject merges json data
 func AddJSONObject(ctx Interface, data map[string]interface{}) error {
-	return ctx.addJSON(data)
+	return ctx.addJSON(data, false)
 }
 
 func AddResource(ctx Interface, dataRaw []byte) error {
@@ -29,12 +29,12 @@ func AddOldResource(ctx Interface, dataRaw []byte) error {
 	return ctx.AddOldResource(data)
 }
 
-func addToContext(ctx *context, data interface{}, tags ...string) error {
+func addToContext(ctx *context, data interface{}, overwriteMaps bool, tags ...string) error {
 	if v, err := convertStructs(data); err != nil {
 		return err
 	} else {
 		dataRaw := push(v, tags...)
-		return ctx.addJSON(dataRaw)
+		return ctx.addJSON(dataRaw, overwriteMaps)
 	}
 }
 
@@ -90,11 +90,11 @@ func push(data interface{}, tags ...string) map[string]interface{} {
 }
 
 // mergeMaps merges srcMap entries into destMap
-func mergeMaps(srcMap, destMap map[string]interface{}) {
+func mergeMaps(srcMap, destMap map[string]interface{}, overwriteMaps bool) {
 	for k, v := range srcMap {
-		if nextSrcMap, ok := v.(map[string]interface{}); ok {
+		if nextSrcMap, ok := v.(map[string]interface{}); ok && !overwriteMaps {
 			if nextDestMap, ok := destMap[k].(map[string]interface{}); ok {
-				mergeMaps(nextSrcMap, nextDestMap)
+				mergeMaps(nextSrcMap, nextDestMap, overwriteMaps)
 			} else {
 				destMap[k] = nextSrcMap
 			}
diff --git a/pkg/engine/context/utils_test.go b/pkg/engine/context/utils_test.go
index 11f758e905..2abc5898d9 100644
--- a/pkg/engine/context/utils_test.go
+++ b/pkg/engine/context/utils_test.go
@@ -33,7 +33,7 @@ func TestMergeMaps(t *testing.T) {
 		},
 	}
 
-	mergeMaps(map1, map2)
+	mergeMaps(map1, map2, false)
 
 	assert.Equal(t, "bar1", map2["strVal"])
 	assert.Equal(t, "bar2", map2["strVal2"])
@@ -52,7 +52,7 @@ func TestMergeMaps(t *testing.T) {
 	}
 
 	ctxMap := map[string]interface{}{}
-	mergeMaps(requestObj, ctxMap)
+	mergeMaps(requestObj, ctxMap, false)
 
 	r := ctxMap["request"].(map[string]interface{})
 	o := r["object"].(map[string]interface{})
@@ -67,7 +67,7 @@ func TestMergeMaps(t *testing.T) {
 		},
 	}
 
-	mergeMaps(requestObj2, ctxMap)
+	mergeMaps(requestObj2, ctxMap, false)
 	r2 := ctxMap["request"].(map[string]interface{})
 	o2 := r2["object"].(map[string]interface{})
 	assert.Equal(t, "bar2", o2["foo"])
@@ -79,13 +79,39 @@ func TestMergeMaps(t *testing.T) {
 		},
 	}
 
-	mergeMaps(request3, ctxMap)
+	mergeMaps(request3, ctxMap, false)
 	r3 := ctxMap["request"].(map[string]interface{})
 	o3 := r3["object"].(map[string]interface{})
 	assert.NotNil(t, o3)
 	assert.Equal(t, "bar2", o2["foo"])
 	assert.Equal(t, "bar2", o2["foo2"])
 	assert.Equal(t, "user1", r3["userInfo"])
+
+	request4 := map[string]interface{}{
+		"request": map[string]interface{}{
+			"object": map[string]interface{}{
+				"foo": "bar3",
+			},
+		},
+	}
+
+	mergeMaps(request4, ctxMap, false)
+	r4 := ctxMap["request"].(map[string]interface{})
+	assert.NotNil(t, r4)
+	assert.Equal(t, "user1", r4["userInfo"])
+
+	request5 := map[string]interface{}{
+		"request": map[string]interface{}{
+			"object": map[string]interface{}{
+				"foo": "bar4",
+			},
+		},
+	}
+
+	mergeMaps(request5, ctxMap, true)
+	r5 := ctxMap["request"].(map[string]interface{})
+	userInfo := r5["userInfo"]
+	assert.Nil(t, userInfo)
 }
 
 func TestStructToUntypedMap(t *testing.T) {