core: Change net.IP to netip.Addr; use netip.Prefix (#4966)

Co-authored-by: Matt Holt <mholt@users.noreply.github.com>
This commit is contained in:
WilczyńskiT 2022-08-18 00:10:57 +02:00 committed by GitHub
parent a944de4ab7
commit c7772588bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 50 additions and 57 deletions

View file

@ -17,6 +17,7 @@ package httpcaddyfile
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"reflect" "reflect"
"sort" "sort"
"strconv" "strconv"
@ -354,9 +355,9 @@ func (a Address) Normalize() Address {
// ensure host is normalized if it's an IP address // ensure host is normalized if it's an IP address
host := strings.TrimSpace(a.Host) host := strings.TrimSpace(a.Host)
if ip := net.ParseIP(host); ip != nil { if ip, err := netip.ParseAddr(host); err == nil {
if ipv6 := ip.To16(); ipv6 != nil && ipv6.DefaultMask() == nil { if ip.Is6() && !ip.Is4() && !ip.Is4In6() {
host = ipv6.String() host = ip.String()
} }
} }

View file

@ -20,6 +20,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@ -400,7 +401,7 @@ func (na NetworkAddress) isLoopback() bool {
if na.Host == "localhost" { if na.Host == "localhost" {
return true return true
} }
if ip := net.ParseIP(na.Host); ip != nil { if ip, err := netip.ParseAddr(na.Host); err == nil {
return ip.IsLoopback() return ip.IsLoopback()
} }
return false return false
@ -410,7 +411,7 @@ func (na NetworkAddress) isWildcardInterface() bool {
if na.Host == "" { if na.Host == "" {
return true return true
} }
if ip := net.ParseIP(na.Host); ip != nil { if ip, err := netip.ParseAddr(na.Host); err == nil {
return ip.IsUnspecified() return ip.IsUnspecified()
} }
return false return false

View file

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/netip"
"net/textproto" "net/textproto"
"net/url" "net/url"
"path" "path"
@ -171,7 +172,7 @@ type (
// cidrs and zones vars should aligned always in the same // cidrs and zones vars should aligned always in the same
// length and indexes for matching later // length and indexes for matching later
cidrs []*net.IPNet cidrs []*netip.Prefix
zones []string zones []string
logger *zap.Logger logger *zap.Logger
} }
@ -1311,27 +1312,24 @@ func (m *MatchRemoteIP) Provision(ctx caddy.Context) error {
m.zones = append(m.zones, "") m.zones = append(m.zones, "")
} }
if strings.Contains(str, "/") { if strings.Contains(str, "/") {
_, ipNet, err := net.ParseCIDR(str) ipNet, err := netip.ParsePrefix(str)
if err != nil { if err != nil {
return fmt.Errorf("parsing CIDR expression '%s': %v", str, err) return fmt.Errorf("parsing CIDR expression '%s': %v", str, err)
} }
m.cidrs = append(m.cidrs, ipNet) m.cidrs = append(m.cidrs, &ipNet)
} else { } else {
ip := net.ParseIP(str) ipAddr, err := netip.ParseAddr(str)
if ip == nil { if err != nil {
return fmt.Errorf("invalid IP address: %s", str) return fmt.Errorf("invalid IP address: '%s': %v", str, err)
} }
mask := len(ip) * 8 ipNew := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
m.cidrs = append(m.cidrs, &net.IPNet{ m.cidrs = append(m.cidrs, &ipNew)
IP: ip,
Mask: net.CIDRMask(mask, mask),
})
} }
} }
return nil return nil
} }
func (m MatchRemoteIP) getClientIP(r *http.Request) (net.IP, string, error) { func (m MatchRemoteIP) getClientIP(r *http.Request) (netip.Addr, string, error) {
remote := r.RemoteAddr remote := r.RemoteAddr
zoneID := "" zoneID := ""
if m.Forwarded { if m.Forwarded {
@ -1350,11 +1348,11 @@ func (m MatchRemoteIP) getClientIP(r *http.Request) (net.IP, string, error) {
ipStr = split[0] ipStr = split[0]
zoneID = split[1] zoneID = split[1]
} }
ip := net.ParseIP(ipStr) ipAddr, err := netip.ParseAddr(ipStr)
if ip == nil { if err != nil {
return nil, zoneID, fmt.Errorf("invalid client IP address: %s", ipStr) return netip.IPv4Unspecified(), "", err
} }
return ip, zoneID, nil return ipAddr, zoneID, nil
} }
// Match returns true if r matches m. // Match returns true if r matches m.

View file

@ -24,6 +24,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"net/netip"
"net/textproto" "net/textproto"
"net/url" "net/url"
"regexp" "regexp"
@ -180,7 +181,7 @@ type Handler struct {
DynamicUpstreams UpstreamSource `json:"-"` DynamicUpstreams UpstreamSource `json:"-"`
// Holds the parsed CIDR ranges from TrustedProxies // Holds the parsed CIDR ranges from TrustedProxies
trustedProxies []*net.IPNet trustedProxies []netip.Prefix
// Holds the named response matchers from the Caddyfile while adapting // Holds the named response matchers from the Caddyfile while adapting
responseMatchers map[string]caddyhttp.ResponseMatcher responseMatchers map[string]caddyhttp.ResponseMatcher
@ -251,24 +252,18 @@ func (h *Handler) Provision(ctx caddy.Context) error {
// parse trusted proxy CIDRs ahead of time // parse trusted proxy CIDRs ahead of time
for _, str := range h.TrustedProxies { for _, str := range h.TrustedProxies {
if strings.Contains(str, "/") { if strings.Contains(str, "/") {
_, ipNet, err := net.ParseCIDR(str) ipNet, err := netip.ParsePrefix(str)
if err != nil { if err != nil {
return fmt.Errorf("parsing CIDR expression: %v", err) return fmt.Errorf("parsing CIDR expression: '%s': %v", str, err)
} }
h.trustedProxies = append(h.trustedProxies, ipNet) h.trustedProxies = append(h.trustedProxies, ipNet)
} else { } else {
ip := net.ParseIP(str) ipAddr, err := netip.ParseAddr(str)
if ip == nil { if err != nil {
return fmt.Errorf("invalid IP address: %s", str) return fmt.Errorf("invalid IP address: '%s': %v", str, err)
} }
if ipv4 := ip.To4(); ipv4 != nil { ipNew := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
ip = ipv4 h.trustedProxies = append(h.trustedProxies, ipNew)
}
mask := len(ip) * 8
h.trustedProxies = append(h.trustedProxies, &net.IPNet{
IP: ip,
Mask: net.CIDRMask(mask, mask),
})
} }
} }
@ -672,15 +667,15 @@ func (h Handler) addForwardedHeaders(req *http.Request) error {
if before, _, found := strings.Cut(clientIP, "%"); found { if before, _, found := strings.Cut(clientIP, "%"); found {
clientIP = before clientIP = before
} }
ip := net.ParseIP(clientIP) ipAddr, err := netip.ParseAddr(clientIP)
if ip == nil { if err != nil {
return fmt.Errorf("invalid client IP address: %s", clientIP) return fmt.Errorf("invalid IP address: '%s': %v", clientIP, err)
} }
// Check if the client is a trusted proxy // Check if the client is a trusted proxy
trusted := false trusted := false
for _, ipRange := range h.trustedProxies { for _, ipRange := range h.trustedProxies {
if ipRange.Contains(ip) { if ipRange.Contains(ipAddr) {
trusted = true trusted = true
break break
} }

View file

@ -18,6 +18,7 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
"net/netip"
"strings" "strings"
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
@ -65,8 +66,8 @@ type MatchRemoteIP struct {
// The IPs or CIDR ranges to *NOT* match. // The IPs or CIDR ranges to *NOT* match.
NotRanges []string `json:"not_ranges,omitempty"` NotRanges []string `json:"not_ranges,omitempty"`
cidrs []*net.IPNet cidrs []netip.Prefix
notCidrs []*net.IPNet notCidrs []netip.Prefix
logger *zap.Logger logger *zap.Logger
} }
@ -105,38 +106,35 @@ func (m MatchRemoteIP) Match(hello *tls.ClientHelloInfo) bool {
if err != nil { if err != nil {
ipStr = remoteAddr // weird; maybe no port? ipStr = remoteAddr // weird; maybe no port?
} }
ip := net.ParseIP(ipStr) ipAddr, err := netip.ParseAddr(ipStr)
if ip == nil { if err != nil {
m.logger.Error("invalid client IP addresss", zap.String("ip", ipStr)) m.logger.Error("invalid client IP addresss", zap.String("ip", ipStr))
return false return false
} }
return (len(m.cidrs) == 0 || m.matches(ip, m.cidrs)) && return (len(m.cidrs) == 0 || m.matches(ipAddr, m.cidrs)) &&
(len(m.notCidrs) == 0 || !m.matches(ip, m.notCidrs)) (len(m.notCidrs) == 0 || !m.matches(ipAddr, m.notCidrs))
} }
func (MatchRemoteIP) parseIPRange(str string) ([]*net.IPNet, error) { func (MatchRemoteIP) parseIPRange(str string) ([]netip.Prefix, error) {
var cidrs []*net.IPNet var cidrs []netip.Prefix
if strings.Contains(str, "/") { if strings.Contains(str, "/") {
_, ipNet, err := net.ParseCIDR(str) ipNet, err := netip.ParsePrefix(str)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing CIDR expression: %v", err) return nil, fmt.Errorf("parsing CIDR expression: %v", err)
} }
cidrs = append(cidrs, ipNet) cidrs = append(cidrs, ipNet)
} else { } else {
ip := net.ParseIP(str) ipAddr, err := netip.ParseAddr(str)
if ip == nil { if err != nil {
return nil, fmt.Errorf("invalid IP address: %s", str) return nil, fmt.Errorf("invalid IP address: '%s': %v", str, err)
} }
mask := len(ip) * 8 ip := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
cidrs = append(cidrs, &net.IPNet{ cidrs = append(cidrs, ip)
IP: ip,
Mask: net.CIDRMask(mask, mask),
})
} }
return cidrs, nil return cidrs, nil
} }
func (MatchRemoteIP) matches(ip net.IP, ranges []*net.IPNet) bool { func (MatchRemoteIP) matches(ip netip.Addr, ranges []netip.Prefix) bool {
for _, ipRange := range ranges { for _, ipRange := range ranges {
if ipRange.Contains(ip) { if ipRange.Contains(ip) {
return true return true