1
0
Fork 0
mirror of https://github.com/kubernetes-sigs/node-feature-discovery.git synced 2025-03-31 04:04:51 +00:00

feat: add spiffe integration in master and worker

Signed-off-by: TessaIO <ahmedgrati1999@gmail.com>
This commit is contained in:
TessaIO 2024-04-14 00:12:01 +02:00
parent 6a101422ee
commit c77c370d0d
7 changed files with 421 additions and 2 deletions

View file

@ -73,6 +73,8 @@ func main() {
args.Overrides.ResyncPeriod = overrides.ResyncPeriod
case "nfd-api-parallelism":
args.Overrides.NfdApiParallelism = overrides.NfdApiParallelism
case "enable-spiffe":
args.Overrides.EnableSpiffe = overrides.EnableSpiffe
}
})
@ -140,6 +142,8 @@ func initFlags(flagset *flag.FlagSet) (*master.Args, *master.ConfigOverrideArgs)
flagset.Var(overrides.ResyncPeriod, "resync-period", "Specify the NFD API controller resync period.")
overrides.NfdApiParallelism = flagset.Int("nfd-api-parallelism", 10, "Defines the maximum number of goroutines responsible of updating nodes. "+
"Can be used for the throttling mechanism.")
overrides.EnableSpiffe = flagset.Bool("enable-spiffe", false,
"Enables the Spiffe signature verification of created CRDs. This is still an EXPERIMENTAL feature.")
return args, overrides
}

View file

@ -93,6 +93,8 @@ func parseArgs(flags *flag.FlagSet, osArgs ...string) *worker.Args {
args.Overrides.LabelSources = overrides.LabelSources
case "no-owner-refs":
args.Overrides.NoOwnerRefs = overrides.NoOwnerRefs
case "enable-spiffe":
args.Overrides.EnableSpiffe = overrides.EnableSpiffe
}
})
@ -131,6 +133,8 @@ func initFlags(flagset *flag.FlagSet) (*worker.Args, *worker.ConfigOverrideArgs)
flagset.Var(overrides.LabelSources, "label-sources",
"Comma separated list of label sources. Special value 'all' enables all sources. "+
"Prefix the source name with '-' to disable it.")
overrides.EnableSpiffe = flagset.Bool("enable-spiffe", false,
"Enables the Spiffe signature verification of created CRDs. This is still an EXPERIMENTAL feature.")
return args, overrides
}

22
go.sum
View file

@ -123,6 +123,27 @@ github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad h1:a6HEuzUHeKH6hwfN/Z
github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.0.0-20220520183353-fd19c99a87aa/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8=
github.com/googleapis/enterprise-certificate-proxy v0.1.0/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8=
github.com/googleapis/enterprise-certificate-proxy v0.2.0/go.mod h1:8C0jb7/mgJe/9KK8Lm7X9ctZC2t60YyIpYEI16jx0Qg=
github.com/googleapis/enterprise-certificate-proxy v0.2.1/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM=
github.com/googleapis/gax-go/v2 v2.2.0/go.mod h1:as02EH8zWkzwUoLbBaFeQ+arQaj/OthfcblKl4IGNaM=
github.com/googleapis/gax-go/v2 v2.3.0/go.mod h1:b8LNqSzNabLiUpXKkY7HAR5jr6bIT99EXz9pXxye9YM=
github.com/googleapis/gax-go/v2 v2.4.0/go.mod h1:XOTVJ59hdnfJLIP/dh8n5CGryZR2LxK9wbMD5+iXC6c=
github.com/googleapis/gax-go/v2 v2.5.1/go.mod h1:h6B0KMMFNtI2ddbGJn3T3ZbwkeT6yqEF02fYlzkUCyo=
github.com/googleapis/gax-go/v2 v2.6.0/go.mod h1:1mjbznJAPHFpesgE5ucqfYEscaz5kMdcIDwU/6+DDoY=
github.com/googleapis/gax-go/v2 v2.7.0/go.mod h1:TEop28CZZQ2y+c0VxMUmu1lV+fQx57QpBWsYpwqHJx8=
github.com/googleapis/gax-go/v2 v2.7.1/go.mod h1:4orTrqY6hXxxaUL4LHIPl6lGo8vAE38/qKbhSAKP6QI=
github.com/googleapis/gax-go/v2 v2.8.0/go.mod h1:4orTrqY6hXxxaUL4LHIPl6lGo8vAE38/qKbhSAKP6QI=
github.com/googleapis/gax-go/v2 v2.10.0/go.mod h1:4UOEnMCrxsSqQ940WnTiD6qJ63le2ev3xfyagutxiPw=
github.com/googleapis/gax-go/v2 v2.11.0/go.mod h1:DxmR61SGKkGLa2xigwuZIQpkCI2S5iydzRfb3peWZJI=
github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4=
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
@ -262,6 +283,7 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=

