-
-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathauth.go
135 lines (117 loc) · 3.33 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Copyright 2021 Changkun Ou. All rights reserved.
// Use of this source code is governed by a MIT
// license that can be found in the LICENSE file.
package main
import (
"errors"
"fmt"
"log"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
"changkun.de/x/login"
"changkun.de/x/redir/internal/config"
"changkun.de/x/redir/internal/utils"
)
// errUnauthorized is the sentinel error wrapped (via %w) by authentication
// rejections in this file; detect it with errors.Is.
var (
	errUnauthorized = errors.New("request unauthorized")
)
// blocklist records failed authentication attempts per client IP so that
// requests from IPs with too many failures can be rejected for a while.
//
// The data lives only in memory: nothing is persisted to disk, so restarting
// the service forgets every block and previously blocked clients may try
// again. Because failing clients keep adding entries, the map can grow
// without bound; to cap that, init starts a goroutine that empties the whole
// map every 30 days.
var blocklist sync.Map // map[string]*blockinfo

func init() {
	// The ticker and its goroutine live for the whole process lifetime and
	// are intentionally never stopped.
	const purgeEvery = 30 * 24 * time.Hour
	purge := time.NewTicker(purgeEvery)
	go func() {
		for {
			<-purge.C
			// Deleting while ranging is allowed for sync.Map.
			blocklist.Range(func(ip, _ interface{}) bool {
				blocklist.Delete(ip)
				return true
			})
		}
	}()
}
// blockinfo tracks authentication failures for one client IP.
//
// NOTE(review): keep failCount as the first field. It is updated with the
// 64-bit sync/atomic functions, and per the sync/atomic documentation those
// need 64-bit alignment on 32-bit platforms (386, ARM, 32-bit MIPS), which
// Go only guarantees for the first word of an allocated struct.
type blockinfo struct {
// failCount is the number of failed attempts; once the entry is shared it
// is accessed with sync/atomic.
failCount int64
// lastFail holds the time of the most recent failure (stored in UTC).
lastFail atomic.Value // time.Time
// blockTime holds the current block duration; it starts at 10 seconds and
// is doubled each time a block expires.
blockTime atomic.Value // time.Duration
}
// maxFailureAttempts is the failure threshold for Basic auth: an IP is
// throttled once its recorded failure count exceeds this value.
const (
	maxFailureAttempts = 3
)
// handleAuth authenticates r according to config.Conf.Auth.Enable and
// returns the authenticated user name.
//
//   - config.None: every request is accepted; user is "" and err is nil.
//   - config.SSO: the request is checked by login.HandleAuth. On failure the
//     client is redirected (302 Found) to config.Conf.Auth.SSO with a
//     "redirect" query parameter pointing back at the current URL.
//   - config.Basic, and any unrecognized value: HTTP Basic authentication
//     against the accounts in config.Conf.Auth.Basic, with per-IP throttling
//     through blocklist.
//
// When err is non-nil this function has already written a failure response
// (401, 429, 500 or the SSO redirect), so callers should stop handling the
// request without writing another status. Basic-auth failures wrap
// errUnauthorized.
//
// Throttling: every credential mismatch from an IP increments its failCount.
// Once failCount exceeds maxFailureAttempts (i.e. after maxFailureAttempts+1
// failures), requests from that IP are rejected until blockTime has elapsed
// since the last failure. When a block expires, failCount is reset and
// blockTime doubles (starting from 10s), so repeat offenders wait
// exponentially longer.
// NOTE(review): a successful login does not reset failCount.
func (s *server) handleAuth(w http.ResponseWriter, r *http.Request) (user string, err error) {
	switch config.Conf.Auth.Enable {
	case config.None:
		return "", nil
	case config.SSO:
		ssoUser, ssoErr := login.HandleAuth(w, r)
		if ssoErr == nil {
			return ssoUser, nil
		}
		ssoURL, parseErr := url.Parse(config.Conf.Auth.SSO)
		if parseErr != nil {
			// A malformed SSO endpoint is a server misconfiguration. url.Parse
			// returns a nil *URL on error, so building the redirect would panic.
			log.Printf("invalid SSO endpoint %q: %v", config.Conf.Auth.SSO, parseErr)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return ssoUser, ssoErr
		}
		q := ssoURL.Query()
		q.Set("redirect", "https://"+r.Host+r.URL.String())
		ssoURL.RawQuery = q.Encode()
		http.Redirect(w, r, ssoURL.String(), http.StatusFound)
		return ssoUser, ssoErr
	case config.Basic:
		// Handled below; unrecognized modes also fall through to Basic auth.
	}

	w.Header().Set("WWW-Authenticate", `Basic realm="redir"`)
	u, p, ok := r.BasicAuth()
	if !ok {
		// Missing or malformed credentials are rejected but not counted
		// towards the per-IP failure limit: the recording defer below is not
		// installed yet.
		w.WriteHeader(http.StatusUnauthorized)
		return "", fmt.Errorf("%w: failed to parsing basic auth", errUnauthorized)
	}

	// Reject IPs that are currently blocked without checking credentials.
	ip := utils.ReadIP(r)
	if v, found := blocklist.Load(ip); found {
		info := v.(*blockinfo)
		if atomic.LoadInt64(&info.failCount) > maxFailureAttempts {
			last := info.lastFail.Load().(time.Time)
			bloc := info.blockTime.Load().(time.Duration)
			release := last.Add(bloc)
			if time.Now().Before(release) {
				log.Printf("block ip %v, too much failure attempts. Block time: %v, release until: %v\n",
					ip, bloc, release)
				// Every other failure path writes its own status; without one
				// the blocked client would receive an implicit 200 OK.
				w.WriteHeader(http.StatusTooManyRequests)
				return "", fmt.Errorf("%w: too much failure attempts", errUnauthorized)
			}
			// The block has expired: give the IP a fresh budget of attempts,
			// but double the duration of its next block.
			atomic.StoreInt64(&info.failCount, 0)
			info.blockTime.Store(bloc * 2)
		}
	}

	// Record credential mismatches for this IP. LoadOrStore (rather than a
	// Load followed by a Store) keeps concurrent first failures from the same
	// IP from overwriting each other's entry and losing counts.
	defer func() {
		if !errors.Is(err, errUnauthorized) {
			return
		}
		now := time.Now().UTC()
		fresh := &blockinfo{}
		fresh.lastFail.Store(now)
		fresh.blockTime.Store(10 * time.Second)
		v, _ := blocklist.LoadOrStore(ip, fresh)
		info := v.(*blockinfo)
		atomic.AddInt64(&info.failCount, 1)
		info.lastFail.Store(now)
	}()

	for _, account := range config.Conf.Auth.Basic {
		if u == account.Username && p == account.Password {
			return u, nil
		}
	}
	w.WriteHeader(http.StatusUnauthorized)
	return "", fmt.Errorf("%w: username or password is invalid", errUnauthorized)
}