diff --git a/pkg/nfd-master/nfd-master.go b/pkg/nfd-master/nfd-master.go index d15497e0e..fc6e16861 100644 --- a/pkg/nfd-master/nfd-master.go +++ b/pkg/nfd-master/nfd-master.go @@ -18,9 +18,7 @@ package nfdmaster import ( "crypto/tls" - "crypto/x509" "fmt" - "io/ioutil" "net" "os" "path" @@ -93,6 +91,7 @@ type nfdMaster struct { nodeName string annotationNs string server *grpc.Server + stop chan struct{} ready chan bool apihelper apihelper.APIHelpers } @@ -102,6 +101,7 @@ func NewNfdMaster(args *Args) (NfdMaster, error) { nfd := &nfdMaster{args: *args, nodeName: os.Getenv("NODE_NAME"), ready: make(chan bool, 1), + stop: make(chan struct{}, 1), } if args.Instance == "" { @@ -164,40 +164,61 @@ func (m *nfdMaster) Run() error { close(m.ready) serverOpts := []grpc.ServerOption{} + tlsConfig := utils.TlsConfig{} + // Create watcher for TLS cert files + certWatch, err := utils.CreateFsWatcher(time.Second, m.args.CertFile, m.args.KeyFile, m.args.CaFile) + if err != nil { + return err + } // Enable mutual TLS authentication if --cert-file, --key-file or --ca-file // is defined if m.args.CertFile != "" || m.args.KeyFile != "" || m.args.CaFile != "" { - // Load cert for authenticating this server - cert, err := tls.LoadX509KeyPair(m.args.CertFile, m.args.KeyFile) - if err != nil { - return fmt.Errorf("failed to load server certificate: %v", err) - } - // Load CA cert for client cert verification - caCert, err := ioutil.ReadFile(m.args.CaFile) - if err != nil { - return fmt.Errorf("failed to read root certificate file: %v", err) - } - caPool := x509.NewCertPool() - if ok := caPool.AppendCertsFromPEM(caCert); !ok { - return fmt.Errorf("failed to add certificate from %q", m.args.CaFile) - } - // Create TLS config - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - ClientCAs: caPool, - ClientAuth: tls.RequireAndVerifyClientCert, + if err := tlsConfig.UpdateConfig(m.args.CertFile, m.args.KeyFile, m.args.CaFile); err != nil { + return err } + + tlsConfig := 
&tls.Config{GetConfigForClient: tlsConfig.GetConfig} serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(tlsConfig))) } m.server = grpc.NewServer(serverOpts...) pb.RegisterLabelerServer(m.server, m) klog.Infof("gRPC server serving on port: %d", m.args.Port) - return m.server.Serve(lis) + + // Run gRPC server + grpcErr := make(chan error, 1) + go func() { + defer lis.Close() + grpcErr <- m.server.Serve(lis) + }() + + // NFD-Master main event loop + for { + select { + case <-certWatch.Events: + klog.Infof("reloading TLS certificates") + if err := tlsConfig.UpdateConfig(m.args.CertFile, m.args.KeyFile, m.args.CaFile); err != nil { + return err + } + + case <-grpcErr: + return fmt.Errorf("gRPC server exited with an error: %v", err) + + case <-m.stop: + klog.Infof("shutting down nfd-master") + certWatch.Close() + return nil + } + } } // Stop NfdMaster func (m *nfdMaster) Stop() { m.server.Stop() + + select { + case m.stop <- struct{}{}: + default: + } } // Wait until NfdMaster is able able to accept connections. diff --git a/pkg/nfd-worker/nfd-worker.go b/pkg/nfd-worker/nfd-worker.go index 1f1349aa7..b702ce110 100644 --- a/pkg/nfd-worker/nfd-worker.go +++ b/pkg/nfd-worker/nfd-worker.go @@ -28,7 +28,6 @@ import ( "strings" "time" - "github.com/fsnotify/fsnotify" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -109,6 +108,7 @@ type NfdWorker interface { type nfdWorker struct { args Args + certWatch *utils.FsWatcher clientConn *grpc.ClientConn client pb.LabelerClient configFilePath string @@ -169,40 +169,6 @@ func NewNfdWorker(args *Args) (NfdWorker, error) { return nfd, nil } -func addConfigWatch(path string) (*fsnotify.Watcher, map[string]struct{}, error) { - paths := make(map[string]struct{}) - - // Create watcher - w, err := fsnotify.NewWatcher() - if err != nil { - return w, paths, fmt.Errorf("failed to create fsnotify watcher: %v", err) - } - - // Add watches for all directory components so that we catch e.g. 
renames - // upper in the tree - added := false - for p := path; ; p = filepath.Dir(p) { - - if err := w.Add(p); err != nil { - klog.V(1).Infof("failed to add fsnotify watch for %q: %v", p, err) - } else { - klog.V(1).Infof("added fsnotify watch %q", p) - added = true - } - - paths[p] = struct{}{} - if filepath.Dir(p) == p { - break - } - } - - if !added { - // Want to be sure that we watch something - return w, paths, fmt.Errorf("failed to add any watch") - } - return w, paths, nil -} - func newDefaultConfig() *NFDConfig { return &NFDConfig{ Core: coreConfig{ @@ -221,7 +187,7 @@ func (w *nfdWorker) Run() error { klog.Infof("NodeName: '%s'", nodeName) // Create watcher for config file and read initial configuration - configWatch, paths, err := addConfigWatch(w.configFilePath) + configWatch, err := utils.CreateFsWatcher(time.Second, w.configFilePath) if err != nil { return err } @@ -229,6 +195,12 @@ func (w *nfdWorker) Run() error { return err } + // Create watcher for TLS certificates + w.certWatch, err = utils.CreateFsWatcher(time.Second, w.args.CaFile, w.args.CertFile, w.args.KeyFile) + if err != nil { + return err + } + // Connect to NFD master err = w.connect() if err != nil { @@ -237,7 +209,6 @@ func (w *nfdWorker) Run() error { defer w.disconnect() labelTrigger := time.After(0) - var configTrigger <-chan time.Time for { select { case <-labelTrigger: @@ -260,32 +231,8 @@ func (w *nfdWorker) Run() error { labelTrigger = time.After(w.config.Core.SleepInterval.Duration) } - case e := <-configWatch.Events: - name := filepath.Clean(e.Name) - - // If any of our paths (directories or the file itself) change - if _, ok := paths[name]; ok { - klog.Infof("fsnotify event in %q detected, reconfiguring fsnotify and reloading configuration", name) - - // Blindly remove existing watch and add a new one - if err := configWatch.Close(); err != nil { - klog.Warningf("failed to close fsnotify watcher: %v", err) - } - configWatch, paths, err = addConfigWatch(w.configFilePath) - 
if err != nil { - return err - } - - // Rate limiter. In certain filesystem operations we get - // numerous events in quick succession and we only want one - // config re-load - configTrigger = time.After(time.Second) - } - - case e := <-configWatch.Errors: - klog.Errorf("config file watcher error: %v", e) - - case <-configTrigger: + case <-configWatch.Events: + klog.Infof("reloading configuration") if err := w.configure(w.configFilePath, w.args.Options); err != nil { return err } @@ -301,9 +248,17 @@ func (w *nfdWorker) Run() error { // comes into effect even if the sleep interval is long (or infinite) labelTrigger = time.After(0) + case <-w.certWatch.Events: + klog.Infof("TLS certificate update, renewing connection to nfd-master") + w.disconnect() + if err := w.connect(); err != nil { + return err + } + case <-w.stop: klog.Infof("shutting down nfd-worker") configWatch.Close() + w.certWatch.Close() return nil } } @@ -358,6 +313,7 @@ func (w *nfdWorker) connect() error { } else { dialOpts = append(dialOpts, grpc.WithInsecure()) } + klog.Infof("connecting to nfd-master at %s ...", w.args.Server) conn, err := grpc.DialContext(dialCtx, w.args.Server, dialOpts...) if err != nil { return err @@ -371,6 +327,7 @@ func (w *nfdWorker) connect() error { // disconnect closes the connection to NFD master func (w *nfdWorker) disconnect() { if w.clientConn != nil { + klog.Infof("closing connection to nfd-master ...") w.clientConn.Close() } w.clientConn = nil diff --git a/pkg/utils/fswatcher.go b/pkg/utils/fswatcher.go new file mode 100644 index 000000000..164381934 --- /dev/null +++ b/pkg/utils/fswatcher.go @@ -0,0 +1,159 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "fmt" + "path/filepath" + "time" + + "github.com/fsnotify/fsnotify" + "k8s.io/klog/v2" +) + +// FsWatcher is a wrapper helper for watching files +type FsWatcher struct { + *fsnotify.Watcher + + Events chan struct{} + ratelimit time.Duration + names []string + paths map[string]struct{} +} + +// CreateFsWatcher creates a new FsWatcher +func CreateFsWatcher(ratelimit time.Duration, names ...string) (*FsWatcher, error) { + w := &FsWatcher{ + Events: make(chan struct{}), + names: names, + ratelimit: ratelimit, + } + + if err := w.reset(names...); err != nil { + return nil, err + } + + go w.watch() + + return w, nil +} + +// reset resets the file watches +func (w *FsWatcher) reset(names ...string) error { + if err := w.initWatcher(); err != nil { + return err + } + if err := w.add(names...); err != nil { + return err + } + + return nil +} + +func (w *FsWatcher) initWatcher() error { + if w.Watcher != nil { + if err := w.Watcher.Close(); err != nil { + return fmt.Errorf("failed to close fsnotify watcher: %v", err) + } + } + w.paths = make(map[string]struct{}) + + watcher, err := fsnotify.NewWatcher() + if err != nil { + w.Watcher = nil + return fmt.Errorf("failed to create fsnotify watcher: %v", err) + } + w.Watcher = watcher + + return nil +} + +func (w *FsWatcher) add(names ...string) error { + for _, name := range names { + if name == "" { + continue + } + + added := false + // Add watches for all directory components so that we catch e.g. 
renames + // upper in the tree + for p := name; ; p = filepath.Dir(p) { + if _, ok := w.paths[p]; !ok { + if err := w.Add(p); err != nil { + klog.V(1).Infof("failed to add fsnotify watch for %q: %v", p, err) + } else { + klog.V(1).Infof("added fsnotify watch %q", p) + added = true + } + + w.paths[p] = struct{}{} + } else { + added = true + } + if filepath.Dir(p) == p { + break + } + } + if !added { + // Want to be sure that we watch something + return fmt.Errorf("failed to add any watch") + } + } + + return nil +} + +func (w *FsWatcher) watch() { + var ratelimiter <-chan time.Time + for { + select { + case e, ok := <-w.Watcher.Events: + // Watcher has been closed + if !ok { + klog.Infof("watcher closed") + return + } + + // If any of our paths change + name := filepath.Clean(e.Name) + if _, ok := w.paths[filepath.Clean(name)]; ok { + klog.V(2).Infof("fsnotify %s event in %q detected", e, name) + + // Rate limiter. In certain filesystem operations we get + // numerous events in quick succession + ratelimiter = time.After(w.ratelimit) + } + + case e, ok := <-w.Watcher.Errors: + // Watcher has been closed + if !ok { + klog.Infof("watcher closed") + return + } + klog.Warningf("fswatcher error event detected: %v", e) + + case <-ratelimiter: + // Blindly remove existing watch and add a new one + if err := w.reset(w.names...); err != nil { + klog.Errorf("%v, re-trying in 60 seconds...", err) + ratelimiter = time.After(60 * time.Second) + } + + w.Events <- struct{}{} + } + } +} diff --git a/pkg/utils/tls.go b/pkg/utils/tls.go new file mode 100644 index 000000000..a7df14117 --- /dev/null +++ b/pkg/utils/tls.go @@ -0,0 +1,70 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
// TlsConfig is a TLS config wrapper/helper for cert rotation.
//
// It stores the most recently loaded *tls.Config behind the embedded mutex so
// that the configuration can be replaced with UpdateConfig while GetConfig is
// being called concurrently. The zero value is usable; GetConfig returns a nil
// config until the first successful UpdateConfig.
type TlsConfig struct {
	sync.Mutex
	config *tls.Config
}

// GetConfig returns the current TLS configuration. Intended to be used as the
// GetConfigForClient callback in tls.Config.
func (c *TlsConfig) GetConfig(*tls.ClientHelloInfo) (*tls.Config, error) {
	c.Lock()
	defer c.Unlock()

	return c.config, nil
}

// UpdateConfig updates the wrapped TLS config.
//
// certFile and keyFile are the PEM files of the server certificate and its
// private key, caFile is the PEM bundle used for verifying client
// certificates. On any error the previously stored configuration is left
// untouched. Returned errors wrap the underlying cause so that callers can
// inspect it with errors.Is / errors.As.
func (c *TlsConfig) UpdateConfig(certFile, keyFile, caFile string) error {
	// All file I/O and parsing is done before taking the lock so that
	// concurrent GetConfig calls are never blocked on disk access.

	// Load cert for authenticating this server
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return fmt.Errorf("failed to load server certificate: %w", err)
	}
	// Load CA cert for client cert verification
	caCert, err := ioutil.ReadFile(caFile)
	if err != nil {
		return fmt.Errorf("failed to read root certificate file: %w", err)
	}
	caPool := x509.NewCertPool()
	if ok := caPool.AppendCertsFromPEM(caCert); !ok {
		return fmt.Errorf("failed to add certificate from '%s'", caFile)
	}

	// Create TLS config
	config := &tls.Config{
		Certificates:       []tls.Certificate{cert},
		ClientCAs:          caPool,
		ClientAuth:         tls.RequireAndVerifyClientCert,
		GetConfigForClient: c.GetConfig,
	}

	// Swap in the fully built config
	c.Lock()
	c.config = config
	c.Unlock()

	return nil
}