Files
kvm/firewall.go
luckfox-eng29 d5bfaffd86 Update App version to 0.1.2
Signed-off-by: luckfox-eng29 <eng29@luckfox.com>
2026-03-23 12:04:54 +08:00

1471 lines
35 KiB
Go

package kvm
import (
"context"
"errors"
"fmt"
"net"
"os/exec"
"strconv"
"strings"
"time"
)
// Names of the iptables chains, and the comment tag, used for the rules this
// file creates.
const (
// Chains created in the filter table and hooked into INPUT, OUTPUT and FORWARD.
firewallChainInput = "KVM_UI_INPUT"
firewallChainOutput = "KVM_UI_OUTPUT"
firewallChainForward = "KVM_UI_FORWARD"
// Chains created in the nat table and hooked into PREROUTING and POSTROUTING.
firewallChainNatPrerouting = "KVM_UI_PREROUTING"
firewallChainNatPostrouting = "KVM_UI_POSTROUTING"
// Comment prefix put on rules appended directly to the nat OUTPUT chain so
// they can be recognised and deleted later.
firewallNatOutputCommentPrefix = "KVM_UI_PF"
)
// ApplyFirewallConfig validates cfg and installs it by invoking the iptables
// binary. A nil cfg is a no-op.
//
// The managed chains are created if needed, flushed, hooked into the built-in
// chains and then refilled: base rules, communication rules, port forwards and
// finally the default-policy rules. The steps run in order and the first error
// is returned immediately, so a failure after the flush leaves the managed
// chains only partially populated.
//
// Port forwards require the nat table; if any are configured and the table is
// unavailable an error is returned before anything is changed.
func ApplyFirewallConfig(cfg *FirewallConfig) error {
	if cfg == nil {
		return nil
	}
	if _, lookErr := exec.LookPath("iptables"); lookErr != nil {
		return fmt.Errorf("iptables not found: %w", lookErr)
	}
	if err := validateFirewallConfig(cfg); err != nil {
		return err
	}

	natSupported, err := iptablesTableSupported("nat")
	if err != nil {
		return err
	}
	if len(cfg.PortForwards) > 0 && !natSupported {
		return fmt.Errorf("iptables nat table not supported; port forwarding unavailable")
	}

	// The nat chains are managed whenever the table exists, even when no
	// port forwards are configured.
	steps := []func() error{
		func() error { return ensureFirewallChains(natSupported) },
		func() error { return flushFirewallChains(natSupported) },
		func() error { return ensureFirewallJumps(natSupported) },
		func() error { return buildBaseRules(cfg.Base) },
		func() error { return buildCommunicationRules(cfg.Rules) },
		func() error { return buildPortForwardRules(cfg.PortForwards, natSupported) },
		func() error { return appendDefaultPolicies(cfg.Base) },
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return err
		}
	}
	return nil
}
// validateFirewallConfig checks cfg without touching the system and returns
// an error naming the first offending field.
//
// It verifies the three base policies, then every communication rule (chain,
// optional addresses, optional ports in 1..65535, at least one known protocol,
// action), then every managed port forward (chain, ports, destination address,
// at least one port-capable protocol). Unmanaged port forwards are skipped.
func validateFirewallConfig(cfg *FirewallConfig) error {
	basePolicies := []struct{ field, action string }{
		{"inputPolicy", cfg.Base.InputPolicy},
		{"outputPolicy", cfg.Base.OutputPolicy},
		{"forwardPolicy", cfg.Base.ForwardPolicy},
	}
	for _, bp := range basePolicies {
		if _, err := normalizeFirewallAction(bp.action); err != nil {
			return fmt.Errorf("invalid %s: %w", bp.field, err)
		}
	}

	portOK := func(p int) bool { return p >= 1 && p <= 65535 }

	for idx, rule := range cfg.Rules {
		if _, err := normalizeFirewallChain(rule.Chain); err != nil {
			return fmt.Errorf("rules[%d].chain: %w", idx, err)
		}
		if rule.SourceIP != "" && !isValidIPOrCIDR(rule.SourceIP) {
			return fmt.Errorf("rules[%d].sourceIP: invalid ip", idx)
		}
		if rule.DestinationIP != "" && !isValidIPOrCIDR(rule.DestinationIP) {
			return fmt.Errorf("rules[%d].destinationIP: invalid ip", idx)
		}
		if rule.SourcePort != nil && !portOK(*rule.SourcePort) {
			return fmt.Errorf("rules[%d].sourcePort: out of range", idx)
		}
		if rule.DestinationPort != nil && !portOK(*rule.DestinationPort) {
			return fmt.Errorf("rules[%d].destinationPort: out of range", idx)
		}
		if len(rule.Protocols) == 0 {
			return fmt.Errorf("rules[%d].protocols: required", idx)
		}
		for _, proto := range rule.Protocols {
			if _, err := normalizeFirewallProtocol(proto); err != nil {
				return fmt.Errorf("rules[%d].protocols: %w", idx, err)
			}
		}
		if _, err := normalizeFirewallAction(rule.Action); err != nil {
			return fmt.Errorf("rules[%d].action: %w", idx, err)
		}
	}

	for idx, pf := range cfg.PortForwards {
		if !isManagedPortForward(pf) {
			continue
		}

		// An empty chain is inferred from the destination address.
		chain := strings.ToLower(strings.TrimSpace(pf.Chain))
		if chain == "" {
			chain = "prerouting"
			if isLocalRedirectDestination(pf.DestinationIP) {
				chain = "output"
			}
		}
		switch chain {
		case "output", "prerouting", "prerouting_redirect":
		default:
			return fmt.Errorf("portForwards[%d].chain: unsupported", idx)
		}

		if !portOK(pf.SourcePort) {
			return fmt.Errorf("portForwards[%d].sourcePort: out of range", idx)
		}
		if !portOK(pf.DestinationPort) {
			return fmt.Errorf("portForwards[%d].destinationPort: out of range", idx)
		}

		// A prerouting (DNAT) forward needs a concrete destination; the
		// redirect variants only need the address to parse when present.
		parsed := net.ParseIP(pf.DestinationIP)
		if chain == "prerouting" {
			if parsed == nil || parsed.IsUnspecified() {
				return fmt.Errorf("portForwards[%d].destinationIP: invalid ip", idx)
			}
		} else if pf.DestinationIP != "" && parsed == nil {
			return fmt.Errorf("portForwards[%d].destinationIP: invalid ip", idx)
		}

		if len(pf.Protocols) == 0 {
			return fmt.Errorf("portForwards[%d].protocols: required", idx)
		}
		for _, proto := range pf.Protocols {
			if _, err := normalizeFirewallProtocol(proto); err != nil {
				return fmt.Errorf("portForwards[%d].protocols: %w", idx, err)
			}
			switch strings.ToLower(strings.TrimSpace(proto)) {
			case "tcp", "udp", "sctp", "dccp":
			default:
				return fmt.Errorf("portForwards[%d].protocols: %s not supported for port forwarding", idx, proto)
			}
		}
	}
	return nil
}
// isManagedPortForward reports whether r is managed by this package. A rule
// whose Managed field is unset counts as managed.
func isManagedPortForward(r FirewallPortRule) bool {
	if r.Managed != nil {
		return *r.Managed
	}
	return true
}
// isValidIPOrCIDR reports whether s, after trimming surrounding whitespace,
// parses as an IP address or, when it contains a slash, as a CIDR block.
// An empty or whitespace-only string is invalid.
func isValidIPOrCIDR(s string) bool {
	candidate := strings.TrimSpace(s)
	switch {
	case candidate == "":
		return false
	case strings.Contains(candidate, "/"):
		if _, _, err := net.ParseCIDR(candidate); err != nil {
			return false
		}
		return true
	default:
		return net.ParseIP(candidate) != nil
	}
}
// ensureFirewallChains makes sure the managed filter chains exist and, when
// needNat is true, the managed nat chains as well. It stops at the first error.
func ensureFirewallChains(needNat bool) error {
	type target struct{ table, chain string }

	targets := []target{
		{"filter", firewallChainInput},
		{"filter", firewallChainOutput},
		{"filter", firewallChainForward},
	}
	if needNat {
		targets = append(targets,
			target{"nat", firewallChainNatPrerouting},
			target{"nat", firewallChainNatPostrouting},
		)
	}

	for _, tgt := range targets {
		if err := ensureChain(tgt.table, tgt.chain); err != nil {
			return err
		}
	}
	return nil
}
func flushFirewallChains(needNat bool) error {
for _, chain := range []string{firewallChainInput, firewallChainOutput, firewallChainForward} {
if err := iptables("filter", "-F", chain); err != nil {
return err
}
}
if needNat {
for _, chain := range []string{firewallChainNatPrerouting, firewallChainNatPostrouting} {
if err := iptables("nat", "-F", chain); err != nil {
return err
}
}
}
return nil
}
func ensureFirewallJumps(needNat bool) error {
if err := ensureJump("filter", "INPUT", firewallChainInput); err != nil {
return err
}
if err := ensureJump("filter", "OUTPUT", firewallChainOutput); err != nil {
return err
}
if err := ensureJump("filter", "FORWARD", firewallChainForward); err != nil {
return err
}
if needNat {
if err := ensureJump("nat", "PREROUTING", firewallChainNatPrerouting); err != nil {
return err
}
if err := ensureJump("nat", "POSTROUTING", firewallChainNatPostrouting); err != nil {
return err
}
}
return nil
}
// iptablesTableSupported probes table by listing its rules. It returns
// (true, nil) when the listing succeeds, (false, nil) when the failure looks
// like a missing table, and (false, err) for any other failure.
func iptablesTableSupported(table string) (bool, error) {
	err := iptables(table, "-S")
	switch {
	case err == nil:
		return true, nil
	case isIptablesTableMissingErr(err):
		return false, nil
	default:
		return false, err
	}
}
// isIptablesTableMissingErr reports whether err's message contains one of the
// phrases this package treats as "the table is unavailable". A nil err yields
// false.
func isIptablesTableMissingErr(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range []string{
		"Table does not exist",
		"can't initialize iptables table",
	} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// buildBaseRules appends the fixed rules that precede user rules: accept
// loopback traffic on the input and output chains, then accept
// ESTABLISHED,RELATED connections on all three managed filter chains.
// The base argument is not consulted here.
func buildBaseRules(base FirewallBaseRule) error {
	specs := [][]string{
		{"-A", firewallChainInput, "-i", "lo", "-j", "ACCEPT"},
		{"-A", firewallChainOutput, "-o", "lo", "-j", "ACCEPT"},
	}
	for _, chain := range []string{firewallChainInput, firewallChainOutput, firewallChainForward} {
		specs = append(specs, []string{
			"-A", chain, "-m", "conntrack", "--ctstate", "ESTABLISHED,RELATED", "-j", "ACCEPT",
		})
	}

	for _, spec := range specs {
		if err := iptables("filter", spec...); err != nil {
			return err
		}
	}
	return nil
}
// appendDefaultPolicies appends one unconditional rule to the end of each
// managed filter chain, jumping to the target configured as that chain's
// policy. All three policies are normalized before any rule is added, so an
// invalid policy changes nothing.
func appendDefaultPolicies(base FirewallBaseRule) error {
	plan := []struct{ chain, policy string }{
		{firewallChainInput, base.InputPolicy},
		{firewallChainOutput, base.OutputPolicy},
		{firewallChainForward, base.ForwardPolicy},
	}

	targets := make([]string, len(plan))
	for i, entry := range plan {
		target, err := normalizeFirewallAction(entry.policy)
		if err != nil {
			return err
		}
		targets[i] = target
	}

	for i, entry := range plan {
		if err := iptables("filter", "-A", entry.chain, "-j", targets[i]); err != nil {
			return err
		}
	}
	return nil
}
// buildCommunicationRules appends one filter rule per (rule, protocol) pair
// to the managed chain the rule names. Port matches are only emitted for
// protocols that support ports, and a comment match is added when the rule
// has a non-blank comment. It stops at the first error.
func buildCommunicationRules(rules []FirewallRule) error {
	for _, rule := range rules {
		chain, err := normalizeFirewallChain(rule.Chain)
		if err != nil {
			return err
		}
		target, err := normalizeFirewallAction(rule.Action)
		if err != nil {
			return err
		}
		protoList, err := normalizeProtocolList(rule.Protocols)
		if err != nil {
			return err
		}
		hasComment := strings.TrimSpace(rule.Comment) != ""

		for _, proto := range protoList {
			spec := []string{"-A", chain}
			if rule.SourceIP != "" {
				spec = append(spec, "-s", rule.SourceIP)
			}
			if rule.DestinationIP != "" {
				spec = append(spec, "-d", rule.DestinationIP)
			}
			// An empty proto stands for "any": no -p match at all.
			if proto != "" {
				spec = append(spec, "-p", proto)
			}
			if protoSupportsPorts(proto) {
				if rule.SourcePort != nil {
					spec = append(spec, "--sport", strconv.Itoa(*rule.SourcePort))
				}
				if rule.DestinationPort != nil {
					spec = append(spec, "--dport", strconv.Itoa(*rule.DestinationPort))
				}
			}
			if hasComment {
				spec = append(spec, "-m", "comment", "--comment", rule.Comment)
			}
			spec = append(spec, "-j", target)

			if err := iptables("filter", spec...); err != nil {
				return err
			}
		}
	}
	return nil
}
// buildPortForwardRules installs the managed port forwards. It does nothing
// when manageNat is false.
//
// Previously installed managed rules in the nat OUTPUT chain are cleared
// first. Managed rules are then split by chain: "output" rules become local
// REDIRECTs in nat OUTPUT, "prerouting_redirect" rules become REDIRECTs in the
// managed prerouting chain, and everything else becomes a DNAT. When at least
// one DNAT exists, IP forwarding is switched on (a failure is only logged), a
// MASQUERADE rule for DNATed connections is added, and each DNAT gets a
// matching ACCEPT rule in the managed forward chain.
func buildPortForwardRules(rules []FirewallPortRule, manageNat bool) error {
	if !manageNat {
		return nil
	}
	if err := clearManagedNatOutputRules(); err != nil {
		return err
	}

	var outputRedirects, preroutingRedirects, dnatRules []FirewallPortRule
	for _, rule := range rules {
		if !isManagedPortForward(rule) {
			continue
		}
		chain := strings.ToLower(strings.TrimSpace(rule.Chain))
		// With no explicit chain, a local destination means an OUTPUT redirect.
		if chain == "" && isLocalRedirectDestination(rule.DestinationIP) {
			chain = "output"
		}
		switch chain {
		case "output":
			outputRedirects = append(outputRedirects, rule)
		case "prerouting_redirect":
			preroutingRedirects = append(preroutingRedirects, rule)
		default:
			dnatRules = append(dnatRules, rule)
		}
	}

	if err := buildNatOutputRedirectRules(outputRedirects); err != nil {
		return err
	}
	if err := buildNatPreroutingRedirectRules(preroutingRedirects); err != nil {
		return err
	}
	if len(dnatRules) == 0 {
		return nil
	}

	if err := sysctlWrite("net.ipv4.ip_forward", "1"); err != nil {
		logger.Warn().Err(err).Msg("failed to enable ip_forward")
	}
	if err := iptables("nat", "-A", firewallChainNatPostrouting, "-m", "conntrack", "--ctstate", "DNAT", "-j", "MASQUERADE"); err != nil {
		return err
	}

	for _, rule := range dnatRules {
		protoList, err := normalizeProtocolList(rule.Protocols)
		if err != nil {
			return err
		}
		hasComment := strings.TrimSpace(rule.Comment) != ""
		listenPort := strconv.Itoa(rule.SourcePort)
		targetPort := strconv.Itoa(rule.DestinationPort)

		for _, proto := range protoList {
			if proto == "" {
				continue
			}

			dnat := []string{"-A", firewallChainNatPrerouting, "-p", proto, "--dport", listenPort}
			if hasComment {
				dnat = append(dnat, "-m", "comment", "--comment", rule.Comment)
			}
			dnat = append(dnat, "-j", "DNAT", "--to-destination", fmt.Sprintf("%s:%d", rule.DestinationIP, rule.DestinationPort))
			if err := iptables("nat", dnat...); err != nil {
				return err
			}

			// Let the translated traffic through the managed forward chain.
			accept := []string{
				"-A", firewallChainForward,
				"-p", proto,
				"-d", rule.DestinationIP,
				"--dport", targetPort,
				"-m", "conntrack",
				"--ctstate", "NEW,ESTABLISHED,RELATED",
			}
			if hasComment {
				accept = append(accept, "-m", "comment", "--comment", rule.Comment)
			}
			accept = append(accept, "-j", "ACCEPT")
			if err := iptables("filter", accept...); err != nil {
				return err
			}
		}
	}
	return nil
}
// buildNatPreroutingRedirectRules appends a REDIRECT rule to the managed nat
// prerouting chain for every concrete protocol of every rule, mapping the
// rule's source port to its destination port. The "any" protocol is skipped.
func buildNatPreroutingRedirectRules(rules []FirewallPortRule) error {
	for _, rule := range rules {
		protoList, err := normalizeProtocolList(rule.Protocols)
		if err != nil {
			return err
		}
		hasComment := strings.TrimSpace(rule.Comment) != ""
		fromPort := strconv.Itoa(rule.SourcePort)
		toPort := strconv.Itoa(rule.DestinationPort)

		for _, proto := range protoList {
			if proto == "" {
				continue
			}
			spec := []string{"-A", firewallChainNatPrerouting, "-p", proto, "--dport", fromPort}
			if hasComment {
				spec = append(spec, "-m", "comment", "--comment", rule.Comment)
			}
			spec = append(spec, "-j", "REDIRECT", "--to-ports", toPort)
			if err := iptables("nat", spec...); err != nil {
				return err
			}
		}
	}
	return nil
}
// isLocalRedirectDestination reports whether dstIP, ignoring surrounding
// whitespace, is exactly "0.0.0.0" or "127.0.0.1".
func isLocalRedirectDestination(dstIP string) bool {
	addr := strings.TrimSpace(dstIP)
	return addr == "0.0.0.0" || addr == "127.0.0.1"
}
// clearManagedNatOutputRules deletes the rules this package previously added
// to the nat OUTPUT chain.
//
// The chain is listed with "iptables -S", and every "-A OUTPUT" line whose
// --comment value is exactly the managed tag, or the tag followed by ":" and
// user text, is replayed as a "-D" delete. The comment value itself is
// compared, not the raw line, so a foreign rule that merely mentions the tag
// somewhere in its text is left alone.
//
// A missing nat table is not an error. Individual delete failures are ignored
// on purpose: cleanup is best-effort and must not block applying new rules.
func clearManagedNatOutputRules() error {
	out, err := iptablesOutput("nat", "-S", "OUTPUT")
	if err != nil {
		if isIptablesTableMissingErr(err) {
			return nil
		}
		return err
	}

	for _, line := range strings.Split(strings.ReplaceAll(out, "\r\n", "\n"), "\n") {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "-A OUTPUT ") {
			continue
		}
		tokens, err := splitShellLike(line)
		if err != nil || len(tokens) < 2 || tokens[0] != "-A" || tokens[1] != "OUTPUT" {
			continue
		}
		if !hasManagedNatOutputComment(tokens) {
			continue
		}
		// Replay the rule spec verbatim as a delete.
		tokens[0] = "-D"
		_ = iptables("nat", tokens...) // best-effort, see function comment
	}
	return nil
}

