1
0
Fork 0
mirror of https://github.com/kyverno/kyverno.git synced 2024-12-15 17:51:20 +00:00

feat: add globalcontext controller (#9601)

* feat: add globalcontext controller

Signed-off-by: Khaled Emara <khaled.emara@nirmata.com>

* rework controller

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* rbac

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* cmd

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* fix rbac

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* engine

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* k8s resources

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* k8s resource

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* resync zero

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* api call

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* api call

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* clean

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

* fix linter

Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>

---------

Signed-off-by: Khaled Emara <khaled.emara@nirmata.com>
Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>
Co-authored-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>
This commit is contained in:
Khaled Emara 2024-02-02 12:41:35 +02:00 committed by GitHub
parent 3510998d4f
commit 226fa9515a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 611 additions and 19 deletions

View file

@ -74,6 +74,13 @@ rules:
- update
- watch
- deletecollection
- apiGroups:
- kyverno.io
resources:
- globalcontextentries
verbs:
- list
- watch
- apiGroups:
- reports.kyverno.io
resources:

View file

@ -41,6 +41,13 @@ rules:
- update
- watch
- deletecollection
- apiGroups:
- kyverno.io
resources:
- globalcontextentries
verbs:
- list
- watch
- apiGroups:
- ''
resources:

View file

@ -51,6 +51,13 @@ rules:
verbs:
- list
- watch
- apiGroups:
- kyverno.io
resources:
- globalcontextentries
verbs:
- list
- watch
- apiGroups:
- kyverno.io
resources:

View file

@ -34,6 +34,13 @@ rules:
- get
- list
- watch
- apiGroups:
- kyverno.io
resources:
- globalcontextentries
verbs:
- list
- watch
- apiGroups:
- kyverno.io
resources:

View file

@ -115,6 +115,7 @@ func main() {
internal.WithKyvernoDynamicClient(),
internal.WithEventsClient(),
internal.WithApiServerClient(),
internal.WithGlobalContext(),
internal.WithFlagSets(flagset),
)
// parse flags

View file

@ -19,7 +19,9 @@ import (
"github.com/kyverno/kyverno/pkg/controllers/cleanup"
genericloggingcontroller "github.com/kyverno/kyverno/pkg/controllers/generic/logging"
genericwebhookcontroller "github.com/kyverno/kyverno/pkg/controllers/generic/webhook"
"github.com/kyverno/kyverno/pkg/controllers/globalcontext"
ttlcontroller "github.com/kyverno/kyverno/pkg/controllers/ttl"
globalcontextstore "github.com/kyverno/kyverno/pkg/engine/globalcontext/store"
"github.com/kyverno/kyverno/pkg/event"
"github.com/kyverno/kyverno/pkg/informers"
"github.com/kyverno/kyverno/pkg/leaderelection"
@ -101,6 +103,7 @@ func main() {
internal.WithDeferredLoading(),
internal.WithMetadataClient(),
internal.WithApiServerClient(),
internal.WithGlobalContext(),
internal.WithFlagSets(flagset),
)
// parse flags
@ -156,6 +159,16 @@ func main() {
eventGenerator,
event.Workers,
)
store := globalcontextstore.New()
gceController := internal.NewController(
globalcontext.ControllerName,
globalcontext.NewController(
kyvernoInformer.Kyverno().V2alpha1().GlobalContextEntries(),
setup.KyvernoDynamicClient,
store,
),
globalcontext.Workers,
)
// start informers and wait for cache sync
if !internal.StartInformersAndWaitForCacheSync(ctx, setup.Logger, kubeInformer, kyvernoInformer) {
os.Exit(1)
@ -349,6 +362,7 @@ func main() {
defer server.Stop()
// start non leader controllers
eventController.Run(ctx, setup.Logger, &wg)
gceController.Run(ctx, setup.Logger, &wg)
// start leader election
le.Run(ctx)
// wait for everything to shut down and exit

View file

@ -22,6 +22,7 @@ type Configuration interface {
UsesMetadataClient() bool
UsesKyvernoDynamicClient() bool
UsesEventsClient() bool
UsesGlobalContext() bool
FlagSets() []*flag.FlagSet
}
@ -139,6 +140,12 @@ func WithEventsClient() ConfigurationOption {
}
}
func WithGlobalContext() ConfigurationOption {
return func(c *configuration) {
c.usesGlobalContext = true
}
}
func WithFlagSets(flagsets ...*flag.FlagSet) ConfigurationOption {
return func(c *configuration) {
c.flagSets = append(c.flagSets, flagsets...)
@ -163,6 +170,7 @@ type configuration struct {
usesMetadataClient bool
usesKyvernoDynamicClient bool
usesEventsClient bool
usesGlobalContext bool
flagSets []*flag.FlagSet
}
@ -234,6 +242,10 @@ func (c *configuration) UsesEventsClient() bool {
return c.usesEventsClient
}
func (c *configuration) UsesGlobalContext() bool {
return c.usesGlobalContext
}
func (c *configuration) FlagSets() []*flag.FlagSet {
return c.flagSets
}

View file

@ -57,6 +57,8 @@ var (
imageVerifyCacheEnabled bool
imageVerifyCacheTTLDuration time.Duration
imageVerifyCacheMaxSize int64
// global context
enableGlobalContext bool
)
func initLoggingFlags() {
@ -135,6 +137,10 @@ func initCleanupFlags() {
flag.StringVar(&cleanupServerPort, "cleanupServerPort", "9443", "kyverno cleanup server port, defaults to '9443'.")
}
func initGlobalContextFlags() {
flag.BoolVar(&enableGlobalContext, "enableGlobalContext", true, "Enable global context feature.")
}
type options struct {
clientRateLimitQPS float64
clientRateLimitBurst int
@ -218,6 +224,10 @@ func initFlags(config Configuration, opts ...Option) {
if config.UsesLeaderElection() {
initLeaderElectionFlags()
}
// leader election
if config.UsesGlobalContext() {
initGlobalContextFlags()
}
initCleanupFlags()
for _, flagset := range config.FlagSets() {
flagset.VisitAll(func(f *flag.Flag) {
@ -255,6 +265,10 @@ func CleanupServerPort() string {
return cleanupServerPort
}
func GlobalContextEnabled() bool {
return enableGlobalContext
}
func printFlagSettings(logger logr.Logger) {
logger = logger.WithName("flag")
flag.VisitAll(func(f *flag.Flag) {

View file

@ -258,6 +258,7 @@ func main() {
internal.WithKyvernoDynamicClient(),
internal.WithEventsClient(),
internal.WithApiServerClient(),
internal.WithGlobalContext(),
internal.WithFlagSets(flagset),
)
// parse flags

View file

@ -242,6 +242,7 @@ func main() {
internal.WithKyvernoDynamicClient(),
internal.WithEventsClient(),
internal.WithApiServerClient(),
internal.WithGlobalContext(),
internal.WithFlagSets(flagset),
)
// parse flags

View file

@ -50588,6 +50588,13 @@ rules:
- update
- watch
- deletecollection
- apiGroups:
- kyverno.io
resources:
- globalcontextentries
verbs:
- list
- watch
- apiGroups:
- reports.kyverno.io
resources:
@ -50711,6 +50718,13 @@ rules:
- update
- watch
- deletecollection
- apiGroups:
- kyverno.io
resources:
- globalcontextentries
verbs:
- list
- watch
- apiGroups:
- ''
resources:
@ -50833,6 +50847,13 @@ rules:
verbs:
- list
- watch
- apiGroups:
- kyverno.io
resources:
- globalcontextentries
verbs:
- list
- watch
- apiGroups:
- kyverno.io
resources:
@ -51137,6 +51158,13 @@ rules:
- get
- list
- watch
- apiGroups:
- kyverno.io
resources:
- globalcontextentries
verbs:
- list
- watch
- apiGroups:
- kyverno.io
resources:

View file

@ -0,0 +1,107 @@
package globalcontext
import (
"context"
"errors"
"time"
"github.com/go-logr/logr"
kyvernov2alpha1 "github.com/kyverno/kyverno/api/kyverno/v2alpha1"
kyvernov2alpha1informers "github.com/kyverno/kyverno/pkg/client/informers/externalversions/kyverno/v2alpha1"
kyvernov2alpha1listers "github.com/kyverno/kyverno/pkg/client/listers/kyverno/v2alpha1"
"github.com/kyverno/kyverno/pkg/clients/dclient"
"github.com/kyverno/kyverno/pkg/controllers"
"github.com/kyverno/kyverno/pkg/engine/adapters"
"github.com/kyverno/kyverno/pkg/engine/globalcontext/externalapi"
"github.com/kyverno/kyverno/pkg/engine/globalcontext/k8sresource"
"github.com/kyverno/kyverno/pkg/engine/globalcontext/store"
controllerutils "github.com/kyverno/kyverno/pkg/utils/controller"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/util/workqueue"
)
const (
// Workers is the number of workers for this controller
Workers = 1
ControllerName = "global-context"
maxRetries = 10
)
type controller struct {
// listers
gceLister kyvernov2alpha1listers.GlobalContextEntryLister
// queue
queue workqueue.RateLimitingInterface
// state
dclient dclient.Interface
store store.Store
}
func NewController(
gceInformer kyvernov2alpha1informers.GlobalContextEntryInformer,
dclient dclient.Interface,
storage store.Store,
) controllers.Controller {
queue := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), ControllerName)
_, _, err := controllerutils.AddDefaultEventHandlers(logger, gceInformer.Informer(), queue)
if err != nil {
logger.Error(err, "failed to register event handlers")
}
return &controller{
gceLister: gceInformer.Lister(),
queue: queue,
dclient: dclient,
store: storage,
}
}
func (c *controller) Run(ctx context.Context, workers int) {
controllerutils.Run(ctx, logger, ControllerName, time.Second, c.queue, workers, maxRetries, c.reconcile)
}
func (c *controller) reconcile(ctx context.Context, logger logr.Logger, key, _, name string) error {
gce, err := c.getEntry(name)
if err != nil {
if apierrors.IsNotFound(err) {
// entry was deleted, remove it from the store
c.store.Delete(name)
return nil
}
return err
}
// either it's a new entry or an existing entry changed
// create a new element and set it in the store
entry, err := c.makeStoreEntry(ctx, gce)
if err != nil {
return err
}
c.store.Set(name, entry)
return nil
}
func (c *controller) getEntry(name string) (*kyvernov2alpha1.GlobalContextEntry, error) {
return c.gceLister.Get(name)
}
func (c *controller) makeStoreEntry(ctx context.Context, gce *kyvernov2alpha1.GlobalContextEntry) (store.Entry, error) {
// TODO: should be done at validation time
if gce.Spec.KubernetesResource == nil && gce.Spec.APICall == nil {
return nil, errors.New("global context entry neither has K8sResource nor APICall")
}
// TODO: should be done at validation time
if gce.Spec.KubernetesResource != nil && gce.Spec.APICall != nil {
return nil, errors.New("global context entry has both K8sResource and APICall")
}
if gce.Spec.KubernetesResource != nil {
gvr := schema.GroupVersionResource{
Group: gce.Spec.KubernetesResource.Group,
Version: gce.Spec.KubernetesResource.Version,
Resource: gce.Spec.KubernetesResource.Resource,
}
return k8sresource.New(ctx, c.dclient.GetDynamicInterface(), gvr, gce.Spec.KubernetesResource.Namespace)
}
return externalapi.New(ctx, logger, adapters.Client(c.dclient), gce.Spec.APICall.APICall, time.Duration(gce.Spec.APICall.RefreshIntervalSeconds))
}

View file

@ -0,0 +1,5 @@
package globalcontext
import "github.com/kyverno/kyverno/pkg/logging"
var logger = logging.ControllerLogger(ControllerName)

View file

@ -29,20 +29,6 @@ type apiCall struct {
config APICallConfiguration
}
type APICallConfiguration struct {
maxAPICallResponseLength int64
}
func NewAPICallConfiguration(maxLen int64) APICallConfiguration {
return APICallConfiguration{
maxAPICallResponseLength: maxLen,
}
}
type ClientInterface interface {
RawAbsPath(ctx context.Context, path string, method string, dataReader io.Reader) ([]byte, error)
}
func New(
logger logr.Logger,
jp jmespath.Interface,
@ -83,7 +69,7 @@ func (a *apiCall) Fetch(ctx context.Context) ([]byte, error) {
if err != nil {
return nil, fmt.Errorf("failed to substitute variables in context entry %s %s: %v", a.entry.Name, a.entry.APICall.URLPath, err)
}
data, err := a.execute(ctx, call)
data, err := a.Execute(ctx, call)
if err != nil {
return nil, err
}
@ -98,11 +84,10 @@ func (a *apiCall) Store(data []byte) ([]byte, error) {
return results, nil
}
func (a *apiCall) execute(ctx context.Context, call *kyvernov1.APICall) ([]byte, error) {
func (a *apiCall) Execute(ctx context.Context, call *kyvernov1.APICall) ([]byte, error) {
if call.URLPath != "" {
return a.executeK8sAPICall(ctx, call.URLPath, call.Method, call.Data)
}
return a.executeServiceCall(ctx, call)
}
@ -111,12 +96,10 @@ func (a *apiCall) executeK8sAPICall(ctx context.Context, path string, method kyv
if err != nil {
return nil, err
}
jsonData, err := a.client.RawAbsPath(ctx, path, string(method), requestData)
if err != nil {
return nil, fmt.Errorf("failed to %v resource with raw url\n: %s: %v", method, path, err)
}
a.logger.V(4).Info("executed APICall", "name", a.entry.Name, "path", path, "method", method, "len", len(jsonData))
return jsonData, nil
}

View file

@ -0,0 +1,169 @@
package apicall
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/go-logr/logr"
kyvernov1 "github.com/kyverno/kyverno/api/kyverno/v1"
"github.com/kyverno/kyverno/pkg/tracing"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)
type Caller interface {
Execute(context.Context, *kyvernov1.APICall) ([]byte, error)
}
type caller struct {
logger logr.Logger
name string
client ClientInterface
config APICallConfiguration
}
func NewCaller(
logger logr.Logger,
name string,
client ClientInterface,
config APICallConfiguration,
) *caller {
return &caller{
logger: logger,
name: name,
client: client,
config: config,
}
}
func (a *caller) Execute(ctx context.Context, call *kyvernov1.APICall) ([]byte, error) {
if call.URLPath != "" {
return a.executeK8sAPICall(ctx, call.URLPath, call.Method, call.Data)
}
return a.executeServiceCall(ctx, call)
}
func (a *caller) executeK8sAPICall(ctx context.Context, path string, method kyvernov1.Method, data []kyvernov1.RequestData) ([]byte, error) {
requestData, err := a.buildRequestData(data)
if err != nil {
return nil, err
}
jsonData, err := a.client.RawAbsPath(ctx, path, string(method), requestData)
if err != nil {
return nil, fmt.Errorf("failed to %v resource with raw url\n: %s: %v", method, path, err)
}
a.logger.V(4).Info("executed APICall", "name", a.name, "path", path, "method", method, "len", len(jsonData))
return jsonData, nil
}
func (a *caller) executeServiceCall(ctx context.Context, apiCall *kyvernov1.APICall) ([]byte, error) {
if apiCall.Service == nil {
return nil, fmt.Errorf("missing service for APICall %s", a.name)
}
client, err := a.buildHTTPClient(apiCall.Service)
if err != nil {
return nil, err
}
req, err := a.buildHTTPRequest(ctx, apiCall)
if err != nil {
return nil, fmt.Errorf("failed to build HTTP request for APICall %s: %w", a.name, err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute HTTP request for APICall %s: %w", a.name, err)
}
defer resp.Body.Close()
var w http.ResponseWriter
if a.config.maxAPICallResponseLength != 0 {
resp.Body = http.MaxBytesReader(w, resp.Body, a.config.maxAPICallResponseLength)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
b, err := io.ReadAll(resp.Body)
if err == nil {
return nil, fmt.Errorf("HTTP %s: %s", resp.Status, string(b))
}
return nil, fmt.Errorf("HTTP %s", resp.Status)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
if _, ok := err.(*http.MaxBytesError); ok {
return nil, fmt.Errorf("response length must be less than max allowed response length of %d.", a.config.maxAPICallResponseLength)
} else {
return nil, fmt.Errorf("failed to read data from APICall %s: %w", a.name, err)
}
}
a.logger.Info("executed service APICall", "name", a.name, "len", len(body))
return body, nil
}
func (a *caller) buildRequestData(data []kyvernov1.RequestData) (io.Reader, error) {
dataMap := make(map[string]interface{})
for _, d := range data {
dataMap[d.Key] = d.Value
}
buffer := new(bytes.Buffer)
if err := json.NewEncoder(buffer).Encode(dataMap); err != nil {
return nil, fmt.Errorf("failed to encode HTTP POST data %v for APICall %s: %w", dataMap, a.name, err)
}
return buffer, nil
}
func (a *caller) buildHTTPClient(service *kyvernov1.ServiceCall) (*http.Client, error) {
if service == nil || service.CABundle == "" {
return http.DefaultClient, nil
}
caCertPool := x509.NewCertPool()
if ok := caCertPool.AppendCertsFromPEM([]byte(service.CABundle)); !ok {
return nil, fmt.Errorf("failed to parse PEM CA bundle for APICall %s", a.name)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: caCertPool,
MinVersion: tls.VersionTLS12,
},
}
return &http.Client{
Transport: tracing.Transport(transport, otelhttp.WithFilter(tracing.RequestFilterIsInSpan)),
}, nil
}
func (a *caller) buildHTTPRequest(ctx context.Context, apiCall *kyvernov1.APICall) (req *http.Request, err error) {
if apiCall.Service == nil {
return nil, fmt.Errorf("missing service")
}
token := a.getToken()
defer func() {
if token != "" && req != nil {
req.Header.Add("Authorization", "Bearer "+token)
}
}()
if apiCall.Method == "GET" {
req, err = http.NewRequestWithContext(ctx, "GET", apiCall.Service.URL, nil)
return
}
if apiCall.Method == "POST" {
data, dataErr := a.buildRequestData(apiCall.Data)
if dataErr != nil {
return nil, dataErr
}
req, err = http.NewRequest("POST", apiCall.Service.URL, data)
return
}
return nil, fmt.Errorf("invalid request type %s for APICall %s", apiCall.Method, a.name)
}
func (a *caller) getToken() string {
fileName := "/var/run/secrets/kubernetes.io/serviceaccount/token"
b, err := os.ReadFile(fileName)
if err != nil {
a.logger.Info("failed to read service account token", "path", fileName)
return ""
}
return string(b)
}

View file

@ -0,0 +1,10 @@
package apicall
import (
"context"
"io"
)
type ClientInterface interface {
RawAbsPath(ctx context.Context, path string, method string, dataReader io.Reader) ([]byte, error)
}

View file

@ -0,0 +1,11 @@
package apicall
type APICallConfiguration struct {
maxAPICallResponseLength int64
}
func NewAPICallConfiguration(maxLen int64) APICallConfiguration {
return APICallConfiguration{
maxAPICallResponseLength: maxLen,
}
}

View file

@ -0,0 +1,70 @@
package externalapi
import (
"context"
"sync"
"time"
"github.com/go-logr/logr"
kyvernov1 "github.com/kyverno/kyverno/api/kyverno/v1"
"github.com/kyverno/kyverno/pkg/engine/apicall"
"k8s.io/apimachinery/pkg/util/wait"
)
type entry struct {
sync.Mutex
data any
stop func()
}
func New(ctx context.Context, logger logr.Logger, client apicall.ClientInterface, call kyvernov1.APICall, period time.Duration) (*entry, error) {
var group wait.Group
ctx, cancel := context.WithCancel(ctx)
stop := func() {
// Send stop signal to informer's goroutine
cancel()
// Wait for the group to terminate
group.Wait()
}
e := &entry{
stop: stop,
}
group.StartWithContext(ctx, func(ctx context.Context) {
// TODO: make sure we have called it at least once before returning
// TODO: config
config := apicall.NewAPICallConfiguration(10000)
caller := apicall.NewCaller(logger, "TODO", client, config)
wait.UntilWithContext(ctx, func(ctx context.Context) {
// TODO
if data, err := doCall(ctx, caller, call); err != nil {
logger.Error(err, "failed to get data from api caller")
} else {
e.setData(data)
}
}, period)
})
return e, nil
}
func (e *entry) Get() (any, error) {
e.Lock()
defer e.Unlock()
return e.data, nil
}
func (e *entry) Stop() {
e.Lock()
defer e.Unlock()
e.stop()
}
func (e *entry) setData(data any) {
e.Lock()
defer e.Unlock()
e.data = data
}
func doCall(ctx context.Context, caller apicall.Caller, call kyvernov1.APICall) (any, error) {
// TODO: unmarshall json ?
return caller.Execute(ctx, &call)
}

View file

@ -0,0 +1,21 @@
package invalid
import (
"github.com/pkg/errors"
)
type entry struct {
err error
}
func (i *entry) Get() (any, error) {
return nil, errors.Wrapf(i.err, "failed to create cached context entry")
}
func (i *entry) Stop() {}
func New(err error) *entry {
return &entry{
err: err,
}
}

View file

@ -0,0 +1,61 @@
package k8sresource
import (
"context"
"fmt"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/dynamic"
"k8s.io/client-go/dynamic/dynamicinformer"
"k8s.io/client-go/tools/cache"
)
type entry struct {
lister cache.GenericLister
stop func()
}
// TODO: error handling
func New(ctx context.Context, client dynamic.Interface, gvr schema.GroupVersionResource, namespace string) (*entry, error) {
indexers := cache.Indexers{
cache.NamespaceIndex: cache.MetaNamespaceIndexFunc,
}
if namespace == "" {
namespace = metav1.NamespaceAll
}
informer := dynamicinformer.NewFilteredDynamicInformer(client, gvr, namespace, 0, indexers, nil)
var group wait.Group
ctx, cancel := context.WithCancel(ctx)
stop := func() {
// Send stop signal to informer's goroutine
cancel()
// Wait for the group to terminate
group.Wait()
}
group.StartWithContext(ctx, func(ctx context.Context) {
informer.Informer().Run(ctx.Done())
})
if !cache.WaitForCacheSync(ctx.Done(), informer.Informer().HasSynced) {
stop()
return nil, fmt.Errorf("failed to wait for cache sync: %s", gvr.Resource)
}
return &entry{
lister: informer.Lister(),
stop: stop,
}, nil
}
func (e *entry) Get() (any, error) {
obj, err := e.lister.List(labels.Everything())
if err != nil {
return nil, err
}
return obj, nil
}
func (e *entry) Stop() {
e.stop()
}

View file

@ -0,0 +1,6 @@
package store
type Entry interface {
Get() (any, error)
Stop()
}

View file

@ -0,0 +1,50 @@
package store
import (
"sync"
)
type Store interface {
Set(key string, val Entry)
Get(key string) (Entry, bool)
Delete(key string)
}
type store struct {
sync.RWMutex
store map[string]Entry
}
func New() Store {
return &store{
store: make(map[string]Entry),
}
}
func (l *store) Set(key string, val Entry) {
l.Lock()
defer l.Unlock()
old := l.store[key]
// If the key already exists, skip it before replacing it
if old != nil {
val.Stop()
}
l.store[key] = val
}
func (l *store) Get(key string) (Entry, bool) {
l.RLock()
defer l.RUnlock()
entry, ok := l.store[key]
return entry, ok
}
func (l *store) Delete(key string) {
l.Lock()
defer l.Unlock()
entry := l.store[key]
if entry != nil {
entry.Stop()
}
delete(l.store, key)
}