瀏覽代碼

Refactor firehol

tags/v2.1.3^2
9seconds 4 年之前
父節點
當前提交
558fec60de
共有 5 個檔案被更改,包括 64 行新增135 行删除
  1. 5
    1
      ipblocklist/files/http.go
  2. 1
    1
      ipblocklist/files/http_test.go
  3. 1
    0
      ipblocklist/files/init.go
  4. 7
    7
      ipblocklist/files/local.go
  5. 50
    126
      ipblocklist/firehol.go

+ 5
- 1
ipblocklist/files/http.go 查看文件

@@ -22,7 +22,7 @@ func (h httpFile) Open(ctx context.Context) (io.ReadCloser, error) {
22 22
 	response, err := h.http.Do(request)
23 23
 	if err != nil {
24 24
 		if response != nil {
25
-			io.Copy(io.Discard, response.Body)
25
+			io.Copy(io.Discard, response.Body) // nolint: errcheck
26 26
 			response.Body.Close()
27 27
 		}
28 28
 
@@ -36,6 +36,10 @@ func (h httpFile) Open(ctx context.Context) (io.ReadCloser, error) {
36 36
 	return response.Body, nil
37 37
 }
38 38
 
39
+func (h httpFile) String() string {
40
+	return h.url
41
+}
42
+
39 43
 func NewHTTP(client *http.Client, endpoint string) (File, error) {
40 44
 	if client == nil {
41 45
 		return nil, ErrBadHTTPClient

+ 1
- 1
ipblocklist/files/http_test.go 查看文件

@@ -22,7 +22,7 @@ type HTTPTestSuite struct {
22 22
 }
23 23
 
24 24
 func (suite *HTTPTestSuite) makeFile(path string) (files.File, error) {
25
-	return files.NewHTTP(suite.httpClient, suite.httpServer.URL+"/"+path)
25
+	return files.NewHTTP(suite.httpClient, suite.httpServer.URL+"/"+path) // nolint: wrapcheck
26 26
 }
27 27
 
28 28
 func (suite *HTTPTestSuite) SetupSuite() {

+ 1
- 0
ipblocklist/files/init.go 查看文件

@@ -10,4 +10,5 @@ var ErrBadHTTPClient = errors.New("incorrect http client")
10 10
 
11 11
 type File interface {
12 12
 	Open(context.Context) (io.ReadCloser, error)
13
+	String() string
13 14
 }

+ 7
- 7
ipblocklist/files/local.go 查看文件

@@ -4,18 +4,19 @@ import (
4 4
 	"context"
5 5
 	"fmt"
6 6
 	"io"
7
-	"io/fs"
8 7
 	"os"
9
-	"path/filepath"
10 8
 )
11 9
 
12 10
 type localFile struct {
13
-	root fs.FS
14
-	name string
11
+	path string
15 12
 }
16 13
 
17 14
 func (l localFile) Open(ctx context.Context) (io.ReadCloser, error) {
18
-	return l.root.Open(l.name)
15
+	return os.Open(l.path) // nolint: wrapcheck
16
+}
17
+
18
+func (l localFile) String() string {
19
+	return l.path
19 20
 }
20 21
 
21 22
 func NewLocal(path string) (File, error) {
@@ -24,7 +25,6 @@ func NewLocal(path string) (File, error) {
24 25
 	}
25 26
 
26 27
 	return localFile{
27
-		root: os.DirFS(filepath.Dir(path)),
28
-		name: filepath.Base(path),
28
+		path: path,
29 29
 	}, nil
30 30
 }

+ 50
- 126
ipblocklist/firehol.go 查看文件

@@ -4,16 +4,13 @@ import (
4 4
 	"bufio"
5 5
 	"context"
6 6
 	"fmt"
7
-	"io"
8 7
 	"net"
9
-	"net/http"
10
-	"net/url"
11
-	"os"
12 8
 	"regexp"
13 9
 	"strings"
14 10
 	"sync"
15 11
 	"time"
16 12
 
13
+	"github.com/9seconds/mtg/v2/ipblocklist/files"
17 14
 	"github.com/9seconds/mtg/v2/mtglib"
18 15
 	"github.com/kentik/patricia"
19 16
 	"github.com/kentik/patricia/bool_tree"
@@ -41,20 +38,16 @@ var fireholRegexpComment = regexp.MustCompile(`\s*#.*?$`)
41 38
 //     127.0.0.1   # you can specify an IP
42 39
 //     10.0.0.0/8  # or cidr
43 40
 type Firehol struct {
44
-	ctx       context.Context
45
-	ctxCancel context.CancelFunc
46
-	logger    mtglib.Logger
47
-
41
+	ctx         context.Context
42
+	ctxCancel   context.CancelFunc
43
+	logger      mtglib.Logger
48 44
 	updateMutex sync.RWMutex
49 45
 
50
-	remoteURLs []string
51
-	localFiles []string
46
+	blocklists []files.File
52 47
 
53
-	httpClient *http.Client
54 48
 	workerPool *ants.Pool
55
-
56
-	treeV4 *bool_tree.TreeV4
57
-	treeV6 *bool_tree.TreeV6
49
+	treeV4     *bool_tree.TreeV4
50
+	treeV6     *bool_tree.TreeV6
58 51
 }
59 52
 
60 53
 // Shutdown stop a background update process.
@@ -98,22 +91,14 @@ func (f *Firehol) Run(updateEach time.Duration) {
98 91
 		}
99 92
 	}()
100 93
 
101
-	if err := f.update(); err != nil {
102
-		f.logger.WarningError("cannot update blocklist", err)
103
-	} else {
104
-		f.logger.Info("blocklist was updated")
105
-	}
94
+	f.update()
106 95
 
107 96
 	for {
108 97
 		select {
109 98
 		case <-f.ctx.Done():
110 99
 			return
111 100
 		case <-ticker.C:
112
-			if err := f.update(); err != nil {
113
-				f.logger.WarningError("cannot update blocklist", err)
114
-			} else {
115
-				f.logger.Info("blocklist was updated")
116
-			}
101
+			f.update()
117 102
 		}
118 103
 	}
119 104
 }
@@ -138,121 +123,53 @@ func (f *Firehol) containsIPv6(addr net.IP) bool {
138 123
 	return false
139 124
 }
140 125
 
141
-func (f *Firehol) update() error { // nolint: funlen, cyclop
126
+func (f *Firehol) update() {
142 127
 	ctx, cancel := context.WithCancel(f.ctx)
143 128
 	defer cancel()
144 129
 
145 130
 	wg := &sync.WaitGroup{}
146
-	wg.Add(len(f.remoteURLs) + len(f.localFiles))
131
+	wg.Add(len(f.blocklists))
147 132
 
148 133
 	treeMutex := &sync.Mutex{}
149 134
 	v4tree := bool_tree.NewTreeV4()
150 135
 	v6tree := bool_tree.NewTreeV6()
151 136
 
152
-	errorChan := make(chan error, 1)
153
-	defer close(errorChan)
154
-
155
-	for _, v := range f.localFiles {
156
-		go func(filename string) {
137
+	for _, v := range f.blocklists {
138
+		go func(file files.File) {
157 139
 			defer wg.Done()
158 140
 
159
-			if err := f.updateLocalFile(ctx, filename, treeMutex, v4tree, v6tree); err != nil {
160
-				cancel()
161
-				f.logger.BindStr("filename", filename).WarningError("cannot update", err)
141
+			logger := f.logger.BindStr("filename", file.String())
162 142
 
163
-				select {
164
-				case errorChan <- err:
165
-				default:
166
-				}
167
-			}
168
-		}(v)
169
-	}
143
+			fileContent, err := file.Open(ctx)
144
+			if err != nil {
145
+				logger.WarningError("update has failed", err)
170 146
 
171
-	for _, v := range f.remoteURLs {
172
-		value := v
173
-
174
-		f.workerPool.Submit(func() { // nolint: errcheck
175
-			defer wg.Done()
147
+				return
148
+			}
176 149
 
177
-			if err := f.updateRemoteURL(ctx, value, treeMutex, v4tree, v6tree); err != nil {
178
-				cancel()
179
-				f.logger.BindStr("url", value).WarningError("cannot update", err)
150
+			defer fileContent.Close()
180 151
 
181
-				select {
182
-				case errorChan <- err:
183
-				default:
184
-				}
152
+			if err := f.updateFromFile(treeMutex, v4tree, v6tree, bufio.NewScanner(fileContent)); err != nil {
153
+				logger.WarningError("update has failed", err)
185 154
 			}
186
-		})
155
+		}(v)
187 156
 	}
188 157
 
189 158
 	wg.Wait()
190 159
 
191
-	select {
192
-	case err := <-errorChan:
193
-		return fmt.Errorf("cannot update trees: %w", err)
194
-	default:
195
-	}
196
-
197 160
 	f.updateMutex.Lock()
198 161
 	defer f.updateMutex.Unlock()
199 162
 
200 163
 	f.treeV4 = v4tree
201 164
 	f.treeV6 = v6tree
202 165
 
203
-	return nil
204
-}
205
-
206
-func (f *Firehol) updateLocalFile(ctx context.Context, filename string,
207
-	mutex sync.Locker,
208
-	v4tree *bool_tree.TreeV4, v6tree *bool_tree.TreeV6) error {
209
-	filefp, err := os.Open(filename)
210
-	if err != nil {
211
-		return fmt.Errorf("cannot open file: %w", err)
212
-	}
213
-
214
-	go func(ctx context.Context, closer io.Closer) {
215
-		<-ctx.Done()
216
-		closer.Close()
217
-	}(ctx, filefp)
218
-
219
-	defer filefp.Close()
220
-
221
-	return f.updateTrees(mutex, filefp, v4tree, v6tree)
166
+	f.logger.Info("blocklist was updated")
222 167
 }
223 168
 
224
-func (f *Firehol) updateRemoteURL(ctx context.Context, url string,
225
-	mutex sync.Locker,
226
-	v4tree *bool_tree.TreeV4, v6tree *bool_tree.TreeV6) error {
227
-	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
228
-	if err != nil {
229
-		return fmt.Errorf("cannot build a request: %w", err)
230
-	}
231
-
232
-	resp, err := f.httpClient.Do(req) // nolint: bodyclose
233
-	if err != nil {
234
-		return fmt.Errorf("cannot request a remote URL %s: %w", url, err)
235
-	}
236
-
237
-	go func(ctx context.Context, closer io.Closer) {
238
-		<-ctx.Done()
239
-		closer.Close()
240
-	}(ctx, resp.Body)
241
-
242
-	defer func(rc io.ReadCloser) {
243
-		io.Copy(io.Discard, rc) // nolint: errcheck
244
-		rc.Close()
245
-	}(resp.Body)
246
-
247
-	return f.updateTrees(mutex, resp.Body, v4tree, v6tree)
248
-}
249
-
250
-func (f *Firehol) updateTrees(mutex sync.Locker,
251
-	reader io.Reader,
169
+func (f *Firehol) updateFromFile(mutex sync.Locker,
252 170
 	v4tree *bool_tree.TreeV4,
253
-	v6tree *bool_tree.TreeV6) error {
254
-	scanner := bufio.NewScanner(reader)
255
-
171
+	v6tree *bool_tree.TreeV6,
172
+	scanner *bufio.Scanner) error {
256 173
 	for scanner.Scan() {
257 174
 		text := scanner.Text()
258 175
 		text = fireholRegexpComment.ReplaceAllLiteralString(text, "")
@@ -271,7 +188,7 @@ func (f *Firehol) updateTrees(mutex sync.Locker,
271 188
 	}
272 189
 
273 190
 	if scanner.Err() != nil {
274
-		return fmt.Errorf("cannot parse a response: %w", scanner.Err())
191
+		return fmt.Errorf("cannot parse a file: %w", scanner.Err())
275 192
 	}
276 193
 
277 194
 	return nil
@@ -317,27 +234,36 @@ func (f *Firehol) updateAddToTrees(ip net.IP, cidr uint,
317 234
 // when it is necessary.
318 235
 func NewFirehol(logger mtglib.Logger, network mtglib.Network,
319 236
 	downloadConcurrency uint,
320
-	remoteURLs []string,
237
+	urls []string,
321 238
 	localFiles []string) (*Firehol, error) {
322
-	for _, v := range remoteURLs {
323
-		parsed, err := url.Parse(v)
239
+	blocklists := []files.File{}
240
+
241
+	for _, v := range localFiles {
242
+		file, err := files.NewLocal(v)
324 243
 		if err != nil {
325
-			return nil, fmt.Errorf("incorrect url %s: %w", v, err)
244
+			return nil, fmt.Errorf("cannot create a local file %s: %w", v, err)
326 245
 		}
327 246
 
328
-		switch parsed.Scheme {
329
-		case "http", "https":
330
-		default:
331
-			return nil, fmt.Errorf("unsupported url %s", v)
332
-		}
247
+		blocklists = append(blocklists, file)
333 248
 	}
334 249
 
335
-	for _, v := range localFiles {
336
-		if stat, err := os.Stat(v); os.IsNotExist(err) || stat.IsDir() || stat.Mode().Perm()&0o400 == 0 {
337
-			return nil, fmt.Errorf("%s is not a readable file", v)
250
+	httpClient := network.MakeHTTPClient(nil)
251
+
252
+	for _, v := range urls {
253
+		file, err := files.NewHTTP(httpClient, v)
254
+		if err != nil {
255
+			return nil, fmt.Errorf("cannot create a HTTP file %s: %w", v, err)
338 256
 		}
257
+
258
+		blocklists = append(blocklists, file)
339 259
 	}
340 260
 
261
+	return NewFireholFromFiles(logger, downloadConcurrency, blocklists)
262
+}
263
+
264
+func NewFireholFromFiles(logger mtglib.Logger,
265
+	downloadConcurrency uint,
266
+	blocklists []files.File) (*Firehol, error) {
341 267
 	if downloadConcurrency == 0 {
342 268
 		downloadConcurrency = DefaultFireholDownloadConcurrency
343 269
 	}
@@ -349,11 +275,9 @@ func NewFirehol(logger mtglib.Logger, network mtglib.Network,
349 275
 		ctx:        ctx,
350 276
 		ctxCancel:  cancel,
351 277
 		logger:     logger.Named("firehol"),
352
-		httpClient: network.MakeHTTPClient(nil),
353 278
 		treeV4:     bool_tree.NewTreeV4(),
354 279
 		treeV6:     bool_tree.NewTreeV6(),
355 280
 		workerPool: workerPool,
356
-		remoteURLs: remoteURLs,
357
-		localFiles: localFiles,
281
+		blocklists: blocklists,
358 282
 	}, nil
359 283
 }

Loading…
取消
儲存