Просмотр исходного кода

Merge pull request #231 from 9seconds/whitelists

Whitelist support
tags/v2.1.3^2
Sergey Arkhipov 4 лет назад
Родитель
Сommit
16c06f247c
Аккаунт пользователя с таким Email не найден

+ 21
- 0
example.config.toml Просмотреть файл

@@ -174,6 +174,27 @@ urls = [
174 174
 # How often do we need to update a blocklist set.
175 175
 update-each = "24h"
176 176
 
177
+# Allowlist is an opposite to a blocklist. Only those IPs that are coming from
178
+# subnets defined in these lists are allowed. All others will be rejected.
179
+#
180
+# If this feature is disabled, then there won't be any check performed by this
181
+# validator. It is possible to combine both blocklist and whitelist.
182
+[defense.allowlist]
183
+# You can enable/disable this feature.
184
+enabled = false
185
+# This is a limiter for concurrency. In order to protect website
186
+# from overloading, we download files in this number of threads.
187
+download-concurrency = 2
188
+# A list of URLs in FireHOL format (https://iplists.firehol.org/)
189
+# You can provider links here (starts with https:// or http://) or
190
+# path to a local file, but in this case it should be absolute.
191
+urls = [
192
+    # "https://iplists.firehol.org/files/firehol_level1.netset",
193
+    # "/local.file"
194
+
195
+]
196
+update-each = "24h"
197
+
177 198
 # statsd statistics integration.
178 199
 [stats.statsd]
179 200
 # enabled/disabled

+ 18
- 6
internal/cli/run_proxy.go Просмотреть файл

@@ -86,15 +86,15 @@ func makeAntiReplayCache(conf *config.Config) mtglib.AntiReplayCache {
86 86
 	)
87 87
 }
88 88
 
