Skip to content

Commit d48db2c

Browse files
committed
chore: cleanup dns policy match code
1 parent 4c10d42 commit d48db2c

File tree

9 files changed

+223
-400
lines changed

9 files changed

+223
-400
lines changed

config/config.go

+125-145
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ import (
2020
N "github.com/metacubex/mihomo/common/net"
2121
"github.com/metacubex/mihomo/common/utils"
2222
"github.com/metacubex/mihomo/component/auth"
23+
"github.com/metacubex/mihomo/component/cidr"
2324
"github.com/metacubex/mihomo/component/fakeip"
2425
"github.com/metacubex/mihomo/component/geodata"
25-
"github.com/metacubex/mihomo/component/geodata/router"
2626
P "github.com/metacubex/mihomo/component/process"
2727
"github.com/metacubex/mihomo/component/resolver"
2828
SNIFF "github.com/metacubex/mihomo/component/sniffer"
@@ -114,33 +114,25 @@ type NTP struct {
114114

115115
// DNS config
116116
type DNS struct {
117-
Enable bool `yaml:"enable"`
118-
PreferH3 bool `yaml:"prefer-h3"`
119-
IPv6 bool `yaml:"ipv6"`
120-
IPv6Timeout uint `yaml:"ipv6-timeout"`
121-
UseSystemHosts bool `yaml:"use-system-hosts"`
122-
NameServer []dns.NameServer `yaml:"nameserver"`
123-
Fallback []dns.NameServer `yaml:"fallback"`
124-
FallbackFilter FallbackFilter `yaml:"fallback-filter"`
125-
Listen string `yaml:"listen"`
126-
EnhancedMode C.DNSMode `yaml:"enhanced-mode"`
127-
DefaultNameserver []dns.NameServer `yaml:"default-nameserver"`
128-
CacheAlgorithm string `yaml:"cache-algorithm"`
117+
Enable bool
118+
PreferH3 bool
119+
IPv6 bool
120+
IPv6Timeout uint
121+
UseSystemHosts bool
122+
NameServer []dns.NameServer
123+
Fallback []dns.NameServer
124+
FallbackIPFilter []C.Rule
125+
FallbackDomainFilter []C.Rule
126+
Listen string
127+
EnhancedMode C.DNSMode
128+
DefaultNameserver []dns.NameServer
129+
CacheAlgorithm string
129130
FakeIPRange *fakeip.Pool
130131
Hosts *trie.DomainTrie[resolver.HostValue]
131-
NameServerPolicy *orderedmap.OrderedMap[string, []dns.NameServer]
132+
NameServerPolicy []dns.Policy
132133
ProxyServerNameserver []dns.NameServer
133134
}
134135

135-
// FallbackFilter config
136-
type FallbackFilter struct {
137-
GeoIP bool `yaml:"geoip"`
138-
GeoIPCode string `yaml:"geoip-code"`
139-
IPCIDR []netip.Prefix `yaml:"ipcidr"`
140-
Domain []string `yaml:"domain"`
141-
GeoSite []router.DomainMatcher `yaml:"geosite"`
142-
}
143-
144136
// Profile config
145137
type Profile struct {
146138
StoreSelected bool `yaml:"store-selected"`
@@ -1205,125 +1197,81 @@ func parsePureDNSServer(server string) string {
12051197
}
12061198
}
12071199

1208-
func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], ruleProviders map[string]providerTypes.RuleProvider, respectRules bool, preferH3 bool) (*orderedmap.OrderedMap[string, []dns.NameServer], error) {
1209-
policy := orderedmap.New[string, []dns.NameServer]()
1210-
updatedPolicy := orderedmap.New[string, any]()
1200+
func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rules []C.Rule, ruleProviders map[string]providerTypes.RuleProvider, respectRules bool, preferH3 bool) ([]dns.Policy, error) {
1201+
var tmpPolicy []dns.Policy
12111202
re := regexp.MustCompile(`[a-zA-Z0-9\-]+\.[a-zA-Z]{2,}(\.[a-zA-Z]{2,})?`)
12121203

12131204
for pair := nsPolicy.Oldest(); pair != nil; pair = pair.Next() {
12141205
k, v := pair.Key, pair.Value
1206+
servers, err := utils.ToStringSlice(v)
1207+
if err != nil {
1208+
return nil, err
1209+
}
1210+
nameservers, err := parseNameServer(servers, respectRules, preferH3)
1211+
if err != nil {
1212+
return nil, err
1213+
}
12151214
if strings.Contains(strings.ToLower(k), ",") {
12161215
if strings.Contains(k, "geosite:") {
12171216
subkeys := strings.Split(k, ":")
12181217
subkeys = subkeys[1:]
12191218
subkeys = strings.Split(subkeys[0], ",")
12201219
for _, subkey := range subkeys {
12211220
newKey := "geosite:" + subkey
1222-
updatedPolicy.Store(newKey, v)
1221+
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: newKey, NameServers: nameservers})
12231222
}
12241223
} else if strings.Contains(strings.ToLower(k), "rule-set:") {
12251224
subkeys := strings.Split(k, ":")
12261225
subkeys = subkeys[1:]
12271226
subkeys = strings.Split(subkeys[0], ",")
12281227
for _, subkey := range subkeys {
12291228
newKey := "rule-set:" + subkey
1230-
updatedPolicy.Store(newKey, v)
1229+
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: newKey, NameServers: nameservers})
12311230
}
12321231
} else if re.MatchString(k) {
12331232
subkeys := strings.Split(k, ",")
12341233
for _, subkey := range subkeys {
1235-
updatedPolicy.Store(subkey, v)
1234+
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: subkey, NameServers: nameservers})
12361235
}
12371236
}
12381237
} else {
12391238
if strings.Contains(strings.ToLower(k), "geosite:") {
1240-
updatedPolicy.Store("geosite:"+k[8:], v)
1239+
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: "geosite:" + k[8:], NameServers: nameservers})
12411240
} else if strings.Contains(strings.ToLower(k), "rule-set:") {
1242-
updatedPolicy.Store("rule-set:"+k[9:], v)
1243-
}
1244-
updatedPolicy.Store(k, v)
1245-
}
1246-
}
1247-
1248-
for pair := updatedPolicy.Oldest(); pair != nil; pair = pair.Next() {
1249-
domain, server := pair.Key, pair.Value
1250-
servers, err := utils.ToStringSlice(server)
1251-
if err != nil {
1252-
return nil, err
1253-
}
1254-
nameservers, err := parseNameServer(servers, respectRules, preferH3)
1255-
if err != nil {
1256-
return nil, err
1257-
}
1258-
if _, valid := trie.ValidAndSplitDomain(domain); !valid {
1259-
return nil, fmt.Errorf("DNS ResoverRule invalid domain: %s", domain)
1260-
}
1261-
if strings.HasPrefix(domain, "rule-set:") {
1262-
domainSetName := domain[9:]
1263-
if provider, ok := ruleProviders[domainSetName]; !ok {
1264-
return nil, fmt.Errorf("not found rule-set: %s", domainSetName)
1241+
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: "rule-set:" + k[9:], NameServers: nameservers})
12651242
} else {
1266-
switch provider.Behavior() {
1267-
case providerTypes.IPCIDR:
1268-
return nil, fmt.Errorf("rule provider type error, except domain,actual %s", provider.Behavior())
1269-
case providerTypes.Classical:
1270-
log.Warnln("%s provider is %s, only matching it contain domain rule", provider.Name(), provider.Behavior())
1271-
}
1243+
tmpPolicy = append(tmpPolicy, dns.Policy{Domain: k, NameServers: nameservers})
12721244
}
12731245
}
1274-
policy.Store(domain, nameservers)
12751246
}
12761247

