-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathalgo_hs.go
105 lines (92 loc) · 1.92 KB
/
algo_hs.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package jwt
import (
"crypto"
"crypto/hmac"
"hash"
"sync"
)
// NewSignerHS returns a new HMAC-based signer for alg using key.
//
// It fails with ErrNilKey when key is empty and with ErrUnsupportedAlg when
// alg is not one of HS256, HS384 or HS512.
func NewSignerHS(alg Algorithm, key []byte) (*HSAlg, error) {
	signer, err := newHS(alg, key)
	return signer, err
}
// NewVerifierHS returns a new HMAC-based verifier for alg using key.
//
// HMAC is symmetric, so the verifier is built exactly like a signer. It fails
// with ErrNilKey when key is empty and with ErrUnsupportedAlg when alg is not
// one of HS256, HS384 or HS512.
func NewVerifierHS(alg Algorithm, key []byte) (*HSAlg, error) {
	verifier, err := newHS(alg, key)
	return verifier, err
}
// newHS builds an HSAlg for alg using key. NewSignerHS and NewVerifierHS share
// it because HMAC signing and verification use the same secret.
//
// It returns ErrNilKey when key is empty. It returns ErrUnsupportedAlg when
// alg is not recognized by getHashHMAC, or when the matching hash function is
// not linked into the binary.
//
// The key is copied before it is stored. Hashers in hashPool are created
// lazily, so without the copy a later mutation of the caller's slice would
// give new hashers a different secret than the ones already pooled.
func newHS(alg Algorithm, key []byte) (*HSAlg, error) {
	if len(key) == 0 {
		return nil, ErrNilKey
	}
	// Named h rather than hash to avoid shadowing the imported hash package.
	h, ok := getHashHMAC(alg)
	if !ok {
		return nil, ErrUnsupportedAlg
	}
	// crypto.Hash.New panics if the implementation was never registered
	// (e.g. crypto/sha512 not imported); surface that as an error instead.
	if !h.Available() {
		return nil, ErrUnsupportedAlg
	}

	secret := make([]byte, len(key))
	copy(secret, key)

	return &HSAlg{
		alg:  alg,
		hash: h,
		key:  secret,
		hashPool: &sync.Pool{
			New: func() any {
				return hmac.New(h.New, secret)
			},
		},
	}, nil
}
// getHashHMAC maps an HMAC algorithm identifier to the hash function that
// backs it. The boolean result is false (and the hash is zero) for any
// algorithm other than HS256, HS384 or HS512.
func getHashHMAC(alg Algorithm) (h crypto.Hash, ok bool) {
	switch alg {
	case HS256:
		h = crypto.SHA256
	case HS384:
		h = crypto.SHA384
	case HS512:
		h = crypto.SHA512
	default:
		return 0, false
	}
	return h, true
}
// HSAlg signs and verifies using HMAC with one of HS256, HS384 or HS512.
//
// Construct it with NewSignerHS or NewVerifierHS. The zero value is not
// usable: sign dereferences hashPool, which is nil in a zero HSAlg.
type HSAlg struct {
// alg is the algorithm identifier this instance was constructed with.
alg Algorithm
// hash is the hash function underlying the HMAC.
hash crypto.Hash
// key is the HMAC secret.
key []byte
// hashPool holds reusable HMAC hashers created from hash and key; sign
// resets each hasher before putting it back.
hashPool *sync.Pool
}
// Algorithm reports the algorithm identifier (HS256, HS384 or HS512) that
// this HSAlg was constructed with.
func (hs *HSAlg) Algorithm() Algorithm {
	alg := hs.alg
	return alg
}
// SignSize reports the length in bytes of signatures produced by Sign. This
// equals the digest size of the underlying hash function.
func (hs *HSAlg) SignSize() int {
	size := hs.hash.Size()
	return size
}
// Sign returns the HMAC of payload, computed with the configured hash
// function and key. Any error from writing payload to the hasher is returned
// with a nil signature.
func (hs *HSAlg) Sign(payload []byte) ([]byte, error) {
	signature, err := hs.sign(payload)
	return signature, err
}
// Verify checks token against this HSAlg. Checks run in order:
//   - ErrUninitializedToken if the token is not valid (token.isValid),
//   - ErrAlgorithmMismatch if the header algorithm differs from hs's
//     algorithm (compared via constTimeAlgEqual),
//   - otherwise the result of checking token.Signature() against the HMAC of
//     token.PayloadPart().
func (hs *HSAlg) Verify(token *Token) error {
	if !token.isValid() {
		return ErrUninitializedToken
	}
	if !constTimeAlgEqual(token.Header().Algorithm, hs.alg) {
		return ErrAlgorithmMismatch
	}
	return hs.verify(token.PayloadPart(), token.Signature())
}
// verify recomputes the HMAC of payload and compares it with signature using
// hmac.Equal, which does not leak timing information about the contents.
// It propagates any error from sign and returns ErrInvalidSignature on a
// mismatch.
func (hs *HSAlg) verify(payload, signature []byte) error {
	expected, err := hs.sign(payload)
	switch {
	case err != nil:
		return err
	case hmac.Equal(signature, expected):
		return nil
	default:
		return ErrInvalidSignature
	}
}
// sign computes the HMAC of payload using a hasher borrowed from hs.hashPool.
// The hasher is reset before it is returned to the pool, so the next borrower
// starts from a clean state. The digest is computed before that reset runs.
func (hs *HSAlg) sign(payload []byte) ([]byte, error) {
	mac := hs.hashPool.Get().(hash.Hash)
	defer func(h hash.Hash) {
		h.Reset()
		hs.hashPool.Put(h)
	}(mac)

	_, err := mac.Write(payload)
	if err != nil {
		return nil, err
	}
	sum := mac.Sum(nil)
	return sum, nil
}