diff --git a/pkg/nfd-master/nfd-master-internal_test.go b/pkg/nfd-master/nfd-master-internal_test.go index eda29561b..73dc43126 100644 --- a/pkg/nfd-master/nfd-master-internal_test.go +++ b/pkg/nfd-master/nfd-master-internal_test.go @@ -77,6 +77,7 @@ func TestUpdateNodeFeatures(t *testing.T) { sort.Strings(fakeExtResourceNames) mockAPIHelper := new(apihelper.MockAPIHelpers) + mockMaster := newMockMaster(mockAPIHelper) mockClient := &k8sclient.Clientset{} // Mock node with old features mockNode := newMockNode() @@ -107,7 +108,7 @@ func TestUpdateNodeFeatures(t *testing.T) { mockAPIHelper.On("GetNode", mockClient, mockNodeName).Return(mockNode, nil).Once() mockAPIHelper.On("PatchNode", mockClient, mockNodeName, mock.MatchedBy(jsonPatchMatcher(metadataPatches))).Return(nil) mockAPIHelper.On("PatchNodeStatus", mockClient, mockNodeName, mock.MatchedBy(jsonPatchMatcher(statusPatches))).Return(nil) - err := updateNodeFeatures(mockAPIHelper, mockNodeName, fakeFeatureLabels, fakeAnnotations, fakeExtResources) + err := mockMaster.updateNodeFeatures(mockNodeName, fakeFeatureLabels, fakeAnnotations, fakeExtResources) Convey("Error is nil", func() { So(err, ShouldBeNil) @@ -117,7 +118,7 @@ func TestUpdateNodeFeatures(t *testing.T) { Convey("When I fail to update the node with feature labels", func() { expectedError := errors.New("fake error") mockAPIHelper.On("GetClient").Return(nil, expectedError) - err := updateNodeFeatures(mockAPIHelper, mockNodeName, fakeFeatureLabels, fakeAnnotations, fakeExtResources) + err := mockMaster.updateNodeFeatures(mockNodeName, fakeFeatureLabels, fakeAnnotations, fakeExtResources) Convey("Error is produced", func() { So(err, ShouldEqual, expectedError) @@ -127,7 +128,7 @@ func TestUpdateNodeFeatures(t *testing.T) { Convey("When I fail to get a mock client while updating feature labels", func() { expectedError := errors.New("fake error") mockAPIHelper.On("GetClient").Return(nil, expectedError) - err := updateNodeFeatures(mockAPIHelper, mockNodeName, fakeFeatureLabels, fakeAnnotations, fakeExtResources) + err := mockMaster.updateNodeFeatures(mockNodeName, fakeFeatureLabels, fakeAnnotations, fakeExtResources) Convey("Error is produced", func() { So(err, ShouldEqual, expectedError) @@ -138,7 +139,7 @@ func TestUpdateNodeFeatures(t *testing.T) { expectedError := errors.New("fake error") mockAPIHelper.On("GetClient").Return(mockClient, nil) mockAPIHelper.On("GetNode", mockClient, mockNodeName).Return(nil, expectedError).Once() - err := updateNodeFeatures(mockAPIHelper, mockNodeName, fakeFeatureLabels, fakeAnnotations, fakeExtResources) + err := mockMaster.updateNodeFeatures(mockNodeName, fakeFeatureLabels, fakeAnnotations, fakeExtResources) Convey("Error is produced", func() { So(err, ShouldEqual, expectedError) @@ -150,7 +151,7 @@ func TestUpdateNodeFeatures(t *testing.T) { mockAPIHelper.On("GetClient").Return(mockClient, nil) mockAPIHelper.On("GetNode", mockClient, mockNodeName).Return(mockNode, nil).Once() mockAPIHelper.On("PatchNode", mockClient, mockNodeName, mock.Anything).Return(expectedError).Once() - err := updateNodeFeatures(mockAPIHelper, mockNodeName, fakeFeatureLabels, fakeAnnotations, fakeExtResources) + err := mockMaster.updateNodeFeatures(mockNodeName, fakeFeatureLabels, fakeAnnotations, fakeExtResources) Convey("Error is produced", func() { So(err, ShouldEqual, expectedError) diff --git a/pkg/nfd-master/nfd-master.go b/pkg/nfd-master/nfd-master.go index b5ba0a24c..a5502051c 100644 --- a/pkg/nfd-master/nfd-master.go +++ b/pkg/nfd-master/nfd-master.go @@ -218,7 +218,7 @@ func (m *nfdMaster) prune() error { stdoutLogger.Printf("pruning node %q...", node.Name) // Prune labels and extended resources - err := updateNodeFeatures(m.apihelper, node.Name, Labels{}, Annotations{}, ExtendedResources{}) + err := m.updateNodeFeatures(node.Name, Labels{}, Annotations{}, ExtendedResources{}) if err != nil { return fmt.Errorf("failed to prune labels from node %q: %v", node.Name, err) } @@ -345,7 +345,7 @@ func (m *nfdMaster) SetLabels(c context.Context, r *pb.SetLabelsRequest) (*pb.Se // Advertise NFD worker version as an annotation annotations := Annotations{workerVersionAnnotation: r.NfdVersion} - err := updateNodeFeatures(m.apihelper, r.NodeName, labels, annotations, extendedResources) + err := m.updateNodeFeatures(r.NodeName, labels, annotations, extendedResources) if err != nil { stderrLogger.Printf("failed to advertise labels: %s", err.Error()) return &pb.SetLabelsReply{}, err @@ -357,14 +357,14 @@ func (m *nfdMaster) SetLabels(c context.Context, r *pb.SetLabelsRequest) (*pb.Se // updateNodeFeatures ensures the Kubernetes node object is up to date, // creating new labels and extended resources where necessary and removing // outdated ones. Also updates the corresponding annotations. -func updateNodeFeatures(helper apihelper.APIHelpers, nodeName string, labels Labels, annotations Annotations, extendedResources ExtendedResources) error { - cli, err := helper.GetClient() +func (m *nfdMaster) updateNodeFeatures(nodeName string, labels Labels, annotations Annotations, extendedResources ExtendedResources) error { + cli, err := m.apihelper.GetClient() if err != nil { return err } // Get the worker node object - node, err := helper.GetNode(cli, nodeName) + node, err := m.apihelper.GetNode(cli, nodeName) if err != nil { return err } @@ -397,7 +397,7 @@ func updateNodeFeatures(helper apihelper.APIHelpers, nodeName string, labels Lab patches = append(patches, removeLabelsWithPrefix(node, "node.alpha.kubernetes-incubator.io/node-feature-discovery")...) // Patch the node object in the apiserver - err = helper.PatchNode(cli, node.Name, patches) + err = m.apihelper.PatchNode(cli, node.Name, patches) if err != nil { stderrLogger.Printf("error while patching node object: %s", err.Error()) return err @@ -405,7 +405,7 @@ func updateNodeFeatures(helper apihelper.APIHelpers, nodeName string, labels Lab // patch node status with extended resource changes patches = createExtendedResourcePatches(node, extendedResources) - err = helper.PatchNodeStatus(cli, node.Name, patches) + err = m.apihelper.PatchNodeStatus(cli, node.Name, patches) if err != nil { stderrLogger.Printf("error while patching extended resources: %s", err.Error()) return err