package cli import ( "context" "fmt" "net" "os" "time" "github.com/dolonet/mtg-multi/antireplay" "github.com/dolonet/mtg-multi/events" "github.com/dolonet/mtg-multi/internal/config" "github.com/dolonet/mtg-multi/internal/proxyprotocol" "github.com/dolonet/mtg-multi/internal/utils" "github.com/dolonet/mtg-multi/ipblocklist" "github.com/dolonet/mtg-multi/ipblocklist/files" "github.com/dolonet/mtg-multi/logger" "github.com/dolonet/mtg-multi/mtglib" "github.com/dolonet/mtg-multi/network/v2" "github.com/dolonet/mtg-multi/stats" "github.com/pires/go-proxyproto" "github.com/rs/zerolog" "github.com/yl2chen/cidranger" ) func makeLogger(conf *config.Config) mtglib.Logger { zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMs zerolog.TimestampFieldName = "timestamp" zerolog.LevelFieldName = "level" if conf.Debug.Get(false) { zerolog.SetGlobalLevel(zerolog.DebugLevel) } else { zerolog.SetGlobalLevel(zerolog.WarnLevel) } baseLogger := zerolog.New(os.Stdout).With().Timestamp().Logger() return logger.NewZeroLogger(baseLogger) } func makeNetwork(conf *config.Config, version string) (mtglib.Network, error) { resolver, err := network.GetDNS(conf.GetDNS()) if err != nil { return nil, fmt.Errorf("cannot create DNS resolver: %w", err) } base := network.New( resolver, "", conf.Network.Timeout.TCP.Get(0), conf.Network.Timeout.HTTP.Get(0), conf.Network.Timeout.Idle.Get(0), ) proxyDialers := make([]mtglib.Network, len(conf.Network.Proxies)) for idx, v := range conf.Network.Proxies { value, err := network.NewProxyNetwork(base, v.Get(nil)) if err != nil { return nil, fmt.Errorf("cannot use %v for proxy url: %w", v.Get(nil), err) } proxyDialers[idx] = value } switch len(proxyDialers) { case 0: return base, nil case 1: return proxyDialers[0], nil } value, err := network.Join(proxyDialers...) if err != nil { panic(err) } return value, nil } func makeAntiReplayCache(conf *config.Config) mtglib.AntiReplayCache { if !conf.Defense.AntiReplay.Enabled.Get(false) { return antireplay.NewNoop() } return antireplay.NewStableBloomFilter( conf.Defense.AntiReplay.MaxSize.Get(antireplay.DefaultStableBloomFilterMaxSize), conf.Defense.AntiReplay.ErrorRate.Get(antireplay.DefaultStableBloomFilterErrorRate), ) } func makeIPBlocklist(conf config.ListConfig, logger mtglib.Logger, ntw mtglib.Network, updateCallback ipblocklist.FireholUpdateCallback, ) (mtglib.IPBlocklist, error) { if !conf.Enabled.Get(false) { return ipblocklist.NewNoop(), nil } remoteURLs := []string{} localFiles := []string{} for _, v := range conf.URLs { if v.IsRemote() { remoteURLs = append(remoteURLs, v.String()) } else { localFiles = append(localFiles, v.String()) } } blocklist, err := ipblocklist.NewFirehol(logger.Named("ipblockist"), ntw, conf.DownloadConcurrency.Get(1), remoteURLs, localFiles, updateCallback) if err != nil { return nil, fmt.Errorf("incorrect parameters for firehol: %w", err) } go blocklist.Run(conf.UpdateEach.Get(ipblocklist.DefaultFireholUpdateEach)) return blocklist, nil } func makeIPAllowlist(conf config.ListConfig, logger mtglib.Logger, ntw mtglib.Network, updateCallback ipblocklist.FireholUpdateCallback, ) (mtglib.IPBlocklist, error) { var ( allowlist mtglib.IPBlocklist err error ) if !conf.Enabled.Get(false) { allowlist, err = ipblocklist.NewFireholFromFiles( logger.Named("ipblocklist"), 1, []files.File{ files.NewMem([]*net.IPNet{ cidranger.AllIPv4, cidranger.AllIPv6, }), }, updateCallback, ) go allowlist.Run(conf.UpdateEach.Get(ipblocklist.DefaultFireholUpdateEach)) } else { allowlist, err = makeIPBlocklist( conf, logger, ntw, updateCallback, ) } if err != nil { return nil, fmt.Errorf("cannot build allowlist: %w", err) } return allowlist, nil } func makeEventStream(conf *config.Config, logger mtglib.Logger) (mtglib.EventStream, error) { factories := make([]events.ObserverFactory, 0, 2) if conf.Stats.StatsD.Enabled.Get(false) { statsdFactory, err := stats.NewStatsd( conf.Stats.StatsD.Address.Get(""), logger.Named("statsd"), conf.Stats.StatsD.MetricPrefix.Get(stats.DefaultStatsdMetricPrefix), conf.Stats.StatsD.TagFormat.Get(stats.DefaultStatsdTagFormat)) if err != nil { return nil, fmt.Errorf("cannot build statsd observer: %w", err) } factories = append(factories, statsdFactory.Make) } if conf.Stats.Prometheus.Enabled.Get(false) { prometheus := stats.NewPrometheus( conf.Stats.Prometheus.MetricPrefix.Get(stats.DefaultMetricPrefix), conf.Stats.Prometheus.HTTPPath.Get("/"), ) listener, err := net.Listen("tcp", conf.Stats.Prometheus.BindTo.Get("")) if err != nil { return nil, fmt.Errorf("cannot start a listener for prometheus: %w", err) } go prometheus.Serve(listener) //nolint: errcheck factories = append(factories, prometheus.Make) } if len(factories) > 0 { return events.NewEventStream(factories), nil } return events.NewNoopStream(), nil } func runProxy(conf *config.Config, version string) error { //nolint: funlen logger := makeLogger(conf) logger.BindJSON("configuration", conf.String()).Debug("configuration") eventStream, err := makeEventStream(conf, logger) if err != nil { return fmt.Errorf("cannot build event stream: %w", err) } ntw, err := makeNetwork(conf, version) if err != nil { return fmt.Errorf("cannot build network: %w", err) } blocklist, err := makeIPBlocklist( conf.Defense.Blocklist, logger.Named("blocklist"), ntw, func(ctx context.Context, size int) { eventStream.Send(ctx, mtglib.NewEventIPListSize(size, true)) }) if err != nil { return fmt.Errorf("cannot build ip blocklist: %w", err) } allowlist, err := makeIPAllowlist( conf.Defense.Allowlist, logger.Named("allowlist"), ntw, func(ctx context.Context, size int) { eventStream.Send(ctx, mtglib.NewEventIPListSize(size, false)) }, ) if err != nil { return fmt.Errorf("cannot build ip allowlist: %w", err) } doppelGangerURLs := make([]string, len(conf.Defense.Doppelganger.URLs)) for i, v := range conf.Defense.Doppelganger.URLs { doppelGangerURLs[i] = v.String() } opts := mtglib.ProxyOpts{ Logger: logger, Network: ntw, AntiReplayCache: makeAntiReplayCache(conf), IPBlocklist: blocklist, IPAllowlist: allowlist, EventStream: eventStream, Secrets: conf.GetSecrets(), Concurrency: conf.GetConcurrency(mtglib.DefaultConcurrency), DomainFrontingPort: conf.GetDomainFrontingPort(mtglib.DefaultDomainFrontingPort), DomainFrontingIP: conf.GetDomainFrontingIP(nil), DomainFrontingProxyProtocol: conf.GetDomainFrontingProxyProtocol(false), PreferIP: conf.PreferIP.Get(mtglib.DefaultPreferIP), AutoUpdate: conf.AutoUpdate.Get(false), AllowFallbackOnUnknownDC: conf.AllowFallbackOnUnknownDC.Get(false), TolerateTimeSkewness: conf.TolerateTimeSkewness.Value, IdleTimeout: conf.Network.Timeout.Idle.Get(time.Minute), DoppelGangerURLs: doppelGangerURLs, DoppelGangerPerRaid: conf.Defense.Doppelganger.Repeats.Get(mtglib.DoppelGangerPerRaid), DoppelGangerEach: conf.Defense.Doppelganger.UpdateEach.Get(mtglib.DoppelGangerEach), DoppelGangerDRS: conf.Defense.Doppelganger.DRS.Get(false), APIBindTo: conf.APIBindTo.Get(""), ThrottleMaxConnections: conf.Throttle.MaxConnections.Get(0), ThrottleCheckInterval: conf.Throttle.CheckInterval.Get(5 * time.Second), } proxy, err := mtglib.NewProxy(opts) if err != nil { return fmt.Errorf("cannot create a proxy: %w", err) } listener, err := utils.NewListener(conf.BindTo.Get(""), 0) if err != nil { return fmt.Errorf("cannot start proxy: %w", err) } if conf.ProxyProtocolListener.Get(false) { listener = &proxyprotocol.ListenerAdapter{ Listener: proxyproto.Listener{ Listener: listener, }, } } ctx := utils.RootContext() go proxy.Serve(listener) //nolint: errcheck <-ctx.Done() listener.Close() //nolint: errcheck proxy.Shutdown() return nil }