// hasManagedNatOutputComment reports whether the first --comment value in
// tokens carries the managed nat OUTPUT tag.
func hasManagedNatOutputComment(tokens []string) bool {
	for i := 0; i+1 < len(tokens); i++ {
		if tokens[i] != "--comment" {
			continue
		}
		comment := strings.TrimSpace(tokens[i+1])
		return comment == firewallNatOutputCommentPrefix ||
			strings.HasPrefix(comment, firewallNatOutputCommentPrefix+":")
	}
	return false
}
// buildNatOutputRedirectRules appends a REDIRECT rule to the built-in nat
// OUTPUT chain for every concrete protocol of every rule. Each rule carries
// the managed comment tag so it can be found and removed later. The "any"
// protocol is skipped.
func buildNatOutputRedirectRules(rules []FirewallPortRule) error {
	for _, rule := range rules {
		protoList, err := normalizeProtocolList(rule.Protocols)
		if err != nil {
			return err
		}
		tag := formatManagedPortForwardComment(rule.Comment)
		fromPort := strconv.Itoa(rule.SourcePort)
		toPort := strconv.Itoa(rule.DestinationPort)

		for _, proto := range protoList {
			if proto == "" {
				continue
			}
			err := iptables("nat",
				"-A", "OUTPUT",
				"-p", proto,
				"--dport", fromPort,
				"-m", "comment", "--comment", tag,
				"-j", "REDIRECT",
				"--to-ports", toPort,
			)
			if err != nil {
				return err
			}
		}
	}
	return nil
}
// formatManagedPortForwardComment builds the iptables comment for a managed
// nat OUTPUT rule: the managed prefix alone when userComment is blank,
// otherwise the prefix, a colon and the trimmed user comment.
func formatManagedPortForwardComment(userComment string) string {
	if trimmed := strings.TrimSpace(userComment); trimmed != "" {
		return firewallNatOutputCommentPrefix + ":" + trimmed
	}
	return firewallNatOutputCommentPrefix
}
// normalizeFirewallChain maps a configuration chain name ("input", "output"
// or "forward", case-insensitive, surrounding whitespace ignored) to the name
// of the corresponding managed filter chain. Any other value is an error.
func normalizeFirewallChain(chain string) (string, error) {
	name := strings.ToLower(strings.TrimSpace(chain))
	if name == "input" {
		return firewallChainInput, nil
	}
	if name == "output" {
		return firewallChainOutput, nil
	}
	if name == "forward" {
		return firewallChainForward, nil
	}
	return "", fmt.Errorf("unsupported chain %q", chain)
}
// normalizeFirewallAction converts "accept", "drop" or "reject"
// (case-insensitive, surrounding whitespace ignored) to the upper-case
// iptables target name. Any other value is an error.
func normalizeFirewallAction(action string) (string, error) {
	key := strings.ToLower(strings.TrimSpace(action))
	if key == "accept" || key == "drop" || key == "reject" {
		return strings.ToUpper(key), nil
	}
	return "", fmt.Errorf("unsupported action %q", action)
}
// normalizeFirewallProtocol lower-cases and trims proto. "any" becomes the
// empty string; tcp, udp, icmp, igmp, sctp and dccp are returned as-is in
// lower case. Any other value is an error.
func normalizeFirewallProtocol(proto string) (string, error) {
	name := strings.ToLower(strings.TrimSpace(proto))
	switch name {
	case "any":
		return "", nil
	case "tcp", "udp", "icmp", "igmp", "sctp", "dccp":
		return name, nil
	}
	return "", fmt.Errorf("unsupported protocol %q", proto)
}
// normalizeProtocolList normalizes every entry of protos. If any entry is
// "any", or the list ends up empty, the result is a single empty string
// (meaning "no protocol match"); otherwise it is the normalized names in
// input order. The first invalid entry aborts with its error.
func normalizeProtocolList(protos []string) ([]string, error) {
	names := make([]string, 0, len(protos))
	sawAny := false
	for _, raw := range protos {
		name, err := normalizeFirewallProtocol(raw)
		switch {
		case err != nil:
			return nil, err
		case name == "":
			sawAny = true
		default:
			names = append(names, name)
		}
	}
	if sawAny || len(names) == 0 {
		return []string{""}, nil
	}
	return names, nil
}
// protoSupportsPorts reports whether proto is one of the protocols for which
// this package emits port matches: tcp, udp, sctp or dccp. The comparison is
// exact (case-sensitive, no trimming).
func protoSupportsPorts(proto string) bool {
	return proto == "tcp" || proto == "udp" || proto == "sctp" || proto == "dccp"
}
// ensureChain creates chain in table unless it is already there. Existence is
// probed by listing the chain; if that fails the chain is created, and a
// "Chain already exists" failure from the create is treated as success.
func ensureChain(table, chain string) error {
	if iptables(table, "-nL", chain) == nil {
		return nil
	}
	if err := iptables(table, "-N", chain); err != nil && !strings.Contains(err.Error(), "Chain already exists") {
		return err
	}
	return nil
}
// ensureJump makes sure fromChain contains a jump to toChain. If the check
// for an existing jump fails for any reason, the jump is inserted at
// position 1 of fromChain.
func ensureJump(table, fromChain, toChain string) error {
	if err := iptables(table, "-C", fromChain, "-j", toChain); err != nil {
		return iptables(table, "-I", fromChain, "1", "-j", toChain)
	}
	return nil
}
// iptables runs "iptables -t <table> <args...>" with a three-second timeout
// and discards its output on success. On failure the returned error contains
// the full argument list and, unless the timeout fired, the underlying error
// and the command's trimmed combined output.
func iptables(table string, args ...string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	argv := make([]string, 0, len(args)+2)
	argv = append(argv, "-t", table)
	argv = append(argv, args...)

	out, err := exec.CommandContext(ctx, "iptables", argv...).CombinedOutput()
	switch {
	case err == nil:
		return nil
	case errors.Is(ctx.Err(), context.DeadlineExceeded):
		return fmt.Errorf("iptables timeout: %s", strings.Join(argv, " "))
	default:
		return fmt.Errorf("iptables failed: %s: %w: %s", strings.Join(argv, " "), err, strings.TrimSpace(string(out)))
	}
}
// sysctlWrite sets a kernel parameter by running "sysctl -w key=value" with a
// three-second timeout.
//
// A timeout is reported explicitly, in the same way the iptables helpers in
// this file report theirs, rather than as the less informative error left by
// the killed process. Any other failure wraps the underlying error and
// includes the command's trimmed combined output.
func sysctlWrite(key, value string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	assignment := fmt.Sprintf("%s=%s", key, value)
	out, err := exec.CommandContext(ctx, "sysctl", "-w", assignment).CombinedOutput()
	if err == nil {
		return nil
	}
	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return fmt.Errorf("sysctl timeout: -w %s", assignment)
	}
	return fmt.Errorf("sysctl failed: %w: %s", err, strings.TrimSpace(string(out)))
}
// iptablesOutput runs "iptables -t <table> <args...>" with a three-second
// timeout and returns its combined stdout/stderr on success. On failure the
// returned string is empty and the error contains the full argument list and,
// unless the timeout fired, the underlying error and the trimmed output.
func iptablesOutput(table string, args ...string) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	argv := make([]string, 0, len(args)+2)
	argv = append(argv, "-t", table)
	argv = append(argv, args...)

	out, err := exec.CommandContext(ctx, "iptables", argv...).CombinedOutput()
	switch {
	case err == nil:
		return string(out), nil
	case errors.Is(ctx.Err(), context.DeadlineExceeded):
		return "", fmt.Errorf("iptables timeout: %s", strings.Join(argv, " "))
	default:
		return "", fmt.Errorf("iptables failed: %s: %w: %s", strings.Join(argv, " "), err, strings.TrimSpace(string(out)))
	}
}
// splitShellLike splits s into whitespace-separated fields with a small
// subset of shell quoting:
//
//   - a backslash makes the next byte literal, inside or outside quotes;
//   - single or double quotes group text, and the quotes are dropped;
//   - outside quotes, space, tab, CR and LF separate fields.
//
// Empty fields are never produced, so a bare "" contributes nothing. Input
// ending in a dangling backslash or inside an open quote is an error.
func splitShellLike(s string) ([]string, error) {
	var (
		fields  []string
		field   strings.Builder
		quote   byte // active quote character, 0 when outside quotes
		escaped bool
	)

	text := strings.TrimSpace(s)
	for pos := 0; pos < len(text); pos++ {
		c := text[pos]
		switch {
		case escaped:
			field.WriteByte(c)
			escaped = false
		case c == '\\':
			escaped = true
		case quote != 0:
			if c == quote {
				quote = 0
			} else {
				field.WriteByte(c)
			}
		case c == '"' || c == '\'':
			quote = c
		case c == ' ' || c == '\t' || c == '\n' || c == '\r':
			if field.Len() > 0 {
				fields = append(fields, field.String())
				field.Reset()
			}
		default:
			field.WriteByte(c)
		}
	}

	switch {
	case escaped:
		return nil, fmt.Errorf("unterminated escape")
	case quote != 0:
		return nil, fmt.Errorf("unterminated quote")
	}
	if field.Len() > 0 {
		fields = append(fields, field.String())
	}
	return fields, nil
}
// boolPtr returns a pointer to a fresh bool holding v.
func boolPtr(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}
// iptablesParsedRule holds the options parseIptablesTokens recognises in one
// "iptables -S" rule line. Options that were absent keep their zero value
// (empty string or nil pointer).
type iptablesParsedRule struct {
// chain is the value following -A.
chain string
// srcIP and dstIP are the values following -s and -d.
srcIP string
dstIP string
// proto is the lower-cased, trimmed value following -p.
proto string
// sport and dport are set only when the port value is a plain integer.
sport *int
dport *int
// toPorts is the integer value following --to-ports.
toPorts *int
// jump is the upper-cased, trimmed value following -j.
jump string
// inIface and outIface are the values following -i and -o.
inIface string
outIface string
// ctstate is the raw value following --ctstate.
ctstate string
// toDest is the raw value following --to-destination.
toDest string
// comment is the value following --comment.
comment string
}
// ReadFirewallConfigFromSystem reconstructs a FirewallConfig from the rules
// currently loaded in iptables.
//
// It returns (nil, nil) when the managed input chain does not exist. Otherwise
// it reads the three managed filter chains, the managed nat prerouting chain
// and the built-in nat PREROUTING and OUTPUT chains. From those it derives the
// base policies, the port forwards (managed and unmanaged) and the
// communication rules, merging entries that differ only by protocol.
func ReadFirewallConfigFromSystem() (*FirewallConfig, error) {
	if _, err := exec.LookPath("iptables"); err != nil {
		return nil, fmt.Errorf("iptables not found: %w", err)
	}

	// load lists one chain and parses its rule lines.
	load := func(table, chain string) ([]iptablesParsedRule, bool, error) {
		lines, found, err := iptablesChainSpecLines(table, chain)
		if err != nil {
			return nil, false, err
		}
		return parseIptablesSpecLines(lines), found, nil
	}

	inputRules, found, err := load("filter", firewallChainInput)
	if err != nil {
		return nil, err
	}
	if !found {
		return nil, nil
	}
	outputRules, _, err := load("filter", firewallChainOutput)
	if err != nil {
		return nil, err
	}
	forwardRules, _, err := load("filter", firewallChainForward)
	if err != nil {
		return nil, err
	}
	managedPrerouting, _, err := load("nat", firewallChainNatPrerouting)
	if err != nil {
		return nil, err
	}
	systemPrerouting, _, err := load("nat", "PREROUTING")
	if err != nil {
		return nil, err
	}
	natOutput, _, err := load("nat", "OUTPUT")
	if err != nil {
		return nil, err
	}

	cfg := &FirewallConfig{
		Base: FirewallBaseRule{
			InputPolicy:   chainDefaultPolicy(inputRules),
			OutputPolicy:  chainDefaultPolicy(outputRules),
			ForwardPolicy: chainDefaultPolicy(forwardRules),
		},
	}
	inputRules = stripDefaultPolicyRule(inputRules)
	outputRules = stripDefaultPolicyRule(outputRules)
	forwardRules = stripDefaultPolicyRule(forwardRules)

	var forwards []FirewallPortRule
	forwards = append(forwards, parsePortForwardsFromNat(managedPrerouting)...)
	forwards = append(forwards, parsePortForwardsFromSystemPrerouting(systemPrerouting)...)
	forwards = append(forwards, parsePortForwardsFromNatOutput(natOutput)...)
	forwards = append(forwards, parsePortForwardsFromSystemNatOutput(natOutput)...)

	// Drop the forward-chain rules that merely accompany a DNAT forward.
	forwardRules = filterAutoForwardRules(forwardRules, forwards)

	var comm []FirewallRule
	comm = append(comm, parseCommRulesFromChain("input", inputRules)...)
	comm = append(comm, parseCommRulesFromChain("output", outputRules)...)
	comm = append(comm, parseCommRulesFromChain("forward", forwardRules)...)

	cfg.Rules = groupFirewallRules(comm)
	cfg.PortForwards = groupPortForwards(forwards)
	return cfg, nil
}
// iptablesChainSpecLines lists chain in table with "iptables -S" and returns
// the trimmed lines that start with "-A ".
//
// The boolean tells whether the listing succeeded. A failure that looks like
// a missing table or a missing chain yields (nil, false, nil); any other
// failure is returned as an error.
func iptablesChainSpecLines(table, chain string) ([]string, bool, error) {
	raw, err := iptablesOutput(table, "-S", chain)
	if err != nil {
		text := err.Error()
		absent := isIptablesTableMissingErr(err) ||
			strings.Contains(text, "No chain/target/match by that name") ||
			strings.Contains(text, "No such file")
		if absent {
			return nil, false, nil
		}
		return nil, false, err
	}

	rawLines := strings.Split(strings.ReplaceAll(raw, "\r\n", "\n"), "\n")
	specs := make([]string, 0, len(rawLines))
	for _, rawLine := range rawLines {
		if trimmed := strings.TrimSpace(rawLine); strings.HasPrefix(trimmed, "-A ") {
			specs = append(specs, trimmed)
		}
	}
	return specs, true, nil
}
// parseIptablesSpecLines tokenizes and parses each rule line. Lines that
// cannot be tokenized, or that yield no chain name, are silently skipped.
func parseIptablesSpecLines(lines []string) []iptablesParsedRule {
	parsed := make([]iptablesParsedRule, 0, len(lines))
	for _, spec := range lines {
		tokens, err := splitShellLike(spec)
		if err != nil {
			continue
		}
		if rule := parseIptablesTokens(tokens); rule.chain != "" {
			parsed = append(parsed, rule)
		}
	}
	return parsed
}
// parseIptablesTokens extracts the recognised option/value pairs from a
// tokenized rule line. Unrecognised tokens are ignored, and a recognised
// option in the last position (with no value after it) is ignored too.
// Numeric options keep their previous value when the text is not an integer.
func parseIptablesTokens(tokens []string) iptablesParsedRule {
	var rule iptablesParsedRule

	// assignInt stores the parsed number in *dst, leaving it alone on failure.
	assignInt := func(dst **int, text string) {
		if n, err := strconv.Atoi(text); err == nil {
			*dst = &n
		}
	}

	// An option needs a following value, so the last token is never an option.
	for i := 0; i+1 < len(tokens); i++ {
		value := tokens[i+1]
		consumed := true
		switch tokens[i] {
		case "-A":
			rule.chain = value
		case "-s":
			rule.srcIP = value
		case "-d":
			rule.dstIP = value
		case "-p":
			rule.proto = strings.ToLower(strings.TrimSpace(value))
		case "--sport", "--source-port":
			assignInt(&rule.sport, value)
		case "--dport", "--destination-port":
			assignInt(&rule.dport, value)
		case "-j":
			rule.jump = strings.ToUpper(strings.TrimSpace(value))
		case "-i":
			rule.inIface = value
		case "-o":
			rule.outIface = value
		case "--ctstate":
			rule.ctstate = value
		case "--to-destination":
			rule.toDest = value
		case "--to-ports":
			assignInt(&rule.toPorts, value)
		case "--comment":
			rule.comment = value
		default:
			consumed = false
		}
		if consumed {
			i++ // skip the value just read
		}
	}
	return rule
}
// chainDefaultPolicy returns the policy ("accept", "drop" or "reject")
// expressed by the last unconditional rule in rules. When there is none the
// result is "accept".
func chainDefaultPolicy(rules []iptablesParsedRule) string {
	for n := len(rules); n > 0; n-- {
		candidate := rules[n-1]
		if !isUnconditionalDefaultRule(candidate) {
			continue
		}
		if candidate.jump == "ACCEPT" {
			return "accept"
		}
		if candidate.jump == "DROP" {
			return "drop"
		}
		if candidate.jump == "REJECT" {
			return "reject"
		}
	}
	return "accept"
}
// stripDefaultPolicyRule returns a copy of rules without the default-policy
// rule, i.e. without the last unconditional rule in the chain. That is the
// same rule chainDefaultPolicy reports as the policy.
//
// Only that one rule is removed. A user rule with protocol "any", no
// addresses and no comment is written to iptables as an unconditional rule
// too; stripping every unconditional rule would hide such a rule from the
// configuration read back, and it would then be lost on the next apply.
//
// When no unconditional rule exists the copy contains every input rule.
func stripDefaultPolicyRule(rules []iptablesParsedRule) []iptablesParsedRule {
	policyIdx := -1
	for i := len(rules) - 1; i >= 0; i-- {
		if isUnconditionalDefaultRule(rules[i]) {
			policyIdx = i
			break
		}
	}

	out := make([]iptablesParsedRule, 0, len(rules))
	for i := range rules {
		if i == policyIdx {
			continue
		}
		out = append(out, rules[i])
	}
	return out
}
// isUnconditionalDefaultRule reports whether r jumps to ACCEPT, DROP or
// REJECT with none of these set: source or destination address, protocol,
// source or destination port, interface, ctstate, to-destination, and a
// non-blank comment.
func isUnconditionalDefaultRule(r iptablesParsedRule) bool {
	switch r.jump {
	case "ACCEPT", "DROP", "REJECT":
	default:
		return false
	}

	hasMatch := r.srcIP != "" || r.dstIP != "" || r.proto != "" ||
		r.sport != nil || r.dport != nil ||
		r.inIface != "" || r.outIface != "" ||
		r.ctstate != "" || r.toDest != ""
	return !hasMatch && strings.TrimSpace(r.comment) == ""
}
// parsePortForwardsFromNat converts parsed rules of the managed nat
// prerouting chain into managed port forwards. DNAT rules with a usable
// "ip:port" destination become "prerouting" forwards; REDIRECT rules with a
// target port become "prerouting_redirect" forwards with destination
// "0.0.0.0". Rules without a protocol or destination port are skipped.
func parsePortForwardsFromNat(prerouting []iptablesParsedRule) []FirewallPortRule {
	forwards := make([]FirewallPortRule, 0)
	for _, rule := range prerouting {
		if rule.proto == "" || rule.dport == nil {
			continue
		}
		pf := FirewallPortRule{
			Managed:    boolPtr(true),
			SourcePort: *rule.dport,
			Protocols:  []string{rule.proto},
			Comment:    rule.comment,
		}
		switch {
		case rule.jump == "DNAT":
			ip, port := parseToDestination(rule.toDest)
			if ip == "" || port == 0 {
				continue
			}
			pf.Chain, pf.DestinationIP, pf.DestinationPort = "prerouting", ip, port
		case rule.jump == "REDIRECT" && rule.toPorts != nil:
			pf.Chain, pf.DestinationIP, pf.DestinationPort = "prerouting_redirect", "0.0.0.0", *rule.toPorts
		default:
			continue
		}
		forwards = append(forwards, pf)
	}
	return forwards
}
// parsePortForwardsFromSystemPrerouting converts DNAT and REDIRECT rules
// found directly in the built-in nat PREROUTING chain into port forwards
// marked as unmanaged. Rules of other chains, rules without a protocol or
// destination port, and rules with any other target are skipped.
func parsePortForwardsFromSystemPrerouting(rules []iptablesParsedRule) []FirewallPortRule {
	forwards := make([]FirewallPortRule, 0)
	for _, rule := range rules {
		if rule.chain != "PREROUTING" || rule.proto == "" || rule.dport == nil {
			continue
		}
		pf := FirewallPortRule{
			Managed:    boolPtr(false),
			SourcePort: *rule.dport,
			Protocols:  []string{rule.proto},
			Comment:    rule.comment,
		}
		switch {
		case rule.jump == "DNAT":
			ip, port := parseToDestination(rule.toDest)
			if ip == "" || port == 0 {
				continue
			}
			pf.Chain, pf.DestinationIP, pf.DestinationPort = "prerouting", ip, port
		case rule.jump == "REDIRECT" && rule.toPorts != nil:
			pf.Chain, pf.DestinationIP, pf.DestinationPort = "prerouting_redirect", "0.0.0.0", *rule.toPorts
		default:
			continue
		}
		forwards = append(forwards, pf)
	}
	return forwards
}
// parsePortForwardsFromNatOutput converts REDIRECT rules of the nat OUTPUT
// chain that carry the managed comment tag into managed "output" port
// forwards. The tag is removed from the reported comment.
func parsePortForwardsFromNatOutput(rules []iptablesParsedRule) []FirewallPortRule {
	forwards := make([]FirewallPortRule, 0)
	for _, rule := range rules {
		if rule.chain != "OUTPUT" || rule.jump != "REDIRECT" {
			continue
		}
		if rule.dport == nil || rule.toPorts == nil || rule.proto == "" {
			continue
		}
		tag := strings.TrimSpace(rule.comment)
		tagged := tag == firewallNatOutputCommentPrefix ||
			strings.HasPrefix(tag, firewallNatOutputCommentPrefix+":")
		if !tagged {
			continue
		}
		forwards = append(forwards, FirewallPortRule{
			Chain:           "output",
			Managed:         boolPtr(true),
			SourcePort:      *rule.dport,
			Protocols:       []string{rule.proto},
			DestinationIP:   "0.0.0.0",
			DestinationPort: *rule.toPorts,
			Comment:         parseManagedPortForwardComment(rule.comment),
		})
	}
	return forwards
}
// parsePortForwardsFromSystemNatOutput converts REDIRECT rules of the nat
// OUTPUT chain that do NOT carry the managed comment tag into "output" port
// forwards marked as unmanaged. The comment is reported unchanged.
func parsePortForwardsFromSystemNatOutput(rules []iptablesParsedRule) []FirewallPortRule {
	forwards := make([]FirewallPortRule, 0)
	for _, rule := range rules {
		if rule.chain != "OUTPUT" || rule.jump != "REDIRECT" {
			continue
		}
		if rule.dport == nil || rule.toPorts == nil || rule.proto == "" {
			continue
		}
		tag := strings.TrimSpace(rule.comment)
		tagged := tag == firewallNatOutputCommentPrefix ||
			strings.HasPrefix(tag, firewallNatOutputCommentPrefix+":")
		if tagged {
			continue
		}
		forwards = append(forwards, FirewallPortRule{
			Chain:           "output",
			Managed:         boolPtr(false),
			SourcePort:      *rule.dport,
			Protocols:       []string{rule.proto},
			DestinationIP:   "0.0.0.0",
			DestinationPort: *rule.toPorts,
			Comment:         rule.comment,
		})
	}
	return forwards
}
// parseManagedPortForwardComment returns the user part of a managed comment:
// the text after the managed prefix and its colon. A comment that is only the
// prefix, or that does not start with "prefix:", yields the empty string.
func parseManagedPortForwardComment(s string) string {
	marker := firewallNatOutputCommentPrefix + ":"
	if trimmed := strings.TrimSpace(s); strings.HasPrefix(trimmed, marker) {
		return trimmed[len(marker):]
	}
	return ""
}
// parseToDestination splits a "--to-destination" value at its last colon into
// an address and a port. It returns ("", 0) when there is no colon or when
// the text after the last colon is not an integer in 1..65535.
func parseToDestination(toDest string) (string, int) {
	spec := strings.TrimSpace(toDest)
	sep := strings.LastIndex(spec, ":")
	if sep < 0 {
		return "", 0
	}
	port, err := strconv.Atoi(spec[sep+1:])
	if err != nil || port < 1 || port > 65535 {
		return "", 0
	}
	return spec[:sep], port
}
// filterAutoForwardRules removes from forward the ACCEPT rules that
// buildPortForwardRules adds alongside each DNAT forward, so they are not
// reported as user rules.
//
// A forward rule is dropped when it is an ACCEPT with a destination address,
// destination port, protocol and a ctstate containing NEW, ESTABLISHED and
// RELATED, and its (destination, port, protocol, comment) matches a DNAT
// entry of portForwards. Entries on the "output" or "prerouting_redirect"
// chains, or with a local redirect destination, never produce such rules and
// are ignored.
//
// The ctstate comparison is order-insensitive: iptables prints the state list
// in its own fixed order (typically NEW,RELATED,ESTABLISHED) rather than in
// the order the rule was written with, so a literal substring match on
// "NEW,ESTABLISHED,RELATED" would not recognise rules read back from the
// system.
func filterAutoForwardRules(forward []iptablesParsedRule, portForwards []FirewallPortRule) []iptablesParsedRule {
	if len(portForwards) == 0 || len(forward) == 0 {
		return forward
	}

	keys := make(map[string]struct{}, len(portForwards))
	for _, pf := range portForwards {
		switch strings.ToLower(strings.TrimSpace(pf.Chain)) {
		case "output", "prerouting_redirect":
			continue
		}
		if isLocalRedirectDestination(pf.DestinationIP) {
			continue
		}
		for _, p := range pf.Protocols {
			keys[fmt.Sprintf("%s|%d|%s|%s", pf.DestinationIP, pf.DestinationPort, strings.ToLower(p), pf.Comment)] = struct{}{}
		}
	}

	out := make([]iptablesParsedRule, 0, len(forward))
	for _, r := range forward {
		if r.jump == "ACCEPT" && isAutoForwardCtstate(r.ctstate) && r.dstIP != "" && r.dport != nil && r.proto != "" {
			if _, ok := keys[fmt.Sprintf("%s|%d|%s|%s", r.dstIP, *r.dport, strings.ToLower(r.proto), r.comment)]; ok {
				continue
			}
		}
		out = append(out, r)
	}
	return out
}

