Remove length-based hash lookups

Using the length of a digest to infer the hash algorithm is fragile: if we
ever support multiple hash algorithms that share the same digest length,
the lookup becomes ambiguous and breaks. Instead, pass more complete digest
information through to the relevant functions and determine the algorithm
by mapping the TPM hash algorithm identifier to the appropriate Go
crypto.Hash type.
Matthew Garrett 2020-04-14 14:38:24 -07:00 committed by GitHub
parent fe41cef1db
commit 0815f5e221
2 changed files with 34 additions and 33 deletions
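
The fragility described in the commit message is easy to demonstrate with the standard library alone: several crypto.Hash algorithms share a 32-byte digest size, so a table keyed on digest length stops being unambiguous as soon as a second 32-byte algorithm is supported. The following is a minimal standalone sketch of the problem with the removed hashBySize approach, not go-attestation code:

// The length-keyed table that this commit removes only stays unambiguous
// while every supported algorithm has a distinct digest size.
package main

import (
    "crypto"
    "fmt"
)

var hashBySize = map[int]crypto.Hash{
    crypto.SHA1.Size():   crypto.SHA1,   // 20 bytes
    crypto.SHA256.Size(): crypto.SHA256, // 32 bytes
}

func main() {
    // SHA-256, SHA3-256 and BLAKE2s-256 all produce 32-byte digests.
    fmt.Println(crypto.SHA256.Size(), crypto.SHA3_256.Size(), crypto.BLAKE2s_256.Size()) // 32 32 32

    // Supporting a second 32-byte algorithm silently overwrites the SHA-256
    // entry, so a 32-byte digest can no longer be attributed correctly.
    hashBySize[crypto.SHA3_256.Size()] = crypto.SHA3_256
    fmt.Println(hashBySize[32] == crypto.SHA256) // false
}

Carrying the algorithm alongside each digest, as this commit does, avoids that ambiguity entirely.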


@@ -223,49 +223,38 @@ func (a *AKPublic) validate20Quote(quote Quote, pcrs []PCR, nonce []byte) error {
 	}
 	pcrByIndex := map[int][]byte{}
-	pcrDigestLength := HashAlg(att.AttestedQuoteInfo.PCRSelection.Hash).cryptoHash().Size()
+	pcrDigestAlg := HashAlg(att.AttestedQuoteInfo.PCRSelection.Hash).cryptoHash()
 	for _, pcr := range pcrs {
-		// TODO(jsonp): Use pcr.DigestAlg once #116 is fixed.
-		if len(pcr.Digest) == pcrDigestLength {
+		if pcr.DigestAlg == pcrDigestAlg {
 			pcrByIndex[pcr.Index] = pcr.Digest
 		}
 	}
-	n := len(att.AttestedQuoteInfo.PCRDigest)
-	hash, ok := hashBySize[n]
-	if !ok {
-		return fmt.Errorf("quote used unsupported hash algorithm length: %d", n)
-	}
-	h := hash.New()
+	sigHash.Reset()
 	for _, index := range att.AttestedQuoteInfo.PCRSelection.PCRs {
 		digest, ok := pcrByIndex[index]
 		if !ok {
 			return fmt.Errorf("quote was over PCR %d which wasn't provided", index)
 		}
-		h.Write(digest)
+		sigHash.Write(digest)
 	}
-	if !bytes.Equal(h.Sum(nil), att.AttestedQuoteInfo.PCRDigest) {
+	if !bytes.Equal(sigHash.Sum(nil), att.AttestedQuoteInfo.PCRDigest) {
 		return fmt.Errorf("quote digest didn't match pcrs provided")
 	}
 	return nil
 }
 
-var hashBySize = map[int]crypto.Hash{
-	crypto.SHA1.Size():   crypto.SHA1,
-	crypto.SHA256.Size(): crypto.SHA256,
-}
-
-func extend(pcr, replay []byte, e rawEvent) (pcrDigest []byte, eventDigest []byte, err error) {
-	h, ok := hashBySize[len(pcr)]
-	if !ok {
-		return nil, nil, fmt.Errorf("pcr %d was not a known hash size: %d", e.index, len(pcr))
-	}
+func extend(pcr PCR, replay []byte, e rawEvent) (pcrDigest []byte, eventDigest []byte, err error) {
+	h := pcr.DigestAlg
 	for _, digest := range e.digests {
-		if len(digest) != len(pcr) {
+		if digest.hash != pcr.DigestAlg {
 			continue
 		}
+		if len(digest.data) != len(pcr.Digest) {
+			return nil, nil, fmt.Errorf("digest data length (%d) doesn't match PCR digest length (%d)", len(digest.data), len(pcr.Digest))
+		}
 		hash := h.New()
 		if len(replay) != 0 {
 			hash.Write(replay)
@@ -273,10 +262,10 @@ func extend(pcr, replay []byte, e rawEvent) (pcrDigest []byte, eventDigest []byte, err error) {
 			b := make([]byte, h.Size())
 			hash.Write(b)
 		}
-		hash.Write(digest)
-		return hash.Sum(nil), digest, nil
+		hash.Write(digest.data)
+		return hash.Sum(nil), digest.data, nil
 	}
-	return nil, nil, fmt.Errorf("no event digest matches pcr length: %d", len(pcr))
+	return nil, nil, fmt.Errorf("no event digest matches pcr algorithm: %v", pcr.DigestAlg)
 }
 
 // replayPCR replays the event log for a specific PCR, using pcr and
@@ -294,7 +283,7 @@ func replayPCR(rawEvents []rawEvent, pcr PCR) ([]Event, bool) {
 			continue
 		}
 
-		replayValue, digest, err := extend(pcr.Digest, replay, e)
+		replayValue, digest, err := extend(pcr, replay, e)
 		if err != nil {
 			return nil, false
 		}
@@ -489,12 +478,17 @@ func parseSpecIDEvent(b []byte) (*specIDEvent, error) {
 	return &e, nil
 }
 
+type digest struct {
+	hash crypto.Hash
+	data []byte
+}
+
 type rawEvent struct {
 	sequence int
 	index    int
 	typ      EventType
 	data     []byte
-	digests  [][]byte
+	digests  []digest
 }
 
 // TPM 1.2 event log format. See "5.1 SHA1 Event Log Entry Format"
@@ -526,15 +520,19 @@ func parseRawEvent(r *bytes.Buffer, specID *specIDEvent) (event rawEvent, err error) {
 	if h.EventSize > uint32(r.Len()) {
 		return event, &eventSizeErr{h.EventSize, r.Len()}
 	}
 	data := make([]byte, int(h.EventSize))
 	if _, err := io.ReadFull(r, data); err != nil {
 		return event, err
 	}
 
+	digests := []digest{{hash: crypto.SHA1, data: h.Digest[:]}}
+
 	return rawEvent{
 		typ:     EventType(h.Type),
 		data:    data,
 		index:   int(h.PCRIndex),
-		digests: [][]byte{h.Digest[:]},
+		digests: digests,
 	}, nil
 }
@@ -547,6 +545,7 @@ type rawEvent2Header struct {
 func parseRawEvent2(r *bytes.Buffer, specID *specIDEvent) (event rawEvent, err error) {
 	var h rawEvent2Header
 	if err = binary.Read(r, binary.LittleEndian, &h); err != nil {
 		return event, err
 	}
@@ -564,7 +563,8 @@ func parseRawEvent2(r *bytes.Buffer, specID *specIDEvent) (event rawEvent, err error) {
 		if err := binary.Read(r, binary.LittleEndian, &algID); err != nil {
 			return event, err
 		}
-		var digest []byte
+		var digest digest
 		for _, alg := range specID.algs {
 			if alg.ID != algID {
 				continue
@@ -572,12 +572,13 @@ func parseRawEvent2(r *bytes.Buffer, specID *specIDEvent) (event rawEvent, err error) {
 			if uint16(r.Len()) < alg.Size {
 				return event, fmt.Errorf("reading digest: %v", io.ErrUnexpectedEOF)
 			}
-			digest = make([]byte, alg.Size)
+			digest.data = make([]byte, alg.Size)
+			digest.hash = HashAlg(alg.ID).cryptoHash()
 		}
-		if len(digest) == 0 {
+		if len(digest.data) == 0 {
 			return event, fmt.Errorf("unknown algorithm ID %x", algID)
 		}
-		if _, err := io.ReadFull(r, digest); err != nil {
+		if _, err := io.ReadFull(r, digest.data); err != nil {
 			return event, err
 		}
 		event.digests = append(event.digests, digest)

File diff suppressed because one or more lines are too long
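
The new code resolves algorithms via HashAlg(...).cryptoHash(), whose implementation lives elsewhere in the package and is not part of this diff. As a rough illustration of the idea, here is a hypothetical stand-in that maps TPM 2.0 algorithm identifiers (TPM_ALG_SHA1 = 0x0004, TPM_ALG_SHA256 = 0x000B) to crypto.Hash values; the hashAlg type and constants below are illustrative only, not the package's API:

// A hypothetical stand-in for the HashAlg(...).cryptoHash() lookup the new
// code relies on: resolve a TPM 2.0 algorithm identifier to a crypto.Hash
// instead of guessing the algorithm from a digest's length.
package main

import (
    "crypto"
    _ "crypto/sha1"   // register SHA-1 so crypto.SHA1.New() is usable
    _ "crypto/sha256" // register SHA-256 so crypto.SHA256.New() is usable
    "fmt"
)

type hashAlg uint16 // illustrative; not go-attestation's HashAlg

const (
    algSHA1   hashAlg = 0x0004 // TPM_ALG_SHA1
    algSHA256 hashAlg = 0x000B // TPM_ALG_SHA256
)

// cryptoHash maps a TPM algorithm identifier to the corresponding Go hash.
func (a hashAlg) cryptoHash() (crypto.Hash, error) {
    switch a {
    case algSHA1:
        return crypto.SHA1, nil
    case algSHA256:
        return crypto.SHA256, nil
    }
    return 0, fmt.Errorf("unsupported TPM hash algorithm 0x%04x", uint16(a))
}

func main() {
    h, err := algSHA256.cryptoHash()
    if err != nil {
        panic(err)
    }
    fmt.Printf("digest size %d bytes, empty-input digest %x\n", h.Size(), h.New().Sum(nil))
}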