diff --git a/source/custom/rules/pci_id_rule.go b/source/custom/rules/pci_id_rule.go index 46496b42b..4ce7e97e4 100644 --- a/source/custom/rules/pci_id_rule.go +++ b/source/custom/rules/pci_id_rule.go @@ -1,5 +1,5 @@ /* -Copyright 2020 The Kubernetes Authors. +Copyright 2020-2021 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,10 @@ package rules import ( "fmt" - pciutils "sigs.k8s.io/node-feature-discovery/source/internal" + + "sigs.k8s.io/node-feature-discovery/pkg/api/feature" + "sigs.k8s.io/node-feature-discovery/source" + "sigs.k8s.io/node-feature-discovery/source/pci" ) // Rule that matches on the following PCI device attributes: @@ -37,40 +40,40 @@ type PciIDRule struct { // Match PCI devices on provided PCI device attributes func (r *PciIDRule) Match() (bool, error) { + devs, ok := source.GetFeatureSource("pci").GetFeatures().Instances[pci.DeviceFeature] + if !ok { + return false, fmt.Errorf("cpuid information not available") + } + devAttr := map[string]bool{} for _, attr := range []string{"class", "vendor", "device"} { devAttr[attr] = true } - allDevs, err := pciutils.DetectPci(devAttr) - if err != nil { - return false, fmt.Errorf("failed to detect PCI devices: %s", err.Error()) - } - for _, classDevs := range allDevs { - for _, dev := range classDevs { - // match rule on a single device - if r.matchDevOnRule(dev) { - return true, nil - } + for _, dev := range devs.Elements { + // match rule on a single device + if r.matchDevOnRule(dev) { + return true, nil } } return false, nil } -func (r *PciIDRule) matchDevOnRule(dev pciutils.PciDeviceInfo) bool { +func (r *PciIDRule) matchDevOnRule(dev feature.InstanceFeature) bool { if len(r.Class) == 0 && len(r.Vendor) == 0 && len(r.Device) == 0 { return false } - if len(r.Class) > 0 && !in(dev["class"], r.Class) { + attrs := dev.Attributes + if len(r.Class) > 0 && !in(attrs["class"], r.Class) { return false } - if len(r.Vendor) > 0 && !in(dev["vendor"], r.Vendor) { + if len(r.Vendor) > 0 && !in(attrs["vendor"], r.Vendor) { return false } - if len(r.Device) > 0 && !in(dev["device"], r.Device) { + if len(r.Device) > 0 && !in(attrs["device"], r.Device) { return false } diff --git a/source/pci/pci.go b/source/pci/pci.go index de7a9f71a..987a1c934 100644 --- a/source/pci/pci.go +++ b/source/pci/pci.go @@ -22,12 +22,15 @@ import ( "k8s.io/klog/v2" + "sigs.k8s.io/node-feature-discovery/pkg/api/feature" + "sigs.k8s.io/node-feature-discovery/pkg/utils" "sigs.k8s.io/node-feature-discovery/source" - pciutils "sigs.k8s.io/node-feature-discovery/source/internal" ) const Name = "pci" +const DeviceFeature = "device" + type Config struct { DeviceClassWhitelist []string `json:"deviceClassWhitelist,omitempty"` DeviceLabelFields []string `json:"deviceLabelFields,omitempty"` @@ -41,14 +44,16 @@ func newDefaultConfig() *Config { } } -// pciSource implements the LabelSource and ConfigurableSource interfaces. +// pciSource implements the FeatureSource, LabelSource and ConfigurableSource interfaces. type pciSource struct { - config *Config + config *Config + features *feature.DomainFeatures } // Singleton source instance var ( - src pciSource + src = pciSource{config: newDefaultConfig()} + _ source.FeatureSource = &src _ source.LabelSource = &src _ source.ConfigurableSource = &src ) @@ -77,16 +82,17 @@ func (s *pciSource) Priority() int { return 0 } // GetLabels method of the LabelSource interface func (s *pciSource) GetLabels() (source.FeatureLabels, error) { - features := source.FeatureLabels{} + labels := source.FeatureLabels{} + features := s.GetFeatures() // Construct a device label format, a sorted list of valid attributes - deviceLabelFields := []string{} - configLabelFields := map[string]bool{} + deviceLabelFields := make([]string, 0) + configLabelFields := make(map[string]struct{}, len(s.config.DeviceLabelFields)) for _, field := range s.config.DeviceLabelFields { - configLabelFields[field] = true + configLabelFields[field] = struct{}{} } - for _, attr := range pciutils.DefaultPciDevAttrs { + for _, attr := range mandatoryDevAttrs { if _, ok := configLabelFields[attr]; ok { deviceLabelFields = append(deviceLabelFields, attr) delete(configLabelFields, attr) @@ -97,50 +103,59 @@ func (s *pciSource) GetLabels() (source.FeatureLabels, error) { for key := range configLabelFields { keys = append(keys, key) } - klog.Warningf("invalid fields '%v' in deviceLabelFields, ignoring...", keys) + klog.Warningf("invalid fields (%s) in deviceLabelFields, ignoring...", strings.Join(keys, ", ")) } if len(deviceLabelFields) == 0 { klog.Warningf("no valid fields in deviceLabelFields defined, using the defaults") deviceLabelFields = []string{"class", "vendor"} } - // Read extraDevAttrs + configured or default labels. Attributes - // set to 'true' are considered must-have. - deviceAttrs := map[string]bool{} - for _, label := range pciutils.ExtraPciDevAttrs { - deviceAttrs[label] = false - } - for _, label := range deviceLabelFields { - deviceAttrs[label] = true - } - - devs, err := pciutils.DetectPci(deviceAttrs) - if err != nil { - return nil, fmt.Errorf("failed to detect PCI devices: %s", err.Error()) - } - // Iterate over all device classes - for class, classDevs := range devs { + for _, dev := range features.Instances[DeviceFeature].Elements { + attrs := dev.Attributes + class := attrs["class"] for _, white := range s.config.DeviceClassWhitelist { - if strings.HasPrefix(class, strings.ToLower(white)) { - for _, dev := range classDevs { - devLabel := "" - for i, attr := range deviceLabelFields { - devLabel += dev[attr] - if i < len(deviceLabelFields)-1 { - devLabel += "_" - } - } - features[devLabel+".present"] = true - - if _, ok := dev["sriov_totalvfs"]; ok { - features[devLabel+".sriov.capable"] = true + if strings.HasPrefix(string(class), strings.ToLower(white)) { + devLabel := "" + for i, attr := range deviceLabelFields { + devLabel += attrs[attr] + if i < len(deviceLabelFields)-1 { + devLabel += "_" } } + labels[devLabel+".present"] = true + + if _, ok := attrs["sriov_totalvfs"]; ok { + labels[devLabel+".sriov.capable"] = true + } + break } } } - return features, nil + return labels, nil +} + +// Discover method of the FeatureSource interface +func (s *pciSource) Discover() error { + s.features = feature.NewDomainFeatures() + + devs, err := detectPci() + if err != nil { + return fmt.Errorf("failed to detect PCI devices: %s", err.Error()) + } + s.features.Instances[DeviceFeature] = feature.NewInstanceFeatures(devs) + + utils.KlogDump(3, "discovered pci features:", " ", s.features) + + return nil +} + +// GetFeatures method of the FeatureSource Interface +func (s *pciSource) GetFeatures() *feature.DomainFeatures { + if s.features == nil { + s.features = feature.NewDomainFeatures() + } + return s.features } func init() { diff --git a/source/pci/pci_test.go b/source/pci/pci_test.go new file mode 100644 index 000000000..e419a4018 --- /dev/null +++ b/source/pci/pci_test.go @@ -0,0 +1,36 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pci + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "sigs.k8s.io/node-feature-discovery/pkg/api/feature" +) + +func TestPciSource(t *testing.T) { + assert.Equal(t, src.Name(), Name) + + // Check that GetLabels works with empty features + src.features = feature.NewDomainFeatures() + l, err := src.GetLabels() + + assert.Nil(t, err, err) + assert.Empty(t, l) + +} diff --git a/source/internal/pci_utils.go b/source/pci/utils.go similarity index 55% rename from source/internal/pci_utils.go rename to source/pci/utils.go index a0c11743a..5bc05a674 100644 --- a/source/internal/pci_utils.go +++ b/source/pci/utils.go @@ -14,28 +14,27 @@ See the License for the specific language governing permissions and limitations under the License. */ -package busutils +package pci import ( "fmt" "io/ioutil" - "path" + "path/filepath" "strings" "k8s.io/klog/v2" + "sigs.k8s.io/node-feature-discovery/pkg/api/feature" "sigs.k8s.io/node-feature-discovery/source" ) -type PciDeviceInfo map[string]string - -var DefaultPciDevAttrs = []string{"class", "vendor", "device", "subsystem_vendor", "subsystem_device"} -var ExtraPciDevAttrs = []string{"sriov_totalvfs"} +var mandatoryDevAttrs = []string{"class", "vendor", "device", "subsystem_vendor", "subsystem_device"} +var optionalDevAttrs = []string{"sriov_totalvfs"} // Read a single PCI device attribute // A PCI attribute in this context, maps to the corresponding sysfs file func readSinglePciAttribute(devPath string, attrName string) (string, error) { - data, err := ioutil.ReadFile(path.Join(devPath, attrName)) + data, err := ioutil.ReadFile(filepath.Join(devPath, attrName)) if err != nil { return "", fmt.Errorf("failed to read device attribute %s: %v", attrName, err) } @@ -51,49 +50,43 @@ func readSinglePciAttribute(devPath string, attrName string) (string, error) { } // Read information of one PCI device -func readPciDevInfo(devPath string, deviceAttrSpec map[string]bool) (PciDeviceInfo, error) { - info := PciDeviceInfo{} - - for attr, must := range deviceAttrSpec { +func readPciDevInfo(devPath string) (*feature.InstanceFeature, error) { + attrs := make(map[string]string) + for _, attr := range mandatoryDevAttrs { attrVal, err := readSinglePciAttribute(devPath, attr) if err != nil { - if must { - return info, fmt.Errorf("failed to read device %s: %s", attr, err) - } - continue - + return nil, fmt.Errorf("failed to read device %s: %s", attr, err) } - info[attr] = attrVal + attrs[attr] = attrVal } - return info, nil + for _, attr := range optionalDevAttrs { + attrVal, err := readSinglePciAttribute(devPath, attr) + if err == nil { + attrs[attr] = attrVal + } + } + return feature.NewInstanceFeature(attrs), nil } -// DetectPci lists available PCI devices and retrieve device attributes. -// deviceAttrSpec is a map which specifies which attributes to retrieve. -// a false value for a specific attribute marks the attribute as optional. -// a true value for a specific attribute marks the attribute as mandatory. -// "class" attribute is considered mandatory. -// will fail if the retrieval of a mandatory attribute fails. -func DetectPci(deviceAttrSpec map[string]bool) (map[string][]PciDeviceInfo, error) { +// detectPci detects available PCI devices and retrieves their device attributes. +// An error is returned if reading any of the mandatory attributes fails. +func detectPci() ([]feature.InstanceFeature, error) { sysfsBasePath := source.SysfsDir.Path("bus/pci/devices") - devInfo := make(map[string][]PciDeviceInfo) devices, err := ioutil.ReadDir(sysfsBasePath) if err != nil { return nil, err } - // "class" is a mandatory attribute, inject it to spec if needed. - deviceAttrSpec["class"] = true // Iterate over devices + devInfo := make([]feature.InstanceFeature, 0, len(devices)) for _, device := range devices { - info, err := readPciDevInfo(path.Join(sysfsBasePath, device.Name()), deviceAttrSpec) + info, err := readPciDevInfo(filepath.Join(sysfsBasePath, device.Name())) if err != nil { klog.Error(err) continue } - class := info["class"] - devInfo[class] = append(devInfo[class], info) + devInfo = append(devInfo, *info) } return devInfo, nil