89
-func makeIPBlocklist(conf *config.Config, logger mtglib.Logger, ntw mtglib.Network) (mtglib.IPBlocklist, error) {
90
-	if !conf.Defense.Blocklist.Enabled.Get(false) {
89
+func makeIPBlocklist(conf config.ListConfig, logger mtglib.Logger, ntw mtglib.Network) (mtglib.IPBlocklist, error) {
90
+	if !conf.Enabled.Get(false) {
91 91
 		return ipblocklist.NewNoop(), nil
92 92
 	}
93 93
 
94 94
 	remoteURLs := []string{}
95 95
 	localFiles := []string{}
96 96
 
97
-	for _, v := range conf.Defense.Blocklist.URLs {
97
+	for _, v := range conf.URLs {
98 98
 		if v.IsRemote() {
99 99
 			remoteURLs = append(remoteURLs, v.String())
100 100
 		} else {
@@ -104,7 +104,7 @@ func makeIPBlocklist(conf *config.Config, logger mtglib.Logger, ntw mtglib.Netwo
104 104
 
105 105
 	firehol, err := ipblocklist.NewFirehol(logger.Named("ipblockist"),
106 106
 		ntw,
107
-		conf.Defense.Blocklist.DownloadConcurrency.Get(1),
107
+		conf.DownloadConcurrency.Get(1),
108 108
 		remoteURLs,
109 109
 		localFiles)
110 110
 	if err != nil {
@@ -153,7 +153,7 @@ func makeEventStream(conf *config.Config, logger mtglib.Logger) (mtglib.EventStr
153 153
 	return events.NewNoopStream(), nil
154 154
 }
155 155
 
156
-func runProxy(conf *config.Config, version string) error {
156
+func runProxy(conf *config.Config, version string) error { // nolint: funlen
157 157
 	logger := makeLogger(conf)
158 158
 
159 159
 	logger.BindJSON("configuration", conf.String()).Debug("configuration")
@@ -163,11 +163,22 @@ func runProxy(conf *config.Config, version string) error {
163 163
 		return fmt.Errorf("cannot build network: %w", err)
164 164
 	}
165 165
 
166
-	blocklist, err := makeIPBlocklist(conf, logger, ntw)
166
+	blocklist, err := makeIPBlocklist(conf.Defense.Blocklist, logger, ntw)
167 167
 	if err != nil {
168 168
 		return fmt.Errorf("cannot build ip blocklist: %w", err)
169 169
 	}
170 170
 
171
+	var whitelist mtglib.IPBlocklist
172
+
173
+	if conf.Defense.Allowlist.Enabled.Get(false) {
174
+		whlist, err := makeIPBlocklist(conf.Defense.Allowlist, logger, ntw)
175
+		if err != nil {
176
+			return fmt.Errorf("cannot build ip blocklist: %w", err)
177
+		}
178
+
179
+		whitelist = whlist
180
+	}
181
+
171 182
 	eventStream, err := makeEventStream(conf, logger)
172 183
 	if err != nil {
173 184
 		return fmt.Errorf("cannot build event stream: %w", err)
@@ -178,6 +189,7 @@ func runProxy(conf *config.Config, version string) error {
178 189
 		Network:         ntw,
179 190
 		AntiReplayCache: makeAntiReplayCache(conf),
180 191
 		IPBlocklist:     blocklist,
192
+		IPWhitelist:     whitelist,
181 193
 		EventStream:     eventStream,
182 194
 
183 195
 		Secret:             conf.Secret,

+ 20
- 9
internal/config/config.go Просмотреть файл

@@ -8,6 +8,18 @@ import (
8 8
 	"github.com/9seconds/mtg/v2/mtglib"
9 9
 )
10 10
 
11
+type Optional struct {
12
+	Enabled TypeBool `json:"enabled"`
13
+}
14
+
15
+type ListConfig struct {
16
+	Optional
17
+
18
+	DownloadConcurrency TypeConcurrency    `json:"downloadConcurrency"`
19
+	URLs                []TypeBlocklistURI `json:"urls"`
20
+	UpdateEach          TypeDuration       `json:"updateEach"`
21
+}
22
+
11 23
 type Config struct {
12 24
 	Debug                    TypeBool        `json:"debug"`
13 25
 	AllowFallbackOnUnknownDC TypeBool        `json:"allowFallbackOnUnknownDc"`
@@ -20,16 +32,13 @@ type Config struct {
20 32
 	Concurrency              TypeConcurrency `json:"concurrency"`
21 33
 	Defense                  struct {
22 34
 		AntiReplay struct {
23
-			Enabled   TypeBool      `json:"enabled"`
35
+			Optional
36
+
24 37
 			MaxSize   TypeBytes     `json:"maxSize"`
25 38
 			ErrorRate TypeErrorRate `json:"errorRate"`
26 39
 		} `json:"antiReplay"`
27
-		Blocklist struct {
28
-			Enabled             TypeBool           `json:"enabled"`
29
-			DownloadConcurrency TypeConcurrency    `json:"downloadConcurrency"`
30
-			URLs                []TypeBlocklistURI `json:"urls"`
31
-			UpdateEach          TypeDuration       `json:"updateEach"`
32
-		} `json:"blocklist"`
40
+		Blocklist ListConfig `json:"blocklist"`
41
+		Allowlist ListConfig `json:"allowlist"`
33 42
 	} `json:"defense"`
34 43
 	Network struct {
35 44
 		Timeout struct {
@@ -42,13 +51,15 @@ type Config struct {
42 51
 	} `json:"network"`
43 52
 	Stats struct {
44 53
 		StatsD struct {
45
-			Enabled      TypeBool            `json:"enabled"`
54
+			Optional
55
+
46 56
 			Address      TypeHostPort        `json:"address"`
47 57
 			MetricPrefix TypeMetricPrefix    `json:"metricPrefix"`
48 58
 			TagFormat    TypeStatsdTagFormat `json:"tagFormat"`
49 59
 		} `json:"statsd"`
50 60
 		Prometheus struct {
51
-			Enabled      TypeBool         `json:"enabled"`
61
+			Optional
62
+
52 63
 			BindTo       TypeHostPort     `json:"bindTo"`
53 64
 			HTTPPath     TypeHTTPPath     `json:"httpPath"`
54 65
 			MetricPrefix TypeMetricPrefix `json:"metricPrefix"`

+ 6
- 0
internal/config/parse.go Просмотреть файл

@@ -30,6 +30,12 @@ type tomlConfig struct {
30 30
 			URLs                []string `toml:"urls" json:"urls,omitempty"`
31 31
 			UpdateEach          string   `toml:"update-each" json:"updateEach,omitempty"`
32 32
 		} `toml:"blocklist" json:"blocklist,omitempty"`
33
+		Allowlist struct {
34
+			Enabled             bool     `toml:"enabled" json:"enabled,omitempty"`
35
+			DownloadConcurrency uint     `toml:"download-concurrency" json:"downloadConcurrency,omitempty"`
36
+			URLs                []string `toml:"urls" json:"urls,omitempty"`
37
+			UpdateEach          string   `toml:"update-each" json:"updateEach,omitempty"`
38
+		} `toml:"allowlist" json:"allowlist,omitempty"`
33 39
 	} `toml:"defense" json:"defense,omitempty"`
34 40
 	Network struct {
35 41
 		Timeout struct {

+ 63
- 0
ipblocklist/files/http.go Просмотреть файл

@@ -0,0 +1,63 @@
1
+package files
2
+
3
+import (
4
+	"context"
5
+	"fmt"
6
+	"io"
7
+	"net/http"
8
+	"net/url"
9
+)
10
+
11
+type httpFile struct {
12
+	http *http.Client
13
+	url  string
14
+}
15
+
16
+func (h httpFile) Open(ctx context.Context) (io.ReadCloser, error) {
17
+	request, err := http.NewRequestWithContext(ctx, http.MethodGet, h.url, nil)
18
+	if err != nil {
19
+		panic(err)
20
+	}
21
+
22
+	response, err := h.http.Do(request)
23
+	if err != nil {
24
+		if response != nil {
25
+			io.Copy(io.Discard, response.Body) // nolint: errcheck
26
+			response.Body.Close()
27
+		}
28
+
29
+		return nil, fmt.Errorf("cannot get url %s: %w", h.url, err)
30
+	}
31
+
32
+	if response.StatusCode >= http.StatusBadRequest {
33
+		return nil, fmt.Errorf("unexpected status code %d", response.StatusCode)
34
+	}
35
+
36
+	return response.Body, nil
37
+}
38
+
39
+func (h httpFile) String() string {
40
+	return h.url
41
+}
42
+
43
+func NewHTTP(client *http.Client, endpoint string) (File, error) {
44
+	if client == nil {
45
+		return nil, ErrBadHTTPClient
46
+	}
47
+
48
+	parsed, err := url.Parse(endpoint)
49
+	if err != nil {
50
+		return nil, fmt.Errorf("incorrect url %s: %w", endpoint, err)
51
+	}
52
+
53
+	switch parsed.Scheme {
54
+	case "http", "https":
55
+	default:
56
+		return nil, fmt.Errorf("unsupported url %s", endpoint)
57
+	}
58
+
59
+	return httpFile{
60
+		http: client,
61
+		url:  endpoint,
62
+	}, nil
63
+}

+ 90
- 0
ipblocklist/files/http_test.go Просмотреть файл

@@ -0,0 +1,90 @@
1
+package files_test
2
+
3
+import (
4
+	"context"
5
+	"io"
6
+	"net/http"
7
+	"net/http/httptest"
8
+	"strings"
9
+	"testing"
10
+
11
+	"github.com/9seconds/mtg/v2/ipblocklist/files"
12
+	"github.com/stretchr/testify/suite"
13
+)
14
+
15
+type HTTPTestSuite struct {
16
+	suite.Suite
17
+
18
+	httpClient *http.Client
19
+	httpServer *httptest.Server
20
+	ctx        context.Context
21
+	ctxCancel  context.CancelFunc
22
+}
23
+
24
+func (suite *HTTPTestSuite) makeFile(path string) (files.File, error) {
25
+	return files.NewHTTP(suite.httpClient, suite.httpServer.URL+"/"+path) // nolint: wrapcheck
26
+}
27
+
28
+func (suite *HTTPTestSuite) SetupSuite() {
29
+	mux := http.NewServeMux()
30
+
31
+	mux.Handle("/", http.FileServer(http.Dir("testdata")))
32
+
33
+	suite.httpServer = httptest.NewServer(mux)
34
+	suite.httpClient = suite.httpServer.Client()
35
+}
36
+
37
+func (suite *HTTPTestSuite) SetupTest() {
38
+	suite.ctx, suite.ctxCancel = context.WithCancel(context.Background())
39
+}
40
+
41
+func (suite *HTTPTestSuite) TearDownTest() {
42
+	suite.ctxCancel()
43
+	suite.httpServer.CloseClientConnections()
44
+}
45
+
46
+func (suite *HTTPTestSuite) TearDownSuite() {
47
+	suite.httpServer.Close()
48
+}
49
+
50
+func (suite *HTTPTestSuite) TestBadURL() {
51
+	_, err := files.NewHTTP(suite.httpClient, "sdfsdf")
52
+	suite.Error(err)
53
+}
54
+
55
+func (suite *HTTPTestSuite) TestBadSchema() {
56
+	_, err := files.NewHTTP(suite.httpClient, "gopher://lala")
57
+	suite.Error(err)
58
+}
59
+
60
+func (suite *HTTPTestSuite) TestNilHTTPClient() {
61
+	_, err := files.NewHTTP(nil, "")
62
+	suite.Error(err)
63
+}
64
+
65
+func (suite *HTTPTestSuite) TestAbsentFile() {
66
+	file, err := suite.makeFile("absent")
67
+	suite.NoError(err)
68
+
69
+	_, err = file.Open(suite.ctx)
70
+	suite.Error(err)
71
+}
72
+
73
+func (suite *HTTPTestSuite) TestOk() {
74
+	file, err := suite.makeFile("readable")
75
+	suite.NoError(err)
76
+
77
+	readCloser, err := file.Open(suite.ctx)
78
+	suite.NoError(err)
79
+
80
+	defer readCloser.Close()
81
+
82
+	data, err := io.ReadAll(readCloser)
83
+	suite.NoError(err)
84
+	suite.Equal("Hooray!", strings.TrimSpace(string(data)))
85
+}
86
+
87
+func TestHTTP(t *testing.T) {
88
+	t.Parallel()
89
+	suite.Run(t, &HTTPTestSuite{})
90
+}

+ 14
- 0
ipblocklist/files/init.go Просмотреть файл

@@ -0,0 +1,14 @@
1
+package files
2
+
3
+import (
4
+	"context"
5
+	"errors"
6
+	"io"
7
+)
8
+
9
+var ErrBadHTTPClient = errors.New("incorrect http client")
10
+
11
+type File interface {
12
+	Open(context.Context) (io.ReadCloser, error)
13
+	String() string
14
+}

+ 30
- 0
ipblocklist/files/local.go Просмотреть файл

@@ -0,0 +1,30 @@
1
+package files
2
+
3
+import (
4
+	"context"
5
+	"fmt"
6
+	"io"
7
+	"os"
8
+)
9
+
10
+type localFile struct {
11
+	path string
12
+}
13
+
14
+func (l localFile) Open(ctx context.Context) (io.ReadCloser, error) {
15
+	return os.Open(l.path) // nolint: wrapcheck
16
+}
17
+
18
+func (l localFile) String() string {
19
+	return l.path
20
+}
21
+
22
+func NewLocal(path string) (File, error) {
23
+	if stat, err := os.Stat(path); os.IsNotExist(err) || stat.IsDir() || stat.Mode().Perm()&0o400 == 0 {
24
+		return nil, fmt.Errorf("%s is not a readable file", path)
25
+	}
26
+
27
+	return localFile{
28
+		path: path,
29
+	}, nil
30
+}

+ 55
- 0
ipblocklist/files/local_test.go Просмотреть файл

@@ -0,0 +1,55 @@
1
+package files_test
2
+
3
+import (
4
+	"context"
5
+	"io"
6
+	"path/filepath"
7
+	"strings"
8
+	"testing"
9
+
10
+	"github.com/9seconds/mtg/v2/ipblocklist/files"
11
+	"github.com/stretchr/testify/assert"
12
+	"github.com/stretchr/testify/suite"
13
+)
14
+
15
+type LocalTestSuite struct {
16
+	suite.Suite
17
+}
18
+
19
+func (suite *LocalTestSuite) getLocalFile(name string) string {
20
+	return filepath.Join("testdata", name)
21
+}
22
+
23
+func (suite *LocalTestSuite) TestIncorrect() {
24
+	names := []string{
25
+		"absent",
26
+		"directory",
27
+	}
28
+
29
+	for _, v := range names {
30
+		value := v
31
+
32
+		suite.T().Run(v, func(t *testing.T) {
33
+			_, err := files.NewLocal(suite.getLocalFile(value))
34
+			assert.Error(t, err)
35
+		})
36
+	}
37
+}
38
+
39
+func (suite *LocalTestSuite) TestOk() {
40
+	file, err := files.NewLocal(suite.getLocalFile("readable"))
41
+	suite.NoError(err)
42
+
43
+	reader, err := file.Open(context.Background())
44
+	suite.NoError(err)
45
+
46
+	data, err := io.ReadAll(reader)
47
+	suite.NoError(err)
48
+
49
+	suite.Equal("Hooray!", strings.TrimSpace(string(data)))
50
+}
51
+
52
+func TestLocal(t *testing.T) {
53
+	t.Parallel()
54
+	suite.Run(t, &LocalTestSuite{})
55
+}

+ 0
- 0
ipblocklist/files/testdata/directory/.gitkeep Просмотреть файл


+ 1
- 0
ipblocklist/files/testdata/readable Просмотреть файл

@@ -0,0 +1 @@
1
+Hooray!

+ 55
- 131
ipblocklist/firehol.go Просмотреть файл

@@ -4,16 +4,13 @@ import (
4 4
 	"bufio"
5 5
 	"context"
6 6
 	"fmt"
7
-	"io"
8 7
 	"net"
9
-	"net/http"
10
-	"net/url"
11
-	"os"
12 8
 	"regexp"
13 9
 	"strings"
14 10
 	"sync"
15 11
 	"time"
16 12
 
13
+	"github.com/9seconds/mtg/v2/ipblocklist/files"
17 14
 	"github.com/9seconds/mtg/v2/mtglib"
18 15
 	"github.com/kentik/patricia"
19 16
 	"github.com/kentik/patricia/bool_tree"
@@ -41,20 +38,16 @@ var fireholRegexpComment = regexp.MustCompile(`\s*#.*?$`)
41 38
 //     127.0.0.1   # you can specify an IP
42 39
 //     10.0.0.0/8  # or cidr
43 40
 type Firehol struct {
44
-	ctx       context.Context
45
-	ctxCancel context.CancelFunc
46
-	logger    mtglib.Logger
41
+	ctx         context.Context
42
+	ctxCancel   context.CancelFunc
43
+	logger      mtglib.Logger
44
+	updateMutex sync.RWMutex
47 45
 
48
-	rwMutex sync.RWMutex
46
+	blocklists []files.File
49 47
 
50
-	remoteURLs []string
51
-	localFiles []string
52
-
53
-	httpClient *http.Client
54 48
 	workerPool *ants.Pool
55
-
56
-	treeV4 *bool_tree.TreeV4
57
-	treeV6 *bool_tree.TreeV6
49
+	treeV4     *bool_tree.TreeV4
50
+	treeV6     *bool_tree.TreeV6
58 51
 }
59 52
 
60 53
 // Shutdown stop a background update process.
@@ -68,8 +61,8 @@ func (f *Firehol) Contains(ip net.IP) bool {
68 61
 		return true
69 62
 	}
70 63
 
71
-	f.rwMutex.RLock()
72
-	defer f.rwMutex.RUnlock()
64
+	f.updateMutex.RLock()
65
+	defer f.updateMutex.RUnlock()
73 66
 
74 67
 	if ip4 := ip.To4(); ip4 != nil {
75 68
 		return f.containsIPv4(ip4)
@@ -98,22 +91,14 @@ func (f *Firehol) Run(updateEach time.Duration) {
98 91
 		}
99 92
 	}()
100 93
 
101
-	if err := f.update(); err != nil {
102
-		f.logger.WarningError("cannot update blocklist", err)
103
-	} else {
104
-		f.logger.Info("blocklist was updated")
105
-	}
94
+	f.update()
106 95
 
107 96
 	for {
108 97
 		select {
109 98
 		case <-f.ctx.Done():
110 99
 			return
111 100
 		case <-ticker.C:
112
-			if err := f.update(); err != nil {
113
-				f.logger.WarningError("cannot update blocklist", err)
114
-			} else {
115
-				f.logger.Info("blocklist was updated")
116
-			}
101
+			f.update()
117 102
 		}
118 103
 	}
119 104
 }
@@ -138,121 +123,53 @@ func (f *Firehol) containsIPv6(addr net.IP) bool {
138 123
 	return false
139 124
 }
140 125
 
141
-func (f *Firehol) update() error { // nolint: funlen, cyclop
126
+func (f *Firehol) update() {
142 127
 	ctx, cancel := context.WithCancel(f.ctx)
143 128
 	defer cancel()
144 129
 
145 130
 	wg := &sync.WaitGroup{}
146
-	wg.Add(len(f.remoteURLs) + len(f.localFiles))
131
+	wg.Add(len(f.blocklists))
147 132
 
148 133
 	treeMutex := &sync.Mutex{}
149 134
 	v4tree := bool_tree.NewTreeV4()
150 135
 	v6tree := bool_tree.NewTreeV6()
151 136
 
152
-	errorChan := make(chan error, 1)
153
-	defer close(errorChan)
154
-
155
-	for _, v := range f.localFiles {
156
-		go func(filename string) {
137
+	for _, v := range f.blocklists {
138
+		go func(file files.File) {
157 139
 			defer wg.Done()
158 140
 
159
-			if err := f.updateLocalFile(ctx, filename, treeMutex, v4tree, v6tree); err != nil {
160
-				cancel()
161
-				f.logger.BindStr("filename", filename).WarningError("cannot update", err)
141
+			logger := f.logger.BindStr("filename", file.String())
162 142
 
163
-				select {
164
-				case errorChan <- err:
165
-				default:
166
-				}
167
-			}
168
-		}(v)
169
-	}
143
+			fileContent, err := file.Open(ctx)
144
+			if err != nil {
145
+				logger.WarningError("update has failed", err)
170 146
 
171
-	for _, v := range f.remoteURLs {
172
-		value := v
173
-
174
-		f.workerPool.Submit(func() { // nolint: errcheck
175
-			defer wg.Done()
147
+				return
148
+			}
176 149
 
177
-			if err := f.updateRemoteURL(ctx, value, treeMutex, v4tree, v6tree); err != nil {
178
-				cancel()
179
-				f.logger.BindStr("url", value).WarningError("cannot update", err)
150
+			defer fileContent.Close()
180 151
 
181
-				select {
182
-				case errorChan <- err:
183
-				default:
184
-				}
152
+			if err := f.updateFromFile(treeMutex, v4tree, v6tree, bufio.NewScanner(fileContent)); err != nil {
153
+				logger.WarningError("update has failed", err)
185 154
 			}
186
-		})
155
+		}(v)
187 156
 	}
188 157
 
189 158
 	wg.Wait()
190 159
 
191
-	select {
192
-	case err := <-errorChan:
193
-		return fmt.Errorf("cannot update trees: %w", err)
194
-	default:
195
-	}
196
-
197
-	f.rwMutex.Lock()
198
-	defer f.rwMutex.Unlock()
160
+	f.updateMutex.Lock()
161
+	defer f.updateMutex.Unlock()
199 162
 
200 163
 	f.treeV4 = v4tree
201 164
 	f.treeV6 = v6tree
202 165
 
203
-	return nil
204
-}
205
-
206
-func (f *Firehol) updateLocalFile(ctx context.Context, filename string,
207
-	mutex sync.Locker,
208
-	v4tree *bool_tree.TreeV4, v6tree *bool_tree.TreeV6) error {
209
-	filefp, err := os.Open(filename)
210
-	if err != nil {
211
-		return fmt.Errorf("cannot open file: %w", err)
212
-	}
213
-
214
-	go func(ctx context.Context, closer io.Closer) {
215
-		<-ctx.Done()
216
-		closer.Close()
217
-	}(ctx, filefp)
218
-
219
-	defer filefp.Close()
220
-
221
-	return f.updateTrees(mutex, filefp, v4tree, v6tree)
166
+	f.logger.Info("blocklist was updated")
222 167
 }
223 168
 
224
-func (f *Firehol) updateRemoteURL(ctx context.Context, url string,
225
-	mutex sync.Locker,
226
-	v4tree *bool_tree.TreeV4, v6tree *bool_tree.TreeV6) error {
227
-	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
228
-	if err != nil {
229
-		return fmt.Errorf("cannot build a request: %w", err)
230
-	}
231
-
232
-	resp, err := f.httpClient.Do(req) // nolint: bodyclose
233
-	if err != nil {
234
-		return fmt.Errorf("cannot request a remote URL %s: %w", url, err)
235
-	}
236
-
237
-	go func(ctx context.Context, closer io.Closer) {
238
-		<-ctx.Done()
239
-		closer.Close()
240
-	}(ctx, resp.Body)
241
-
242
-	defer func(rc io.ReadCloser) {
243
-		io.Copy(io.Discard, rc) // nolint: errcheck
244
-		rc.Close()
245
-	}(resp.Body)
246
-
247
-	return f.updateTrees(mutex, resp.Body, v4tree, v6tree)
248
-}
249
-
250
-func (f *Firehol) updateTrees(mutex sync.Locker,
251
-	reader io.Reader,
169
+func (f *Firehol) updateFromFile(mutex sync.Locker,
252 170
 	v4tree *bool_tree.TreeV4,
253
-	v6tree *bool_tree.TreeV6) error {
254
-	scanner := bufio.NewScanner(reader)
255
-
171
+	v6tree *bool_tree.TreeV6,
172
+	scanner *bufio.Scanner) error {
256 173
 	for scanner.Scan() {
257 174
 		text := scanner.Text()
258 175
 		text = fireholRegexpComment.ReplaceAllLiteralString(text, "")
@@ -271,7 +188,7 @@ func (f *Firehol) updateTrees(mutex sync.Locker,
271 188
 	}
272 189
 
273 190
 	if scanner.Err() != nil {
274
-		return fmt.Errorf("cannot parse a response: %w", scanner.Err())
191
+		return fmt.Errorf("cannot parse a file: %w", scanner.Err())
275 192
 	}
276 193
 
277 194
 	return nil
@@ -317,27 +234,36 @@ func (f *Firehol) updateAddToTrees(ip net.IP, cidr uint,
317 234
 // when it is necessary.
318 235
 func NewFirehol(logger mtglib.Logger, network mtglib.Network,
319 236
 	downloadConcurrency uint,
320
-	remoteURLs []string,
237
+	urls []string,
321 238
 	localFiles []string) (*Firehol, error) {
322
-	for _, v := range remoteURLs {
323
-		parsed, err := url.Parse(v)
239
+	blocklists := []files.File{}
240
+
241
+	for _, v := range localFiles {
242
+		file, err := files.NewLocal(v)
324 243
 		if err != nil {
325
-			return nil, fmt.Errorf("incorrect url %s: %w", v, err)
244
+			return nil, fmt.Errorf("cannot create a local file %s: %w", v, err)
326 245
 		}
327 246
 
328
-		switch parsed.Scheme {
329
-		case "http", "https":
330
-		default:
331
-			return nil, fmt.Errorf("unsupported url %s", v)
332
-		}
247
+		blocklists = append(blocklists, file)
333 248
 	}
334 249
 
335
-	for _, v := range localFiles {
336
-		if stat, err := os.Stat(v); os.IsNotExist(err) || stat.IsDir() || stat.Mode().Perm()&0o400 == 0 {
337
-			return nil, fmt.Errorf("%s is not a readable file", v)
250
+	httpClient := network.MakeHTTPClient(nil)
251
+
252
+	for _, v := range urls {
253
+		file, err := files.NewHTTP(httpClient, v)
254
+		if err != nil {
255
+			return nil, fmt.Errorf("cannot create a HTTP file %s: %w", v, err)
338 256
 		}
257
+
258
+		blocklists = append(blocklists, file)
339 259
 	}
340 260
 
261
+	return NewFireholFromFiles(logger, downloadConcurrency, blocklists)
262
+}
263
+
264
+func NewFireholFromFiles(logger mtglib.Logger,
265
+	downloadConcurrency uint,
266
+	blocklists []files.File) (*Firehol, error) {
341 267
 	if downloadConcurrency == 0 {
342 268
 		downloadConcurrency = DefaultFireholDownloadConcurrency
343 269
 	}
@@ -349,11 +275,9 @@ func NewFirehol(logger mtglib.Logger, network mtglib.Network,
349 275
 		ctx:        ctx,
350 276
 		ctxCancel:  cancel,
351 277
 		logger:     logger.Named("firehol"),
352
-		httpClient: network.MakeHTTPClient(nil),
353 278
 		treeV4:     bool_tree.NewTreeV4(),
354 279
 		treeV6:     bool_tree.NewTreeV6(),
355 280
 		workerPool: workerPool,
356
-		remoteURLs: remoteURLs,
357
-		localFiles: localFiles,
281
+		blocklists: blocklists,
358 282
 	}, nil
359 283
 }

+ 14
- 4
mtglib/proxy.go Просмотреть файл

@@ -33,7 +33,8 @@ type Proxy struct {
33 33
 	secret          Secret
34 34
 	network         Network
35 35
 	antiReplayCache AntiReplayCache
36
-	ipBlocklist     IPBlocklist
36
+	blocklist       IPBlocklist
37
+	whitelist       IPBlocklist
37 38
 	eventStream     EventStream
38 39
 	logger          Logger
39 40
 }
@@ -91,7 +92,7 @@ func (p *Proxy) ServeConn(conn net.Conn) {
91 92
 }
92 93
 
93 94
 // Serve starts a proxy on a given listener.
94
-func (p *Proxy) Serve(listener net.Listener) error {
95
+func (p *Proxy) Serve(listener net.Listener) error { // nolint: cyclop
95 96
 	p.streamWaitGroup.Add(1)
96 97
 	defer p.streamWaitGroup.Done()
97 98
 
@@ -109,7 +110,15 @@ func (p *Proxy) Serve(listener net.Listener) error {
109 110
 		ipAddr := conn.RemoteAddr().(*net.TCPAddr).IP
110 111
 		logger := p.logger.BindStr("ip", ipAddr.String())
111 112
 
112
-		if p.ipBlocklist.Contains(ipAddr) {
113
+		if p.whitelist != nil && !p.whitelist.Contains(ipAddr) {
114
+			conn.Close()
115
+			logger.Info("ip was rejected by whitelist")
116
+			p.eventStream.Send(p.ctx, NewEventIPBlocklisted(ipAddr))
117
+
118
+			continue
119
+		}
120
+
121
+		if p.blocklist.Contains(ipAddr) {
113 122
 			conn.Close()
114 123
 			logger.Info("ip was blacklisted")
115 124
 			p.eventStream.Send(p.ctx, NewEventIPBlocklisted(ipAddr))
@@ -291,7 +300,8 @@ func NewProxy(opts ProxyOpts) (*Proxy, error) {
291 300
 		secret:                   opts.Secret,
292 301
 		network:                  opts.Network,
293 302
 		antiReplayCache:          opts.AntiReplayCache,
294
-		ipBlocklist:              opts.IPBlocklist,
303
+		blocklist:                opts.IPBlocklist,
304
+		whitelist:                opts.IPWhitelist,
295 305
 		eventStream:              opts.EventStream,
296 306
 		logger:                   opts.getLogger("proxy"),
297 307
 		domainFrontingPort:       opts.getDomainFrontingPort(),

+ 5
- 0
mtglib/proxy_opts.go Просмотреть файл

@@ -28,6 +28,11 @@ type ProxyOpts struct {
28 28
 	// This is a mandatory setting.
29 29
 	IPBlocklist IPBlocklist
30 30
 
31
+	// IPWhitelist defines a whitelist of IPs to allow to use proxy.
32
+	//
33
+	// This is an optional setting, ignored by default (no restrictions).
34
+	IPWhitelist IPBlocklist
35
+
31 36
 	// EventStream defines an instance of event stream.
32 37
 	//
33 38
 	// This ia a mandatory setting.

Загрузка…
Отмена
Сохранить