// isAutoForwardCtstate reports whether the comma-separated ctstate list
// contains NEW, ESTABLISHED and RELATED, in any order.
func isAutoForwardCtstate(ctstate string) bool {
	var hasNew, hasEstablished, hasRelated bool
	for _, state := range strings.Split(ctstate, ",") {
		switch strings.TrimSpace(state) {
		case "NEW":
			hasNew = true
		case "ESTABLISHED":
			hasEstablished = true
		case "RELATED":
			hasRelated = true
		}
	}
	return hasNew && hasEstablished && hasRelated
}
// parseCommRulesFromChain converts parsed filter rules into FirewallRule
// values labelled with chain. Internal accept rules and rules whose target is
// not ACCEPT, DROP or REJECT are skipped. A rule without a protocol is
// reported as "any", and the catch-all address "0.0.0.0/0" is reported as
// empty.
func parseCommRulesFromChain(chain string, rules []iptablesParsedRule) []FirewallRule {
	// blankIfWildcard hides the catch-all address iptables may print.
	blankIfWildcard := func(addr string) string {
		if addr == "0.0.0.0/0" {
			return ""
		}
		return addr
	}

	result := make([]FirewallRule, 0)
	for _, parsed := range rules {
		if isInternalAcceptRule(chain, parsed) {
			continue
		}
		action := strings.ToLower(parsed.jump)
		switch action {
		case "accept", "drop", "reject":
		default:
			continue
		}

		protocols := []string{"any"}
		if parsed.proto != "" {
			protocols = []string{parsed.proto}
		}

		result = append(result, FirewallRule{
			Chain:           chain,
			SourceIP:        blankIfWildcard(parsed.srcIP),
			SourcePort:      parsed.sport,
			Protocols:       protocols,
			DestinationIP:   blankIfWildcard(parsed.dstIP),
			DestinationPort: parsed.dport,
			Action:          action,
			Comment:         parsed.comment,
		})
	}
	return result
}
// isInternalAcceptRule reports whether r is one of the ACCEPT rules this
// package installs itself rather than a user rule: the loopback accept on the
// input chain ("-i lo"), the loopback accept on the output chain ("-o lo"),
// or a conntrack accept whose state list contains both ESTABLISHED and
// RELATED.
//
// The state list is checked order-insensitively. iptables prints conntrack
// states in its own fixed order (typically RELATED,ESTABLISHED) regardless of
// how the rule was written, so a literal substring match on
// "ESTABLISHED,RELATED" would miss rules read back from the system.
func isInternalAcceptRule(chain string, r iptablesParsedRule) bool {
	if r.jump != "ACCEPT" {
		return false
	}
	if chain == "input" && r.inIface == "lo" {
		return true
	}
	if chain == "output" && r.outIface == "lo" {
		return true
	}

	var hasEstablished, hasRelated bool
	for _, state := range strings.Split(r.ctstate, ",") {
		switch strings.TrimSpace(state) {
		case "ESTABLISHED":
			hasEstablished = true
		case "RELATED":
			hasRelated = true
		}
	}
	return hasEstablished && hasRelated
}
// groupFirewallRules merges rules that are identical except for their
// protocol list, keeping the position of the first occurrence. Protocols are
// unioned; once either side of a merge is exactly ["any"], the merged rule's
// protocols become ["any"].
func groupFirewallRules(in []FirewallRule) []FirewallRule {
	type groupKey struct {
		chain, src, dst, sport, dport, action, comment string
	}
	portText := func(p *int) string {
		if p == nil {
			return ""
		}
		return strconv.Itoa(*p)
	}
	isAny := func(protos []string) bool {
		return len(protos) == 1 && protos[0] == "any"
	}

	merged := make([]FirewallRule, 0)
	position := make(map[groupKey]int)
	for _, rule := range in {
		k := groupKey{
			chain:   rule.Chain,
			src:     rule.SourceIP,
			dst:     rule.DestinationIP,
			sport:   portText(rule.SourcePort),
			dport:   portText(rule.DestinationPort),
			action:  rule.Action,
			comment: rule.Comment,
		}
		at, seen := position[k]
		if !seen {
			position[k] = len(merged)
			merged = append(merged, rule)
			continue
		}
		switch {
		case isAny(merged[at].Protocols):
			// Already matches every protocol; nothing to add.
		case isAny(rule.Protocols):
			merged[at].Protocols = []string{"any"}
		default:
			merged[at].Protocols = appendUnique(merged[at].Protocols, rule.Protocols...)
		}
	}
	return merged
}
// groupPortForwards merges port forwards that are identical except for their
// protocol list, keeping the position of the first occurrence and unioning
// the protocols. A nil Managed field is grouped together with managed rules.
func groupPortForwards(in []FirewallPortRule) []FirewallPortRule {
	type groupKey struct {
		chain   string
		managed string
		srcPort int
		dstIP   string
		dstPort int
		comment string
	}

	merged := make([]FirewallPortRule, 0)
	position := make(map[groupKey]int)
	for _, rule := range in {
		k := groupKey{
			chain:   rule.Chain,
			managed: "true",
			srcPort: rule.SourcePort,
			dstIP:   rule.DestinationIP,
			dstPort: rule.DestinationPort,
			comment: rule.Comment,
		}
		if rule.Managed != nil && !*rule.Managed {
			k.managed = "false"
		}

		at, seen := position[k]
		if !seen {
			position[k] = len(merged)
			merged = append(merged, rule)
			continue
		}
		merged[at].Protocols = appendUnique(merged[at].Protocols, rule.Protocols...)
	}
	return merged
}
// appendUnique appends each of items to dst after lower-casing and trimming
// it, skipping blanks and values already present. Entries already in dst are
// compared as they are, without normalization.
func appendUnique(dst []string, items ...string) []string {
	seen := make(map[string]struct{}, len(dst))
	for _, existing := range dst {
		seen[existing] = struct{}{}
	}
	for _, raw := range items {
		item := strings.ToLower(strings.TrimSpace(raw))
		_, dup := seen[item]
		if item == "" || dup {
			continue
		}
		seen[item] = struct{}{}
		dst = append(dst, item)
	}
	return dst
}
// resetFirewallForFactory removes everything this package installed into
// iptables and switches IPv4 forwarding off. It is best-effort: it does
// nothing when the iptables binary is missing and ignores every individual
// failure.
//
// Besides the managed chains and the jumps into them, the comment-tagged
// REDIRECT rules that are added directly to the built-in nat OUTPUT chain are
// removed as well. They live outside the managed chains, so deleting the
// chains alone would leave them active.
func resetFirewallForFactory() {
	if _, err := exec.LookPath("iptables"); err != nil {
		return
	}

	// Filter table first, so it is cleaned even if the nat probe fails.
	_ = removeFirewallJumps(false)
	_ = removeFirewallChains(false)

	natSupported, err := iptablesTableSupported("nat")
	if err == nil && natSupported {
		_ = removeFirewallJumps(true)
		_ = removeFirewallChains(true)
		_ = clearManagedNatOutputRules()
	}

	_ = sysctlWrite("net.ipv4.ip_forward", "0")
}
func removeFirewallJumps(needNat bool) error {
if err := removeJumpAll("filter", "INPUT", firewallChainInput); err != nil {
return err
}
if err := removeJumpAll("filter", "OUTPUT", firewallChainOutput); err != nil {
return err
}
if err := removeJumpAll("filter", "FORWARD", firewallChainForward); err != nil {
return err
}
if needNat {
if err := removeJumpAll("nat", "PREROUTING", firewallChainNatPrerouting); err != nil {
return err
}
if err := removeJumpAll("nat", "POSTROUTING", firewallChainNatPostrouting); err != nil {
return err
}
}
return nil
}
// removeJumpAll deletes the jump from fromChain to toChain repeatedly, up to
// 16 times, to clear duplicates. It returns nil once iptables reports that no
// such rule is left (or after 16 successful deletions), and the error for any
// other failure.
func removeJumpAll(table, fromChain, toChain string) error {
	for attempt := 0; attempt < 16; attempt++ {
		err := iptables(table, "-D", fromChain, "-j", toChain)
		switch {
		case err == nil:
			// One copy removed; look for another.
		case isNoSuchRuleErr(err):
			return nil
		default:
			return err
		}
	}
	return nil
}
func removeFirewallChains(needNat bool) error {
for _, chain := range []string{firewallChainInput, firewallChainOutput, firewallChainForward} {
_ = iptables("filter", "-F", chain)
_ = iptables("filter", "-X", chain)
}
if needNat {
for _, chain := range []string{firewallChainNatPrerouting, firewallChainNatPostrouting} {
_ = iptables("nat", "-F", chain)
_ = iptables("nat", "-X", chain)
}
}
return nil
}
// isNoSuchRuleErr reports whether err means "there was nothing to delete":
// its message contains one of the known iptables phrases for a missing rule,
// chain or target, or it wraps exec.ErrNotFound. A nil err yields false.
func isNoSuchRuleErr(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range []string{
		"No chain/target/match by that name",
		"Bad rule",
		"does a matching rule exist in that chain",
		"No such file or directory",
	} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return errors.Is(err, exec.ErrNotFound)
}