diff --git a/source/pci/utils.go b/source/pci/utils.go index 29c7afa11..41e13868b 100644 --- a/source/pci/utils.go +++ b/source/pci/utils.go @@ -17,6 +17,8 @@ limitations under the License. package pci import ( + "bytes" + "encoding/json" "fmt" "os" "path/filepath" @@ -68,6 +70,11 @@ func readPciDevInfo(devPath string) (*nfdv1alpha1.InstanceFeature, error) { return nfdv1alpha1.NewInstanceFeature(attrs), nil } +type DevGroupedEntry struct { + Count int + Bytes []byte +} + // detectPci detects available PCI devices and retrieves their device attributes. // An error is returned if reading any of the mandatory attributes fails. func detectPci() ([]nfdv1alpha1.InstanceFeature, error) { @@ -80,13 +87,35 @@ func detectPci() ([]nfdv1alpha1.InstanceFeature, error) { // Iterate over devices devInfo := make([]nfdv1alpha1.InstanceFeature, 0, len(devices)) + devGrouped := make(map[string]map[string]DevGroupedEntry) for _, device := range devices { info, err := readPciDevInfo(filepath.Join(sysfsBasePath, device.Name())) if err != nil { klog.ErrorS(err, "failed to read PCI device info") continue } - devInfo = append(devInfo, *info) + + b, err := json.Marshal(info.Attributes) + if err != nil { + return nil, err + } + + if entry, ok := devGrouped[info.Attributes["vendor"]][info.Attributes["device"]]; !ok { + devGrouped[info.Attributes["vendor"]] = make(map[string]DevGroupedEntry) + devGrouped[info.Attributes["vendor"]][info.Attributes["device"]] = DevGroupedEntry{Bytes: b, Count: 1} + devInfo = append(devInfo, *info) + } else { + result := bytes.Compare(b, devGrouped[info.Attributes["vendor"]][info.Attributes["device"]].Bytes) + if result == 0 { + entry.Count += 1 + devGrouped[info.Attributes["vendor"]][info.Attributes["device"]] = entry + } + } + } + + for _, dev := range devInfo { + entry := devGrouped[dev.Attributes["vendor"]][dev.Attributes["device"]] + dev.Attributes["count"] = string(entry.Count) } return devInfo, nil