diff --git a/attest/attest-tool/attest-tool.go b/attest/attest-tool/attest-tool.go index 41711e4..aad3327 100644 --- a/attest/attest-tool/attest-tool.go +++ b/attest/attest-tool/attest-tool.go @@ -113,11 +113,14 @@ func runCommand(tpm *attest.TPM) error { } case "list-pcrs": - pcrs, alg, err := tpm.PCRs() + alg := attest.HashSHA1 + if *useSHA256 { + alg = attest.HashSHA256 + } + pcrs, err := tpm.PCRs(alg) if err != nil { return fmt.Errorf("failed to read PCRs: %v", err) } - fmt.Printf("PCR digest: %v\n", alg) for _, pcr := range pcrs { fmt.Printf("PCR[%d]: %x\n", pcr.Index, pcr.Digest) } @@ -180,13 +183,9 @@ func runDump(tpm *attest.TPM) (*internal.Dump, error) { return nil, fmt.Errorf("failed to read measurement log: %v", err) } // Get PCR values. - pcrs, _, err := tpm.PCRs() - if err != nil { + if out.Log.PCRs, err = tpm.PCRs(out.Quote.Alg); err != nil { return nil, fmt.Errorf("failed to read PCRs: %v", err) } - for _, pcr := range pcrs { - out.Log.PCRs = append(out.Log.PCRs, pcr) - } return &out, nil } diff --git a/attest/attest.go b/attest/attest.go index 994987a..ebde9c8 100644 --- a/attest/attest.go +++ b/attest/attest.go @@ -254,6 +254,26 @@ var ( HashSHA256 = HashAlg(tpm2.AlgSHA256) ) +func (a HashAlg) cryptoHash() crypto.Hash { + switch a { + case HashSHA1: + return crypto.SHA1 + case HashSHA256: + return crypto.SHA256 + } + return 0 +} + +func (a HashAlg) goTPMAlg() tpm2.Algorithm { + switch a { + case HashSHA1: + return tpm2.AlgSHA1 + case HashSHA256: + return tpm2.AlgSHA256 + } + return 0 +} + var ( defaultOpenConfig = &OpenConfig{} diff --git a/attest/attest_simulated_tpm20_test.go b/attest/attest_simulated_tpm20_test.go index ed53c4c..459f0a8 100644 --- a/attest/attest_simulated_tpm20_test.go +++ b/attest/attest_simulated_tpm20_test.go @@ -25,7 +25,6 @@ import ( "github.com/google/certificate-transparency-go/x509" "github.com/google/go-tpm-tools/simulator" - "github.com/google/go-tpm/tpm2" ) func setupSimulatedTPM(t *testing.T) (*simulator.Simulator, *TPM) { @@ -193,16 +192,14 @@ func TestSimTPM20PCRs(t *testing.T) { sim, tpm := setupSimulatedTPM(t) defer sim.Close() - PCRs, alg, err := tpm.PCRs() + PCRs, err := tpm.PCRs(HashSHA256) if err != nil { t.Fatalf("PCRs() failed: %v", err) } if len(PCRs) != 24 { t.Errorf("len(PCRs) = %d, want %d", len(PCRs), 24) } - if got, want := tpm2.AlgSHA256, alg; got != want { - t.Errorf("alg = %v, want %v", got, want) - } + for i, pcr := range PCRs { if len(pcr.Digest) != pcr.DigestAlg.Size() { t.Errorf("PCR %d len(digest) = %d, expected match with digest algorithm size (%d)", pcr.Index, len(pcr.Digest), pcr.DigestAlg.Size()) diff --git a/attest/attest_test.go b/attest/attest_test.go index 6b9a41b..c7d0e7d 100644 --- a/attest/attest_test.go +++ b/attest/attest_test.go @@ -146,7 +146,7 @@ func TestPCRs(t *testing.T) { } defer tpm.Close() - PCRs, _, err := tpm.PCRs() + PCRs, err := tpm.PCRs(HashSHA1) if err != nil { t.Fatalf("PCRs() failed: %v", err) } diff --git a/attest/attest_tpm12_test.go b/attest/attest_tpm12_test.go index 72a194f..6a4be84 100644 --- a/attest/attest_tpm12_test.go +++ b/attest/attest_tpm12_test.go @@ -54,7 +54,7 @@ func TestTPM12PCRs(t *testing.T) { tpm := openTPM12(t) defer tpm.Close() - PCRs, _, err := tpm.PCRs() + PCRs, err := tpm.PCRs(HashSHA1) if err != nil { t.Fatalf("Failed to get PCR values: %v", err) } diff --git a/attest/tpm.go b/attest/tpm.go index 2cf82bb..c0d52fa 100644 --- a/attest/tpm.go +++ b/attest/tpm.go @@ -16,7 +16,6 @@ package attest import ( "bytes" - "crypto" "encoding/binary" "fmt" "io" @@ -251,16 +250,6 @@ func readAllPCRs20(tpm io.ReadWriter, alg tpm2.Algorithm) (map[uint32][]byte, er return out, nil } -func allPCRs20(tpm io.ReadWriter) (map[uint32][]byte, crypto.Hash, error) { - out256, err256 := readAllPCRs20(tpm, tpm2.AlgSHA256) - if err256 != nil { - // TPM may not implement active banks with SHA256 - try SHA1. - out1, err1 := readAllPCRs20(tpm, tpm2.AlgSHA1) - return out1, crypto.SHA1, err1 - } - return out256, crypto.SHA256, nil -} - // LoadAIK loads a previously-created aik into the TPM for use. // A key loaded via this function needs to be closed with .Close(). // Only blobs generated by calling AIK.Serialize() are valid parameters diff --git a/attest/tpm_linux.go b/attest/tpm_linux.go index a98a9ae..55abed9 100644 --- a/attest/tpm_linux.go +++ b/attest/tpm_linux.go @@ -17,7 +17,6 @@ package attest import ( - "crypto" "crypto/rsa" "encoding/binary" "errors" @@ -358,49 +357,45 @@ func allPCRs12(ctx *tspi.Context) (map[uint32][]byte, error) { return PCRs, nil } -// PCRs returns the present value of all Platform Configuration Registers. -func (t *TPM) PCRs() (map[int]PCR, tpm2.Algorithm, error) { +// TODO: Refactor PCRs() into a file not subject to build tags, and implement +// platform-specific logic in private methods. + +// PCRs returns the present value of Platform Configuration Registers with the +// given digest algorithm. +func (t *TPM) PCRs(alg HashAlg) ([]PCR, error) { var PCRs map[uint32][]byte - var alg crypto.Hash var err error switch t.version { case TPMVersion12: + if alg != HashSHA1 { + return nil, fmt.Errorf("non-SHA1 algorithm %v is not supported on TPM 1.2", alg) + } PCRs, err = allPCRs12(t.ctx) if err != nil { - return nil, 0, fmt.Errorf("failed to read PCRs: %v", err) + return nil, fmt.Errorf("failed to read PCRs: %v", err) } - alg = crypto.SHA1 case TPMVersion20: - PCRs, alg, err = allPCRs20(t.rwc) + PCRs, err = readAllPCRs20(t.rwc, alg.goTPMAlg()) if err != nil { - return nil, 0, fmt.Errorf("failed to read PCRs: %v", err) + return nil, fmt.Errorf("failed to read PCRs: %v", err) } default: - return nil, 0, fmt.Errorf("unsupported TPM version: %x", t.version) + return nil, fmt.Errorf("unsupported TPM version: %x", t.version) } - out := map[int]PCR{} - var lastAlg crypto.Hash + out := make([]PCR, len(PCRs)) for index, digest := range PCRs { out[int(index)] = PCR{ Index: int(index), Digest: digest, - DigestAlg: alg, + DigestAlg: alg.cryptoHash(), } - lastAlg = alg } - switch lastAlg { - case crypto.SHA1: - return out, tpm2.AlgSHA1, nil - case crypto.SHA256: - return out, tpm2.AlgSHA256, nil - default: - return nil, 0, fmt.Errorf("unexpected algorithm: %v", lastAlg) - } + return out, nil } // MeasurementLog returns the present value of the System Measurement Log. diff --git a/attest/tpm_windows.go b/attest/tpm_windows.go index 2f8b6ec..d0cee68 100644 --- a/attest/tpm_windows.go +++ b/attest/tpm_windows.go @@ -18,7 +18,6 @@ package attest import ( "bytes" - "crypto" "crypto/aes" "crypto/cipher" "crypto/rand" @@ -34,7 +33,6 @@ import ( "net/http" tpm1 "github.com/google/go-tpm/tpm" - "github.com/google/go-tpm/tpm2" tpmtbs "github.com/google/go-tpm/tpmutil/tbs" ) @@ -364,55 +362,49 @@ func allPCRs12(tpm io.ReadWriter) (map[uint32][]byte, error) { return out, nil } -// PCRs returns the present value of all Platform Configuration Registers. -func (t *TPM) PCRs() (map[int]PCR, tpm2.Algorithm, error) { +// PCRs returns the present value of Platform Configuration Registers with the +// given digest algorithm. +func (t *TPM) PCRs(alg HashAlg) ([]PCR, error) { var PCRs map[uint32][]byte - var alg crypto.Hash + switch t.version { case TPMVersion12: - alg = crypto.SHA1 + if alg != HashSHA1 { + return nil, fmt.Errorf("non-SHA1 algorithm %v is not supported on TPM 1.2", alg) + } tpm, err := t.pcp.TPMCommandInterface() if err != nil { - return nil, 0, fmt.Errorf("TPMCommandInterface() failed: %v", err) + return nil, fmt.Errorf("TPMCommandInterface() failed: %v", err) } PCRs, err = allPCRs12(tpm) if err != nil { - return nil, 0, fmt.Errorf("failed to read PCRs: %v", err) + return nil, fmt.Errorf("failed to read PCRs: %v", err) } case TPMVersion20: tpm, err := t.pcp.TPMCommandInterface() if err != nil { - return nil, 0, fmt.Errorf("TPMCommandInterface() failed: %v", err) + return nil, fmt.Errorf("TPMCommandInterface() failed: %v", err) } - PCRs, alg, err = allPCRs20(tpm) + PCRs, err = readAllPCRs20(tpm, alg.goTPMAlg()) if err != nil { - return nil, 0, fmt.Errorf("failed to read PCRs: %v", err) + return nil, fmt.Errorf("failed to read PCRs: %v", err) } default: - return nil, 0, fmt.Errorf("unsupported TPM version: %x", t.version) + return nil, fmt.Errorf("unsupported TPM version: %x", t.version) } - out := map[int]PCR{} - var lastAlg crypto.Hash + out := make([]PCR, len(PCRs)) for index, digest := range PCRs { out[int(index)] = PCR{ Index: int(index), Digest: digest, - DigestAlg: alg, + DigestAlg: alg.cryptoHash(), } - lastAlg = alg } - switch lastAlg { - case crypto.SHA1: - return out, tpm2.AlgSHA1, nil - case crypto.SHA256: - return out, tpm2.AlgSHA256, nil - default: - return nil, 0, fmt.Errorf("unexpected algorithm: %v", lastAlg) - } + return out, nil } // MeasurementLog returns the present value of the System Measurement Log.