Import docker.io_1.3.3~dfsg1.orig-libtrust.tar.gz
authorTianon Gravi <admwiggin@gmail.com>
Fri, 19 Dec 2014 04:54:12 +0000 (04:54 +0000)
committerTianon Gravi <admwiggin@gmail.com>
Fri, 19 Dec 2014 04:54:12 +0000 (04:54 +0000)
[dgit import orig docker.io_1.3.3~dfsg1.orig-libtrust.tar.gz]

31 files changed:
CONTRIBUTING.md [new file with mode: 0644]
LICENSE [new file with mode: 0644]
MAINTAINERS [new file with mode: 0644]
README.md [new file with mode: 0644]
certificates.go [new file with mode: 0644]
certificates_test.go [new file with mode: 0644]
doc.go [new file with mode: 0644]
ec_key.go [new file with mode: 0644]
ec_key_test.go [new file with mode: 0644]
filter.go [new file with mode: 0644]
filter_test.go [new file with mode: 0644]
hash.go [new file with mode: 0644]
jsonsign.go [new file with mode: 0644]
jsonsign_test.go [new file with mode: 0644]
key.go [new file with mode: 0644]
key_files.go [new file with mode: 0644]
key_files_test.go [new file with mode: 0644]
rsa_key.go [new file with mode: 0644]
rsa_key_test.go [new file with mode: 0644]
testutil/certificates.go [new file with mode: 0644]
tlsdemo/README.md [new file with mode: 0644]
tlsdemo/client.go [new file with mode: 0644]
tlsdemo/gencert.go [new file with mode: 0644]
tlsdemo/genkeys.go [new file with mode: 0644]
tlsdemo/server.go [new file with mode: 0644]
trustgraph/graph.go [new file with mode: 0644]
trustgraph/memory_graph.go [new file with mode: 0644]
trustgraph/memory_graph_test.go [new file with mode: 0644]
trustgraph/statement.go [new file with mode: 0644]
trustgraph/statement_test.go [new file with mode: 0644]
util.go [new file with mode: 0644]

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644 (file)
index 0000000..05be0f8
--- /dev/null
@@ -0,0 +1,13 @@
+# Contributing to libtrust
+
+Want to hack on libtrust? Awesome! Here are instructions to get you
+started.
+
+libtrust is a part of the [Docker](https://www.docker.com) project, and follows
+the same rules and principles. If you're already familiar with the way
+Docker does things, you'll feel right at home.
+
+Otherwise, go read
+[Docker's contributions guidelines](https://github.com/docker/docker/blob/master/CONTRIBUTING.md).
+
+Happy hacking!
diff --git a/LICENSE b/LICENSE
new file mode 100644 (file)
index 0000000..2744858
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,191 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   Copyright 2014 Docker, Inc.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/MAINTAINERS b/MAINTAINERS
new file mode 100644 (file)
index 0000000..9768175
--- /dev/null
@@ -0,0 +1,3 @@
+Solomon Hykes <solomon@docker.com>
+Josh Hawn <josh@docker.com> (github: jlhawn)
+Derek McGowan <derek@docker.com> (github: dmcgowan)
diff --git a/README.md b/README.md
new file mode 100644 (file)
index 0000000..8e7db38
--- /dev/null
+++ b/README.md
@@ -0,0 +1,18 @@
+# libtrust
+
+Libtrust is library for managing authentication and authorization using public key cryptography.
+
+Authentication is handled using the identity attached to the public key.
+Libtrust provides multiple methods to prove possession of the private key associated with an identity.
+ - TLS x509 certificates
+ - Signature verification
+ - Key Challenge
+
+Authorization and access control is managed through a distributed trust graph.
+Trust servers are used as the authorities of the trust graph and allow caching portions of the graph for faster access.
+
+## Copyright and license
+
+Code and documentation copyright 2014 Docker, inc. Code released under the Apache 2.0 license.
+Docs released under Creative commons.
+
diff --git a/certificates.go b/certificates.go
new file mode 100644 (file)
index 0000000..3dcca33
--- /dev/null
@@ -0,0 +1,175 @@
+package libtrust
+
+import (
+       "crypto/rand"
+       "crypto/x509"
+       "crypto/x509/pkix"
+       "encoding/pem"
+       "fmt"
+       "io/ioutil"
+       "math/big"
+       "net"
+       "time"
+)
+
+type certTemplateInfo struct {
+       commonName  string
+       domains     []string
+       ipAddresses []net.IP
+       isCA        bool
+       clientAuth  bool
+       serverAuth  bool
+}
+
+func generateCertTemplate(info *certTemplateInfo) *x509.Certificate {
+       // Generate a certificate template which is valid from the past week to
+       // 10 years from now. The usage of the certificate depends on the
+       // specified fields in the given certTempInfo object.
+       var (
+               keyUsage    x509.KeyUsage
+               extKeyUsage []x509.ExtKeyUsage
+       )
+
+       if info.isCA {
+               keyUsage = x509.KeyUsageCertSign
+       }
+
+       if info.clientAuth {
+               extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth)
+       }
+
+       if info.serverAuth {
+               extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageServerAuth)
+       }
+
+       return &x509.Certificate{
+               SerialNumber: big.NewInt(0),
+               Subject: pkix.Name{
+                       CommonName: info.commonName,
+               },
+               NotBefore:             time.Now().Add(-time.Hour * 24 * 7),
+               NotAfter:              time.Now().Add(time.Hour * 24 * 365 * 10),
+               DNSNames:              info.domains,
+               IPAddresses:           info.ipAddresses,
+               IsCA:                  info.isCA,
+               KeyUsage:              keyUsage,
+               ExtKeyUsage:           extKeyUsage,
+               BasicConstraintsValid: info.isCA,
+       }
+}
+
+func generateCert(pub PublicKey, priv PrivateKey, subInfo, issInfo *certTemplateInfo) (cert *x509.Certificate, err error) {
+       pubCertTemplate := generateCertTemplate(subInfo)
+       privCertTemplate := generateCertTemplate(issInfo)
+
+       certDER, err := x509.CreateCertificate(
+               rand.Reader, pubCertTemplate, privCertTemplate,
+               pub.CryptoPublicKey(), priv.CryptoPrivateKey(),
+       )
+       if err != nil {
+               return nil, fmt.Errorf("failed to create certificate: %s", err)
+       }
+
+       cert, err = x509.ParseCertificate(certDER)
+       if err != nil {
+               return nil, fmt.Errorf("failed to parse certificate: %s", err)
+       }
+
+       return
+}
+
+// GenerateSelfSignedServerCert creates a self-signed certificate for the
+// given key which is to be used for TLS servers with the given domains and
+// IP addresses.
+func GenerateSelfSignedServerCert(key PrivateKey, domains []string, ipAddresses []net.IP) (*x509.Certificate, error) {
+       info := &certTemplateInfo{
+               commonName:  key.KeyID(),
+               domains:     domains,
+               ipAddresses: ipAddresses,
+               serverAuth:  true,
+       }
+
+       return generateCert(key.PublicKey(), key, info, info)
+}
+
+// GenerateSelfSignedClientCert creates a self-signed certificate for the
+// given key which is to be used for TLS clients.
+func GenerateSelfSignedClientCert(key PrivateKey) (*x509.Certificate, error) {
+       info := &certTemplateInfo{
+               commonName: key.KeyID(),
+               clientAuth: true,
+       }
+
+       return generateCert(key.PublicKey(), key, info, info)
+}
+
+// GenerateCACert creates a certificate which can be used as a trusted
+// certificate authority.
+func GenerateCACert(signer PrivateKey, trustedKey PublicKey) (*x509.Certificate, error) {
+       subjectInfo := &certTemplateInfo{
+               commonName: trustedKey.KeyID(),
+               isCA:       true,
+       }
+       issuerInfo := &certTemplateInfo{
+               commonName: signer.KeyID(),
+       }
+
+       return generateCert(trustedKey, signer, subjectInfo, issuerInfo)
+}
+
+// GenerateCACertPool creates a certificate authority pool to be used for a
+// TLS configuration. Any self-signed certificates issued by the specified
+// trusted keys will be verified during a TLS handshake
+func GenerateCACertPool(signer PrivateKey, trustedKeys []PublicKey) (*x509.CertPool, error) {
+       certPool := x509.NewCertPool()
+
+       for _, trustedKey := range trustedKeys {
+               cert, err := GenerateCACert(signer, trustedKey)
+               if err != nil {
+                       return nil, fmt.Errorf("failed to generate CA certificate: %s", err)
+               }
+
+               certPool.AddCert(cert)
+       }
+
+       return certPool, nil
+}
+
+// LoadCertificateBundle loads certificates from the given file.  The file should be pem encoded
+// containing one or more certificates.  The expected pem type is "CERTIFICATE".
+func LoadCertificateBundle(filename string) ([]*x509.Certificate, error) {
+       b, err := ioutil.ReadFile(filename)
+       if err != nil {
+               return nil, err
+       }
+       certificates := []*x509.Certificate{}
+       var block *pem.Block
+       block, b = pem.Decode(b)
+       for ; block != nil; block, b = pem.Decode(b) {
+               if block.Type == "CERTIFICATE" {
+                       cert, err := x509.ParseCertificate(block.Bytes)
+                       if err != nil {
+                               return nil, err
+                       }
+                       certificates = append(certificates, cert)
+               } else {
+                       return nil, fmt.Errorf("invalid pem block type: %s", block.Type)
+               }
+       }
+
+       return certificates, nil
+}
+
+// LoadCertificatePool loads a CA pool from the given file.  The file should be pem encoded
+// containing one or more certificates. The expected pem type is "CERTIFICATE".
+func LoadCertificatePool(filename string) (*x509.CertPool, error) {
+       certs, err := LoadCertificateBundle(filename)
+       if err != nil {
+               return nil, err
+       }
+       pool := x509.NewCertPool()
+       for _, cert := range certs {
+               pool.AddCert(cert)
+       }
+       return pool, nil
+}
diff --git a/certificates_test.go b/certificates_test.go
new file mode 100644 (file)
index 0000000..c111f35
--- /dev/null
@@ -0,0 +1,111 @@
+package libtrust
+
+import (
+       "encoding/pem"
+       "io/ioutil"
+       "net"
+       "os"
+       "path"
+       "testing"
+)
+
+func TestGenerateCertificates(t *testing.T) {
+       key, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       _, err = GenerateSelfSignedServerCert(key, []string{"localhost"}, []net.IP{net.ParseIP("127.0.0.1")})
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       _, err = GenerateSelfSignedClientCert(key)
+       if err != nil {
+               t.Fatal(err)
+       }
+}
+
+func TestGenerateCACertPool(t *testing.T) {
+       key, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       caKey1, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       caKey2, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       _, err = GenerateCACertPool(key, []PublicKey{caKey1.PublicKey(), caKey2.PublicKey()})
+       if err != nil {
+               t.Fatal(err)
+       }
+}
+
+func TestLoadCertificates(t *testing.T) {
+       key, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       caKey1, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+       caKey2, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       cert1, err := GenerateCACert(caKey1, key)
+       if err != nil {
+               t.Fatal(err)
+       }
+       cert2, err := GenerateCACert(caKey2, key)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       d, err := ioutil.TempDir("/tmp", "cert-test")
+       if err != nil {
+               t.Fatal(err)
+       }
+       caFile := path.Join(d, "ca.pem")
+       f, err := os.OpenFile(caFile, os.O_CREATE|os.O_WRONLY, 0644)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       err = pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: cert1.Raw})
+       if err != nil {
+               t.Fatal(err)
+       }
+       err = pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: cert2.Raw})
+       if err != nil {
+               t.Fatal(err)
+       }
+       f.Close()
+
+       certs, err := LoadCertificateBundle(caFile)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if len(certs) != 2 {
+               t.Fatalf("Wrong number of certs received, expected: %d, received %d", 2, len(certs))
+       }
+
+       pool, err := LoadCertificatePool(caFile)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if len(pool.Subjects()) != 2 {
+               t.Fatalf("Invalid certificate pool")
+       }
+}
diff --git a/doc.go b/doc.go
new file mode 100644 (file)
index 0000000..ec5d215
--- /dev/null
+++ b/doc.go
@@ -0,0 +1,9 @@
+/*
+Package libtrust provides an interface for managing authentication and
+authorization using public key cryptography. Authentication is handled
+using the identity attached to the public key and verified through TLS
+x509 certificates, a key challenge, or signature. Authorization and
+access control is managed through a trust graph distributed between
+both remote trust servers and locally cached and managed data.
+*/
+package libtrust
diff --git a/ec_key.go b/ec_key.go
new file mode 100644 (file)
index 0000000..c7ac684
--- /dev/null
+++ b/ec_key.go
@@ -0,0 +1,437 @@
+package libtrust
+
+import (
+       "crypto"
+       "crypto/ecdsa"
+       "crypto/elliptic"
+       "crypto/rand"
+       "crypto/x509"
+       "encoding/json"
+       "encoding/pem"
+       "errors"
+       "fmt"
+       "io"
+       "math/big"
+)
+
+/*
+ * EC DSA PUBLIC KEY
+ */
+
+// ecPublicKey implements a libtrust.PublicKey using elliptic curve digital
+// signature algorithms.
+type ecPublicKey struct {
+       *ecdsa.PublicKey
+       curveName          string
+       signatureAlgorithm *signatureAlgorithm
+       extended           map[string]interface{}
+}
+
+func fromECPublicKey(cryptoPublicKey *ecdsa.PublicKey) (*ecPublicKey, error) {
+       curve := cryptoPublicKey.Curve
+
+       switch {
+       case curve == elliptic.P256():
+               return &ecPublicKey{cryptoPublicKey, "P-256", es256, map[string]interface{}{}}, nil
+       case curve == elliptic.P384():
+               return &ecPublicKey{cryptoPublicKey, "P-384", es384, map[string]interface{}{}}, nil
+       case curve == elliptic.P521():
+               return &ecPublicKey{cryptoPublicKey, "P-521", es512, map[string]interface{}{}}, nil
+       default:
+               return nil, errors.New("unsupported elliptic curve")
+       }
+}
+
+// KeyType returns the key type for elliptic curve keys, i.e., "EC".
+func (k *ecPublicKey) KeyType() string {
+       return "EC"
+}
+
+// CurveName returns the elliptic curve identifier.
+// Possible values are "P-256", "P-384", and "P-521".
+func (k *ecPublicKey) CurveName() string {
+       return k.curveName
+}
+
+// KeyID returns a distinct identifier which is unique to this Public Key.
+func (k *ecPublicKey) KeyID() string {
+       // Generate and return a libtrust fingerprint of the EC public key.
+       // For an EC key this should be:
+       //   SHA256("EC"+curveName+bytes(X)+bytes(Y))
+       // Then truncated to 240 bits and encoded into 12 base32 groups like so:
+       //   ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP
+       hasher := crypto.SHA256.New()
+       hasher.Write([]byte(k.KeyType() + k.CurveName()))
+       hasher.Write(k.X.Bytes())
+       hasher.Write(k.Y.Bytes())
+       return keyIDEncode(hasher.Sum(nil)[:30])
+}
+
+func (k *ecPublicKey) String() string {
+       return fmt.Sprintf("EC Public Key <%s>", k.KeyID())
+}
+
+// Verify verifyies the signature of the data in the io.Reader using this
+// PublicKey. The alg parameter should identify the digital signature
+// algorithm which was used to produce the signature and should be supported
+// by this public key. Returns a nil error if the signature is valid.
+func (k *ecPublicKey) Verify(data io.Reader, alg string, signature []byte) error {
+       // For EC keys there is only one supported signature algorithm depending
+       // on the curve parameters.
+       if k.signatureAlgorithm.HeaderParam() != alg {
+               return fmt.Errorf("unable to verify signature: EC Public Key with curve %q does not support signature algorithm %q", k.curveName, alg)
+       }
+
+       // signature is the concatenation of (r, s), base64Url encoded.
+       sigLength := len(signature)
+       expectedOctetLength := 2 * ((k.Params().BitSize + 7) >> 3)
+       if sigLength != expectedOctetLength {
+               return fmt.Errorf("signature length is %d octets long, should be %d", sigLength, expectedOctetLength)
+       }
+
+       rBytes, sBytes := signature[:sigLength/2], signature[sigLength/2:]
+       r := new(big.Int).SetBytes(rBytes)
+       s := new(big.Int).SetBytes(sBytes)
+
+       hasher := k.signatureAlgorithm.HashID().New()
+       _, err := io.Copy(hasher, data)
+       if err != nil {
+               return fmt.Errorf("error reading data to sign: %s", err)
+       }
+       hash := hasher.Sum(nil)
+
+       if !ecdsa.Verify(k.PublicKey, hash, r, s) {
+               return errors.New("invalid signature")
+       }
+
+       return nil
+}
+
+// CryptoPublicKey returns the internal object which can be used as a
+// crypto.PublicKey for use with other standard library operations. The type
+// is either *rsa.PublicKey or *ecdsa.PublicKey
+func (k *ecPublicKey) CryptoPublicKey() crypto.PublicKey {
+       return k.PublicKey
+}
+
+func (k *ecPublicKey) toMap() map[string]interface{} {
+       jwk := make(map[string]interface{})
+       for k, v := range k.extended {
+               jwk[k] = v
+       }
+       jwk["kty"] = k.KeyType()
+       jwk["kid"] = k.KeyID()
+       jwk["crv"] = k.CurveName()
+
+       xBytes := k.X.Bytes()
+       yBytes := k.Y.Bytes()
+       octetLength := (k.Params().BitSize + 7) >> 3
+       // MUST include leading zeros in the output so that x, y are each
+       // *octetLength* bytes long.
+       xBuf := make([]byte, octetLength-len(xBytes), octetLength)
+       yBuf := make([]byte, octetLength-len(yBytes), octetLength)
+       xBuf = append(xBuf, xBytes...)
+       yBuf = append(yBuf, yBytes...)
+
+       jwk["x"] = joseBase64UrlEncode(xBuf)
+       jwk["y"] = joseBase64UrlEncode(yBuf)
+
+       return jwk
+}
+
+// MarshalJSON serializes this Public Key using the JWK JSON serialization format for
+// elliptic curve keys.
+func (k *ecPublicKey) MarshalJSON() (data []byte, err error) {
+       return json.Marshal(k.toMap())
+}
+
+// PEMBlock serializes this Public Key to DER-encoded PKIX format.
+func (k *ecPublicKey) PEMBlock() (*pem.Block, error) {
+       derBytes, err := x509.MarshalPKIXPublicKey(k.PublicKey)
+       if err != nil {
+               return nil, fmt.Errorf("unable to serialize EC PublicKey to DER-encoded PKIX format: %s", err)
+       }
+       k.extended["keyID"] = k.KeyID() // For display purposes.
+       return createPemBlock("PUBLIC KEY", derBytes, k.extended)
+}
+
+func (k *ecPublicKey) AddExtendedField(field string, value interface{}) {
+       k.extended[field] = value
+}
+
+func (k *ecPublicKey) GetExtendedField(field string) interface{} {
+       v, ok := k.extended[field]
+       if !ok {
+               return nil
+       }
+       return v
+}
+
+func ecPublicKeyFromMap(jwk map[string]interface{}) (*ecPublicKey, error) {
+       // JWK key type (kty) has already been determined to be "EC".
+       // Need to extract 'crv', 'x', 'y', and 'kid' and check for
+       // consistency.
+
+       // Get the curve identifier value.
+       crv, err := stringFromMap(jwk, "crv")
+       if err != nil {
+               return nil, fmt.Errorf("JWK EC Public Key curve identifier: %s", err)
+       }
+
+       var (
+               curve  elliptic.Curve
+               sigAlg *signatureAlgorithm
+       )
+
+       switch {
+       case crv == "P-256":
+               curve = elliptic.P256()
+               sigAlg = es256
+       case crv == "P-384":
+               curve = elliptic.P384()
+               sigAlg = es384
+       case crv == "P-521":
+               curve = elliptic.P521()
+               sigAlg = es512
+       default:
+               return nil, fmt.Errorf("JWK EC Public Key curve identifier not supported: %q\n", crv)
+       }
+
+       // Get the X and Y coordinates for the public key point.
+       xB64Url, err := stringFromMap(jwk, "x")
+       if err != nil {
+               return nil, fmt.Errorf("JWK EC Public Key x-coordinate: %s", err)
+       }
+       x, err := parseECCoordinate(xB64Url, curve)
+       if err != nil {
+               return nil, fmt.Errorf("JWK EC Public Key x-coordinate: %s", err)
+       }
+
+       yB64Url, err := stringFromMap(jwk, "y")
+       if err != nil {
+               return nil, fmt.Errorf("JWK EC Public Key y-coordinate: %s", err)
+       }
+       y, err := parseECCoordinate(yB64Url, curve)
+       if err != nil {
+               return nil, fmt.Errorf("JWK EC Public Key y-coordinate: %s", err)
+       }
+
+       key := &ecPublicKey{
+               PublicKey: &ecdsa.PublicKey{Curve: curve, X: x, Y: y},
+               curveName: crv, signatureAlgorithm: sigAlg,
+       }
+
+       // Key ID is optional too, but if it exists, it should match the key.
+       _, ok := jwk["kid"]
+       if ok {
+               kid, err := stringFromMap(jwk, "kid")
+               if err != nil {
+                       return nil, fmt.Errorf("JWK EC Public Key ID: %s", err)
+               }
+               if kid != key.KeyID() {
+                       return nil, fmt.Errorf("JWK EC Public Key ID does not match: %s", kid)
+               }
+       }
+
+       key.extended = jwk
+
+       return key, nil
+}
+
+/*
+ * EC DSA PRIVATE KEY
+ */
+
+// ecPrivateKey implements a JWK Private Key using elliptic curve digital signature
+// algorithms.
+type ecPrivateKey struct {
+       ecPublicKey
+       *ecdsa.PrivateKey
+}
+
+func fromECPrivateKey(cryptoPrivateKey *ecdsa.PrivateKey) (*ecPrivateKey, error) {
+       publicKey, err := fromECPublicKey(&cryptoPrivateKey.PublicKey)
+       if err != nil {
+               return nil, err
+       }
+
+       return &ecPrivateKey{*publicKey, cryptoPrivateKey}, nil
+}
+
+// PublicKey returns the Public Key data associated with this Private Key.
+func (k *ecPrivateKey) PublicKey() PublicKey {
+       return &k.ecPublicKey
+}
+
+func (k *ecPrivateKey) String() string {
+       return fmt.Sprintf("EC Private Key <%s>", k.KeyID())
+}
+
+// Sign signs the data read from the io.Reader using a signature algorithm supported
+// by the elliptic curve private key. If the specified hashing algorithm is
+// supported by this key, that hash function is used to generate the signature
+// otherwise the the default hashing algorithm for this key is used. Returns
+// the signature and the name of the JWK signature algorithm used, e.g.,
+// "ES256", "ES384", "ES512".
+func (k *ecPrivateKey) Sign(data io.Reader, hashID crypto.Hash) (signature []byte, alg string, err error) {
+       // Generate a signature of the data using the internal alg.
+       // The given hashId is only a suggestion, and since EC keys only support
+       // on signature/hash algorithm given the curve name, we disregard it for
+       // the elliptic curve JWK signature implementation.
+       hasher := k.signatureAlgorithm.HashID().New()
+       _, err = io.Copy(hasher, data)
+       if err != nil {
+               return nil, "", fmt.Errorf("error reading data to sign: %s", err)
+       }
+       hash := hasher.Sum(nil)
+
+       r, s, err := ecdsa.Sign(rand.Reader, k.PrivateKey, hash)
+       if err != nil {
+               return nil, "", fmt.Errorf("error producing signature: %s", err)
+       }
+       rBytes, sBytes := r.Bytes(), s.Bytes()
+       octetLength := (k.ecPublicKey.Params().BitSize + 7) >> 3
+       // MUST include leading zeros in the output
+       rBuf := make([]byte, octetLength-len(rBytes), octetLength)
+       sBuf := make([]byte, octetLength-len(sBytes), octetLength)
+
+       rBuf = append(rBuf, rBytes...)
+       sBuf = append(sBuf, sBytes...)
+
+       signature = append(rBuf, sBuf...)
+       alg = k.signatureAlgorithm.HeaderParam()
+
+       return
+}
+
+// CryptoPrivateKey returns the internal object which can be used as a
+// crypto.PublicKey for use with other standard library operations. The type
+// is either *rsa.PublicKey or *ecdsa.PublicKey
+func (k *ecPrivateKey) CryptoPrivateKey() crypto.PrivateKey {
+       return k.PrivateKey
+}
+
+func (k *ecPrivateKey) toMap() map[string]interface{} {
+       jwk := k.ecPublicKey.toMap()
+
+       dBytes := k.D.Bytes()
+       // The length of this octet string MUST be ceiling(log-base-2(n)/8)
+       // octets (where n is the order of the curve). This is because the private
+       // key d must be in the interval [1, n-1] so the bitlength of d should be
+       // no larger than the bitlength of n-1. The easiest way to find the octet
+       // length is to take bitlength(n-1), add 7 to force a carry, and shift this
+       // bit sequence right by 3, which is essentially dividing by 8 and adding
+       // 1 if there is any remainder. Thus, the private key value d should be
+       // output to (bitlength(n-1)+7)>>3 octets.
+       n := k.ecPublicKey.Params().N
+       octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
+       // Create a buffer with the necessary zero-padding.
+       dBuf := make([]byte, octetLength-len(dBytes), octetLength)
+       dBuf = append(dBuf, dBytes...)
+
+       jwk["d"] = joseBase64UrlEncode(dBuf)
+
+       return jwk
+}
+
+// MarshalJSON serializes this Private Key using the JWK JSON serialization format for
+// elliptic curve keys.
+func (k *ecPrivateKey) MarshalJSON() (data []byte, err error) {
+       return json.Marshal(k.toMap())
+}
+
+// PEMBlock serializes this Private Key to DER-encoded PKIX format.
+func (k *ecPrivateKey) PEMBlock() (*pem.Block, error) {
+       derBytes, err := x509.MarshalECPrivateKey(k.PrivateKey)
+       if err != nil {
+               return nil, fmt.Errorf("unable to serialize EC PrivateKey to DER-encoded PKIX format: %s", err)
+       }
+       k.extended["keyID"] = k.KeyID() // For display purposes.
+       return createPemBlock("EC PRIVATE KEY", derBytes, k.extended)
+}
+
+func ecPrivateKeyFromMap(jwk map[string]interface{}) (*ecPrivateKey, error) {
+       dB64Url, err := stringFromMap(jwk, "d")
+       if err != nil {
+               return nil, fmt.Errorf("JWK EC Private Key: %s", err)
+       }
+
+       // JWK key type (kty) has already been determined to be "EC".
+       // Need to extract the public key information, then extract the private
+       // key value 'd'.
+       publicKey, err := ecPublicKeyFromMap(jwk)
+       if err != nil {
+               return nil, err
+       }
+
+       d, err := parseECPrivateParam(dB64Url, publicKey.Curve)
+       if err != nil {
+               return nil, fmt.Errorf("JWK EC Private Key d-param: %s", err)
+       }
+
+       key := &ecPrivateKey{
+               ecPublicKey: *publicKey,
+               PrivateKey: &ecdsa.PrivateKey{
+                       PublicKey: *publicKey.PublicKey,
+                       D:         d,
+               },
+       }
+
+       return key, nil
+}
+
+/*
+ *     Key Generation Functions.
+ */
+
+func generateECPrivateKey(curve elliptic.Curve) (k *ecPrivateKey, err error) {
+       k = new(ecPrivateKey)
+       k.PrivateKey, err = ecdsa.GenerateKey(curve, rand.Reader)
+       if err != nil {
+               return nil, err
+       }
+
+       k.ecPublicKey.PublicKey = &k.PrivateKey.PublicKey
+       k.extended = make(map[string]interface{})
+
+       return
+}
+
+// GenerateECP256PrivateKey generates a key pair using elliptic curve P-256.
+func GenerateECP256PrivateKey() (PrivateKey, error) {
+       k, err := generateECPrivateKey(elliptic.P256())
+       if err != nil {
+               return nil, fmt.Errorf("error generating EC P-256 key: %s", err)
+       }
+
+       k.curveName = "P-256"
+       k.signatureAlgorithm = es256
+
+       return k, nil
+}
+
+// GenerateECP384PrivateKey generates a key pair using elliptic curve P-384.
+func GenerateECP384PrivateKey() (PrivateKey, error) {
+       k, err := generateECPrivateKey(elliptic.P384())
+       if err != nil {
+               return nil, fmt.Errorf("error generating EC P-384 key: %s", err)
+       }
+
+       k.curveName = "P-384"
+       k.signatureAlgorithm = es384
+
+       return k, nil
+}
+
+// GenerateECP521PrivateKey generates aß key pair using elliptic curve P-521.
+func GenerateECP521PrivateKey() (PrivateKey, error) {
+       k, err := generateECPrivateKey(elliptic.P521())
+       if err != nil {
+               return nil, fmt.Errorf("error generating EC P-521 key: %s", err)
+       }
+
+       k.curveName = "P-521"
+       k.signatureAlgorithm = es512
+
+       return k, nil
+}
diff --git a/ec_key_test.go b/ec_key_test.go
new file mode 100644 (file)
index 0000000..26ac381
--- /dev/null
@@ -0,0 +1,157 @@
+package libtrust
+
+import (
+       "bytes"
+       "encoding/json"
+       "testing"
+)
+
+func generateECTestKeys(t *testing.T) []PrivateKey {
+       p256Key, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       p384Key, err := GenerateECP384PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       p521Key, err := GenerateECP521PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       return []PrivateKey{p256Key, p384Key, p521Key}
+}
+
+func TestECKeys(t *testing.T) {
+       ecKeys := generateECTestKeys(t)
+
+       for _, ecKey := range ecKeys {
+               if ecKey.KeyType() != "EC" {
+                       t.Fatalf("key type must be %q, instead got %q", "EC", ecKey.KeyType())
+               }
+       }
+}
+
+func TestECSignVerify(t *testing.T) {
+       ecKeys := generateECTestKeys(t)
+
+       message := "Hello, World!"
+       data := bytes.NewReader([]byte(message))
+
+       sigAlgs := []*signatureAlgorithm{es256, es384, es512}
+
+       for i, ecKey := range ecKeys {
+               sigAlg := sigAlgs[i]
+
+               t.Logf("%s signature of %q with kid: %s\n", sigAlg.HeaderParam(), message, ecKey.KeyID())
+
+               data.Seek(0, 0) // Reset the byte reader
+
+               // Sign
+               sig, alg, err := ecKey.Sign(data, sigAlg.HashID())
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               data.Seek(0, 0) // Reset the byte reader
+
+               // Verify
+               err = ecKey.Verify(data, alg, sig)
+               if err != nil {
+                       t.Fatal(err)
+               }
+       }
+}
+
+func TestMarshalUnmarshalECKeys(t *testing.T) {
+       ecKeys := generateECTestKeys(t)
+       data := bytes.NewReader([]byte("This is a test. I repeat: this is only a test."))
+       sigAlgs := []*signatureAlgorithm{es256, es384, es512}
+
+       for i, ecKey := range ecKeys {
+               sigAlg := sigAlgs[i]
+               privateJWKJSON, err := json.MarshalIndent(ecKey, "", "    ")
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               publicJWKJSON, err := json.MarshalIndent(ecKey.PublicKey(), "", "    ")
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               t.Logf("JWK Private Key: %s", string(privateJWKJSON))
+               t.Logf("JWK Public Key: %s", string(publicJWKJSON))
+
+               privKey2, err := UnmarshalPrivateKeyJWK(privateJWKJSON)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               pubKey2, err := UnmarshalPublicKeyJWK(publicJWKJSON)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               // Ensure we can sign/verify a message with the unmarshalled keys.
+               data.Seek(0, 0) // Reset the byte reader
+               signature, alg, err := privKey2.Sign(data, sigAlg.HashID())
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               data.Seek(0, 0) // Reset the byte reader
+               err = pubKey2.Verify(data, alg, signature)
+               if err != nil {
+                       t.Fatal(err)
+               }
+       }
+}
+
+func TestFromCryptoECKeys(t *testing.T) {
+       ecKeys := generateECTestKeys(t)
+
+       for _, ecKey := range ecKeys {
+               cryptoPrivateKey := ecKey.CryptoPrivateKey()
+               cryptoPublicKey := ecKey.CryptoPublicKey()
+
+               pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               if pubKey.KeyID() != ecKey.KeyID() {
+                       t.Fatal("public key key ID mismatch")
+               }
+
+               privKey, err := FromCryptoPrivateKey(cryptoPrivateKey)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               if privKey.KeyID() != ecKey.KeyID() {
+                       t.Fatal("public key key ID mismatch")
+               }
+       }
+}
+
+func TestExtendedFields(t *testing.T) {
+       key, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       key.AddExtendedField("test", "foobar")
+       val := key.GetExtendedField("test")
+
+       gotVal, ok := val.(string)
+       if !ok {
+               t.Fatalf("value is not a string")
+       } else if gotVal != val {
+               t.Fatalf("value %q is not equal to %q", gotVal, val)
+       }
+
+}
diff --git a/filter.go b/filter.go
new file mode 100644 (file)
index 0000000..945852a
--- /dev/null
+++ b/filter.go
@@ -0,0 +1,44 @@
+package libtrust
+
+import (
+       "path/filepath"
+)
+
+// FilterByHosts filters the list of PublicKeys to only those which contain a
+// 'hosts' pattern which matches the given host. If *includeEmpty* is true,
+// then keys which do not specify any hosts are also returned.
+func FilterByHosts(keys []PublicKey, host string, includeEmpty bool) ([]PublicKey, error) {
+       filtered := make([]PublicKey, 0, len(keys))
+
+       for _, pubKey := range keys {
+               hosts, ok := pubKey.GetExtendedField("hosts").([]interface{})
+
+               if !ok || (ok && len(hosts) == 0) {
+                       if includeEmpty {
+                               filtered = append(filtered, pubKey)
+                       }
+                       continue
+               }
+
+               // Check if any hosts match pattern
+               for _, hostVal := range hosts {
+                       hostPattern, ok := hostVal.(string)
+                       if !ok {
+                               continue
+                       }
+
+                       match, err := filepath.Match(hostPattern, host)
+                       if err != nil {
+                               return nil, err
+                       }
+
+                       if match {
+                               filtered = append(filtered, pubKey)
+                               continue
+                       }
+               }
+
+       }
+
+       return filtered, nil
+}
diff --git a/filter_test.go b/filter_test.go
new file mode 100644 (file)
index 0000000..b24e332
--- /dev/null
@@ -0,0 +1,79 @@
+package libtrust
+
+import (
+       "testing"
+)
+
+func compareKeySlices(t *testing.T, sliceA, sliceB []PublicKey) {
+       if len(sliceA) != len(sliceB) {
+               t.Fatalf("slice size %d, expected %d", len(sliceA), len(sliceB))
+       }
+
+       for i, itemA := range sliceA {
+               itemB := sliceB[i]
+               if itemA != itemB {
+                       t.Fatalf("slice index %d not equal: %#v != %#v", i, itemA, itemB)
+               }
+       }
+}
+
+func TestFilter(t *testing.T) {
+       keys := make([]PublicKey, 0, 8)
+
+       // Create 8 keys and add host entries.
+       for i := 0; i < cap(keys); i++ {
+               key, err := GenerateECP256PrivateKey()
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               switch {
+               case i == 0:
+                       // Don't add entries for this key, key 0.
+                       break
+               case i%2 == 0:
+                       // Should catch keys 2, 4, and 6.
+                       key.AddExtendedField("hosts", []interface{}{"*.even.example.com"})
+               case i == 7:
+                       // Should catch only the last key, and make it match any hostname.
+                       key.AddExtendedField("hosts", []interface{}{"*"})
+               default:
+                       // should catch keys 1, 3, 5.
+                       key.AddExtendedField("hosts", []interface{}{"*.example.com"})
+               }
+
+               keys = append(keys, key)
+       }
+
+       // Should match 2 keys, the empty one, and the one that matches all hosts.
+       matchedKeys, err := FilterByHosts(keys, "foo.bar.com", true)
+       if err != nil {
+               t.Fatal(err)
+       }
+       expectedMatch := []PublicKey{keys[0], keys[7]}
+       compareKeySlices(t, expectedMatch, matchedKeys)
+
+       // Should match 1 key, the one that matches any host.
+       matchedKeys, err = FilterByHosts(keys, "foo.bar.com", false)
+       if err != nil {
+               t.Fatal(err)
+       }
+       expectedMatch = []PublicKey{keys[7]}
+       compareKeySlices(t, expectedMatch, matchedKeys)
+
+       // Should match keys that end in "example.com", and the key that matches anything.
+       matchedKeys, err = FilterByHosts(keys, "foo.example.com", false)
+       if err != nil {
+               t.Fatal(err)
+       }
+       expectedMatch = []PublicKey{keys[1], keys[3], keys[5], keys[7]}
+       compareKeySlices(t, expectedMatch, matchedKeys)
+
+       // Should match all of the keys except the empty key.
+       matchedKeys, err = FilterByHosts(keys, "foo.even.example.com", false)
+       if err != nil {
+               t.Fatal(err)
+       }
+       expectedMatch = keys[1:]
+       compareKeySlices(t, expectedMatch, matchedKeys)
+}
diff --git a/hash.go b/hash.go
new file mode 100644 (file)
index 0000000..a2df787
--- /dev/null
+++ b/hash.go
@@ -0,0 +1,56 @@
+package libtrust
+
+import (
+       "crypto"
+       _ "crypto/sha256" // Registrer SHA224 and SHA256
+       _ "crypto/sha512" // Registrer SHA384 and SHA512
+       "fmt"
+)
+
+type signatureAlgorithm struct {
+       algHeaderParam string
+       hashID         crypto.Hash
+}
+
+func (h *signatureAlgorithm) HeaderParam() string {
+       return h.algHeaderParam
+}
+
+func (h *signatureAlgorithm) HashID() crypto.Hash {
+       return h.hashID
+}
+
+var (
+       rs256 = &signatureAlgorithm{"RS256", crypto.SHA256}
+       rs384 = &signatureAlgorithm{"RS384", crypto.SHA384}
+       rs512 = &signatureAlgorithm{"RS512", crypto.SHA512}
+       es256 = &signatureAlgorithm{"ES256", crypto.SHA256}
+       es384 = &signatureAlgorithm{"ES384", crypto.SHA384}
+       es512 = &signatureAlgorithm{"ES512", crypto.SHA512}
+)
+
+func rsaSignatureAlgorithmByName(alg string) (*signatureAlgorithm, error) {
+       switch {
+       case alg == "RS256":
+               return rs256, nil
+       case alg == "RS384":
+               return rs384, nil
+       case alg == "RS512":
+               return rs512, nil
+       default:
+               return nil, fmt.Errorf("RSA Digital Signature Algorithm %q not supported", alg)
+       }
+}
+
+func rsaPKCS1v15SignatureAlgorithmForHashID(hashID crypto.Hash) *signatureAlgorithm {
+       switch {
+       case hashID == crypto.SHA512:
+               return rs512
+       case hashID == crypto.SHA384:
+               return rs384
+       case hashID == crypto.SHA256:
+               fallthrough
+       default:
+               return rs256
+       }
+}
diff --git a/jsonsign.go b/jsonsign.go
new file mode 100644 (file)
index 0000000..c635304
--- /dev/null
@@ -0,0 +1,566 @@
+package libtrust
+
+import (
+       "bytes"
+       "crypto"
+       "crypto/x509"
+       "encoding/base64"
+       "encoding/json"
+       "errors"
+       "fmt"
+       "time"
+       "unicode"
+)
+
+var (
+       // ErrInvalidSignContent is used when the content to be signed is invalid.
+       ErrInvalidSignContent = errors.New("invalid sign content")
+
+       // ErrInvalidJSONContent is used when invalid json is encountered.
+       ErrInvalidJSONContent = errors.New("invalid json content")
+
+       // ErrMissingSignatureKey is used when the specified signature key
+       // does not exist in the JSON content.
+       ErrMissingSignatureKey = errors.New("missing signature key")
+)
+
+type jsHeader struct {
+       JWK       PublicKey `json:"jwk,omitempty"`
+       Algorithm string    `json:"alg"`
+       Chain     []string  `json:"x5c,omitempty"`
+}
+
+type jsSignature struct {
+       Header    *jsHeader `json:"header"`
+       Signature string    `json:"signature"`
+       Protected string    `json:"protected,omitempty"`
+}
+
+type signKey struct {
+       PrivateKey
+       Chain []*x509.Certificate
+}
+
+// JSONSignature represents a signature of a json object.
+type JSONSignature struct {
+       payload      string
+       signatures   []*jsSignature
+       indent       string
+       formatLength int
+       formatTail   []byte
+}
+
+func newJSONSignature() *JSONSignature {
+       return &JSONSignature{
+               signatures: make([]*jsSignature, 0, 1),
+       }
+}
+
+// Payload returns the encoded payload of the signature. This
+// payload should not be signed directly
+func (js *JSONSignature) Payload() ([]byte, error) {
+       return joseBase64UrlDecode(js.payload)
+}
+
+func (js *JSONSignature) protectedHeader() (string, error) {
+       protected := map[string]interface{}{
+               "formatLength": js.formatLength,
+               "formatTail":   joseBase64UrlEncode(js.formatTail),
+               "time":         time.Now().UTC().Format(time.RFC3339),
+       }
+       protectedBytes, err := json.Marshal(protected)
+       if err != nil {
+               return "", err
+       }
+
+       return joseBase64UrlEncode(protectedBytes), nil
+}
+
+func (js *JSONSignature) signBytes(protectedHeader string) ([]byte, error) {
+       buf := make([]byte, len(js.payload)+len(protectedHeader)+1)
+       copy(buf, protectedHeader)
+       buf[len(protectedHeader)] = '.'
+       copy(buf[len(protectedHeader)+1:], js.payload)
+       return buf, nil
+}
+
+// Sign adds a signature using the given private key.
+func (js *JSONSignature) Sign(key PrivateKey) error {
+       protected, err := js.protectedHeader()
+       if err != nil {
+               return err
+       }
+       signBytes, err := js.signBytes(protected)
+       if err != nil {
+               return err
+       }
+       sigBytes, algorithm, err := key.Sign(bytes.NewReader(signBytes), crypto.SHA256)
+       if err != nil {
+               return err
+       }
+
+       header := &jsHeader{
+               JWK:       key.PublicKey(),
+               Algorithm: algorithm,
+       }
+       sig := &jsSignature{
+               Header:    header,
+               Signature: joseBase64UrlEncode(sigBytes),
+               Protected: protected,
+       }
+
+       js.signatures = append(js.signatures, sig)
+
+       return nil
+}
+
+// SignWithChain adds a signature using the given private key
+// and setting the x509 chain. The public key of the first element
+// in the chain must be the public key corresponding with the sign key.
+func (js *JSONSignature) SignWithChain(key PrivateKey, chain []*x509.Certificate) error {
+       // Ensure key.Chain[0] is public key for key
+       //key.Chain.PublicKey
+       //key.PublicKey().CryptoPublicKey()
+
+       // Verify chain
+       protected, err := js.protectedHeader()
+       if err != nil {
+               return err
+       }
+       signBytes, err := js.signBytes(protected)
+       if err != nil {
+               return err
+       }
+       sigBytes, algorithm, err := key.Sign(bytes.NewReader(signBytes), crypto.SHA256)
+       if err != nil {
+               return err
+       }
+
+       header := &jsHeader{
+               Chain:     make([]string, len(chain)),
+               Algorithm: algorithm,
+       }
+
+       for i, cert := range chain {
+               header.Chain[i] = base64.StdEncoding.EncodeToString(cert.Raw)
+       }
+
+       sig := &jsSignature{
+               Header:    header,
+               Signature: joseBase64UrlEncode(sigBytes),
+               Protected: protected,
+       }
+
+       js.signatures = append(js.signatures, sig)
+
+       return nil
+}
+
+// Verify verifies all the signatures and returns the list of
+// public keys used to sign. Any x509 chains are not checked.
+func (js *JSONSignature) Verify() ([]PublicKey, error) {
+       keys := make([]PublicKey, len(js.signatures))
+       for i, signature := range js.signatures {
+               signBytes, err := js.signBytes(signature.Protected)
+               if err != nil {
+                       return nil, err
+               }
+               var publicKey PublicKey
+               if len(signature.Header.Chain) > 0 {
+                       certBytes, err := base64.StdEncoding.DecodeString(signature.Header.Chain[0])
+                       if err != nil {
+                               return nil, err
+                       }
+                       cert, err := x509.ParseCertificate(certBytes)
+                       if err != nil {
+                               return nil, err
+                       }
+                       publicKey, err = FromCryptoPublicKey(cert.PublicKey)
+                       if err != nil {
+                               return nil, err
+                       }
+               } else if signature.Header.JWK != nil {
+                       publicKey = signature.Header.JWK
+               } else {
+                       return nil, errors.New("missing public key")
+               }
+
+               sigBytes, err := joseBase64UrlDecode(signature.Signature)
+               if err != nil {
+                       return nil, err
+               }
+
+               err = publicKey.Verify(bytes.NewReader(signBytes), signature.Header.Algorithm, sigBytes)
+               if err != nil {
+                       return nil, err
+               }
+
+               keys[i] = publicKey
+       }
+       return keys, nil
+}
+
+// VerifyChains verifies all the signatures and the chains associated
+// with each signature and returns the list of verified chains.
+// Signatures without an x509 chain are not checked.
+func (js *JSONSignature) VerifyChains(ca *x509.CertPool) ([][]*x509.Certificate, error) {
+       chains := make([][]*x509.Certificate, 0, len(js.signatures))
+       for _, signature := range js.signatures {
+               signBytes, err := js.signBytes(signature.Protected)
+               if err != nil {
+                       return nil, err
+               }
+               var publicKey PublicKey
+               if len(signature.Header.Chain) > 0 {
+                       certBytes, err := base64.StdEncoding.DecodeString(signature.Header.Chain[0])
+                       if err != nil {
+                               return nil, err
+                       }
+                       cert, err := x509.ParseCertificate(certBytes)
+                       if err != nil {
+                               return nil, err
+                       }
+                       publicKey, err = FromCryptoPublicKey(cert.PublicKey)
+                       if err != nil {
+                               return nil, err
+                       }
+                       intermediates := x509.NewCertPool()
+                       if len(signature.Header.Chain) > 1 {
+                               intermediateChain := signature.Header.Chain[1:]
+                               for i := range intermediateChain {
+                                       certBytes, err := base64.StdEncoding.DecodeString(intermediateChain[i])
+                                       if err != nil {
+                                               return nil, err
+                                       }
+                                       intermediate, err := x509.ParseCertificate(certBytes)
+                                       if err != nil {
+                                               return nil, err
+                                       }
+                                       intermediates.AddCert(intermediate)
+                               }
+                       }
+
+                       verifyOptions := x509.VerifyOptions{
+                               Intermediates: intermediates,
+                               Roots:         ca,
+                       }
+
+                       verifiedChains, err := cert.Verify(verifyOptions)
+                       if err != nil {
+                               return nil, err
+                       }
+                       chains = append(chains, verifiedChains...)
+
+                       sigBytes, err := joseBase64UrlDecode(signature.Signature)
+                       if err != nil {
+                               return nil, err
+                       }
+
+                       err = publicKey.Verify(bytes.NewReader(signBytes), signature.Header.Algorithm, sigBytes)
+                       if err != nil {
+                               return nil, err
+                       }
+               }
+
+       }
+       return chains, nil
+}
+
+// JWS returns JSON serialized JWS according to
+// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-7.2
+func (js *JSONSignature) JWS() ([]byte, error) {
+       if len(js.signatures) == 0 {
+               return nil, errors.New("missing signature")
+       }
+       jsonMap := map[string]interface{}{
+               "payload":    js.payload,
+               "signatures": js.signatures,
+       }
+
+       return json.MarshalIndent(jsonMap, "", "   ")
+}
+
+func notSpace(r rune) bool {
+       return !unicode.IsSpace(r)
+}
+
+func detectJSONIndent(jsonContent []byte) (indent string) {
+       if len(jsonContent) > 2 && jsonContent[0] == '{' && jsonContent[1] == '\n' {
+               quoteIndex := bytes.IndexRune(jsonContent[1:], '"')
+               if quoteIndex > 0 {
+                       indent = string(jsonContent[2 : quoteIndex+1])
+               }
+       }
+       return
+}
+
+type jsParsedHeader struct {
+       JWK       json.RawMessage `json:"jwk"`
+       Algorithm string          `json:"alg"`
+       Chain     []string        `json:"x5c"`
+}
+
+type jsParsedSignature struct {
+       Header    *jsParsedHeader `json:"header"`
+       Signature string          `json:"signature"`
+       Protected string          `json:"protected"`
+}
+
+// ParseJWS parses a JWS serialized JSON object into a Json Signature.
+func ParseJWS(content []byte) (*JSONSignature, error) {
+       type jsParsed struct {
+               Payload    string               `json:"payload"`
+               Signatures []*jsParsedSignature `json:"signatures"`
+       }
+       parsed := &jsParsed{}
+       err := json.Unmarshal(content, parsed)
+       if err != nil {
+               return nil, err
+       }
+       if len(parsed.Signatures) == 0 {
+               return nil, errors.New("missing signatures")
+       }
+       payload, err := joseBase64UrlDecode(parsed.Payload)
+       if err != nil {
+               return nil, err
+       }
+
+       js, err := NewJSONSignature(payload)
+       if err != nil {
+               return nil, err
+       }
+       js.signatures = make([]*jsSignature, len(parsed.Signatures))
+       for i, signature := range parsed.Signatures {
+               header := &jsHeader{
+                       Algorithm: signature.Header.Algorithm,
+               }
+               if signature.Header.Chain != nil {
+                       header.Chain = signature.Header.Chain
+               }
+               if signature.Header.JWK != nil {
+                       publicKey, err := UnmarshalPublicKeyJWK([]byte(signature.Header.JWK))
+                       if err != nil {
+                               return nil, err
+                       }
+                       header.JWK = publicKey
+               }
+               js.signatures[i] = &jsSignature{
+                       Header:    header,
+                       Signature: signature.Signature,
+                       Protected: signature.Protected,
+               }
+       }
+
+       return js, nil
+}
+
+// NewJSONSignature returns a new unsigned JWS from a json byte array.
+// JSONSignature will need to be signed before serializing or storing.
+func NewJSONSignature(content []byte) (*JSONSignature, error) {
+       var dataMap map[string]interface{}
+       err := json.Unmarshal(content, &dataMap)
+       if err != nil {
+               return nil, err
+       }
+
+       js := newJSONSignature()
+       js.indent = detectJSONIndent(content)
+
+       js.payload = joseBase64UrlEncode(content)
+
+       // Find trailing } and whitespace, put in protected header
+       closeIndex := bytes.LastIndexFunc(content, notSpace)
+       if content[closeIndex] != '}' {
+               return nil, ErrInvalidJSONContent
+       }
+       lastRuneIndex := bytes.LastIndexFunc(content[:closeIndex], notSpace)
+       if content[lastRuneIndex] == ',' {
+               return nil, ErrInvalidJSONContent
+       }
+       js.formatLength = lastRuneIndex + 1
+       js.formatTail = content[js.formatLength:]
+
+       return js, nil
+}
+
+// NewJSONSignatureFromMap returns a new unsigned JSONSignature from a map or
+// struct. JWS will need to be signed before serializing or storing.
+func NewJSONSignatureFromMap(content interface{}) (*JSONSignature, error) {
+       switch content.(type) {
+       case map[string]interface{}:
+       case struct{}:
+       default:
+               return nil, errors.New("invalid data type")
+       }
+
+       js := newJSONSignature()
+       js.indent = "   "
+
+       payload, err := json.MarshalIndent(content, "", js.indent)
+       if err != nil {
+               return nil, err
+       }
+       js.payload = joseBase64UrlEncode(payload)
+
+       // Remove '\n}' from formatted section, put in protected header
+       js.formatLength = len(payload) - 2
+       js.formatTail = payload[js.formatLength:]
+
+       return js, nil
+}
+
+func readIntFromMap(key string, m map[string]interface{}) (int, bool) {
+       value, ok := m[key]
+       if !ok {
+               return 0, false
+       }
+       switch v := value.(type) {
+       case int:
+               return v, true
+       case float64:
+               return int(v), true
+       default:
+               return 0, false
+       }
+}
+
+func readStringFromMap(key string, m map[string]interface{}) (v string, ok bool) {
+       value, ok := m[key]
+       if !ok {
+               return "", false
+       }
+       v, ok = value.(string)
+       return
+}
+
+// ParsePrettySignature parses a formatted signature into a
+// JSON signature. If the signatures are missing the format information
+// an error is thrown. The formatted signature must be created by
+// the same method as format signature.
+func ParsePrettySignature(content []byte, signatureKey string) (*JSONSignature, error) {
+       var contentMap map[string]json.RawMessage
+       err := json.Unmarshal(content, &contentMap)
+       if err != nil {
+               return nil, fmt.Errorf("error unmarshalling content: %s", err)
+       }
+       sigMessage, ok := contentMap[signatureKey]
+       if !ok {
+               return nil, ErrMissingSignatureKey
+       }
+
+       var signatureBlocks []jsParsedSignature
+       err = json.Unmarshal([]byte(sigMessage), &signatureBlocks)
+       if err != nil {
+               return nil, fmt.Errorf("error unmarshalling signatures: %s", err)
+       }
+
+       js := newJSONSignature()
+       js.signatures = make([]*jsSignature, len(signatureBlocks))
+
+       for i, signatureBlock := range signatureBlocks {
+               protectedBytes, err := joseBase64UrlDecode(signatureBlock.Protected)
+               if err != nil {
+                       return nil, fmt.Errorf("base64 decode error: %s", err)
+               }
+               var protectedHeader map[string]interface{}
+               err = json.Unmarshal(protectedBytes, &protectedHeader)
+               if err != nil {
+                       return nil, fmt.Errorf("error unmarshalling protected header: %s", err)
+               }
+
+               formatLength, ok := readIntFromMap("formatLength", protectedHeader)
+               if !ok {
+                       return nil, errors.New("missing formatted length")
+               }
+               encodedTail, ok := readStringFromMap("formatTail", protectedHeader)
+               if !ok {
+                       return nil, errors.New("missing formatted tail")
+               }
+               formatTail, err := joseBase64UrlDecode(encodedTail)
+               if err != nil {
+                       return nil, fmt.Errorf("base64 decode error on tail: %s", err)
+               }
+               if js.formatLength == 0 {
+                       js.formatLength = formatLength
+               } else if js.formatLength != formatLength {
+                       return nil, errors.New("conflicting format length")
+               }
+               if len(js.formatTail) == 0 {
+                       js.formatTail = formatTail
+               } else if bytes.Compare(js.formatTail, formatTail) != 0 {
+                       return nil, errors.New("conflicting format tail")
+               }
+
+               header := &jsHeader{
+                       Algorithm: signatureBlock.Header.Algorithm,
+                       Chain:     signatureBlock.Header.Chain,
+               }
+               if signatureBlock.Header.JWK != nil {
+                       publicKey, err := UnmarshalPublicKeyJWK([]byte(signatureBlock.Header.JWK))
+                       if err != nil {
+                               return nil, fmt.Errorf("error unmarshalling public key: %s", err)
+                       }
+                       header.JWK = publicKey
+               }
+               js.signatures[i] = &jsSignature{
+                       Header:    header,
+                       Signature: signatureBlock.Signature,
+                       Protected: signatureBlock.Protected,
+               }
+       }
+       if js.formatLength > len(content) {
+               return nil, errors.New("invalid format length")
+       }
+       formatted := make([]byte, js.formatLength+len(js.formatTail))
+       copy(formatted, content[:js.formatLength])
+       copy(formatted[js.formatLength:], js.formatTail)
+       js.indent = detectJSONIndent(formatted)
+       js.payload = joseBase64UrlEncode(formatted)
+
+       return js, nil
+}
+
+// PrettySignature formats a json signature into an easy to read
+// single json serialized object.
+func (js *JSONSignature) PrettySignature(signatureKey string) ([]byte, error) {
+       if len(js.signatures) == 0 {
+               return nil, errors.New("no signatures")
+       }
+       payload, err := joseBase64UrlDecode(js.payload)
+       if err != nil {
+               return nil, err
+       }
+       payload = payload[:js.formatLength]
+
+       var marshalled []byte
+       var marshallErr error
+       if js.indent != "" {
+               marshalled, marshallErr = json.MarshalIndent(js.signatures, js.indent, js.indent)
+       } else {
+               marshalled, marshallErr = json.Marshal(js.signatures)
+       }
+       if marshallErr != nil {
+               return nil, marshallErr
+       }
+
+       buf := bytes.NewBuffer(make([]byte, 0, len(payload)+len(marshalled)+34))
+       buf.Write(payload)
+       buf.WriteByte(',')
+       if js.indent != "" {
+               buf.WriteByte('\n')
+               buf.WriteString(js.indent)
+               buf.WriteByte('"')
+               buf.WriteString(signatureKey)
+               buf.WriteString("\": ")
+               buf.Write(marshalled)
+               buf.WriteByte('\n')
+       } else {
+               buf.WriteByte('"')
+               buf.WriteString(signatureKey)
+               buf.WriteString("\":")
+               buf.Write(marshalled)
+       }
+       buf.WriteByte('}')
+
+       return buf.Bytes(), nil
+}
diff --git a/jsonsign_test.go b/jsonsign_test.go
new file mode 100644 (file)
index 0000000..59616b9
--- /dev/null
@@ -0,0 +1,297 @@
+package libtrust
+
+import (
+       "bytes"
+       "crypto/x509"
+       "encoding/json"
+       "fmt"
+       "testing"
+
+       "github.com/docker/libtrust/testutil"
+)
+
+func createTestJSON(sigKey string, indent string) (map[string]interface{}, []byte) {
+       testMap := map[string]interface{}{
+               "name": "dmcgowan/mycontainer",
+               "config": map[string]interface{}{
+                       "ports": []int{9101, 9102},
+                       "run":   "/bin/echo \"Hello\"",
+               },
+               "layers": []string{
+                       "2893c080-27f5-11e4-8c21-0800200c9a66",
+                       "c54bc25b-fbb2-497b-a899-a8bc1b5b9d55",
+                       "4d5d7e03-f908-49f3-a7f6-9ba28dfe0fb4",
+                       "0b6da891-7f7f-4abf-9c97-7887549e696c",
+                       "1d960389-ae4f-4011-85fd-18d0f96a67ad",
+               },
+       }
+       formattedSection := `{"config":{"ports":[9101,9102],"run":"/bin/echo \"Hello\""},"layers":["2893c080-27f5-11e4-8c21-0800200c9a66","c54bc25b-fbb2-497b-a899-a8bc1b5b9d55","4d5d7e03-f908-49f3-a7f6-9ba28dfe0fb4","0b6da891-7f7f-4abf-9c97-7887549e696c","1d960389-ae4f-4011-85fd-18d0f96a67ad"],"name":"dmcgowan/mycontainer","%s":[{"header":{`
+       formattedSection = fmt.Sprintf(formattedSection, sigKey)
+       if indent != "" {
+               buf := bytes.NewBuffer(nil)
+               json.Indent(buf, []byte(formattedSection), "", indent)
+               return testMap, buf.Bytes()
+       }
+       return testMap, []byte(formattedSection)
+
+}
+
+func TestSignJSON(t *testing.T) {
+       key, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatalf("Error generating EC key: %s", err)
+       }
+
+       testMap, _ := createTestJSON("buildSignatures", "   ")
+       indented, err := json.MarshalIndent(testMap, "", "   ")
+       if err != nil {
+               t.Fatalf("Marshall error: %s", err)
+       }
+
+       js, err := NewJSONSignature(indented)
+       if err != nil {
+               t.Fatalf("Error creating JSON signature: %s", err)
+       }
+       err = js.Sign(key)
+       if err != nil {
+               t.Fatalf("Error signing content: %s", err)
+       }
+
+       keys, err := js.Verify()
+       if err != nil {
+               t.Fatalf("Error verifying signature: %s", err)
+       }
+       if len(keys) != 1 {
+               t.Fatalf("Error wrong number of keys returned")
+       }
+       if keys[0].KeyID() != key.KeyID() {
+               t.Fatalf("Unexpected public key returned")
+       }
+
+}
+
+func TestSignMap(t *testing.T) {
+       key, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatalf("Error generating EC key: %s", err)
+       }
+
+       testMap, _ := createTestJSON("buildSignatures", "   ")
+       js, err := NewJSONSignatureFromMap(testMap)
+       if err != nil {
+               t.Fatalf("Error creating JSON signature: %s", err)
+       }
+       err = js.Sign(key)
+       if err != nil {
+               t.Fatalf("Error signing JSON signature: %s", err)
+       }
+
+       keys, err := js.Verify()
+       if err != nil {
+               t.Fatalf("Error verifying signature: %s", err)
+       }
+       if len(keys) != 1 {
+               t.Fatalf("Error wrong number of keys returned")
+       }
+       if keys[0].KeyID() != key.KeyID() {
+               t.Fatalf("Unexpected public key returned")
+       }
+}
+
+func TestFormattedJson(t *testing.T) {
+       key, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatalf("Error generating EC key: %s", err)
+       }
+
+       testMap, firstSection := createTestJSON("buildSignatures", "     ")
+       indented, err := json.MarshalIndent(testMap, "", "     ")
+       if err != nil {
+               t.Fatalf("Marshall error: %s", err)
+       }
+
+       js, err := NewJSONSignature(indented)
+       if err != nil {
+               t.Fatalf("Error creating JSON signature: %s", err)
+       }
+       err = js.Sign(key)
+       if err != nil {
+               t.Fatalf("Error signing content: %s", err)
+       }
+
+       b, err := js.PrettySignature("buildSignatures")
+       if err != nil {
+               t.Fatalf("Error signing map: %s", err)
+       }
+
+       if bytes.Compare(b[:len(firstSection)], firstSection) != 0 {
+               t.Fatalf("Wrong signed value\nExpected:\n%s\nActual:\n%s", firstSection, b[:len(firstSection)])
+       }
+
+       parsed, err := ParsePrettySignature(b, "buildSignatures")
+       if err != nil {
+               t.Fatalf("Error parsing formatted signature: %s", err)
+       }
+
+       keys, err := parsed.Verify()
+       if err != nil {
+               t.Fatalf("Error verifying signature: %s", err)
+       }
+       if len(keys) != 1 {
+               t.Fatalf("Error wrong number of keys returned")
+       }
+       if keys[0].KeyID() != key.KeyID() {
+               t.Fatalf("Unexpected public key returned")
+       }
+
+       var unmarshalled map[string]interface{}
+       err = json.Unmarshal(b, &unmarshalled)
+       if err != nil {
+               t.Fatalf("Could not unmarshall after parse: %s", err)
+       }
+
+}
+
+func TestFormattedFlatJson(t *testing.T) {
+       key, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatalf("Error generating EC key: %s", err)
+       }
+
+       testMap, firstSection := createTestJSON("buildSignatures", "")
+       unindented, err := json.Marshal(testMap)
+       if err != nil {
+               t.Fatalf("Marshall error: %s", err)
+       }
+
+       js, err := NewJSONSignature(unindented)
+       if err != nil {
+               t.Fatalf("Error creating JSON signature: %s", err)
+       }
+       err = js.Sign(key)
+       if err != nil {
+               t.Fatalf("Error signing JSON signature: %s", err)
+       }
+
+       b, err := js.PrettySignature("buildSignatures")
+       if err != nil {
+               t.Fatalf("Error signing map: %s", err)
+       }
+
+       if bytes.Compare(b[:len(firstSection)], firstSection) != 0 {
+               t.Fatalf("Wrong signed value\nExpected:\n%s\nActual:\n%s", firstSection, b[:len(firstSection)])
+       }
+
+       parsed, err := ParsePrettySignature(b, "buildSignatures")
+       if err != nil {
+               t.Fatalf("Error parsing formatted signature: %s", err)
+       }
+
+       keys, err := parsed.Verify()
+       if err != nil {
+               t.Fatalf("Error verifying signature: %s", err)
+       }
+       if len(keys) != 1 {
+               t.Fatalf("Error wrong number of keys returned")
+       }
+       if keys[0].KeyID() != key.KeyID() {
+               t.Fatalf("Unexpected public key returned")
+       }
+}
+
+func generateTrustChain(t *testing.T, key PrivateKey, ca *x509.Certificate) (PrivateKey, []*x509.Certificate) {
+       parent := ca
+       parentKey := key
+       chain := make([]*x509.Certificate, 6)
+       for i := 5; i > 0; i-- {
+               intermediatekey, err := GenerateECP256PrivateKey()
+               if err != nil {
+                       t.Fatalf("Error generate key: %s", err)
+               }
+               chain[i], err = testutil.GenerateIntermediate(intermediatekey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
+               if err != nil {
+                       t.Fatalf("Error generating intermdiate certificate: %s", err)
+               }
+               parent = chain[i]
+               parentKey = intermediatekey
+       }
+       trustKey, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatalf("Error generate key: %s", err)
+       }
+       chain[0], err = testutil.GenerateTrustCert(trustKey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
+       if err != nil {
+               t.Fatalf("Error generate trust cert: %s", err)
+       }
+
+       return trustKey, chain
+}
+
+func TestChainVerify(t *testing.T) {
+       caKey, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatalf("Error generating key: %s", err)
+       }
+       ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey())
+       if err != nil {
+               t.Fatalf("Error generating ca: %s", err)
+       }
+       trustKey, chain := generateTrustChain(t, caKey, ca)
+
+       testMap, _ := createTestJSON("verifySignatures", "   ")
+       js, err := NewJSONSignatureFromMap(testMap)
+       if err != nil {
+               t.Fatalf("Error creating JSONSignature from map: %s", err)
+       }
+
+       err = js.SignWithChain(trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error signing with chain: %s", err)
+       }
+
+       pool := x509.NewCertPool()
+       pool.AddCert(ca)
+       chains, err := js.VerifyChains(pool)
+       if err != nil {
+               t.Fatalf("Error verifying content: %s", err)
+       }
+       if len(chains) != 1 {
+               t.Fatalf("Unexpected chains length: %d", len(chains))
+       }
+       if len(chains[0]) != 7 {
+               t.Fatalf("Unexpected chain length: %d", len(chains[0]))
+       }
+}
+
+func TestInvalidChain(t *testing.T) {
+       caKey, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatalf("Error generating key: %s", err)
+       }
+       ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey())
+       if err != nil {
+               t.Fatalf("Error generating ca: %s", err)
+       }
+       trustKey, chain := generateTrustChain(t, caKey, ca)
+
+       testMap, _ := createTestJSON("verifySignatures", "   ")
+       js, err := NewJSONSignatureFromMap(testMap)
+       if err != nil {
+               t.Fatalf("Error creating JSONSignature from map: %s", err)
+       }
+
+       err = js.SignWithChain(trustKey, chain[:5])
+       if err != nil {
+               t.Fatalf("Error signing with chain: %s", err)
+       }
+
+       pool := x509.NewCertPool()
+       pool.AddCert(ca)
+       chains, err := js.VerifyChains(pool)
+       if err == nil {
+               t.Fatalf("Expected error verifying with bad chain")
+       }
+       if len(chains) != 0 {
+               t.Fatalf("Unexpected chains returned from invalid verify")
+       }
+}
diff --git a/key.go b/key.go
new file mode 100644 (file)
index 0000000..73642db
--- /dev/null
+++ b/key.go
@@ -0,0 +1,253 @@
+package libtrust
+
+import (
+       "crypto"
+       "crypto/ecdsa"
+       "crypto/rsa"
+       "crypto/x509"
+       "encoding/json"
+       "encoding/pem"
+       "errors"
+       "fmt"
+       "io"
+)
+
+// PublicKey is a generic interface for a Public Key.
+type PublicKey interface {
+       // KeyType returns the key type for this key. For elliptic curve keys,
+       // this value should be "EC". For RSA keys, this value should be "RSA".
+       KeyType() string
+       // KeyID returns a distinct identifier which is unique to this Public Key.
+       // The format generated by this library is a base32 encoding of a 240 bit
+       // hash of the public key data divided into 12 groups like so:
+       //    ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP
+       KeyID() string
+       // Verify verifyies the signature of the data in the io.Reader using this
+       // Public Key. The alg parameter should identify the digital signature
+       // algorithm which was used to produce the signature and should be
+       // supported by this public key. Returns a nil error if the signature
+       // is valid.
+       Verify(data io.Reader, alg string, signature []byte) error
+       // CryptoPublicKey returns the internal object which can be used as a
+       // crypto.PublicKey for use with other standard library operations. The type
+       // is either *rsa.PublicKey or *ecdsa.PublicKey
+       CryptoPublicKey() crypto.PublicKey
+       // These public keys can be serialized to the standard JSON encoding for
+       // JSON Web Keys. See section 6 of the IETF draft RFC for JOSE JSON Web
+       // Algorithms.
+       MarshalJSON() ([]byte, error)
+       // These keys can also be serialized to the standard PEM encoding.
+       PEMBlock() (*pem.Block, error)
+       // The string representation of a key is its key type and ID.
+       String() string
+       AddExtendedField(string, interface{})
+       GetExtendedField(string) interface{}
+}
+
+// PrivateKey is a generic interface for a Private Key.
+type PrivateKey interface {
+       // A PrivateKey contains all fields and methods of a PublicKey of the
+       // same type. The MarshalJSON method also outputs the private key as a
+       // JSON Web Key, and the PEMBlock method outputs the private key as a
+       // PEM block.
+       PublicKey
+       // PublicKey returns the PublicKey associated with this PrivateKey.
+       PublicKey() PublicKey
+       // Sign signs the data read from the io.Reader using a signature algorithm
+       // supported by the private key. If the specified hashing algorithm is
+       // supported by this key, that hash function is used to generate the
+       // signature otherwise the the default hashing algorithm for this key is
+       // used. Returns the signature and identifier of the algorithm used.
+       Sign(data io.Reader, hashID crypto.Hash) (signature []byte, alg string, err error)
+       // CryptoPrivateKey returns the internal object which can be used as a
+       // crypto.PublicKey for use with other standard library operations. The
+       // type is either *rsa.PublicKey or *ecdsa.PublicKey
+       CryptoPrivateKey() crypto.PrivateKey
+}
+
+// FromCryptoPublicKey returns a libtrust PublicKey representation of the given
+// *ecdsa.PublicKey or *rsa.PublicKey. Returns a non-nil error when the given
+// key is of an unsupported type.
+func FromCryptoPublicKey(cryptoPublicKey crypto.PublicKey) (PublicKey, error) {
+       switch cryptoPublicKey := cryptoPublicKey.(type) {
+       case *ecdsa.PublicKey:
+               return fromECPublicKey(cryptoPublicKey)
+       case *rsa.PublicKey:
+               return fromRSAPublicKey(cryptoPublicKey), nil
+       default:
+               return nil, fmt.Errorf("public key type %T is not supported", cryptoPublicKey)
+       }
+}
+
+// FromCryptoPrivateKey returns a libtrust PrivateKey representation of the given
+// *ecdsa.PrivateKey or *rsa.PrivateKey. Returns a non-nil error when the given
+// key is of an unsupported type.
+func FromCryptoPrivateKey(cryptoPrivateKey crypto.PrivateKey) (PrivateKey, error) {
+       switch cryptoPrivateKey := cryptoPrivateKey.(type) {
+       case *ecdsa.PrivateKey:
+               return fromECPrivateKey(cryptoPrivateKey)
+       case *rsa.PrivateKey:
+               return fromRSAPrivateKey(cryptoPrivateKey), nil
+       default:
+               return nil, fmt.Errorf("private key type %T is not supported", cryptoPrivateKey)
+       }
+}
+
+// UnmarshalPublicKeyPEM parses the PEM encoded data and returns a libtrust
+// PublicKey or an error if there is a problem with the encoding.
+func UnmarshalPublicKeyPEM(data []byte) (PublicKey, error) {
+       pemBlock, _ := pem.Decode(data)
+       if pemBlock == nil {
+               return nil, errors.New("unable to find PEM encoded data")
+       } else if pemBlock.Type != "PUBLIC KEY" {
+               return nil, fmt.Errorf("unable to get PublicKey from PEM type: %s", pemBlock.Type)
+       }
+
+       return pubKeyFromPEMBlock(pemBlock)
+}
+
+// UnmarshalPublicKeyPEMBundle parses the PEM encoded data as a bundle of
+// PEM blocks appended one after the other and returns a slice of PublicKey
+// objects that it finds.
+func UnmarshalPublicKeyPEMBundle(data []byte) ([]PublicKey, error) {
+       pubKeys := []PublicKey{}
+
+       for {
+               var pemBlock *pem.Block
+               pemBlock, data = pem.Decode(data)
+               if pemBlock == nil {
+                       break
+               } else if pemBlock.Type != "PUBLIC KEY" {
+                       return nil, fmt.Errorf("unable to get PublicKey from PEM type: %s", pemBlock.Type)
+               }
+
+               pubKey, err := pubKeyFromPEMBlock(pemBlock)
+               if err != nil {
+                       return nil, err
+               }
+
+               pubKeys = append(pubKeys, pubKey)
+       }
+
+       return pubKeys, nil
+}
+
+// UnmarshalPrivateKeyPEM parses the PEM encoded data and returns a libtrust
+// PrivateKey or an error if there is a problem with the encoding.
+func UnmarshalPrivateKeyPEM(data []byte) (PrivateKey, error) {
+       pemBlock, _ := pem.Decode(data)
+       if pemBlock == nil {
+               return nil, errors.New("unable to find PEM encoded data")
+       }
+
+       var key PrivateKey
+
+       switch {
+       case pemBlock.Type == "RSA PRIVATE KEY":
+               rsaPrivateKey, err := x509.ParsePKCS1PrivateKey(pemBlock.Bytes)
+               if err != nil {
+                       return nil, fmt.Errorf("unable to decode RSA Private Key PEM data: %s", err)
+               }
+               key = fromRSAPrivateKey(rsaPrivateKey)
+       case pemBlock.Type == "EC PRIVATE KEY":
+               ecPrivateKey, err := x509.ParseECPrivateKey(pemBlock.Bytes)
+               if err != nil {
+                       return nil, fmt.Errorf("unable to decode EC Private Key PEM data: %s", err)
+               }
+               key, err = fromECPrivateKey(ecPrivateKey)
+               if err != nil {
+                       return nil, err
+               }
+       default:
+               return nil, fmt.Errorf("unable to get PrivateKey from PEM type: %s", pemBlock.Type)
+       }
+
+       addPEMHeadersToKey(pemBlock, key.PublicKey())
+
+       return key, nil
+}
+
+// UnmarshalPublicKeyJWK unmarshals the given JSON Web Key into a generic
+// Public Key to be used with libtrust.
+func UnmarshalPublicKeyJWK(data []byte) (PublicKey, error) {
+       jwk := make(map[string]interface{})
+
+       err := json.Unmarshal(data, &jwk)
+       if err != nil {
+               return nil, fmt.Errorf(
+                       "decoding JWK Public Key JSON data: %s\n", err,
+               )
+       }
+
+       // Get the Key Type value.
+       kty, err := stringFromMap(jwk, "kty")
+       if err != nil {
+               return nil, fmt.Errorf("JWK Public Key type: %s", err)
+       }
+
+       switch {
+       case kty == "EC":
+               // Call out to unmarshal EC public key.
+               return ecPublicKeyFromMap(jwk)
+       case kty == "RSA":
+               // Call out to unmarshal RSA public key.
+               return rsaPublicKeyFromMap(jwk)
+       default:
+               return nil, fmt.Errorf(
+                       "JWK Public Key type not supported: %q\n", kty,
+               )
+       }
+}
+
+// UnmarshalPublicKeyJWKSet parses the JSON encoded data as a JSON Web Key Set
+// and returns a slice of Public Key objects.
+func UnmarshalPublicKeyJWKSet(data []byte) ([]PublicKey, error) {
+       rawKeys, err := loadJSONKeySetRaw(data)
+       if err != nil {
+               return nil, err
+       }
+
+       pubKeys := make([]PublicKey, 0, len(rawKeys))
+
+       for _, rawKey := range rawKeys {
+               pubKey, err := UnmarshalPublicKeyJWK(rawKey)
+               if err != nil {
+                       return nil, err
+               }
+               pubKeys = append(pubKeys, pubKey)
+       }
+
+       return pubKeys, nil
+}
+
+// UnmarshalPrivateKeyJWK unmarshals the given JSON Web Key into a generic
+// Private Key to be used with libtrust.
+func UnmarshalPrivateKeyJWK(data []byte) (PrivateKey, error) {
+       jwk := make(map[string]interface{})
+
+       err := json.Unmarshal(data, &jwk)
+       if err != nil {
+               return nil, fmt.Errorf(
+                       "decoding JWK Private Key JSON data: %s\n", err,
+               )
+       }
+
+       // Get the Key Type value.
+       kty, err := stringFromMap(jwk, "kty")
+       if err != nil {
+               return nil, fmt.Errorf("JWK Private Key type: %s", err)
+       }
+
+       switch {
+       case kty == "EC":
+               // Call out to unmarshal EC private key.
+               return ecPrivateKeyFromMap(jwk)
+       case kty == "RSA":
+               // Call out to unmarshal RSA private key.
+               return rsaPrivateKeyFromMap(jwk)
+       default:
+               return nil, fmt.Errorf(
+                       "JWK Private Key type not supported: %q\n", kty,
+               )
+       }
+}
diff --git a/key_files.go b/key_files.go
new file mode 100644 (file)
index 0000000..c526de5
--- /dev/null
@@ -0,0 +1,255 @@
+package libtrust
+
+import (
+       "encoding/json"
+       "encoding/pem"
+       "errors"
+       "fmt"
+       "io/ioutil"
+       "os"
+       "strings"
+)
+
+var (
+       // ErrKeyFileDoesNotExist indicates that the private key file does not exist.
+       ErrKeyFileDoesNotExist = errors.New("key file does not exist")
+)
+
+func readKeyFileBytes(filename string) ([]byte, error) {
+       data, err := ioutil.ReadFile(filename)
+       if err != nil {
+               if os.IsNotExist(err) {
+                       err = ErrKeyFileDoesNotExist
+               } else {
+                       err = fmt.Errorf("unable to read key file %s: %s", filename, err)
+               }
+
+               return nil, err
+       }
+
+       return data, nil
+}
+
+/*
+       Loading and Saving of Public and Private Keys in either PEM or JWK format.
+*/
+
+// LoadKeyFile opens the given filename and attempts to read a Private Key
+// encoded in either PEM or JWK format (if .json or .jwk file extension).
+func LoadKeyFile(filename string) (PrivateKey, error) {
+       contents, err := readKeyFileBytes(filename)
+       if err != nil {
+               return nil, err
+       }
+
+       var key PrivateKey
+
+       if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+               key, err = UnmarshalPrivateKeyJWK(contents)
+               if err != nil {
+                       return nil, fmt.Errorf("unable to decode private key JWK: %s", err)
+               }
+       } else {
+               key, err = UnmarshalPrivateKeyPEM(contents)
+               if err != nil {
+                       return nil, fmt.Errorf("unable to decode private key PEM: %s", err)
+               }
+       }
+
+       return key, nil
+}
+
+// LoadPublicKeyFile opens the given filename and attempts to read a Public Key
+// encoded in either PEM or JWK format (if .json or .jwk file extension).
+func LoadPublicKeyFile(filename string) (PublicKey, error) {
+       contents, err := readKeyFileBytes(filename)
+       if err != nil {
+               return nil, err
+       }
+
+       var key PublicKey
+
+       if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+               key, err = UnmarshalPublicKeyJWK(contents)
+               if err != nil {
+                       return nil, fmt.Errorf("unable to decode public key JWK: %s", err)
+               }
+       } else {
+               key, err = UnmarshalPublicKeyPEM(contents)
+               if err != nil {
+                       return nil, fmt.Errorf("unable to decode public key PEM: %s", err)
+               }
+       }
+
+       return key, nil
+}
+
+// SaveKey saves the given key to a file using the provided filename.
+// This process will overwrite any existing file at the provided location.
+func SaveKey(filename string, key PrivateKey) error {
+       var encodedKey []byte
+       var err error
+
+       if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+               // Encode in JSON Web Key format.
+               encodedKey, err = json.MarshalIndent(key, "", "    ")
+               if err != nil {
+                       return fmt.Errorf("unable to encode private key JWK: %s", err)
+               }
+       } else {
+               // Encode in PEM format.
+               pemBlock, err := key.PEMBlock()
+               if err != nil {
+                       return fmt.Errorf("unable to encode private key PEM: %s", err)
+               }
+               encodedKey = pem.EncodeToMemory(pemBlock)
+       }
+
+       err = ioutil.WriteFile(filename, encodedKey, os.FileMode(0600))
+       if err != nil {
+               return fmt.Errorf("unable to write private key file %s: %s", filename, err)
+       }
+
+       return nil
+}
+
+// SavePublicKey saves the given public key to the file.
+func SavePublicKey(filename string, key PublicKey) error {
+       var encodedKey []byte
+       var err error
+
+       if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+               // Encode in JSON Web Key format.
+               encodedKey, err = json.MarshalIndent(key, "", "    ")
+               if err != nil {
+                       return fmt.Errorf("unable to encode public key JWK: %s", err)
+               }
+       } else {
+               // Encode in PEM format.
+               pemBlock, err := key.PEMBlock()
+               if err != nil {
+                       return fmt.Errorf("unable to encode public key PEM: %s", err)
+               }
+               encodedKey = pem.EncodeToMemory(pemBlock)
+       }
+
+       err = ioutil.WriteFile(filename, encodedKey, os.FileMode(0644))
+       if err != nil {
+               return fmt.Errorf("unable to write public key file %s: %s", filename, err)
+       }
+
+       return nil
+}
+
+// Public Key Set files
+
+type jwkSet struct {
+       Keys []json.RawMessage `json:"keys"`
+}
+
+// LoadKeySetFile loads a key set
+func LoadKeySetFile(filename string) ([]PublicKey, error) {
+       if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+               return loadJSONKeySetFile(filename)
+       }
+
+       // Must be a PEM format file
+       return loadPEMKeySetFile(filename)
+}
+
+func loadJSONKeySetRaw(data []byte) ([]json.RawMessage, error) {
+       if len(data) == 0 {
+               // This is okay, just return an empty slice.
+               return []json.RawMessage{}, nil
+       }
+
+       keySet := jwkSet{}
+
+       err := json.Unmarshal(data, &keySet)
+       if err != nil {
+               return nil, fmt.Errorf("unable to decode JSON Web Key Set: %s", err)
+       }
+
+       return keySet.Keys, nil
+}
+
+func loadJSONKeySetFile(filename string) ([]PublicKey, error) {
+       contents, err := readKeyFileBytes(filename)
+       if err != nil && err != ErrKeyFileDoesNotExist {
+               return nil, err
+       }
+
+       return UnmarshalPublicKeyJWKSet(contents)
+}
+
+func loadPEMKeySetFile(filename string) ([]PublicKey, error) {
+       data, err := readKeyFileBytes(filename)
+       if err != nil && err != ErrKeyFileDoesNotExist {
+               return nil, err
+       }
+
+       return UnmarshalPublicKeyPEMBundle(data)
+}
+
+// AddKeySetFile adds a key to a key set
+func AddKeySetFile(filename string, key PublicKey) error {
+       if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+               return addKeySetJSONFile(filename, key)
+       }
+
+       // Must be a PEM format file
+       return addKeySetPEMFile(filename, key)
+}
+
+func addKeySetJSONFile(filename string, key PublicKey) error {
+       encodedKey, err := json.Marshal(key)
+       if err != nil {
+               return fmt.Errorf("unable to encode trusted client key: %s", err)
+       }
+
+       contents, err := readKeyFileBytes(filename)
+       if err != nil && err != ErrKeyFileDoesNotExist {
+               return err
+       }
+
+       rawEntries, err := loadJSONKeySetRaw(contents)
+       if err != nil {
+               return err
+       }
+
+       rawEntries = append(rawEntries, json.RawMessage(encodedKey))
+       entriesWrapper := jwkSet{Keys: rawEntries}
+
+       encodedEntries, err := json.MarshalIndent(entriesWrapper, "", "    ")
+       if err != nil {
+               return fmt.Errorf("unable to encode trusted client keys: %s", err)
+       }
+
+       err = ioutil.WriteFile(filename, encodedEntries, os.FileMode(0644))
+       if err != nil {
+               return fmt.Errorf("unable to write trusted client keys file %s: %s", filename, err)
+       }
+
+       return nil
+}
+
+func addKeySetPEMFile(filename string, key PublicKey) error {
+       // Encode to PEM, open file for appending, write PEM.
+       file, err := os.OpenFile(filename, os.O_CREATE|os.O_APPEND|os.O_RDWR, os.FileMode(0644))
+       if err != nil {
+               return fmt.Errorf("unable to open trusted client keys file %s: %s", filename, err)
+       }
+       defer file.Close()
+
+       pemBlock, err := key.PEMBlock()
+       if err != nil {
+               return fmt.Errorf("unable to encoded trusted key: %s", err)
+       }
+
+       _, err = file.Write(pem.EncodeToMemory(pemBlock))
+       if err != nil {
+               return fmt.Errorf("unable to write trusted keys file: %s", err)
+       }
+
+       return nil
+}
diff --git a/key_files_test.go b/key_files_test.go
new file mode 100644 (file)
index 0000000..66c71dd
--- /dev/null
@@ -0,0 +1,220 @@
+package libtrust
+
+import (
+       "errors"
+       "io/ioutil"
+       "os"
+       "testing"
+)
+
+func makeTempFile(t *testing.T, prefix string) (filename string) {
+       file, err := ioutil.TempFile("", prefix)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       filename = file.Name()
+       file.Close()
+
+       return
+}
+
+func TestKeyFiles(t *testing.T) {
+       key, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       testKeyFiles(t, key)
+
+       key, err = GenerateRSA2048PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       testKeyFiles(t, key)
+}
+
+func testKeyFiles(t *testing.T, key PrivateKey) {
+       var err error
+
+       privateKeyFilename := makeTempFile(t, "private_key")
+       privateKeyFilenamePEM := privateKeyFilename + ".pem"
+       privateKeyFilenameJWK := privateKeyFilename + ".jwk"
+
+       publicKeyFilename := makeTempFile(t, "public_key")
+       publicKeyFilenamePEM := publicKeyFilename + ".pem"
+       publicKeyFilenameJWK := publicKeyFilename + ".jwk"
+
+       if err = SaveKey(privateKeyFilenamePEM, key); err != nil {
+               t.Fatal(err)
+       }
+
+       if err = SaveKey(privateKeyFilenameJWK, key); err != nil {
+               t.Fatal(err)
+       }
+
+       if err = SavePublicKey(publicKeyFilenamePEM, key.PublicKey()); err != nil {
+               t.Fatal(err)
+       }
+
+       if err = SavePublicKey(publicKeyFilenameJWK, key.PublicKey()); err != nil {
+               t.Fatal(err)
+       }
+
+       loadedPEMKey, err := LoadKeyFile(privateKeyFilenamePEM)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       loadedJWKKey, err := LoadKeyFile(privateKeyFilenameJWK)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       loadedPEMPublicKey, err := LoadPublicKeyFile(publicKeyFilenamePEM)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       loadedJWKPublicKey, err := LoadPublicKeyFile(publicKeyFilenameJWK)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if key.KeyID() != loadedPEMKey.KeyID() {
+               t.Fatal(errors.New("key IDs do not match"))
+       }
+
+       if key.KeyID() != loadedJWKKey.KeyID() {
+               t.Fatal(errors.New("key IDs do not match"))
+       }
+
+       if key.KeyID() != loadedPEMPublicKey.KeyID() {
+               t.Fatal(errors.New("key IDs do not match"))
+       }
+
+       if key.KeyID() != loadedJWKPublicKey.KeyID() {
+               t.Fatal(errors.New("key IDs do not match"))
+       }
+
+       os.Remove(privateKeyFilename)
+       os.Remove(privateKeyFilenamePEM)
+       os.Remove(privateKeyFilenameJWK)
+       os.Remove(publicKeyFilename)
+       os.Remove(publicKeyFilenamePEM)
+       os.Remove(publicKeyFilenameJWK)
+}
+
+func TestTrustedHostKeysFile(t *testing.T) {
+       trustedHostKeysFilename := makeTempFile(t, "trusted_host_keys")
+       trustedHostKeysFilenamePEM := trustedHostKeysFilename + ".pem"
+       trustedHostKeysFilenameJWK := trustedHostKeysFilename + ".json"
+
+       testTrustedHostKeysFile(t, trustedHostKeysFilenamePEM)
+       testTrustedHostKeysFile(t, trustedHostKeysFilenameJWK)
+
+       os.Remove(trustedHostKeysFilename)
+       os.Remove(trustedHostKeysFilenamePEM)
+       os.Remove(trustedHostKeysFilenameJWK)
+}
+
+func testTrustedHostKeysFile(t *testing.T, trustedHostKeysFilename string) {
+       hostAddress1 := "docker.example.com:2376"
+       hostKey1, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       hostKey1.AddExtendedField("hosts", []string{hostAddress1})
+       err = AddKeySetFile(trustedHostKeysFilename, hostKey1.PublicKey())
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       trustedHostKeysMapping, err := LoadKeySetFile(trustedHostKeysFilename)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       for addr, hostKey := range trustedHostKeysMapping {
+               t.Logf("Host Address: %s\n", addr)
+               t.Logf("Host Key: %s\n\n", hostKey)
+       }
+
+       hostAddress2 := "192.168.59.103:2376"
+       hostKey2, err := GenerateRSA2048PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       hostKey2.AddExtendedField("hosts", hostAddress2)
+       err = AddKeySetFile(trustedHostKeysFilename, hostKey2.PublicKey())
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       trustedHostKeysMapping, err = LoadKeySetFile(trustedHostKeysFilename)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       for addr, hostKey := range trustedHostKeysMapping {
+               t.Logf("Host Address: %s\n", addr)
+               t.Logf("Host Key: %s\n\n", hostKey)
+       }
+
+}
+
+func TestTrustedClientKeysFile(t *testing.T) {
+       trustedClientKeysFilename := makeTempFile(t, "trusted_client_keys")
+       trustedClientKeysFilenamePEM := trustedClientKeysFilename + ".pem"
+       trustedClientKeysFilenameJWK := trustedClientKeysFilename + ".json"
+
+       testTrustedClientKeysFile(t, trustedClientKeysFilenamePEM)
+       testTrustedClientKeysFile(t, trustedClientKeysFilenameJWK)
+
+       os.Remove(trustedClientKeysFilename)
+       os.Remove(trustedClientKeysFilenamePEM)
+       os.Remove(trustedClientKeysFilenameJWK)
+}
+
+func testTrustedClientKeysFile(t *testing.T, trustedClientKeysFilename string) {
+       clientKey1, err := GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       err = AddKeySetFile(trustedClientKeysFilename, clientKey1.PublicKey())
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       trustedClientKeys, err := LoadKeySetFile(trustedClientKeysFilename)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       for _, clientKey := range trustedClientKeys {
+               t.Logf("Client Key: %s\n", clientKey)
+       }
+
+       clientKey2, err := GenerateRSA2048PrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       err = AddKeySetFile(trustedClientKeysFilename, clientKey2.PublicKey())
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       trustedClientKeys, err = LoadKeySetFile(trustedClientKeysFilename)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       for _, clientKey := range trustedClientKeys {
+               t.Logf("Client Key: %s\n", clientKey)
+       }
+}
diff --git a/rsa_key.go b/rsa_key.go
new file mode 100644 (file)
index 0000000..4546303
--- /dev/null
@@ -0,0 +1,436 @@
+package libtrust
+
+import (
+       "crypto"
+       "crypto/rand"
+       "crypto/rsa"
+       "crypto/x509"
+       "encoding/json"
+       "encoding/pem"
+       "errors"
+       "fmt"
+       "io"
+       "math/big"
+)
+
+/*
+ * RSA DSA PUBLIC KEY
+ */
+
+// rsaPublicKey implements a JWK Public Key using RSA digital signature algorithms.
+type rsaPublicKey struct {
+       *rsa.PublicKey
+       extended map[string]interface{}
+}
+
+func fromRSAPublicKey(cryptoPublicKey *rsa.PublicKey) *rsaPublicKey {
+       return &rsaPublicKey{cryptoPublicKey, map[string]interface{}{}}
+}
+
+// KeyType returns the JWK key type for RSA keys, i.e., "RSA".
+func (k *rsaPublicKey) KeyType() string {
+       return "RSA"
+}
+
+// KeyID returns a distinct identifier which is unique to this Public Key.
+func (k *rsaPublicKey) KeyID() string {
+       // Generate and return a 'libtrust' fingerprint of the RSA public key.
+       // For an RSA key this should be:
+       //   SHA256("RSA"+bytes(N)+bytes(E))
+       // Then truncated to 240 bits and encoded into 12 base32 groups like so:
+       //   ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP
+       hasher := crypto.SHA256.New()
+       hasher.Write([]byte(k.KeyType()))
+       hasher.Write(k.N.Bytes())
+       hasher.Write(serializeRSAPublicExponentParam(k.E))
+       return keyIDEncode(hasher.Sum(nil)[:30])
+}
+
+func (k *rsaPublicKey) String() string {
+       return fmt.Sprintf("RSA Public Key <%s>", k.KeyID())
+}
+
+// Verify verifyies the signature of the data in the io.Reader using this Public Key.
+// The alg parameter should be the name of the JWA digital signature algorithm
+// which was used to produce the signature and should be supported by this
+// public key. Returns a nil error if the signature is valid.
+func (k *rsaPublicKey) Verify(data io.Reader, alg string, signature []byte) error {
+       // Verify the signature of the given date, return non-nil error if valid.
+       sigAlg, err := rsaSignatureAlgorithmByName(alg)
+       if err != nil {
+               return fmt.Errorf("unable to verify Signature: %s", err)
+       }
+
+       hasher := sigAlg.HashID().New()
+       _, err = io.Copy(hasher, data)
+       if err != nil {
+               return fmt.Errorf("error reading data to sign: %s", err)
+       }
+       hash := hasher.Sum(nil)
+
+       err = rsa.VerifyPKCS1v15(k.PublicKey, sigAlg.HashID(), hash, signature)
+       if err != nil {
+               return fmt.Errorf("invalid %s signature: %s", sigAlg.HeaderParam(), err)
+       }
+
+       return nil
+}
+
+// CryptoPublicKey returns the internal object which can be used as a
+// crypto.PublicKey for use with other standard library operations. The type
+// is either *rsa.PublicKey or *ecdsa.PublicKey
+func (k *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
+       return k.PublicKey
+}
+
+func (k *rsaPublicKey) toMap() map[string]interface{} {
+       jwk := make(map[string]interface{})
+       for k, v := range k.extended {
+               jwk[k] = v
+       }
+       jwk["kty"] = k.KeyType()
+       jwk["kid"] = k.KeyID()
+       jwk["n"] = joseBase64UrlEncode(k.N.Bytes())
+       jwk["e"] = joseBase64UrlEncode(serializeRSAPublicExponentParam(k.E))
+
+       return jwk
+}
+
+// MarshalJSON serializes this Public Key using the JWK JSON serialization format for
+// RSA keys.
+func (k *rsaPublicKey) MarshalJSON() (data []byte, err error) {
+       return json.Marshal(k.toMap())
+}
+
+// PEMBlock serializes this Public Key to DER-encoded PKIX format.
+func (k *rsaPublicKey) PEMBlock() (*pem.Block, error) {
+       derBytes, err := x509.MarshalPKIXPublicKey(k.PublicKey)
+       if err != nil {
+               return nil, fmt.Errorf("unable to serialize RSA PublicKey to DER-encoded PKIX format: %s", err)
+       }
+       k.extended["keyID"] = k.KeyID() // For display purposes.
+       return createPemBlock("PUBLIC KEY", derBytes, k.extended)
+}
+
+func (k *rsaPublicKey) AddExtendedField(field string, value interface{}) {
+       k.extended[field] = value
+}
+
+func (k *rsaPublicKey) GetExtendedField(field string) interface{} {
+       v, ok := k.extended[field]
+       if !ok {
+               return nil
+       }
+       return v
+}
+
+func rsaPublicKeyFromMap(jwk map[string]interface{}) (*rsaPublicKey, error) {
+       // JWK key type (kty) has already been determined to be "RSA".
+       // Need to extract 'n', 'e', and 'kid' and check for
+       // consistency.
+
+       // Get the modulus parameter N.
+       nB64Url, err := stringFromMap(jwk, "n")
+       if err != nil {
+               return nil, fmt.Errorf("JWK RSA Public Key modulus: %s", err)
+       }
+
+       n, err := parseRSAModulusParam(nB64Url)
+       if err != nil {
+               return nil, fmt.Errorf("JWK RSA Public Key modulus: %s", err)
+       }
+
+       // Get the public exponent E.
+       eB64Url, err := stringFromMap(jwk, "e")
+       if err != nil {
+               return nil, fmt.Errorf("JWK RSA Public Key exponent: %s", err)
+       }
+
+       e, err := parseRSAPublicExponentParam(eB64Url)
+       if err != nil {
+               return nil, fmt.Errorf("JWK RSA Public Key exponent: %s", err)
+       }
+
+       key := &rsaPublicKey{
+               PublicKey: &rsa.PublicKey{N: n, E: e},
+       }
+
+       // Key ID is optional, but if it exists, it should match the key.
+       _, ok := jwk["kid"]
+       if ok {
+               kid, err := stringFromMap(jwk, "kid")
+               if err != nil {
+                       return nil, fmt.Errorf("JWK RSA Public Key ID: %s", err)
+               }
+               if kid != key.KeyID() {
+                       return nil, fmt.Errorf("JWK RSA Public Key ID does not match: %s", kid)
+               }
+       }
+
+       if _, ok := jwk["d"]; ok {
+               return nil, fmt.Errorf("JWK RSA Public Key cannot contain private exponent")
+       }
+
+       key.extended = jwk
+
+       return key, nil
+}
+
+/*
+ * RSA DSA PRIVATE KEY
+ */
+
+// rsaPrivateKey implements a JWK Private Key using RSA digital signature algorithms.
+type rsaPrivateKey struct {
+       rsaPublicKey
+       *rsa.PrivateKey
+}
+
+func fromRSAPrivateKey(cryptoPrivateKey *rsa.PrivateKey) *rsaPrivateKey {
+       return &rsaPrivateKey{
+               *fromRSAPublicKey(&cryptoPrivateKey.PublicKey),
+               cryptoPrivateKey,
+       }
+}
+
+// PublicKey returns the Public Key data associated with this Private Key.
+func (k *rsaPrivateKey) PublicKey() PublicKey {
+       return &k.rsaPublicKey
+}
+
+func (k *rsaPrivateKey) String() string {
+       return fmt.Sprintf("RSA Private Key <%s>", k.KeyID())
+}
+
+// Sign signs the data read from the io.Reader using a signature algorithm supported
+// by the RSA private key. If the specified hashing algorithm is supported by
+// this key, that hash function is used to generate the signature otherwise the
+// the default hashing algorithm for this key is used. Returns the signature
+// and the name of the JWK signature algorithm used, e.g., "RS256", "RS384",
+// "RS512".
+func (k *rsaPrivateKey) Sign(data io.Reader, hashID crypto.Hash) (signature []byte, alg string, err error) {
+       // Generate a signature of the data using the internal alg.
+       sigAlg := rsaPKCS1v15SignatureAlgorithmForHashID(hashID)
+       hasher := sigAlg.HashID().New()
+
+       _, err = io.Copy(hasher, data)
+       if err != nil {
+               return nil, "", fmt.Errorf("error reading data to sign: %s", err)
+       }
+       hash := hasher.Sum(nil)
+
+       signature, err = rsa.SignPKCS1v15(rand.Reader, k.PrivateKey, sigAlg.HashID(), hash)
+       if err != nil {
+               return nil, "", fmt.Errorf("error producing signature: %s", err)
+       }
+
+       alg = sigAlg.HeaderParam()
+
+       return
+}
+
+// CryptoPrivateKey returns the internal object which can be used as a
+// crypto.PublicKey for use with other standard library operations. The type
+// is either *rsa.PublicKey or *ecdsa.PublicKey
+func (k *rsaPrivateKey) CryptoPrivateKey() crypto.PrivateKey {
+       return k.PrivateKey
+}
+
+func (k *rsaPrivateKey) toMap() map[string]interface{} {
+       k.Precompute() // Make sure the precomputed values are stored.
+       jwk := k.rsaPublicKey.toMap()
+
+       jwk["d"] = joseBase64UrlEncode(k.D.Bytes())
+       jwk["p"] = joseBase64UrlEncode(k.Primes[0].Bytes())
+       jwk["q"] = joseBase64UrlEncode(k.Primes[1].Bytes())
+       jwk["dp"] = joseBase64UrlEncode(k.Precomputed.Dp.Bytes())
+       jwk["dq"] = joseBase64UrlEncode(k.Precomputed.Dq.Bytes())
+       jwk["qi"] = joseBase64UrlEncode(k.Precomputed.Qinv.Bytes())
+
+       otherPrimes := k.Primes[2:]
+
+       if len(otherPrimes) > 0 {
+               otherPrimesInfo := make([]interface{}, len(otherPrimes))
+               for i, r := range otherPrimes {
+                       otherPrimeInfo := make(map[string]string, 3)
+                       otherPrimeInfo["r"] = joseBase64UrlEncode(r.Bytes())
+                       crtVal := k.Precomputed.CRTValues[i]
+                       otherPrimeInfo["d"] = joseBase64UrlEncode(crtVal.Exp.Bytes())
+                       otherPrimeInfo["t"] = joseBase64UrlEncode(crtVal.Coeff.Bytes())
+                       otherPrimesInfo[i] = otherPrimeInfo
+               }
+               jwk["oth"] = otherPrimesInfo
+       }
+
+       return jwk
+}
+
+// MarshalJSON serializes this Private Key using the JWK JSON serialization format for
+// RSA keys.
+func (k *rsaPrivateKey) MarshalJSON() (data []byte, err error) {
+       return json.Marshal(k.toMap())
+}
+
+// PEMBlock serializes this Private Key to DER-encoded PKIX format.
+func (k *rsaPrivateKey) PEMBlock() (*pem.Block, error) {
+       derBytes := x509.MarshalPKCS1PrivateKey(k.PrivateKey)
+       k.extended["keyID"] = k.KeyID() // For display purposes.
+       return createPemBlock("RSA PRIVATE KEY", derBytes, k.extended)
+}
+
+func rsaPrivateKeyFromMap(jwk map[string]interface{}) (*rsaPrivateKey, error) {
+       // The JWA spec for RSA Private Keys (draft rfc section 5.3.2) states that
+       // only the private key exponent 'd' is REQUIRED, the others are just for
+       // signature/decryption optimizations and SHOULD be included when the JWK
+       // is produced. We MAY choose to accept a JWK which only includes 'd', but
+       // we're going to go ahead and not choose to accept it without the extra
+       // fields. Only the 'oth' field will be optional (for multi-prime keys).
+       privateExponent, err := parseRSAPrivateKeyParamFromMap(jwk, "d")
+       if err != nil {
+               return nil, fmt.Errorf("JWK RSA Private Key exponent: %s", err)
+       }
+       firstPrimeFactor, err := parseRSAPrivateKeyParamFromMap(jwk, "p")
+       if err != nil {
+               return nil, fmt.Errorf("JWK RSA Private Key prime factor: %s", err)
+       }
+       secondPrimeFactor, err := parseRSAPrivateKeyParamFromMap(jwk, "q")
+       if err != nil {
+               return nil, fmt.Errorf("JWK RSA Private Key prime factor: %s", err)
+       }
+       firstFactorCRT, err := parseRSAPrivateKeyParamFromMap(jwk, "dp")
+       if err != nil {
+               return nil, fmt.Errorf("JWK RSA Private Key CRT exponent: %s", err)
+       }
+       secondFactorCRT, err := parseRSAPrivateKeyParamFromMap(jwk, "dq")
+       if err != nil {
+               return nil, fmt.Errorf("JWK RSA Private Key CRT exponent: %s", err)
+       }
+       crtCoeff, err := parseRSAPrivateKeyParamFromMap(jwk, "qi")
+       if err != nil {
+               return nil, fmt.Errorf("JWK RSA Private Key CRT coefficient: %s", err)
+       }
+
+       var oth interface{}
+       if _, ok := jwk["oth"]; ok {
+               oth = jwk["oth"]
+               delete(jwk, "oth")
+       }
+
+       // JWK key type (kty) has already been determined to be "RSA".
+       // Need to extract the public key information, then extract the private
+       // key values.
+       publicKey, err := rsaPublicKeyFromMap(jwk)
+       if err != nil {
+               return nil, err
+       }
+
+       privateKey := &rsa.PrivateKey{
+               PublicKey: *publicKey.PublicKey,
+               D:         privateExponent,
+               Primes:    []*big.Int{firstPrimeFactor, secondPrimeFactor},
+               Precomputed: rsa.PrecomputedValues{
+                       Dp:   firstFactorCRT,
+                       Dq:   secondFactorCRT,
+                       Qinv: crtCoeff,
+               },
+       }
+
+       if oth != nil {
+               // Should be an array of more JSON objects.
+               otherPrimesInfo, ok := oth.([]interface{})
+               if !ok {
+                       return nil, errors.New("JWK RSA Private Key: Invalid other primes info: must be an array")
+               }
+               numOtherPrimeFactors := len(otherPrimesInfo)
+               if numOtherPrimeFactors == 0 {
+                       return nil, errors.New("JWK RSA Privake Key: Invalid other primes info: must be absent or non-empty")
+               }
+               otherPrimeFactors := make([]*big.Int, numOtherPrimeFactors)
+               productOfPrimes := new(big.Int).Mul(firstPrimeFactor, secondPrimeFactor)
+               crtValues := make([]rsa.CRTValue, numOtherPrimeFactors)
+
+               for i, val := range otherPrimesInfo {
+                       otherPrimeinfo, ok := val.(map[string]interface{})
+                       if !ok {
+                               return nil, errors.New("JWK RSA Private Key: Invalid other prime info: must be a JSON object")
+                       }
+
+                       otherPrimeFactor, err := parseRSAPrivateKeyParamFromMap(otherPrimeinfo, "r")
+                       if err != nil {
+                               return nil, fmt.Errorf("JWK RSA Private Key prime factor: %s", err)
+                       }
+                       otherFactorCRT, err := parseRSAPrivateKeyParamFromMap(otherPrimeinfo, "d")
+                       if err != nil {
+                               return nil, fmt.Errorf("JWK RSA Private Key CRT exponent: %s", err)
+                       }
+                       otherCrtCoeff, err := parseRSAPrivateKeyParamFromMap(otherPrimeinfo, "t")
+                       if err != nil {
+                               return nil, fmt.Errorf("JWK RSA Private Key CRT coefficient: %s", err)
+                       }
+
+                       crtValue := crtValues[i]
+                       crtValue.Exp = otherFactorCRT
+                       crtValue.Coeff = otherCrtCoeff
+                       crtValue.R = productOfPrimes
+                       otherPrimeFactors[i] = otherPrimeFactor
+                       productOfPrimes = new(big.Int).Mul(productOfPrimes, otherPrimeFactor)
+               }
+
+               privateKey.Primes = append(privateKey.Primes, otherPrimeFactors...)
+               privateKey.Precomputed.CRTValues = crtValues
+       }
+
+       key := &rsaPrivateKey{
+               rsaPublicKey: *publicKey,
+               PrivateKey:   privateKey,
+       }
+
+       return key, nil
+}
+
+/*
+ *     Key Generation Functions.
+ */
+
+func generateRSAPrivateKey(bits int) (k *rsaPrivateKey, err error) {
+       k = new(rsaPrivateKey)
+       k.PrivateKey, err = rsa.GenerateKey(rand.Reader, bits)
+       if err != nil {
+               return nil, err
+       }
+
+       k.rsaPublicKey.PublicKey = &k.PrivateKey.PublicKey
+       k.extended = make(map[string]interface{})
+
+       return
+}
+
+// GenerateRSA2048PrivateKey generates a key pair using 2048-bit RSA.
+func GenerateRSA2048PrivateKey() (PrivateKey, error) {
+       k, err := generateRSAPrivateKey(2048)
+       if err != nil {
+               return nil, fmt.Errorf("error generating RSA 2048-bit key: %s", err)
+       }
+
+       return k, nil
+}
+
+// GenerateRSA3072PrivateKey generates a key pair using 3072-bit RSA.
+func GenerateRSA3072PrivateKey() (PrivateKey, error) {
+       k, err := generateRSAPrivateKey(3072)
+       if err != nil {
+               return nil, fmt.Errorf("error generating RSA 3072-bit key: %s", err)
+       }
+
+       return k, nil
+}
+
+// GenerateRSA4096PrivateKey generates a key pair using 4096-bit RSA.
+func GenerateRSA4096PrivateKey() (PrivateKey, error) {
+       k, err := generateRSAPrivateKey(4096)
+       if err != nil {
+               return nil, fmt.Errorf("error generating RSA 4096-bit key: %s", err)
+       }
+
+       return k, nil
+}
diff --git a/rsa_key_test.go b/rsa_key_test.go
new file mode 100644 (file)
index 0000000..5ec7707
--- /dev/null
@@ -0,0 +1,157 @@
+package libtrust
+
+import (
+       "bytes"
+       "encoding/json"
+       "log"
+       "testing"
+)
+
+var rsaKeys []PrivateKey
+
+func init() {
+       var err error
+       rsaKeys, err = generateRSATestKeys()
+       if err != nil {
+               log.Fatal(err)
+       }
+}
+
+func generateRSATestKeys() (keys []PrivateKey, err error) {
+       log.Println("Generating RSA 2048-bit Test Key")
+       rsa2048Key, err := GenerateRSA2048PrivateKey()
+       if err != nil {
+               return
+       }
+
+       log.Println("Generating RSA 3072-bit Test Key")
+       rsa3072Key, err := GenerateRSA3072PrivateKey()
+       if err != nil {
+               return
+       }
+
+       log.Println("Generating RSA 4096-bit Test Key")
+       rsa4096Key, err := GenerateRSA4096PrivateKey()
+       if err != nil {
+               return
+       }
+
+       log.Println("Done generating RSA Test Keys!")
+       keys = []PrivateKey{rsa2048Key, rsa3072Key, rsa4096Key}
+
+       return
+}
+
+func TestRSAKeys(t *testing.T) {
+       for _, rsaKey := range rsaKeys {
+               if rsaKey.KeyType() != "RSA" {
+                       t.Fatalf("key type must be %q, instead got %q", "RSA", rsaKey.KeyType())
+               }
+       }
+}
+
+func TestRSASignVerify(t *testing.T) {
+       message := "Hello, World!"
+       data := bytes.NewReader([]byte(message))
+
+       sigAlgs := []*signatureAlgorithm{rs256, rs384, rs512}
+
+       for i, rsaKey := range rsaKeys {
+               sigAlg := sigAlgs[i]
+
+               t.Logf("%s signature of %q with kid: %s\n", sigAlg.HeaderParam(), message, rsaKey.KeyID())
+
+               data.Seek(0, 0) // Reset the byte reader
+
+               // Sign
+               sig, alg, err := rsaKey.Sign(data, sigAlg.HashID())
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               data.Seek(0, 0) // Reset the byte reader
+
+               // Verify
+               err = rsaKey.Verify(data, alg, sig)
+               if err != nil {
+                       t.Fatal(err)
+               }
+       }
+}
+
+func TestMarshalUnmarshalRSAKeys(t *testing.T) {
+       data := bytes.NewReader([]byte("This is a test. I repeat: this is only a test."))
+       sigAlgs := []*signatureAlgorithm{rs256, rs384, rs512}
+
+       for i, rsaKey := range rsaKeys {
+               sigAlg := sigAlgs[i]
+               privateJWKJSON, err := json.MarshalIndent(rsaKey, "", "    ")
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               publicJWKJSON, err := json.MarshalIndent(rsaKey.PublicKey(), "", "    ")
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               t.Logf("JWK Private Key: %s", string(privateJWKJSON))
+               t.Logf("JWK Public Key: %s", string(publicJWKJSON))
+
+               privKey2, err := UnmarshalPrivateKeyJWK(privateJWKJSON)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               pubKey2, err := UnmarshalPublicKeyJWK(publicJWKJSON)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               // Ensure we can sign/verify a message with the unmarshalled keys.
+               data.Seek(0, 0) // Reset the byte reader
+               signature, alg, err := privKey2.Sign(data, sigAlg.HashID())
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               data.Seek(0, 0) // Reset the byte reader
+               err = pubKey2.Verify(data, alg, signature)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               // It's a good idea to validate the Private Key to make sure our
+               // (un)marshal process didn't corrupt the extra parameters.
+               k := privKey2.(*rsaPrivateKey)
+               err = k.PrivateKey.Validate()
+               if err != nil {
+                       t.Fatal(err)
+               }
+       }
+}
+
+func TestFromCryptoRSAKeys(t *testing.T) {
+       for _, rsaKey := range rsaKeys {
+               cryptoPrivateKey := rsaKey.CryptoPrivateKey()
+               cryptoPublicKey := rsaKey.CryptoPublicKey()
+
+               pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               if pubKey.KeyID() != rsaKey.KeyID() {
+                       t.Fatal("public key key ID mismatch")
+               }
+
+               privKey, err := FromCryptoPrivateKey(cryptoPrivateKey)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               if privKey.KeyID() != rsaKey.KeyID() {
+                       t.Fatal("public key key ID mismatch")
+               }
+       }
+}
diff --git a/testutil/certificates.go b/testutil/certificates.go
new file mode 100644 (file)
index 0000000..89debf6
--- /dev/null
@@ -0,0 +1,94 @@
+package testutil
+
+import (
+       "crypto"
+       "crypto/rand"
+       "crypto/x509"
+       "crypto/x509/pkix"
+       "math/big"
+       "time"
+)
+
+// GenerateTrustCA generates a new certificate authority for testing.
+func GenerateTrustCA(pub crypto.PublicKey, priv crypto.PrivateKey) (*x509.Certificate, error) {
+       cert := &x509.Certificate{
+               SerialNumber: big.NewInt(0),
+               Subject: pkix.Name{
+                       CommonName: "CA Root",
+               },
+               NotBefore:             time.Now().Add(-time.Second),
+               NotAfter:              time.Now().Add(time.Hour),
+               IsCA:                  true,
+               KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
+               BasicConstraintsValid: true,
+       }
+
+       certDER, err := x509.CreateCertificate(rand.Reader, cert, cert, pub, priv)
+       if err != nil {
+               return nil, err
+       }
+
+       cert, err = x509.ParseCertificate(certDER)
+       if err != nil {
+               return nil, err
+       }
+
+       return cert, nil
+}
+
+// GenerateIntermediate generates an intermediate certificate for testing using
+// the parent certificate (likely a CA) and the provided keys.
+func GenerateIntermediate(key crypto.PublicKey, parentKey crypto.PrivateKey, parent *x509.Certificate) (*x509.Certificate, error) {
+       cert := &x509.Certificate{
+               SerialNumber: big.NewInt(0),
+               Subject: pkix.Name{
+                       CommonName: "Intermediate",
+               },
+               NotBefore:             time.Now().Add(-time.Second),
+               NotAfter:              time.Now().Add(time.Hour),
+               IsCA:                  true,
+               KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
+               BasicConstraintsValid: true,
+       }
+
+       certDER, err := x509.CreateCertificate(rand.Reader, cert, parent, key, parentKey)
+       if err != nil {
+               return nil, err
+       }
+
+       cert, err = x509.ParseCertificate(certDER)
+       if err != nil {
+               return nil, err
+       }
+
+       return cert, nil
+}
+
+// GenerateTrustCert generates a new trust certificate for testing.  Unlike the
+// intermediate certificates, this certificate should  be used for signature
+// only, not creating certificates.
+func GenerateTrustCert(key crypto.PublicKey, parentKey crypto.PrivateKey, parent *x509.Certificate) (*x509.Certificate, error) {
+       cert := &x509.Certificate{
+               SerialNumber: big.NewInt(0),
+               Subject: pkix.Name{
+                       CommonName: "Trust Cert",
+               },
+               NotBefore:             time.Now().Add(-time.Second),
+               NotAfter:              time.Now().Add(time.Hour),
+               IsCA:                  true,
+               KeyUsage:              x509.KeyUsageDigitalSignature,
+               BasicConstraintsValid: true,
+       }
+
+       certDER, err := x509.CreateCertificate(rand.Reader, cert, parent, key, parentKey)
+       if err != nil {
+               return nil, err
+       }
+
+       cert, err = x509.ParseCertificate(certDER)
+       if err != nil {
+               return nil, err
+       }
+
+       return cert, nil
+}
diff --git a/tlsdemo/README.md b/tlsdemo/README.md
new file mode 100644 (file)
index 0000000..24124db
--- /dev/null
@@ -0,0 +1,50 @@
+## Libtrust TLS Config Demo
+
+This program generates key pairs and trust files for a TLS client and server.
+
+To generate the keys, run:
+
+```
+$ go run genkeys.go
+```
+
+The generated files are:
+
+```
+$ ls -l client_data/ server_data/
+client_data/:
+total 24
+-rw-------  1 jlhawn  staff  281 Aug  8 16:21 private_key.json
+-rw-r--r--  1 jlhawn  staff  225 Aug  8 16:21 public_key.json
+-rw-r--r--  1 jlhawn  staff  275 Aug  8 16:21 trusted_hosts.json
+
+server_data/:
+total 24
+-rw-r--r--  1 jlhawn  staff  348 Aug  8 16:21 trusted_clients.json
+-rw-------  1 jlhawn  staff  281 Aug  8 16:21 private_key.json
+-rw-r--r--  1 jlhawn  staff  225 Aug  8 16:21 public_key.json
+```
+
+The private key and public key for the client and server are stored in `private_key.json` and `public_key.json`, respectively, and in their respective directories. They are represented as JSON Web Keys: JSON objects which represent either an ECDSA or RSA private key. The host keys trusted by the client are stored in `trusted_hosts.json` and contain a mapping of an internet address, `<HOSTNAME_OR_IP>:<PORT>`, to a JSON Web Key which is a JSON object representing either an ECDSA or RSA public key of the trusted server. The client keys trusted by the server are stored in `trusted_clients.json` and contain an array of JSON objects which contain a comment field which can be used describe the key and a JSON Web Key which is a JSON object representing either an ECDSA or RSA public key of the trusted client.
+
+To start the server, run:
+
+```
+$ go run server.go
+```
+
+This starts an HTTPS server which listens on `localhost:8888`. The server configures itself with a certificate which is valid for both `localhost` and `127.0.0.1` and uses the key from `server_data/private_key.json`. It accepts connections from clients which present a certificate for a key that it is configured to trust from the `trusted_clients.json` file and returns a simple 'hello' message.
+
+To make a request using the client, run:
+
+```
+$ go run client.go
+```
+
+This command creates an HTTPS client which makes a GET request to `https://localhost:8888`. The client configures itself with a certificate using the key from `client_data/private_key.json`. It only connects to a server which presents a certificate signed by the key specified for the `localhost:8888` address from `client_data/trusted_hosts.json` and made to be used for the `localhost` hostname. If the connection succeeds, it prints the response from the server.
+
+The file `gencert.go` can be used to generate PEM encoded version of the client key and certificate. If you save them to `key.pem` and `cert.pem` respectively, you can use them with `curl` to test out the server (if it is still running).
+
+```
+curl --cert cert.pem --key key.pem -k https://localhost:8888
+``` 
diff --git a/tlsdemo/client.go b/tlsdemo/client.go
new file mode 100644 (file)
index 0000000..0a699a0
--- /dev/null
@@ -0,0 +1,89 @@
+package main
+
+import (
+       "crypto/tls"
+       "fmt"
+       "io/ioutil"
+       "log"
+       "net"
+       "net/http"
+
+       "github.com/docker/libtrust"
+)
+
+var (
+       serverAddress        = "localhost:8888"
+       privateKeyFilename   = "client_data/private_key.pem"
+       trustedHostsFilename = "client_data/trusted_hosts.pem"
+)
+
+func main() {
+       // Load Client Key.
+       clientKey, err := libtrust.LoadKeyFile(privateKeyFilename)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Generate Client Certificate.
+       selfSignedClientCert, err := libtrust.GenerateSelfSignedClientCert(clientKey)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Load trusted host keys.
+       hostKeys, err := libtrust.LoadKeySetFile(trustedHostsFilename)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Ensure the host we want to connect to is trusted!
+       host, _, err := net.SplitHostPort(serverAddress)
+       if err != nil {
+               log.Fatal(err)
+       }
+       serverKeys, err := libtrust.FilterByHosts(hostKeys, host, false)
+       if err != nil {
+               log.Fatalf("%q is not a known and trusted host", host)
+       }
+
+       // Generate a CA pool with the trusted host's key.
+       caPool, err := libtrust.GenerateCACertPool(clientKey, serverKeys)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Create HTTP Client.
+       client := &http.Client{
+               Transport: &http.Transport{
+                       TLSClientConfig: &tls.Config{
+                               Certificates: []tls.Certificate{
+                                       tls.Certificate{
+                                               Certificate: [][]byte{selfSignedClientCert.Raw},
+                                               PrivateKey:  clientKey.CryptoPrivateKey(),
+                                               Leaf:        selfSignedClientCert,
+                                       },
+                               },
+                               RootCAs: caPool,
+                       },
+               },
+       }
+
+       var makeRequest = func(url string) {
+               resp, err := client.Get(url)
+               if err != nil {
+                       log.Fatal(err)
+               }
+               defer resp.Body.Close()
+
+               body, err := ioutil.ReadAll(resp.Body)
+               if err != nil {
+                       log.Fatal(err)
+               }
+
+               log.Println(resp.Status)
+               log.Println(string(body))
+       }
+
+       // Make the request to the trusted server!
+       makeRequest(fmt.Sprintf("https://%s", serverAddress))
+}
diff --git a/tlsdemo/gencert.go b/tlsdemo/gencert.go
new file mode 100644 (file)
index 0000000..c65f3b6
--- /dev/null
@@ -0,0 +1,62 @@
+package main
+
+import (
+       "encoding/pem"
+       "fmt"
+       "log"
+       "net"
+
+       "github.com/docker/libtrust"
+)
+
+var (
+       serverAddress            = "localhost:8888"
+       clientPrivateKeyFilename = "client_data/private_key.pem"
+       trustedHostsFilename     = "client_data/trusted_hosts.pem"
+)
+
+func main() {
+       key, err := libtrust.LoadKeyFile(clientPrivateKeyFilename)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       keyPEMBlock, err := key.PEMBlock()
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       encodedPrivKey := pem.EncodeToMemory(keyPEMBlock)
+       fmt.Printf("Client Key:\n\n%s\n", string(encodedPrivKey))
+
+       cert, err := libtrust.GenerateSelfSignedClientCert(key)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       encodedCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
+       fmt.Printf("Client Cert:\n\n%s\n", string(encodedCert))
+
+       trustedServerKeys, err := libtrust.LoadKeySetFile(trustedHostsFilename)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       hostname, _, err := net.SplitHostPort(serverAddress)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       trustedServerKeys, err = libtrust.FilterByHosts(trustedServerKeys, hostname, false)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       caCert, err := libtrust.GenerateCACert(key, trustedServerKeys[0])
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       encodedCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCert.Raw})
+       fmt.Printf("CA Cert:\n\n%s\n", string(encodedCert))
+}
diff --git a/tlsdemo/genkeys.go b/tlsdemo/genkeys.go
new file mode 100644 (file)
index 0000000..9dc8842
--- /dev/null
@@ -0,0 +1,61 @@
+package main
+
+import (
+       "log"
+
+       "github.com/docker/libtrust"
+)
+
+func main() {
+       // Generate client key.
+       clientKey, err := libtrust.GenerateECP256PrivateKey()
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Add a comment for the client key.
+       clientKey.AddExtendedField("comment", "TLS Demo Client")
+
+       // Save the client key, public and private versions.
+       err = libtrust.SaveKey("client_data/private_key.pem", clientKey)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       err = libtrust.SavePublicKey("client_data/public_key.pem", clientKey.PublicKey())
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Generate server key.
+       serverKey, err := libtrust.GenerateECP256PrivateKey()
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Set the list of addresses to use for the server.
+       serverKey.AddExtendedField("hosts", []string{"localhost", "docker.example.com"})
+
+       // Save the server key, public and private versions.
+       err = libtrust.SaveKey("server_data/private_key.pem", serverKey)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       err = libtrust.SavePublicKey("server_data/public_key.pem", serverKey.PublicKey())
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Generate Authorized Keys file for server.
+       err = libtrust.AddKeySetFile("server_data/trusted_clients.pem", clientKey.PublicKey())
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Generate Known Host Keys file for client.
+       err = libtrust.AddKeySetFile("client_data/trusted_hosts.pem", serverKey.PublicKey())
+       if err != nil {
+               log.Fatal(err)
+       }
+}
diff --git a/tlsdemo/server.go b/tlsdemo/server.go
new file mode 100644 (file)
index 0000000..d3cb2ea
--- /dev/null
@@ -0,0 +1,80 @@
+package main
+
+import (
+       "crypto/tls"
+       "fmt"
+       "html"
+       "log"
+       "net"
+       "net/http"
+
+       "github.com/docker/libtrust"
+)
+
+var (
+       serverAddress             = "localhost:8888"
+       privateKeyFilename        = "server_data/private_key.pem"
+       authorizedClientsFilename = "server_data/trusted_clients.pem"
+)
+
+func requestHandler(w http.ResponseWriter, r *http.Request) {
+       clientCert := r.TLS.PeerCertificates[0]
+       keyID := clientCert.Subject.CommonName
+       log.Printf("Request from keyID: %s\n", keyID)
+       fmt.Fprintf(w, "Hello, client! I'm a server! And you are %T: %s.\n", clientCert.PublicKey, html.EscapeString(keyID))
+}
+
+func main() {
+       // Load server key.
+       serverKey, err := libtrust.LoadKeyFile(privateKeyFilename)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Generate server certificate.
+       selfSignedServerCert, err := libtrust.GenerateSelfSignedServerCert(
+               serverKey, []string{"localhost"}, []net.IP{net.ParseIP("127.0.0.1")},
+       )
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Load authorized client keys.
+       authorizedClients, err := libtrust.LoadKeySetFile(authorizedClientsFilename)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Create CA pool using trusted client keys.
+       caPool, err := libtrust.GenerateCACertPool(serverKey, authorizedClients)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       // Create TLS config, requiring client certificates.
+       tlsConfig := &tls.Config{
+               Certificates: []tls.Certificate{
+                       tls.Certificate{
+                               Certificate: [][]byte{selfSignedServerCert.Raw},
+                               PrivateKey:  serverKey.CryptoPrivateKey(),
+                               Leaf:        selfSignedServerCert,
+                       },
+               },
+               ClientAuth: tls.RequireAndVerifyClientCert,
+               ClientCAs:  caPool,
+       }
+
+       // Create HTTP server with simple request handler.
+       server := &http.Server{
+               Addr:    serverAddress,
+               Handler: http.HandlerFunc(requestHandler),
+       }
+
+       // Listen and server HTTPS using the libtrust TLS config.
+       listener, err := net.Listen("tcp", server.Addr)
+       if err != nil {
+               log.Fatal(err)
+       }
+       tlsListener := tls.NewListener(listener, tlsConfig)
+       server.Serve(tlsListener)
+}
diff --git a/trustgraph/graph.go b/trustgraph/graph.go
new file mode 100644 (file)
index 0000000..72b0fc3
--- /dev/null
@@ -0,0 +1,50 @@
+package trustgraph
+
+import "github.com/docker/libtrust"
+
+// TrustGraph represents a graph of authorization mapping
+// public keys to nodes and grants between nodes.
+type TrustGraph interface {
+       // Verifies that the given public key is allowed to perform
+       // the given action on the given node according to the trust
+       // graph.
+       Verify(libtrust.PublicKey, string, uint16) (bool, error)
+
+       // GetGrants returns an array of all grant chains which are used to
+       // allow the requested permission.
+       GetGrants(libtrust.PublicKey, string, uint16) ([][]*Grant, error)
+}
+
+// Grant represents a transfer of permission from one part of the
+// trust graph to another. This is the only way to delegate
+// permission between two different sub trees in the graph.
+type Grant struct {
+       // Subject is the namespace being granted
+       Subject string
+
+       // Permissions is a bit map of permissions
+       Permission uint16
+
+       // Grantee represents the node being granted
+       // a permission scope.  The grantee can be
+       // either a namespace item or a key id where namespace
+       // items will always start with a '/'.
+       Grantee string
+
+       // statement represents the statement used to create
+       // this object.
+       statement *Statement
+}
+
+// Permissions
+//  Read node 0x01 (can read node, no sub nodes)
+//  Write node 0x02 (can write to node object, cannot create subnodes)
+//  Read subtree 0x04 (delegates read to each sub node)
+//  Write subtree 0x08 (delegates write to each sub node, included create on the subject)
+//
+// Permission shortcuts
+// ReadItem = 0x01
+// WriteItem = 0x03
+// ReadAccess = 0x07
+// WriteAccess = 0x0F
+// Delegate = 0x0F
diff --git a/trustgraph/memory_graph.go b/trustgraph/memory_graph.go
new file mode 100644 (file)
index 0000000..247bfa7
--- /dev/null
@@ -0,0 +1,133 @@
+package trustgraph
+
+import (
+       "strings"
+
+       "github.com/docker/libtrust"
+)
+
+type grantNode struct {
+       grants   []*Grant
+       children map[string]*grantNode
+}
+
+type memoryGraph struct {
+       roots map[string]*grantNode
+}
+
+func newGrantNode() *grantNode {
+       return &grantNode{
+               grants:   []*Grant{},
+               children: map[string]*grantNode{},
+       }
+}
+
+// NewMemoryGraph returns a new in memory trust graph created from
+// a static list of grants.  This graph is immutable after creation
+// and any alterations should create a new instance.
+func NewMemoryGraph(grants []*Grant) TrustGraph {
+       roots := map[string]*grantNode{}
+       for _, grant := range grants {
+               parts := strings.Split(grant.Grantee, "/")
+               nodes := roots
+               var node *grantNode
+               var nodeOk bool
+               for _, part := range parts {
+                       node, nodeOk = nodes[part]
+                       if !nodeOk {
+                               node = newGrantNode()
+                               nodes[part] = node
+                       }
+                       if part != "" {
+                               node.grants = append(node.grants, grant)
+                       }
+                       nodes = node.children
+               }
+       }
+       return &memoryGraph{roots}
+}
+
+func (g *memoryGraph) getGrants(name string) []*Grant {
+       nameParts := strings.Split(name, "/")
+       nodes := g.roots
+       var node *grantNode
+       var nodeOk bool
+       for _, part := range nameParts {
+               node, nodeOk = nodes[part]
+               if !nodeOk {
+                       return nil
+               }
+               nodes = node.children
+       }
+       return node.grants
+}
+
+func isSubName(name, sub string) bool {
+       if strings.HasPrefix(name, sub) {
+               if len(name) == len(sub) || name[len(sub)] == '/' {
+                       return true
+               }
+       }
+       return false
+}
+
+type walkFunc func(*Grant, []*Grant) bool
+
+func foundWalkFunc(*Grant, []*Grant) bool {
+       return true
+}
+
+func (g *memoryGraph) walkGrants(start, target string, permission uint16, f walkFunc, chain []*Grant, visited map[*Grant]bool, collect bool) bool {
+       if visited == nil {
+               visited = map[*Grant]bool{}
+       }
+       grants := g.getGrants(start)
+       subGrants := make([]*Grant, 0, len(grants))
+       for _, grant := range grants {
+               if visited[grant] {
+                       continue
+               }
+               visited[grant] = true
+               if grant.Permission&permission == permission {
+                       if isSubName(target, grant.Subject) {
+                               if f(grant, chain) {
+                                       return true
+                               }
+                       } else {
+                               subGrants = append(subGrants, grant)
+                       }
+               }
+       }
+       for _, grant := range subGrants {
+               var chainCopy []*Grant
+               if collect {
+                       chainCopy = make([]*Grant, len(chain)+1)
+                       copy(chainCopy, chain)
+                       chainCopy[len(chainCopy)-1] = grant
+               } else {
+                       chainCopy = nil
+               }
+
+               if g.walkGrants(grant.Subject, target, permission, f, chainCopy, visited, collect) {
+                       return true
+               }
+       }
+       return false
+}
+
+func (g *memoryGraph) Verify(key libtrust.PublicKey, node string, permission uint16) (bool, error) {
+       return g.walkGrants(key.KeyID(), node, permission, foundWalkFunc, nil, nil, false), nil
+}
+
+func (g *memoryGraph) GetGrants(key libtrust.PublicKey, node string, permission uint16) ([][]*Grant, error) {
+       grants := [][]*Grant{}
+       collect := func(grant *Grant, chain []*Grant) bool {
+               grantChain := make([]*Grant, len(chain)+1)
+               copy(grantChain, chain)
+               grantChain[len(grantChain)-1] = grant
+               grants = append(grants, grantChain)
+               return false
+       }
+       g.walkGrants(key.KeyID(), node, permission, collect, nil, nil, true)
+       return grants, nil
+}
diff --git a/trustgraph/memory_graph_test.go b/trustgraph/memory_graph_test.go
new file mode 100644 (file)
index 0000000..49fd0f3
--- /dev/null
@@ -0,0 +1,174 @@
+package trustgraph
+
+import (
+       "fmt"
+       "testing"
+
+       "github.com/docker/libtrust"
+)
+
+func createTestKeysAndGrants(count int) ([]*Grant, []libtrust.PrivateKey) {
+       grants := make([]*Grant, count)
+       keys := make([]libtrust.PrivateKey, count)
+       for i := 0; i < count; i++ {
+               pk, err := libtrust.GenerateECP256PrivateKey()
+               if err != nil {
+                       panic(err)
+               }
+               grant := &Grant{
+                       Subject:    fmt.Sprintf("/user-%d", i+1),
+                       Permission: 0x0f,
+                       Grantee:    pk.KeyID(),
+               }
+               keys[i] = pk
+               grants[i] = grant
+       }
+       return grants, keys
+}
+
+func testVerified(t *testing.T, g TrustGraph, k libtrust.PublicKey, keyName, target string, permission uint16) {
+       if ok, err := g.Verify(k, target, permission); err != nil {
+               t.Fatalf("Unexpected error during verification: %s", err)
+       } else if !ok {
+               t.Errorf("key failed verification\n\tKey: %s(%s)\n\tNamespace: %s", keyName, k.KeyID(), target)
+       }
+}
+
+func testNotVerified(t *testing.T, g TrustGraph, k libtrust.PublicKey, keyName, target string, permission uint16) {
+       if ok, err := g.Verify(k, target, permission); err != nil {
+               t.Fatalf("Unexpected error during verification: %s", err)
+       } else if ok {
+               t.Errorf("key should have failed verification\n\tKey: %s(%s)\n\tNamespace: %s", keyName, k.KeyID(), target)
+       }
+}
+
+func TestVerify(t *testing.T) {
+       grants, keys := createTestKeysAndGrants(4)
+       extraGrants := make([]*Grant, 3)
+       extraGrants[0] = &Grant{
+               Subject:    "/user-3",
+               Permission: 0x0f,
+               Grantee:    "/user-2",
+       }
+       extraGrants[1] = &Grant{
+               Subject:    "/user-3/sub-project",
+               Permission: 0x0f,
+               Grantee:    "/user-4",
+       }
+       extraGrants[2] = &Grant{
+               Subject:    "/user-4",
+               Permission: 0x07,
+               Grantee:    "/user-1",
+       }
+       grants = append(grants, extraGrants...)
+
+       g := NewMemoryGraph(grants)
+
+       testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
+       testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1/some-project/sub-value", 0x0f)
+       testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-4", 0x07)
+       testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2/", 0x0f)
+       testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3/sub-value", 0x0f)
+       testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/sub-value", 0x0f)
+       testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f)
+       testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/", 0x0f)
+       testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project", 0x0f)
+       testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project/app", 0x0f)
+       testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-4", 0x0f)
+
+       testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-2", 0x0f)
+       testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-3/sub-value", 0x0f)
+       testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-4", 0x0f)
+       testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-1/", 0x0f)
+       testNotVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-2", 0x0f)
+       testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-4", 0x0f)
+       testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3", 0x0f)
+}
+
+func TestCircularWalk(t *testing.T) {
+       grants, keys := createTestKeysAndGrants(3)
+       user1Grant := &Grant{
+               Subject:    "/user-2",
+               Permission: 0x0f,
+               Grantee:    "/user-1",
+       }
+       user2Grant := &Grant{
+               Subject:    "/user-1",
+               Permission: 0x0f,
+               Grantee:    "/user-2",
+       }
+       grants = append(grants, user1Grant, user2Grant)
+
+       g := NewMemoryGraph(grants)
+
+       testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
+       testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-2", 0x0f)
+       testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f)
+       testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-1", 0x0f)
+       testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3", 0x0f)
+
+       testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-3", 0x0f)
+       testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f)
+}
+
+func assertGrantSame(t *testing.T, actual, expected *Grant) {
+       if actual != expected {
+               t.Fatalf("Unexpected grant retrieved\n\tExpected: %v\n\tActual: %v", expected, actual)
+       }
+}
+
+func TestGetGrants(t *testing.T) {
+       grants, keys := createTestKeysAndGrants(5)
+       extraGrants := make([]*Grant, 4)
+       extraGrants[0] = &Grant{
+               Subject:    "/user-3/friend-project",
+               Permission: 0x0f,
+               Grantee:    "/user-2/friends",
+       }
+       extraGrants[1] = &Grant{
+               Subject:    "/user-3/sub-project",
+               Permission: 0x0f,
+               Grantee:    "/user-4",
+       }
+       extraGrants[2] = &Grant{
+               Subject:    "/user-2/friends",
+               Permission: 0x0f,
+               Grantee:    "/user-5/fun-project",
+       }
+       extraGrants[3] = &Grant{
+               Subject:    "/user-5/fun-project",
+               Permission: 0x0f,
+               Grantee:    "/user-1",
+       }
+       grants = append(grants, extraGrants...)
+
+       g := NewMemoryGraph(grants)
+
+       grantChains, err := g.GetGrants(keys[3], "/user-3/sub-project/specific-app", 0x0f)
+       if err != nil {
+               t.Fatalf("Error getting grants: %s", err)
+       }
+       if len(grantChains) != 1 {
+               t.Fatalf("Expected number of grant chains returned, expected %d, received %d", 1, len(grantChains))
+       }
+       if len(grantChains[0]) != 2 {
+               t.Fatalf("Unexpected number of grants retrieved\n\tExpected: %d\n\tActual: %d", 2, len(grantChains[0]))
+       }
+       assertGrantSame(t, grantChains[0][0], grants[3])
+       assertGrantSame(t, grantChains[0][1], extraGrants[1])
+
+       grantChains, err = g.GetGrants(keys[0], "/user-3/friend-project/fun-app", 0x0f)
+       if err != nil {
+               t.Fatalf("Error getting grants: %s", err)
+       }
+       if len(grantChains) != 1 {
+               t.Fatalf("Expected number of grant chains returned, expected %d, received %d", 1, len(grantChains))
+       }
+       if len(grantChains[0]) != 4 {
+               t.Fatalf("Unexpected number of grants retrieved\n\tExpected: %d\n\tActual: %d", 2, len(grantChains[0]))
+       }
+       assertGrantSame(t, grantChains[0][0], grants[0])
+       assertGrantSame(t, grantChains[0][1], extraGrants[3])
+       assertGrantSame(t, grantChains[0][2], extraGrants[2])
+       assertGrantSame(t, grantChains[0][3], extraGrants[0])
+}
diff --git a/trustgraph/statement.go b/trustgraph/statement.go
new file mode 100644 (file)
index 0000000..7a74b55
--- /dev/null
@@ -0,0 +1,227 @@
+package trustgraph
+
+import (
+       "crypto/x509"
+       "encoding/json"
+       "io"
+       "io/ioutil"
+       "sort"
+       "strings"
+       "time"
+
+       "github.com/docker/libtrust"
+)
+
+type jsonGrant struct {
+       Subject    string `json:"subject"`
+       Permission uint16 `json:"permission"`
+       Grantee    string `json:"grantee"`
+}
+
+type jsonRevocation struct {
+       Subject    string `json:"subject"`
+       Revocation uint16 `json:"revocation"`
+       Grantee    string `json:"grantee"`
+}
+
+type jsonStatement struct {
+       Revocations []*jsonRevocation `json:"revocations"`
+       Grants      []*jsonGrant      `json:"grants"`
+       Expiration  time.Time         `json:"expiration"`
+       IssuedAt    time.Time         `json:"issuedAt"`
+}
+
+func (g *jsonGrant) Grant(statement *Statement) *Grant {
+       return &Grant{
+               Subject:    g.Subject,
+               Permission: g.Permission,
+               Grantee:    g.Grantee,
+               statement:  statement,
+       }
+}
+
+// Statement represents a set of grants made from a verifiable
+// authority.  A statement has an expiration associated with it
+// set by the authority.
+type Statement struct {
+       jsonStatement
+
+       signature *libtrust.JSONSignature
+}
+
+// IsExpired returns whether the statement has expired
+func (s *Statement) IsExpired() bool {
+       return s.Expiration.Before(time.Now().Add(-10 * time.Second))
+}
+
+// Bytes returns an indented json representation of the statement
+// in a byte array.  This value can be written to a file or stream
+// without alteration.
+func (s *Statement) Bytes() ([]byte, error) {
+       return s.signature.PrettySignature("signatures")
+}
+
+// LoadStatement loads and verifies a statement from an input stream.
+func LoadStatement(r io.Reader, authority *x509.CertPool) (*Statement, error) {
+       b, err := ioutil.ReadAll(r)
+       if err != nil {
+               return nil, err
+       }
+       js, err := libtrust.ParsePrettySignature(b, "signatures")
+       if err != nil {
+               return nil, err
+       }
+       payload, err := js.Payload()
+       if err != nil {
+               return nil, err
+       }
+       var statement Statement
+       err = json.Unmarshal(payload, &statement.jsonStatement)
+       if err != nil {
+               return nil, err
+       }
+
+       if authority == nil {
+               _, err = js.Verify()
+               if err != nil {
+                       return nil, err
+               }
+       } else {
+               _, err = js.VerifyChains(authority)
+               if err != nil {
+                       return nil, err
+               }
+       }
+       statement.signature = js
+
+       return &statement, nil
+}
+
+// CreateStatements creates and signs a statement from a stream of grants
+// and revocations in a JSON array.
+func CreateStatement(grants, revocations io.Reader, expiration time.Duration, key libtrust.PrivateKey, chain []*x509.Certificate) (*Statement, error) {
+       var statement Statement
+       err := json.NewDecoder(grants).Decode(&statement.jsonStatement.Grants)
+       if err != nil {
+               return nil, err
+       }
+       err = json.NewDecoder(revocations).Decode(&statement.jsonStatement.Revocations)
+       if err != nil {
+               return nil, err
+       }
+       statement.jsonStatement.Expiration = time.Now().UTC().Add(expiration)
+       statement.jsonStatement.IssuedAt = time.Now().UTC()
+
+       b, err := json.MarshalIndent(&statement.jsonStatement, "", "   ")
+       if err != nil {
+               return nil, err
+       }
+
+       statement.signature, err = libtrust.NewJSONSignature(b)
+       if err != nil {
+               return nil, err
+       }
+       err = statement.signature.SignWithChain(key, chain)
+       if err != nil {
+               return nil, err
+       }
+
+       return &statement, nil
+}
+
+type statementList []*Statement
+
+func (s statementList) Len() int {
+       return len(s)
+}
+
+func (s statementList) Less(i, j int) bool {
+       return s[i].IssuedAt.Before(s[j].IssuedAt)
+}
+
+func (s statementList) Swap(i, j int) {
+       s[i], s[j] = s[j], s[i]
+}
+
+// CollapseStatements returns a single list of the valid statements as well as the
+// time when the next grant will expire.
+func CollapseStatements(statements []*Statement, useExpired bool) ([]*Grant, time.Time, error) {
+       sorted := make(statementList, 0, len(statements))
+       for _, statement := range statements {
+               if useExpired || !statement.IsExpired() {
+                       sorted = append(sorted, statement)
+               }
+       }
+       sort.Sort(sorted)
+
+       var minExpired time.Time
+       var grantCount int
+       roots := map[string]*grantNode{}
+       for i, statement := range sorted {
+               if statement.Expiration.Before(minExpired) || i == 0 {
+                       minExpired = statement.Expiration
+               }
+               for _, grant := range statement.Grants {
+                       parts := strings.Split(grant.Grantee, "/")
+                       nodes := roots
+                       g := grant.Grant(statement)
+                       grantCount = grantCount + 1
+
+                       for _, part := range parts {
+                               node, nodeOk := nodes[part]
+                               if !nodeOk {
+                                       node = newGrantNode()
+                                       nodes[part] = node
+                               }
+                               node.grants = append(node.grants, g)
+                               nodes = node.children
+                       }
+               }
+
+               for _, revocation := range statement.Revocations {
+                       parts := strings.Split(revocation.Grantee, "/")
+                       nodes := roots
+
+                       var node *grantNode
+                       var nodeOk bool
+                       for _, part := range parts {
+                               node, nodeOk = nodes[part]
+                               if !nodeOk {
+                                       break
+                               }
+                               nodes = node.children
+                       }
+                       if node != nil {
+                               for _, grant := range node.grants {
+                                       if isSubName(grant.Subject, revocation.Subject) {
+                                               grant.Permission = grant.Permission &^ revocation.Revocation
+                                       }
+                               }
+                       }
+               }
+       }
+
+       retGrants := make([]*Grant, 0, grantCount)
+       for _, rootNodes := range roots {
+               retGrants = append(retGrants, rootNodes.grants...)
+       }
+
+       return retGrants, minExpired, nil
+}
+
+// FilterStatements filters the statements to statements including the given grants.
+func FilterStatements(grants []*Grant) ([]*Statement, error) {
+       statements := map[*Statement]bool{}
+       for _, grant := range grants {
+               if grant.statement != nil {
+                       statements[grant.statement] = true
+               }
+       }
+       retStatements := make([]*Statement, len(statements))
+       var i int
+       for statement := range statements {
+               retStatements[i] = statement
+               i++
+       }
+       return retStatements, nil
+}
diff --git a/trustgraph/statement_test.go b/trustgraph/statement_test.go
new file mode 100644 (file)
index 0000000..d9c3c1a
--- /dev/null
@@ -0,0 +1,417 @@
+package trustgraph
+
+import (
+       "bytes"
+       "crypto/x509"
+       "encoding/json"
+       "testing"
+       "time"
+
+       "github.com/docker/libtrust"
+       "github.com/docker/libtrust/testutil"
+)
+
+const testStatementExpiration = time.Hour * 5
+
+func generateStatement(grants []*Grant, key libtrust.PrivateKey, chain []*x509.Certificate) (*Statement, error) {
+       var statement Statement
+
+       statement.Grants = make([]*jsonGrant, len(grants))
+       for i, grant := range grants {
+               statement.Grants[i] = &jsonGrant{
+                       Subject:    grant.Subject,
+                       Permission: grant.Permission,
+                       Grantee:    grant.Grantee,
+               }
+       }
+       statement.IssuedAt = time.Now()
+       statement.Expiration = time.Now().Add(testStatementExpiration)
+       statement.Revocations = make([]*jsonRevocation, 0)
+
+       marshalled, err := json.MarshalIndent(statement.jsonStatement, "", "   ")
+       if err != nil {
+               return nil, err
+       }
+
+       sig, err := libtrust.NewJSONSignature(marshalled)
+       if err != nil {
+               return nil, err
+       }
+       err = sig.SignWithChain(key, chain)
+       if err != nil {
+               return nil, err
+       }
+       statement.signature = sig
+
+       return &statement, nil
+}
+
+func generateTrustChain(t *testing.T, chainLen int) (libtrust.PrivateKey, *x509.CertPool, []*x509.Certificate) {
+       caKey, err := libtrust.GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatalf("Error generating key: %s", err)
+       }
+       ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey())
+       if err != nil {
+               t.Fatalf("Error generating ca: %s", err)
+       }
+
+       parent := ca
+       parentKey := caKey
+       chain := make([]*x509.Certificate, chainLen)
+       for i := chainLen - 1; i > 0; i-- {
+               intermediatekey, err := libtrust.GenerateECP256PrivateKey()
+               if err != nil {
+                       t.Fatalf("Error generate key: %s", err)
+               }
+               chain[i], err = testutil.GenerateIntermediate(intermediatekey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
+               if err != nil {
+                       t.Fatalf("Error generating intermdiate certificate: %s", err)
+               }
+               parent = chain[i]
+               parentKey = intermediatekey
+       }
+       trustKey, err := libtrust.GenerateECP256PrivateKey()
+       if err != nil {
+               t.Fatalf("Error generate key: %s", err)
+       }
+       chain[0], err = testutil.GenerateTrustCert(trustKey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
+       if err != nil {
+               t.Fatalf("Error generate trust cert: %s", err)
+       }
+
+       caPool := x509.NewCertPool()
+       caPool.AddCert(ca)
+
+       return trustKey, caPool, chain
+}
+
+func TestLoadStatement(t *testing.T) {
+       grantCount := 4
+       grants, _ := createTestKeysAndGrants(grantCount)
+
+       trustKey, caPool, chain := generateTrustChain(t, 6)
+
+       statement, err := generateStatement(grants, trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error generating statement: %s", err)
+       }
+
+       statementBytes, err := statement.Bytes()
+       if err != nil {
+               t.Fatalf("Error getting statement bytes: %s", err)
+       }
+
+       s2, err := LoadStatement(bytes.NewReader(statementBytes), caPool)
+       if err != nil {
+               t.Fatalf("Error loading statement: %s", err)
+       }
+       if len(s2.Grants) != grantCount {
+               t.Fatalf("Unexpected grant length\n\tExpected: %d\n\tActual: %d", grantCount, len(s2.Grants))
+       }
+
+       pool := x509.NewCertPool()
+       _, err = LoadStatement(bytes.NewReader(statementBytes), pool)
+       if err == nil {
+               t.Fatalf("No error thrown verifying without an authority")
+       } else if _, ok := err.(x509.UnknownAuthorityError); !ok {
+               t.Fatalf("Unexpected error verifying without authority: %s", err)
+       }
+
+       s2, err = LoadStatement(bytes.NewReader(statementBytes), nil)
+       if err != nil {
+               t.Fatalf("Error loading statement: %s", err)
+       }
+       if len(s2.Grants) != grantCount {
+               t.Fatalf("Unexpected grant length\n\tExpected: %d\n\tActual: %d", grantCount, len(s2.Grants))
+       }
+
+       badData := make([]byte, len(statementBytes))
+       copy(badData, statementBytes)
+       badData[0] = '['
+       _, err = LoadStatement(bytes.NewReader(badData), nil)
+       if err == nil {
+               t.Fatalf("No error thrown parsing bad json")
+       }
+
+       alteredData := make([]byte, len(statementBytes))
+       copy(alteredData, statementBytes)
+       alteredData[30] = '0'
+       _, err = LoadStatement(bytes.NewReader(alteredData), nil)
+       if err == nil {
+               t.Fatalf("No error thrown from bad data")
+       }
+}
+
+func TestCollapseGrants(t *testing.T) {
+       grantCount := 8
+       grants, keys := createTestKeysAndGrants(grantCount)
+       linkGrants := make([]*Grant, 4)
+       linkGrants[0] = &Grant{
+               Subject:    "/user-3",
+               Permission: 0x0f,
+               Grantee:    "/user-2",
+       }
+       linkGrants[1] = &Grant{
+               Subject:    "/user-3/sub-project",
+               Permission: 0x0f,
+               Grantee:    "/user-4",
+       }
+       linkGrants[2] = &Grant{
+               Subject:    "/user-6",
+               Permission: 0x0f,
+               Grantee:    "/user-7",
+       }
+       linkGrants[3] = &Grant{
+               Subject:    "/user-6/sub-project/specific-app",
+               Permission: 0x0f,
+               Grantee:    "/user-5",
+       }
+       trustKey, pool, chain := generateTrustChain(t, 3)
+
+       statements := make([]*Statement, 3)
+       var err error
+       statements[0], err = generateStatement(grants[0:4], trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error generating statement: %s", err)
+       }
+       statements[1], err = generateStatement(grants[4:], trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error generating statement: %s", err)
+       }
+       statements[2], err = generateStatement(linkGrants, trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error generating statement: %s", err)
+       }
+
+       statementsCopy := make([]*Statement, len(statements))
+       for i, statement := range statements {
+               b, err := statement.Bytes()
+               if err != nil {
+                       t.Fatalf("Error getting statement bytes: %s", err)
+               }
+               verifiedStatement, err := LoadStatement(bytes.NewReader(b), pool)
+               if err != nil {
+                       t.Fatalf("Error loading statement: %s", err)
+               }
+               // Force sort by reversing order
+               statementsCopy[len(statementsCopy)-i-1] = verifiedStatement
+       }
+       statements = statementsCopy
+
+       collapsedGrants, expiration, err := CollapseStatements(statements, false)
+       if len(collapsedGrants) != 12 {
+               t.Fatalf("Unexpected number of grants\n\tExpected: %d\n\tActual: %s", 12, len(collapsedGrants))
+       }
+       if expiration.After(time.Now().Add(time.Hour*5)) || expiration.Before(time.Now()) {
+               t.Fatalf("Unexpected expiration time: %s", expiration.String())
+       }
+       g := NewMemoryGraph(collapsedGrants)
+
+       testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
+       testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f)
+       testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3", 0x0f)
+       testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-4", 0x0f)
+       testVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-5", 0x0f)
+       testVerified(t, g, keys[5].PublicKey(), "user-key-6", "/user-6", 0x0f)
+       testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-7", 0x0f)
+       testVerified(t, g, keys[7].PublicKey(), "user-key-8", "/user-8", 0x0f)
+       testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f)
+       testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/sub-project/specific-app", 0x0f)
+       testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project", 0x0f)
+       testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6", 0x0f)
+       testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6/sub-project/specific-app", 0x0f)
+       testVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-6/sub-project/specific-app", 0x0f)
+
+       testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3", 0x0f)
+       testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-6/sub-project", 0x0f)
+       testNotVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-6/sub-project", 0x0f)
+
+       // Add revocation grant
+       statements = append(statements, &Statement{
+               jsonStatement{
+                       IssuedAt:   time.Now(),
+                       Expiration: time.Now().Add(testStatementExpiration),
+                       Grants:     []*jsonGrant{},
+                       Revocations: []*jsonRevocation{
+                               &jsonRevocation{
+                                       Subject:    "/user-1",
+                                       Revocation: 0x0f,
+                                       Grantee:    keys[0].KeyID(),
+                               },
+                               &jsonRevocation{
+                                       Subject:    "/user-2",
+                                       Revocation: 0x08,
+                                       Grantee:    keys[1].KeyID(),
+                               },
+                               &jsonRevocation{
+                                       Subject:    "/user-6",
+                                       Revocation: 0x0f,
+                                       Grantee:    "/user-7",
+                               },
+                               &jsonRevocation{
+                                       Subject:    "/user-9",
+                                       Revocation: 0x0f,
+                                       Grantee:    "/user-10",
+                               },
+                       },
+               },
+               nil,
+       })
+
+       collapsedGrants, expiration, err = CollapseStatements(statements, false)
+       if len(collapsedGrants) != 12 {
+               t.Fatalf("Unexpected number of grants\n\tExpected: %d\n\tActual: %s", 12, len(collapsedGrants))
+       }
+       if expiration.After(time.Now().Add(time.Hour*5)) || expiration.Before(time.Now()) {
+               t.Fatalf("Unexpected expiration time: %s", expiration.String())
+       }
+       g = NewMemoryGraph(collapsedGrants)
+
+       testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
+       testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f)
+       testNotVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6/sub-project/specific-app", 0x0f)
+
+       testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x07)
+}
+
+func TestFilterStatements(t *testing.T) {
+       grantCount := 8
+       grants, keys := createTestKeysAndGrants(grantCount)
+       linkGrants := make([]*Grant, 3)
+       linkGrants[0] = &Grant{
+               Subject:    "/user-3",
+               Permission: 0x0f,
+               Grantee:    "/user-2",
+       }
+       linkGrants[1] = &Grant{
+               Subject:    "/user-5",
+               Permission: 0x0f,
+               Grantee:    "/user-4",
+       }
+       linkGrants[2] = &Grant{
+               Subject:    "/user-7",
+               Permission: 0x0f,
+               Grantee:    "/user-6",
+       }
+
+       trustKey, _, chain := generateTrustChain(t, 3)
+
+       statements := make([]*Statement, 5)
+       var err error
+       statements[0], err = generateStatement(grants[0:2], trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error generating statement: %s", err)
+       }
+       statements[1], err = generateStatement(grants[2:4], trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error generating statement: %s", err)
+       }
+       statements[2], err = generateStatement(grants[4:6], trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error generating statement: %s", err)
+       }
+       statements[3], err = generateStatement(grants[6:], trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error generating statement: %s", err)
+       }
+       statements[4], err = generateStatement(linkGrants, trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error generating statement: %s", err)
+       }
+       collapsed, _, err := CollapseStatements(statements, false)
+       if err != nil {
+               t.Fatalf("Error collapsing grants: %s", err)
+       }
+
+       // Filter 1, all 5 statements
+       filter1, err := FilterStatements(collapsed)
+       if err != nil {
+               t.Fatalf("Error filtering statements: %s", err)
+       }
+       if len(filter1) != 5 {
+               t.Fatalf("Wrong number of statements, expected %d, received %d", 5, len(filter1))
+       }
+
+       // Filter 2, one statement
+       filter2, err := FilterStatements([]*Grant{collapsed[0]})
+       if err != nil {
+               t.Fatalf("Error filtering statements: %s", err)
+       }
+       if len(filter2) != 1 {
+               t.Fatalf("Wrong number of statements, expected %d, received %d", 1, len(filter2))
+       }
+
+       // Filter 3, 2 statements, from graph lookup
+       g := NewMemoryGraph(collapsed)
+       lookupGrants, err := g.GetGrants(keys[1], "/user-3", 0x0f)
+       if err != nil {
+               t.Fatalf("Error looking up grants: %s", err)
+       }
+       if len(lookupGrants) != 1 {
+               t.Fatalf("Wrong numberof grant chains returned from lookup, expected %d, received %d", 1, len(lookupGrants))
+       }
+       if len(lookupGrants[0]) != 2 {
+               t.Fatalf("Wrong number of grants looked up, expected %d, received %d", 2, len(lookupGrants))
+       }
+       filter3, err := FilterStatements(lookupGrants[0])
+       if err != nil {
+               t.Fatalf("Error filtering statements: %s", err)
+       }
+       if len(filter3) != 2 {
+               t.Fatalf("Wrong number of statements, expected %d, received %d", 2, len(filter3))
+       }
+
+}
+
+func TestCreateStatement(t *testing.T) {
+       grantJSON := bytes.NewReader([]byte(`[
+   {
+      "subject": "/user-2",
+      "permission": 15,
+      "grantee": "/user-1"
+   },
+   {
+      "subject": "/user-7",
+      "permission": 1,
+      "grantee": "/user-9"
+   },
+   {
+      "subject": "/user-3",
+      "permission": 15,
+      "grantee": "/user-2"
+   }
+]`))
+       revocationJSON := bytes.NewReader([]byte(`[
+   {
+      "subject": "user-8",
+      "revocation": 12,
+      "grantee": "user-9"
+   }
+]`))
+
+       trustKey, pool, chain := generateTrustChain(t, 3)
+
+       statement, err := CreateStatement(grantJSON, revocationJSON, testStatementExpiration, trustKey, chain)
+       if err != nil {
+               t.Fatalf("Error creating statement: %s", err)
+       }
+
+       b, err := statement.Bytes()
+       if err != nil {
+               t.Fatalf("Error retrieving bytes: %s", err)
+       }
+
+       verified, err := LoadStatement(bytes.NewReader(b), pool)
+       if err != nil {
+               t.Fatalf("Error loading statement: %s", err)
+       }
+
+       if len(verified.Grants) != 3 {
+               t.Errorf("Unexpected number of grants, expected %d, received %d", 3, len(verified.Grants))
+       }
+
+       if len(verified.Revocations) != 1 {
+               t.Errorf("Unexpected number of revocations, expected %d, received %d", 1, len(verified.Revocations))
+       }
+}
diff --git a/util.go b/util.go
new file mode 100644 (file)
index 0000000..3b2fac9
--- /dev/null
+++ b/util.go
@@ -0,0 +1,209 @@
+package libtrust
+
+import (
+       "bytes"
+       "crypto/elliptic"
+       "crypto/x509"
+       "encoding/base32"
+       "encoding/base64"
+       "encoding/binary"
+       "encoding/pem"
+       "errors"
+       "fmt"
+       "math/big"
+       "strings"
+)
+
+// joseBase64UrlEncode encodes the given data using the standard base64 url
+// encoding format but with all trailing '=' characters ommitted in accordance
+// with the jose specification.
+// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
+func joseBase64UrlEncode(b []byte) string {
+       return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=")
+}
+
+// joseBase64UrlDecode decodes the given string using the standard base64 url
+// decoder but first adds the appropriate number of trailing '=' characters in
+// accordance with the jose specification.
+// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
+func joseBase64UrlDecode(s string) ([]byte, error) {
+       switch len(s) % 4 {
+       case 0:
+       case 2:
+               s += "=="
+       case 3:
+               s += "="
+       default:
+               return nil, errors.New("illegal base64url string")
+       }
+       return base64.URLEncoding.DecodeString(s)
+}
+
+func keyIDEncode(b []byte) string {
+       s := strings.TrimRight(base32.StdEncoding.EncodeToString(b), "=")
+       var buf bytes.Buffer
+       var i int
+       for i = 0; i < len(s)/4-1; i++ {
+               start := i * 4
+               end := start + 4
+               buf.WriteString(s[start:end] + ":")
+       }
+       buf.WriteString(s[i*4:])
+       return buf.String()
+}
+
+func stringFromMap(m map[string]interface{}, key string) (string, error) {
+       val, ok := m[key]
+       if !ok {
+               return "", fmt.Errorf("%q value not specified", key)
+       }
+
+       str, ok := val.(string)
+       if !ok {
+               return "", fmt.Errorf("%q value must be a string", key)
+       }
+       delete(m, key)
+
+       return str, nil
+}
+
+func parseECCoordinate(cB64Url string, curve elliptic.Curve) (*big.Int, error) {
+       curveByteLen := (curve.Params().BitSize + 7) >> 3
+
+       cBytes, err := joseBase64UrlDecode(cB64Url)
+       if err != nil {
+               return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
+       }
+       cByteLength := len(cBytes)
+       if cByteLength != curveByteLen {
+               return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", cByteLength, curveByteLen)
+       }
+       return new(big.Int).SetBytes(cBytes), nil
+}
+
+func parseECPrivateParam(dB64Url string, curve elliptic.Curve) (*big.Int, error) {
+       dBytes, err := joseBase64UrlDecode(dB64Url)
+       if err != nil {
+               return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
+       }
+
+       // The length of this octet string MUST be ceiling(log-base-2(n)/8)
+       // octets (where n is the order of the curve). This is because the private
+       // key d must be in the interval [1, n-1] so the bitlength of d should be
+       // no larger than the bitlength of n-1. The easiest way to find the octet
+       // length is to take bitlength(n-1), add 7 to force a carry, and shift this
+       // bit sequence right by 3, which is essentially dividing by 8 and adding
+       // 1 if there is any remainder. Thus, the private key value d should be
+       // output to (bitlength(n-1)+7)>>3 octets.
+       n := curve.Params().N
+       octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
+       dByteLength := len(dBytes)
+
+       if dByteLength != octetLength {
+               return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", dByteLength, octetLength)
+       }
+
+       return new(big.Int).SetBytes(dBytes), nil
+}
+
+func parseRSAModulusParam(nB64Url string) (*big.Int, error) {
+       nBytes, err := joseBase64UrlDecode(nB64Url)
+       if err != nil {
+               return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
+       }
+
+       return new(big.Int).SetBytes(nBytes), nil
+}
+
+func serializeRSAPublicExponentParam(e int) []byte {
+       // We MUST use the minimum number of octets to represent E.
+       // E is supposed to be 65537 for performance and security reasons
+       // and is what golang's rsa package generates, but it might be
+       // different if imported from some other generator.
+       buf := make([]byte, 4)
+       binary.BigEndian.PutUint32(buf, uint32(e))
+       var i int
+       for i = 0; i < 8; i++ {
+               if buf[i] != 0 {
+                       break
+               }
+       }
+       return buf[i:]
+}
+
+func parseRSAPublicExponentParam(eB64Url string) (int, error) {
+       eBytes, err := joseBase64UrlDecode(eB64Url)
+       if err != nil {
+               return 0, fmt.Errorf("invalid base64 URL encoding: %s", err)
+       }
+       // Only the minimum number of bytes were used to represent E, but
+       // binary.BigEndian.Uint32 expects at least 4 bytes, so we need
+       // to add zero padding if necassary.
+       byteLen := len(eBytes)
+       buf := make([]byte, 4-byteLen, 4)
+       eBytes = append(buf, eBytes...)
+
+       return int(binary.BigEndian.Uint32(eBytes)), nil
+}
+
+func parseRSAPrivateKeyParamFromMap(m map[string]interface{}, key string) (*big.Int, error) {
+       b64Url, err := stringFromMap(m, key)
+       if err != nil {
+               return nil, err
+       }
+
+       paramBytes, err := joseBase64UrlDecode(b64Url)
+       if err != nil {
+               return nil, fmt.Errorf("invaled base64 URL encoding: %s", err)
+       }
+
+       return new(big.Int).SetBytes(paramBytes), nil
+}
+
+func createPemBlock(name string, derBytes []byte, headers map[string]interface{}) (*pem.Block, error) {
+       pemBlock := &pem.Block{Type: name, Bytes: derBytes, Headers: map[string]string{}}
+       for k, v := range headers {
+               switch val := v.(type) {
+               case string:
+                       pemBlock.Headers[k] = val
+               case []string:
+                       if k == "hosts" {
+                               pemBlock.Headers[k] = strings.Join(val, ",")
+                       } else {
+                               // Return error, non-encodable type
+                       }
+               default:
+                       // Return error, non-encodable type
+               }
+       }
+
+       return pemBlock, nil
+}
+
+func pubKeyFromPEMBlock(pemBlock *pem.Block) (PublicKey, error) {
+       cryptoPublicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
+       if err != nil {
+               return nil, fmt.Errorf("unable to decode Public Key PEM data: %s", err)
+       }
+
+       pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
+       if err != nil {
+               return nil, err
+       }
+
+       addPEMHeadersToKey(pemBlock, pubKey)
+
+       return pubKey, nil
+}
+
+func addPEMHeadersToKey(pemBlock *pem.Block, pubKey PublicKey) {
+       for key, value := range pemBlock.Headers {
+               var safeVal interface{}
+               if key == "hosts" {
+                       safeVal = strings.Split(value, ",")
+               } else {
+                       safeVal = value
+               }
+               pubKey.AddExtendedField(key, safeVal)
+       }
+}