Refactor *platformTPM -> tpmBase interface (#160)

This commit is contained in:
Tom D 2020-05-05 14:56:40 -07:00 committed by GitHub
parent ab116a02a1
commit 1045ef6327
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 99 additions and 80 deletions

View File

@ -85,10 +85,10 @@ const (
)
type ak interface {
close(*platformTPM) error
close(tpmBase) error
marshal() ([]byte, error)
activateCredential(tpm *platformTPM, in EncryptedCredential) ([]byte, error)
quote(t *platformTPM, nonce []byte, alg HashAlg) (*Quote, error)
activateCredential(tpm tpmBase, in EncryptedCredential) ([]byte, error)
quote(t tpmBase, nonce []byte, alg HashAlg) (*Quote, error)
attestationParameters() AttestationParameters
}

View File

@ -33,7 +33,7 @@ func setupSimulatedTPM(t *testing.T) (*simulator.Simulator, *TPM) {
if err != nil {
t.Fatal(err)
}
return tpm, &TPM{tpm: &platformTPM{
return tpm, &TPM{tpm: &linuxTPM{
version: TPMVersion20,
interf: TPMInterfaceKernelManaged,
sysPath: "/dev/tpmrm0",
@ -240,7 +240,7 @@ func TestSimTPM20Persistence(t *testing.T) {
sim, tpm := setupSimulatedTPM(t)
defer sim.Close()
ekHnd, _, err := tpm.tpm.getPrimaryKeyHandle(commonEkEquivalentHandle)
ekHnd, _, err := tpm.tpm.(*linuxTPM).getPrimaryKeyHandle(commonEkEquivalentHandle)
if err != nil {
t.Fatalf("getPrimaryKeyHandle() failed: %v", err)
}
@ -248,7 +248,7 @@ func TestSimTPM20Persistence(t *testing.T) {
t.Fatalf("bad EK-equivalent handle: got 0x%x, wanted 0x%x", ekHnd, commonEkEquivalentHandle)
}
ekHnd, p, err := tpm.tpm.getPrimaryKeyHandle(commonEkEquivalentHandle)
ekHnd, p, err := tpm.tpm.(*linuxTPM).getPrimaryKeyHandle(commonEkEquivalentHandle)
if err != nil {
t.Fatalf("second getPrimaryKeyHandle() failed: %v", err)
}

View File

@ -49,11 +49,16 @@ func (k *key12) marshal() ([]byte, error) {
return out.Serialize()
}
func (k *key12) close(tpm *platformTPM) error {
func (k *key12) close(tpm tpmBase) error {
return nil // No state for tpm 1.2.
}
func (k *key12) activateCredential(t *platformTPM, in EncryptedCredential) ([]byte, error) {
func (k *key12) activateCredential(tb tpmBase, in EncryptedCredential) ([]byte, error) {
t, ok := tb.(*linuxTPM)
if !ok {
return nil, fmt.Errorf("expected *linuxTPM, got %T", tb)
}
cred, err := attestation.AIKChallengeResponse(t.ctx, k.blob, in.Credential, in.Secret)
if err != nil {
return nil, fmt.Errorf("failed to activate ak: %v", err)
@ -61,7 +66,11 @@ func (k *key12) activateCredential(t *platformTPM, in EncryptedCredential) ([]by
return cred, nil
}
func (k *key12) quote(t *platformTPM, nonce []byte, alg HashAlg) (*Quote, error) {
func (k *key12) quote(tb tpmBase, nonce []byte, alg HashAlg) (*Quote, error) {
t, ok := tb.(*linuxTPM)
if !ok {
return nil, fmt.Errorf("expected *linuxTPM, got %T", tb)
}
if alg != HashSHA1 {
return nil, fmt.Errorf("only SHA1 algorithms supported on TPM 1.2, not %v", alg)
}
@ -120,11 +129,20 @@ func (k *key20) marshal() ([]byte, error) {
}).Serialize()
}
func (k *key20) close(tpm *platformTPM) error {
func (k *key20) close(t tpmBase) error {
tpm, ok := t.(*linuxTPM)
if !ok {
return fmt.Errorf("expected *linuxTPM, got %T", t)
}
return tpm2.FlushContext(tpm.rwc, k.hnd)
}
func (k *key20) activateCredential(t *platformTPM, in EncryptedCredential) ([]byte, error) {
func (k *key20) activateCredential(tb tpmBase, in EncryptedCredential) ([]byte, error) {
t, ok := tb.(*linuxTPM)
if !ok {
return nil, fmt.Errorf("expected *linuxTPM, got %T", tb)
}
ekHnd, _, err := t.getPrimaryKeyHandle(commonEkEquivalentHandle)
if err != nil {
return nil, err
@ -154,7 +172,11 @@ func (k *key20) activateCredential(t *platformTPM, in EncryptedCredential) ([]by
}, k.hnd, ekHnd, in.Credential[2:], in.Secret[2:])
}
func (k *key20) quote(t *platformTPM, nonce []byte, alg HashAlg) (*Quote, error) {
func (k *key20) quote(tb tpmBase, nonce []byte, alg HashAlg) (*Quote, error) {
t, ok := tb.(*linuxTPM)
if !ok {
return nil, fmt.Errorf("expected *linuxTPM, got %T", tb)
}
return quote20(t.rwc, k.hnd, tpm2.Algorithm(alg), nonce)
}

View File

@ -47,7 +47,11 @@ func (k *key12) marshal() ([]byte, error) {
return out.Serialize()
}
func (k *key12) activateCredential(tpm *platformTPM, in EncryptedCredential) ([]byte, error) {
func (k *key12) activateCredential(t tpmBase, in EncryptedCredential) ([]byte, error) {
tpm, ok := t.(*windowsTPM)
if !ok {
return nil, fmt.Errorf("expected *windowsTPM, got %T", t)
}
secretKey, err := tpm.pcp.ActivateCredential(k.hnd, in.Credential)
if err != nil {
return nil, err
@ -55,10 +59,14 @@ func (k *key12) activateCredential(tpm *platformTPM, in EncryptedCredential) ([]
return decryptCredential(secretKey, in.Secret)
}
func (k *key12) quote(t *platformTPM, nonce []byte, alg HashAlg) (*Quote, error) {
func (k *key12) quote(tb tpmBase, nonce []byte, alg HashAlg) (*Quote, error) {
if alg != HashSHA1 {
return nil, fmt.Errorf("only SHA1 algorithms supported on TPM 1.2, not %v", alg)
}
t, ok := tb.(*windowsTPM)
if !ok {
return nil, fmt.Errorf("expected *windowsTPM, got %T", tb)
}
tpmKeyHnd, err := t.pcp.TPMKeyHandle(k.hnd)
if err != nil {
@ -93,7 +101,7 @@ func (k *key12) quote(t *platformTPM, nonce []byte, alg HashAlg) (*Quote, error)
}, nil
}
func (k *key12) close(tpm *platformTPM) error {
func (k *key12) close(tpm tpmBase) error {
return closeNCryptObject(k.hnd)
}
@ -139,11 +147,19 @@ func (k *key20) marshal() ([]byte, error) {
return out.Serialize()
}
func (k *key20) activateCredential(tpm *platformTPM, in EncryptedCredential) ([]byte, error) {
func (k *key20) activateCredential(t tpmBase, in EncryptedCredential) ([]byte, error) {
tpm, ok := t.(*windowsTPM)
if !ok {
return nil, fmt.Errorf("expected *windowsTPM, got %T", t)
}
return tpm.pcp.ActivateCredential(k.hnd, append(in.Credential, in.Secret...))
}
func (k *key20) quote(t *platformTPM, nonce []byte, alg HashAlg) (*Quote, error) {
func (k *key20) quote(tb tpmBase, nonce []byte, alg HashAlg) (*Quote, error) {
t, ok := tb.(*windowsTPM)
if !ok {
return nil, fmt.Errorf("expected *windowsTPM, got %T", tb)
}
tpmKeyHnd, err := t.pcp.TPMKeyHandle(k.hnd)
if err != nil {
return nil, fmt.Errorf("TPMKeyHandle() failed: %v", err)
@ -156,7 +172,7 @@ func (k *key20) quote(t *platformTPM, nonce []byte, alg HashAlg) (*Quote, error)
return quote20(tpm, tpmKeyHnd, alg.goTPMAlg(), nonce)
}
func (k *key20) close(tpm *platformTPM) error {
func (k *key20) close(tpm tpmBase) error {
return closeNCryptObject(k.hnd)
}

View File

@ -267,11 +267,24 @@ func readAllPCRs20(tpm io.ReadWriter, alg tpm2.Algorithm) (map[uint32][]byte, er
return out, nil
}
// tpmBase defines the implementation of a TPM invariant.
type tpmBase interface {
close() error
tpmVersion() TPMVersion
eks() ([]EK, error)
info() (*TPMInfo, error)
loadAK(opaqueBlob []byte) (*AK, error)
newAK(opts *AKConfig) (*AK, error)
pcrs(alg HashAlg) ([]PCR, error)
measurementLog() ([]byte, error)
}
//TPM interfaces with a TPM device on the system.
type TPM struct {
// tpm holds a platform specific implementation of TPM logic: Windows or Linux.
// see *_linux.go and *_windows.go files for definitions of these structs.
tpm *platformTPM
// tpm refers to a concrete implementation of TPM logic, based on the current
// platform and TPM version.
tpm tpmBase
}
// Close shuts down the connection to the TPM.

View File

@ -42,8 +42,8 @@ const (
tpmRoot = "/sys/class/tpm"
)
// platformTPM interfaces with a TPM device on the system.
type platformTPM struct {
// linuxTPM interfaces with a TPM device on the system.
type linuxTPM struct {
version TPMVersion
interf TPMInterface
@ -52,6 +52,8 @@ type platformTPM struct {
ctx *tspi.Context
}
func (*linuxTPM) isTPMBase() {}
func probeSystemTPMs() ([]probedTPM, error) {
var tpms []probedTPM
@ -122,7 +124,7 @@ func openTPM(tpm probedTPM) (*TPM, error) {
}
}
return &TPM{tpm: &platformTPM{
return &TPM{tpm: &linuxTPM{
version: tpm.Version,
interf: interf,
sysPath: tpm.Path,
@ -131,11 +133,11 @@ func openTPM(tpm probedTPM) (*TPM, error) {
}}, nil
}
func (t *platformTPM) tpmVersion() TPMVersion {
func (t *linuxTPM) tpmVersion() TPMVersion {
return t.version
}
func (t *platformTPM) close() error {
func (t *linuxTPM) close() error {
switch t.version {
case TPMVersion12:
return t.ctx.Close()
@ -160,7 +162,7 @@ func readTPM12VendorAttributes(context *tspi.Context) (TCGVendorID, string, erro
}
// Info returns information about the TPM.
func (t *platformTPM) info() (*TPMInfo, error) {
func (t *linuxTPM) info() (*TPMInfo, error) {
tInfo := TPMInfo{
Version: t.version,
Interface: t.interf,
@ -188,7 +190,7 @@ func (t *platformTPM) info() (*TPMInfo, error) {
}
// Return value: handle, whether we generated a new one, error
func (t *platformTPM) getPrimaryKeyHandle(pHnd tpmutil.Handle) (tpmutil.Handle, bool, error) {
func (t *linuxTPM) getPrimaryKeyHandle(pHnd tpmutil.Handle) (tpmutil.Handle, bool, error) {
_, _, _, err := tpm2.ReadPublic(t.rwc, pHnd)
if err == nil {
// Found the persistent handle, assume it's the key we want.
@ -223,7 +225,7 @@ func readEKCertFromNVRAM12(ctx *tspi.Context) (*x509.Certificate, error) {
return ParseEKCertificate(ekCert)
}
func (t *platformTPM) eks() ([]EK, error) {
func (t *linuxTPM) eks() ([]EK, error) {
switch t.version {
case TPMVersion12:
cert, err := readEKCertFromNVRAM12(t.ctx)
@ -267,7 +269,7 @@ func (t *platformTPM) eks() ([]EK, error) {
}
}
func (t *platformTPM) newAK(opts *AKConfig) (*AK, error) {
func (t *linuxTPM) newAK(opts *AKConfig) (*AK, error) {
switch t.version {
case TPMVersion12:
pub, blob, err := attestation.CreateAIK(t.ctx)
@ -313,7 +315,7 @@ func (t *platformTPM) newAK(opts *AKConfig) (*AK, error) {
}
}
func (t *platformTPM) loadAK(opaqueBlob []byte) (*AK, error) {
func (t *linuxTPM) loadAK(opaqueBlob []byte) (*AK, error) {
sKey, err := deserializeKey(opaqueBlob, t.version)
if err != nil {
return nil, fmt.Errorf("deserializeKey() failed: %v", err)
@ -355,7 +357,7 @@ func allPCRs12(ctx *tspi.Context) (map[uint32][]byte, error) {
return PCRs, nil
}
func (t *platformTPM) pcrs(alg HashAlg) ([]PCR, error) {
func (t *linuxTPM) pcrs(alg HashAlg) ([]PCR, error) {
var PCRs map[uint32][]byte
var err error
@ -391,6 +393,6 @@ func (t *platformTPM) pcrs(alg HashAlg) ([]PCR, error) {
return out, nil
}
func (t *platformTPM) measurementLog() ([]byte, error) {
func (t *linuxTPM) measurementLog() ([]byte, error) {
return ioutil.ReadFile("/sys/kernel/security/tpm0/binary_bios_measurements")
}

View File

@ -22,9 +22,6 @@ import (
var errUnsupported = errors.New("tpm operations not supported from given build parameters")
type platformTPM struct {
}
func probeSystemTPMs() ([]probedTPM, error) {
return nil, errUnsupported
}
@ -32,36 +29,3 @@ func probeSystemTPMs() ([]probedTPM, error) {
func openTPM(tpm probedTPM) (*TPM, error) {
return nil, errUnsupported
}
func (t *platformTPM) tpmVersion() TPMVersion {
return TPMVersionAgnostic
}
func (t *platformTPM) close() error {
return errUnsupported
}
func (t *platformTPM) info() (*TPMInfo, error) {
return nil, errUnsupported
}
func (t *platformTPM) loadAK(opaqueBlob []byte) (*AK, error) {
return nil, errUnsupported
}
func (t *platformTPM) eks() ([]EK, error) {
return nil, errUnsupported
}
func (t *platformTPM) newAK(opts *AKConfig) (*AK, error) {
return nil, errUnsupported
}
func (t *platformTPM) pcrs(alg HashAlg) ([]PCR, error) {
return nil, errUnsupported
}
func (t *platformTPM) measurementLog() ([]byte, error) {
return nil, errUnsupported
}

View File

@ -34,11 +34,13 @@ import (
var wellKnownAuth [20]byte
type platformTPM struct {
type windowsTPM struct {
version TPMVersion
pcp *winPCP
}
func (*windowsTPM) isTPMBase() {}
func probeSystemTPMs() ([]probedTPM, error) {
// Windows systems appear to only support a single abstracted TPM.
// If we fail to initialize the Platform Crypto Provider, we assume
@ -89,17 +91,17 @@ func openTPM(tpm probedTPM) (*TPM, error) {
return nil, fmt.Errorf("tbsConvertVersion(%v) failed: %v", info.TBSInfo.TPMVersion, err)
}
return &TPM{tpm: &platformTPM{
return &TPM{tpm: &windowsTPM{
pcp: pcp,
version: vers,
}}, nil
}
func (t *platformTPM) tpmVersion() TPMVersion {
func (t *windowsTPM) tpmVersion() TPMVersion {
return t.version
}
func (t *platformTPM) close() error {
func (t *windowsTPM) close() error {
return t.pcp.Close()
}
@ -112,7 +114,7 @@ func readTPM12VendorAttributes(tpm io.ReadWriter) (TCGVendorID, string, error) {
return vendorID, vendorID.String(), nil
}
func (t *platformTPM) info() (*TPMInfo, error) {
func (t *windowsTPM) info() (*TPMInfo, error) {
tInfo := TPMInfo{
Version: t.version,
Interface: TPMInterfaceKernelManaged,
@ -142,7 +144,7 @@ func (t *platformTPM) info() (*TPMInfo, error) {
return &tInfo, nil
}
func (t *platformTPM) eks() ([]EK, error) {
func (t *windowsTPM) eks() ([]EK, error) {
ekCerts, err := t.pcp.EKCerts()
if err != nil {
return nil, fmt.Errorf("could not read EKCerts: %v", err)
@ -171,7 +173,7 @@ func (t *platformTPM) eks() ([]EK, error) {
return []EK{ek}, nil
}
func (t *platformTPM) ekPub() (*rsa.PublicKey, error) {
func (t *windowsTPM) ekPub() (*rsa.PublicKey, error) {
p, err := t.pcp.EKPub()
if err != nil {
return nil, fmt.Errorf("could not read ekpub: %v", err)
@ -264,7 +266,7 @@ func decryptCredential(secretKey, blob []byte) ([]byte, error) {
return secret, nil
}
func (t *platformTPM) newAK(opts *AKConfig) (*AK, error) {
func (t *windowsTPM) newAK(opts *AKConfig) (*AK, error) {
nameHex := make([]byte, 5)
if n, err := rand.Read(nameHex); err != nil || n != len(nameHex) {
return nil, fmt.Errorf("rand.Read() failed with %d/%d bytes read and error: %v", n, len(nameHex), err)
@ -291,7 +293,7 @@ func (t *platformTPM) newAK(opts *AKConfig) (*AK, error) {
}
}
func (t *platformTPM) loadAK(opaqueBlob []byte) (*AK, error) {
func (t *windowsTPM) loadAK(opaqueBlob []byte) (*AK, error) {
sKey, err := deserializeKey(opaqueBlob, t.version)
if err != nil {
return nil, fmt.Errorf("deserializeKey() failed: %v", err)
@ -334,7 +336,7 @@ func allPCRs12(tpm io.ReadWriter) (map[uint32][]byte, error) {
return out, nil
}
func (t *platformTPM) pcrs(alg HashAlg) ([]PCR, error) {
func (t *windowsTPM) pcrs(alg HashAlg) ([]PCR, error) {
var PCRs map[uint32][]byte
switch t.version {
@ -377,7 +379,7 @@ func (t *platformTPM) pcrs(alg HashAlg) ([]PCR, error) {
return out, nil
}
func (t *platformTPM) measurementLog() ([]byte, error) {
func (t *windowsTPM) measurementLog() ([]byte, error) {
context, err := tpmtbs.CreateContext(tpmtbs.TPMVersion20, tpmtbs.IncludeTPM20|tpmtbs.IncludeTPM12)
if err != nil {
return nil, err