View file

@ -51,6 +51,13 @@ import (
"sigs.k8s.io/yaml"
nfdclientset "sigs.k8s.io/node-feature-discovery/api/generated/clientset/versioned"
klogutils "sigs.k8s.io/node-feature-discovery/pkg/utils/klog"
spiffe "sigs.k8s.io/node-feature-discovery/pkg/utils/spiffe"
taintutils "k8s.io/kubernetes/pkg/util/taints"
"sigs.k8s.io/yaml"
"sigs.k8s.io/node-feature-discovery/api/nfd/v1alpha1"
nfdv1alpha1 "sigs.k8s.io/node-feature-discovery/api/nfd/v1alpha1"
"sigs.k8s.io/node-feature-discovery/pkg/apis/nfd/nodefeaturerule"
"sigs.k8s.io/node-feature-discovery/pkg/apis/nfd/validate"
@ -60,6 +67,9 @@ import (
"sigs.k8s.io/node-feature-discovery/pkg/version"
)
// SocketPath specifies Spiffe Socket Path
const SocketPath = "unix:///run/spire/sockets/agent.sock"
// Labels are a Kubernetes representation of discovered features.
type Labels map[string]string
@ -92,6 +102,7 @@ type NFDConfig struct {
NfdApiParallelism int
Klog klogutils.KlogConfigOpts
Restrictions Restrictions
EnableSpiffe bool
}
// LeaderElectionConfig contains the configuration for leader election
@ -110,6 +121,7 @@ type ConfigOverrideArgs struct {
NoPublish *bool
ResyncPeriod *utils.DurationVal
NfdApiParallelism *int
EnableSpiffe *bool
}
// Args holds command line arguments
@ -149,7 +161,8 @@ type nfdMaster struct {
nfdClient nfdclientset.Interface
updaterPool *updaterPool
deniedNs
config *NFDConfig
config *NFDConfig
spiffeClient *spiffe.SpiffeClient
}
// NewNfdMaster creates a new NfdMaster server instance.
@ -206,6 +219,12 @@ func NewNfdMaster(opts ...NfdMasterOption) (NfdMaster, error) {
nfd.updaterPool = newUpdaterPool(nfd)
spiffeClient, err := spiffe.NewSpiffeClient(SocketPath)
if err != nil {
return nfd, err
}
nfd.spiffeClient = spiffeClient
return nfd, nil
}
@ -247,7 +266,6 @@ func newDefaultConfig() *NFDConfig {
RetryPeriod: utils.DurationVal{Duration: time.Duration(2) * time.Second},
RenewDeadline: utils.DurationVal{Duration: time.Duration(10) * time.Second},
},
Klog: make(map[string]string),
Restrictions: Restrictions{
DisableLabels: false,
DisableExtendedResources: false,
@ -255,6 +273,8 @@ func newDefaultConfig() *NFDConfig {
AllowOverwrite: true,
DenyNodeFeatureLabels: false,
},
Klog: make(map[string]string),
EnableSpiffe: false,
}
}
@ -677,6 +697,55 @@ func (m *nfdMaster) nfdAPIUpdateOneNode(cli k8sclient.Interface, node *corev1.No
return fmt.Errorf("failed to merge NodeFeature objects for node %q: %w", node.Name, err)
}
// Sort our objects
sort.Slice(objs, func(i, j int) bool {
// Objects in our nfd namespace gets into the beginning of the list
if objs[i].Namespace == m.namespace && objs[j].Namespace != m.namespace {
return true
}
if objs[i].Namespace != m.namespace && objs[j].Namespace == m.namespace {
return false
}
// After the nfd namespace, sort objects by their name
if objs[i].Name != objs[j].Name {
return objs[i].Name < objs[j].Name
}
// Objects with the same name are sorted by their namespace
return objs[i].Namespace < objs[j].Namespace
})
// If spiffe is enabled, we should filter out the non verified NFD objects
if m.config.EnableSpiffe {
objs, err = m.getVerifiedNFDObjects(objs)
if err != nil {
return err
}
}
klog.V(1).InfoS("processing of node initiated by NodeFeature API", "nodeName", node.Name)
features := nfdv1alpha1.NewNodeFeatureSpec()
if len(objs) > 0 {
// Merge in features
//
// NOTE: changing the rule api to support handle multiple objects instead
// of merging would probably perform better with lot less data to copy.
features = objs[0].Spec.DeepCopy()
if m.config.AutoDefaultNs {
features.Labels = addNsToMapKeys(features.Labels, nfdv1alpha1.FeatureLabelNs)
}
for _, o := range objs[1:] {
s := o.Spec.DeepCopy()
if m.config.AutoDefaultNs {
s.Labels = addNsToMapKeys(s.Labels, nfdv1alpha1.FeatureLabelNs)
}
s.MergeInto(features)
}
klog.V(4).InfoS("merged nodeFeatureSpecs", "newNodeFeatureSpec", utils.DelayedDumper(features))
}
// Update node labels et al. This may also mean removing all NFD-owned
// labels (et al.), for example in the case no NodeFeature objects are
// present.
@ -1187,6 +1256,9 @@ func (m *nfdMaster) configure(filepath string, overrides string) error {
if m.args.Overrides.NfdApiParallelism != nil {
c.NfdApiParallelism = *m.args.Overrides.NfdApiParallelism
}
if m.args.Overrides.EnableSpiffe != nil {
c.EnableSpiffe = *m.args.Overrides.EnableSpiffe
}
if c.NfdApiParallelism <= 0 {
return fmt.Errorf("the maximum number of concurrent labelers should be a non-zero positive number")
@ -1387,3 +1459,27 @@ func patchNode(cli k8sclient.Interface, nodeName string, patches []utils.JsonPat
func patchNodeStatus(cli k8sclient.Interface, nodeName string, patches []utils.JsonPatch) error {
return patchNode(cli, nodeName, patches, "status")
}
func (m *nfdMaster) getVerifiedNFDObjects(objs []*v1alpha1.NodeFeature) ([]*v1alpha1.NodeFeature, error) {
verifiedObjects := []*v1alpha1.NodeFeature{}
workerPrivateKey, workerPublicKey, err := m.spiffeClient.GetWorkerKeys()
if err != nil {
return verifiedObjects, err
}
for _, obj := range objs {
isSignatureVerified, err := spiffe.VerifyDataSignature(obj.Spec, obj.Annotations["signature"], workerPrivateKey, workerPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to verify NodeFeature signature: %w", err)
}
if isSignatureVerified {
klog.InfoS("NodeFeature verified", "NodeFeature name", obj.Name)
verifiedObjects = append(verifiedObjects, obj)
} else {
klog.InfoS("NodeFeature not verified, skipping...", "NodeFeature name", obj.Name)
}
}
return verifiedObjects, nil
}

View file

@ -17,6 +17,9 @@ limitations under the License.
package nfdworker
import (
"crypto/tls"
"crypto/x509"
b64 "encoding/base64"
"encoding/json"
"fmt"
"net/http"
@ -45,6 +48,7 @@ import (
nfdclient "sigs.k8s.io/node-feature-discovery/api/generated/clientset/versioned"
nfdv1alpha1 "sigs.k8s.io/node-feature-discovery/api/nfd/v1alpha1"
"sigs.k8s.io/node-feature-discovery/pkg/utils"
spiffe "sigs.k8s.io/node-feature-discovery/pkg/utils/spiffe"
"sigs.k8s.io/node-feature-discovery/pkg/version"
"sigs.k8s.io/node-feature-discovery/source"
@ -62,6 +66,9 @@ import (
_ "sigs.k8s.io/node-feature-discovery/source/usb"
)
// SocketPath specifies Spiffe Socket Path
const SocketPath = "unix:///run/spire/sockets/agent.sock"
// NfdWorker is the interface for nfd-worker daemon
type NfdWorker interface {
Run() error
@ -83,6 +90,7 @@ type coreConfig struct {
Sources *[]string
LabelSources []string
SleepInterval utils.DurationVal
EnableSpiffe bool
}
type sourcesConfig map[string]source.Config
@ -109,6 +117,7 @@ type ConfigOverrideArgs struct {
NoOwnerRefs *bool
FeatureSources *utils.StringSliceVal
LabelSources *utils.StringSliceVal
EnableSpiffe *bool
}
type nfdWorker struct {
@ -122,6 +131,7 @@ type nfdWorker struct {
featureSources []source.FeatureSource
labelSources []source.LabelSource
ownerReference []metav1.OwnerReference
spiffeClient *spiffe.SpiffeClient
}
// This ticker can represent infinite and normal intervals.
@ -188,6 +198,12 @@ func NewNfdWorker(opts ...NfdWorkerOption) (NfdWorker, error) {
nfd.k8sClient = cli
}
spiffeClient, err := spiffe.NewSpiffeClient(SocketPath)
if err != nil {
return nfd, err
}
nfd.spiffeClient = spiffeClient
return nfd, nil
}
@ -509,6 +525,9 @@ func (w *nfdWorker) configure(filepath string, overrides string) error {
if w.args.Overrides.LabelSources != nil {
c.Core.LabelSources = *w.args.Overrides.LabelSources
}
if w.args.Overrides.EnableSpiffe != nil {
c.Core.EnableSpiffe = *w.args.Overrides.EnableSpiffe
}
c.Core.sanitize()
@ -643,6 +662,14 @@ func (m *nfdWorker) updateNodeFeatureObject(labels Labels) error {
}
klog.InfoS("creating NodeFeature object", "nodefeature", klog.KObj(nfr))
// If Spiffe is enabled, we add the signature to the annotations section
if m.config.Core.EnableSpiffe {
err = m.signNodeFeatureCR(nfr)
if err != nil {
return err
}
}
nfrCreated, err := cli.NfdV1alpha1().NodeFeatures(namespace).Create(context.TODO(), nfr, metav1.CreateOptions{})
if err != nil {
return fmt.Errorf("failed to create NodeFeature object %q: %w", nfr.Name, err)
@ -661,6 +688,13 @@ func (m *nfdWorker) updateNodeFeatureObject(labels Labels) error {
Labels: labels,
}
if m.config.Core.EnableSpiffe {
err = m.signNodeFeatureCR(nfrUpdated)
if err != nil {
return err
}
}
if !apiequality.Semantic.DeepEqual(nfr, nfrUpdated) {
klog.InfoS("updating NodeFeature object", "nodefeature", klog.KObj(nfr))
nfrUpdated, err = cli.NfdV1alpha1().NodeFeatures(namespace).Update(context.TODO(), nfrUpdated, metav1.UpdateOptions{})
@ -718,3 +752,23 @@ func (c *sourcesConfig) UnmarshalJSON(data []byte) error {
return nil
}
// signNodeFeatureCR add the signature to the annotations of a given NodeFeature CR
func (m *nfdWorker) signNodeFeatureCR(nfr *nfdv1alpha1.NodeFeature) error {
workerPrivateKey, _, err := m.spiffeClient.GetWorkerKeys()
if err != nil {
return fmt.Errorf("error while getting worker keys: %w", err)
}
signature, err := spiffe.SignData(nfr.Spec, workerPrivateKey)
if err != nil {
return fmt.Errorf("failed to sign CRD data using Spiffe: %w", err)
}
encodedSignature := b64.StdEncoding.EncodeToString(signature)
nfr.ObjectMeta.Annotations["signature"] = encodedSignature
return nil
}

120
pkg/utils/spiffe/spiffe.go Normal file
View file

@ -0,0 +1,120 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
b64 "encoding/base64"
"encoding/json"
"fmt"
"github.com/spiffe/go-spiffe/v2/workloadapi"
)
// WorkerSpiffeID is the SpiffeID of the worker
const WorkerSpiffeID = "spiffe://nfd.com/worker"
type SpiffeClient struct {
WorkloadApiClient workloadapi.Client
}
func NewSpiffeClient(socketPath string) (*SpiffeClient, error) {
spiffeClient := SpiffeClient{}
workloadApiClient, err := workloadapi.New(context.Background(), workloadapi.WithAddr(socketPath))
if err != nil {
return nil, err
}
spiffeClient.WorkloadApiClient = *workloadApiClient
return &spiffeClient, nil
}
func SignData(data interface{}, privateKey crypto.Signer) ([]byte, error) {
stringifyData, err := json.Marshal(data)
if err != nil {
return []byte{}, err
}
dataHash := sha256.Sum256([]byte(stringifyData))
switch t := privateKey.(type) {
case *rsa.PrivateKey:
signedData, err := rsa.SignPKCS1v15(rand.Reader, privateKey.(*rsa.PrivateKey), crypto.SHA256, dataHash[:])
if err != nil {
return []byte{}, err
}
return signedData, nil
case *ecdsa.PrivateKey:
signedData, err := ecdsa.SignASN1(rand.Reader, privateKey.(*ecdsa.PrivateKey), dataHash[:])
if err != nil {
return []byte{}, err
}
return signedData, nil
default:
return nil, fmt.Errorf("unknown private key type: %v", t)
}
}
func VerifyDataSignature(data interface{}, signedData string, privateKey crypto.Signer, publicKey crypto.PublicKey) (bool, error) {
stringifyData, err := json.Marshal(data)
if err != nil {
return false, err
}
decodedSignature, err := b64.StdEncoding.DecodeString(signedData)
if err != nil {
return false, err
}
dataHash := sha256.Sum256([]byte(stringifyData))
if err != nil {
return false, err
}
switch t := privateKey.(type) {
case *rsa.PrivateKey:
err = rsa.VerifyPKCS1v15(publicKey.(*rsa.PublicKey), crypto.SHA256, dataHash[:], decodedSignature)
if err != nil {
return false, err
}
return true, nil
case *ecdsa.PrivateKey:
verify := ecdsa.VerifyASN1(publicKey.(*ecdsa.PublicKey), dataHash[:], decodedSignature)
return verify, nil
default:
return false, fmt.Errorf("unknown private key type: %v", t)
}
}
func (s *SpiffeClient) GetWorkerKeys() (crypto.Signer, crypto.PublicKey, error) {
ctx := context.Background()
svids, err := s.WorkloadApiClient.FetchX509SVIDs(ctx)
if err != nil {
return nil, nil, err
}
for _, svid := range svids {
if svid.ID.String() == WorkerSpiffeID {
return svid.PrivateKey, svid.PrivateKey.Public, nil
}
}
return nil, nil, fmt.Errorf("cannot sign data: spiffe ID %s is not found", WorkerSpiffeID)
}

View file

@ -0,0 +1,119 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
b64 "encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"sigs.k8s.io/node-feature-discovery/api/nfd/v1alpha1"
)
func mockNFRSpec() v1alpha1.NodeFeatureSpec {
return v1alpha1.NodeFeatureSpec{
Features: v1alpha1.Features{
Flags: map[string]v1alpha1.FlagFeatureSet{
"test": {
Elements: map[string]v1alpha1.Nil{
"test2": {},
},
},
},
},
}
}
func mockWorkerECDSAPrivateKey() (*ecdsa.PrivateKey, *ecdsa.PublicKey) {
privateKey, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
return privateKey, &privateKey.PublicKey
}
func mockWorkerRSAPrivateKey() (*rsa.PrivateKey, *rsa.PublicKey) {
privateKey, _ := rsa.GenerateKey(rand.Reader, 4096)
return privateKey, &privateKey.PublicKey
}
func TestVerify(t *testing.T) {
rsaPrivateKey, rsaPublicKey := mockWorkerRSAPrivateKey()
ecdsaPrivateKey, ecdsaPublicKey := mockWorkerECDSAPrivateKey()
spec := mockNFRSpec()
tc := []struct {
name string
privateKey crypto.Signer
publicKey crypto.PublicKey
wantErr bool
}{
{
name: "RSA Keys",
privateKey: rsaPrivateKey,
publicKey: rsaPublicKey,
wantErr: true,
},
{
name: "ECDSA Keys",
privateKey: ecdsaPrivateKey,
publicKey: ecdsaPublicKey,
wantErr: false,
},
}
for _, tt := range tc {
signedData, err := SignData(spec, tt.privateKey)
assert.NoError(t, err)
isVerified, err := VerifyDataSignature(spec, b64.StdEncoding.EncodeToString(signedData), tt.privateKey, tt.publicKey)
assert.NoError(t, err)
assert.True(t, isVerified)
signedData = append(signedData, "random"...)
isVerified, err = VerifyDataSignature(spec, b64.StdEncoding.EncodeToString(signedData), tt.privateKey, tt.publicKey)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.False(t, isVerified)
}
}
}
func TestSignData(t *testing.T) {
rsaPrivateKey, _ := mockWorkerRSAPrivateKey()
ecdsaPrivateKey, _ := mockWorkerECDSAPrivateKey()
spec := mockNFRSpec()
tc := []struct {
name string
privateKey crypto.Signer
}{
{
name: "RSA Keys",
privateKey: rsaPrivateKey,
},
{
name: "ECDSA Keys",
privateKey: ecdsaPrivateKey,
},
}
for _, tt := range tc {
_, err := SignData(spec, tt.privateKey)
assert.NoError(t, err)
}
}