1277-
return policy, nil
1278-
}
1279-
1280-
func parseFallbackIPCIDR(ips []string) ([]netip.Prefix, error) {
1281-
var ipNets []netip.Prefix
1282-
1283-
for idx, ip := range ips {
1284-
ipnet, err := netip.ParsePrefix(ip)
1285-
if err != nil {
1286-
return nil, fmt.Errorf("DNS FallbackIP[%d] format error: %s", idx, err.Error())
1287-
}
1288-
ipNets = append(ipNets, ipnet)
1289-
}
1290-
1291-
return ipNets, nil
1292-
}
1293-
1294-
func parseFallbackGeoSite(countries []string, rules []C.Rule) ([]router.DomainMatcher, error) {
1295-
var sites []router.DomainMatcher
1296-
if len(countries) > 0 {
1297-
if err := geodata.InitGeoSite(); err != nil {
1298-
return nil, fmt.Errorf("can't initial GeoSite: %s", err)
1299-
}
1300-
log.Warnln("replace fallback-filter.geosite with nameserver-policy, it will be removed in the future")
1301-
}
1248+
var policy []dns.Policy
1249+
for _, p := range tmpPolicy {
1250+
domain, nameservers := p.Domain, p.NameServers
13021251

1303-
for _, country := range countries {
1304-
found := false
1305-
for _, rule := range rules {
1306-
if rule.RuleType() == C.GEOSITE {
1307-
if strings.EqualFold(country, rule.Payload()) {
1308-
found = true
1309-
sites = append(sites, rule.(C.RuleGeoSite).GetDomainMatcher())
1310-
log.Infoln("Start initial GeoSite dns fallback filter from rule `%s`", country)
1311-
}
1252+
if strings.HasPrefix(domain, "rule-set:") {
1253+
domainSetName := domain[9:]
1254+
rule, err := parseDomainRuleSet(domainSetName, ruleProviders)
1255+
if err != nil {
1256+
return nil, err
13121257
}
1313-
}
1314-
1315-
if !found {
1316-
matcher, recordsCount, err := geodata.LoadGeoSiteMatcher(country)
1258+
policy = append(policy, dns.Policy{Rule: rule, NameServers: nameservers})
1259+
} else if strings.HasPrefix(domain, "geosite:") {
1260+
country := domain[8:]
1261+
rule, err := parseGEOSITE(country, rules)
13171262
if err != nil {
13181263
return nil, err
13191264
}
1320-
1321-
sites = append(sites, matcher)
1322-
1323-
log.Infoln("Start initial GeoSite dns fallback filter `%s`, records: %d", country, recordsCount)
1265+
policy = append(policy, dns.Policy{Rule: rule, NameServers: nameservers})
1266+
} else {
1267+
if _, valid := trie.ValidAndSplitDomain(domain); !valid {
1268+
return nil, fmt.Errorf("DNS ResoverRule invalid domain: %s", domain)
1269+
}
1270+
policy = append(policy, dns.Policy{Domain: domain, NameServers: nameservers})
13241271
}
13251272
}
1326-
return sites, nil
1273+
1274+
return policy, nil
13271275
}
13281276

13291277
func paresNTP(rawCfg *RawConfig) *NTP {
@@ -1357,10 +1305,6 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul
13571305
IPv6: cfg.IPv6,
13581306
UseSystemHosts: cfg.UseSystemHosts,
13591307
EnhancedMode: cfg.EnhancedMode,
1360-
FallbackFilter: FallbackFilter{
1361-
IPCIDR: []netip.Prefix{},
1362-
GeoSite: []router.DomainMatcher{},
1363-
},
13641308
}
13651309
var err error
13661310
if dnsCfg.NameServer, err = parseNameServer(cfg.NameServer, cfg.RespectRules, cfg.PreferH3); err != nil {
@@ -1371,7 +1315,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul
13711315
return nil, err
13721316
}
13731317

1374-
if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, ruleProviders, cfg.RespectRules, cfg.PreferH3); err != nil {
1318+
if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, rules, ruleProviders, cfg.RespectRules, cfg.PreferH3); err != nil {
13751319
return nil, err
13761320
}
13771321

@@ -1438,18 +1382,51 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul
14381382
dnsCfg.FakeIPRange = pool
14391383
}
14401384

1385+
var rule C.Rule
14411386
if len(cfg.Fallback) != 0 {
1442-
dnsCfg.FallbackFilter.GeoIP = cfg.FallbackFilter.GeoIP
1443-
dnsCfg.FallbackFilter.GeoIPCode = cfg.FallbackFilter.GeoIPCode
1444-
if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil {
1445-
dnsCfg.FallbackFilter.IPCIDR = fallbackip
1387+
if cfg.FallbackFilter.GeoIP {
1388+
rule, err = RC.NewGEOIP(cfg.FallbackFilter.GeoIPCode, "", false, true)
1389+
if err != nil {
1390+
return nil, fmt.Errorf("load GeoIP dns fallback filter error, %w", err)
1391+
}
1392+
dnsCfg.FallbackIPFilter = append(dnsCfg.FallbackIPFilter, rule)
14461393
}
1447-
dnsCfg.FallbackFilter.Domain = cfg.FallbackFilter.Domain
1448-
fallbackGeoSite, err := parseFallbackGeoSite(cfg.FallbackFilter.GeoSite, rules)
1449-
if err != nil {
1450-
return nil, fmt.Errorf("load GeoSite dns fallback filter error, %w", err)
1394+
if len(cfg.FallbackFilter.IPCIDR) > 0 {
1395+
cidrSet := cidr.NewIpCidrSet()
1396+
for idx, ipcidr := range cfg.FallbackFilter.IPCIDR {
1397+
err = cidrSet.AddIpCidrForString(ipcidr)
1398+
if err != nil {
1399+
return nil, fmt.Errorf("DNS FallbackIP[%d] format error: %w", idx, err)
1400+
}
1401+
}
1402+
err = cidrSet.Merge()
1403+
if err != nil {
1404+
return nil, err
1405+
}
1406+
rule = RP.NewIpCidrSet(cidrSet, "")
1407+
dnsCfg.FallbackIPFilter = append(dnsCfg.FallbackIPFilter, rule)
1408+
}
1409+
if len(cfg.FallbackFilter.Domain) > 0 {
1410+
domainTrie := trie.New[struct{}]()
1411+
for idx, domain := range cfg.FallbackFilter.Domain {
1412+
err = domainTrie.Insert(domain, struct{}{})
1413+
if err != nil {
1414+
return nil, fmt.Errorf("DNS FallbackDomain[%d] format error: %w", idx, err)
1415+
}
1416+
}
1417+
rule = RP.NewDomainSet(domainTrie.NewDomainSet(), "")
1418+
dnsCfg.FallbackIPFilter = append(dnsCfg.FallbackIPFilter, rule)
1419+
}
1420+
if len(cfg.FallbackFilter.GeoSite) > 0 {
1421+
log.Warnln("replace fallback-filter.geosite with nameserver-policy, it will be removed in the future")
1422+
for idx, geoSite := range cfg.FallbackFilter.GeoSite {
1423+
rule, err = parseGEOSITE(geoSite, rules)
1424+
if err != nil {
1425+
return nil, fmt.Errorf("DNS FallbackGeosite[%d] format error: %w", idx, err)
1426+
}
1427+
dnsCfg.FallbackIPFilter = append(dnsCfg.FallbackIPFilter, rule)
1428+
}
14511429
}
1452-
dnsCfg.FallbackFilter.GeoSite = fallbackGeoSite
14531430
}
14541431

14551432
if cfg.UseHosts {
@@ -1636,44 +1613,21 @@ func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], rules
16361613
subkeys = subkeys[1:]
16371614
subkeys = strings.Split(subkeys[0], ",")
16381615
for _, country := range subkeys {
1639-
found := false
1640-
for _, rule = range rules {
1641-
if rule.RuleType() == C.GEOSITE {
1642-
if strings.EqualFold(country, rule.Payload()) {
1643-
found = true
1644-
domainRules = append(domainRules, rule)
1645-
}
1646-
}
1647-
}
1648-
if !found {
1649-
rule, err = RC.NewGEOSITE(country, "")
1650-
if err != nil {
1651-
return nil, err
1652-
}
1653-
domainRules = append(domainRules, rule)
1616+
rule, err = parseGEOSITE(country, rules)
1617+
if err != nil {
1618+
return nil, err
16541619
}
1620+
domainRules = append(domainRules, rule)
16551621
}
16561622
} else if strings.Contains(domainLower, "rule-set:") {
16571623
subkeys := strings.Split(domain, ":")
16581624
subkeys = subkeys[1:]
16591625
subkeys = strings.Split(subkeys[0], ",")
16601626
for _, domainSetName := range subkeys {
1661-
if rp, ok := ruleProviders[domainSetName]; !ok {
1662-
return nil, fmt.Errorf("not found rule-set: %s", domainSetName)
1663-
} else {
1664-
switch rp.Behavior() {
1665-
case providerTypes.IPCIDR:
1666-
return nil, fmt.Errorf("rule provider type error, except domain,actual %s", rp.Behavior())
1667-
case providerTypes.Classical:
1668-
log.Warnln("%s provider is %s, only matching it contain domain rule", rp.Name(), rp.Behavior())
1669-
default:
1670-
}
1671-
}
1672-
rule, err = RP.NewRuleSet(domainSetName, "", true)
1627+
rule, err = parseDomainRuleSet(domainSetName, ruleProviders)
16731628
if err != nil {
16741629
return nil, err
16751630
}
1676-
16771631
domainRules = append(domainRules, rule)
16781632
}
16791633
} else {
@@ -1692,3 +1646,29 @@ func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], rules
16921646
}
16931647
return
16941648
}
1649+
1650+
func parseDomainRuleSet(domainSetName string, ruleProviders map[string]providerTypes.RuleProvider) (C.Rule, error) {
1651+
if rp, ok := ruleProviders[domainSetName]; !ok {
1652+
return nil, fmt.Errorf("not found rule-set: %s", domainSetName)
1653+
} else {
1654+
switch rp.Behavior() {
1655+
case providerTypes.IPCIDR:
1656+
return nil, fmt.Errorf("rule provider type error, except domain,actual %s", rp.Behavior())
1657+
case providerTypes.Classical:
1658+
log.Warnln("%s provider is %s, only matching it contain domain rule", rp.Name(), rp.Behavior())
1659+
default:
1660+
}
1661+
}
1662+
return RP.NewRuleSet(domainSetName, "", true)
1663+
}
1664+
1665+
func parseGEOSITE(country string, rules []C.Rule) (C.Rule, error) {
1666+
for _, rule := range rules {
1667+
if rule.RuleType() == C.GEOSITE {
1668+
if strings.EqualFold(country, rule.Payload()) {
1669+
return rule, nil
1670+
}
1671+
}
1672+
}
1673+
return RC.NewGEOSITE(country, "")
1674+
}

0 commit comments

Comments
 (0)