mox/ratelimit/ratelimit.go
Mechiel Lukkien 2154392bd8
add basic rate limiters
limiting is done based on remote IPs, with 3 IP mask variants to limit networks
of machines. often with two windows, enabling short bursts of activity, but not
sustained high activity. currently only for imap and smtp, not yet http.

limits are currently based on:
- number of open connections
- connection rate
- limits after authentication failures. too many failures, and new connections will be dropped.
- rate of delivery in total number of messages
- rate of delivery in total size of messages

the limits on connections and authentication failures are in-memory. the limits
on delivery of messages are based on stored messages.

the limits themselves are not yet configurable, let's use this first.

in the future, we may also want to have stricter limits for senders without any
reputation.
2023-02-07 23:18:15 +01:00

146 lines
3.2 KiB
Go

// Package ratelimit provides a simple window-based rate limiter.
package ratelimit
import (
"net"
"sync"
"time"
)
// Limiter rate-limits per remote IP using one or more fixed time windows, for
// example the current minute, hour, day or week. Each window keeps counts for
// three keys derived from the IP: the most specific one and two progressively
// wider subnets. The embedded mutex is held while counts are checked or
// changed.
type Limiter struct {
	sync.Mutex

	// WindowLimits lists the windows. An addition is only accepted when it
	// fits within the limits of every window.
	WindowLimits []WindowLimit

	// ipmasked is scratch space holding the three masked forms of the IP that
	// is being processed. It is overwritten while the lock is held.
	ipmasked [3][16]byte
}
// WindowLimit combines the configuration and the running state for a single
// window size: the limit per IP class/subnet, the interval that is currently
// being counted, and the totals collected during that interval.
type WindowLimit struct {
	// Window is the duration of one interval.
	Window time.Duration

	// Limits holds the maximum total per interval for each of the three IP
	// classes, i.e. for "ipmasked1" through "ipmasked3".
	Limits [3]int64

	// Time identifies the interval the counts belong to: the time in
	// nanoseconds since the Unix epoch divided by Window, stored as uint32.
	Time uint32

	// Counts holds the totals for the interval identified by Time, keyed by
	// the IP class index (0 through 2) and the masked IP for that class.
	Counts map[struct {
		Index    uint8
		IPMasked [16]byte
	}]int64
}
// Add tries to count n items for ip at time tm. The addition is accepted only
// if it keeps every window and every IP class/subnet within its limit. If any
// limit would be exceeded, nothing is counted and false is returned. A window
// for which tm falls in a later interval than the one it was counting starts
// again from zero before the check.
func (l *Limiter) Add(ip net.IP, tm time.Time, n int64) bool {
	const record = true
	return l.checkAdd(record, ip, tm, n)
}
// CanAdd reports whether n items for ip at time tm would fit within all
// limits, without counting them. Like Add, it starts a fresh interval for
// windows that tm has moved past.
func (l *Limiter) CanAdd(ip net.IP, tm time.Time, n int64) bool {
	const record = false
	return l.checkAdd(record, ip, tm, n)
}
// checkAdd implements Add and CanAdd. It first verifies that n fits within the
// limits of every window for all three masked forms of ip, and only when all
// checks pass, and add is true, records n in every window. It returns whether
// n fits.
//
// A window is restarted with empty counts when it has no counts yet or when tm
// lies in a later interval than the one stored for the window. This happens
// regardless of the value of add.
func (l *Limiter) checkAdd(add bool, ip net.IP, tm time.Time, n int64) bool {
	// Alias for the anonymous key type of WindowLimit.Counts.
	type countKey = struct {
		Index    uint8
		IPMasked [16]byte
	}

	l.Lock()
	defer l.Unlock()

	// Verification pass: nothing is counted yet, so a rejection by a later
	// window leaves the earlier windows unchanged.
	for i := range l.WindowLimits {
		w := &l.WindowLimits[i]

		interval := uint32(tm.UnixNano() / int64(w.Window))
		if w.Counts == nil || interval > w.Time {
			w.Time = interval
			w.Counts = map[countKey]int64{} // Also written to in the recording pass.
		}

		for class, limit := range w.Limits {
			// The masked IPs are the same for every window, so they are
			// computed while handling the first window only.
			if i == 0 {
				l.ipmasked[class] = l.maskIP(class, ip)
			}

			current := w.Counts[countKey{uint8(class), l.ipmasked[class]}]
			if current+n > limit {
				return false
			}
		}
	}

	if add {
		// Recording pass: every window has a non-nil Counts map by now.
		for i := range l.WindowLimits {
			counts := l.WindowLimits[i].Counts
			for class := range l.ipmasked {
				counts[countKey{uint8(class), l.ipmasked[class]}] += n
			}
		}
	}
	return true
}
// Reset forgets what was counted for ip in the intervals that contain tm.
//
// For each window whose current interval matches tm, the count stored for the
// most specific masked form of ip is removed, and the same amount is
// subtracted from the counts of the two wider subnets, so contributions from
// other addresses in those subnets are kept. Windows that are counting a
// different interval, or that have no counts yet, are left untouched.
//
// Entries that drop to zero are deleted, and nothing is written when there is
// no count for ip. This keeps a Reset for an address that was never counted
// from adding zero-valued entries to the maps.
func (l *Limiter) Reset(ip net.IP, tm time.Time) {
	// Alias for the anonymous key type of WindowLimit.Counts.
	type countKey = struct {
		Index    uint8
		IPMasked [16]byte
	}

	l.Lock()
	defer l.Unlock()

	// Compute the three masked forms of ip once; they are the same for every
	// window.
	for i := range l.ipmasked {
		l.ipmasked[i] = l.maskIP(i, ip)
	}

	for _, pl := range l.WindowLimits {
		t := uint32(tm.UnixNano() / int64(pl.Window))
		if t != pl.Time || pl.Counts == nil {
			continue
		}

		// The amount to give back is whatever was counted for the most
		// specific key.
		n := pl.Counts[countKey{0, l.ipmasked[0]}]
		if n == 0 {
			// Nothing to subtract. Assigning here would only insert
			// zero-valued entries into the map.
			continue
		}

		for j := range l.ipmasked {
			k := countKey{uint8(j), l.ipmasked[j]}
			if left := pl.Counts[k] - n; left != 0 {
				pl.Counts[k] = left
			} else {
				// A missing key reads as zero, so dropping it is equivalent
				// to storing zero and keeps the map small.
				delete(pl.Counts, k)
			}
		}
	}
}
// maskIP returns the 16-byte key for IP class i (0, 1 or 2) of ip.
//
// For IPv4 addresses, including IPv4-mapped IPv6 addresses, the classes are
// the full address, the /26 and the /21. For IPv6 addresses they are the /64,
// the /48 and the /32. Keys are always in 16-byte form, so IPv4 keys are
// IPv4-mapped addresses.
//
// A nil or malformed ip (a length other than 4 or 16 bytes) results in the
// all-zero key for every class, so such addresses share one set of counters
// rather than causing a run-time panic. A value of i outside 0 through 2 is a
// programming error and panics.
func (l *Limiter) maskIP(i int, ip net.IP) [16]byte {
	var masked net.IP
	if ip4 := ip.To4(); ip4 != nil {
		switch i {
		case 0:
			masked = ip4
		case 1:
			masked = ip4.Mask(net.CIDRMask(26, 32))
		case 2:
			masked = ip4.Mask(net.CIDRMask(21, 32))
		default:
			panic("missing case for maskip ipv4")
		}
	} else {
		switch i {
		case 0:
			masked = ip.Mask(net.CIDRMask(64, 128))
		case 1:
			masked = ip.Mask(net.CIDRMask(48, 128))
		case 2:
			masked = ip.Mask(net.CIDRMask(32, 128))
		default:
			panic("missing case for masking ipv6")
		}
	}

	// Mask and To16 return nil for an ip of invalid length. Copying from a nil
	// slice copies nothing and leaves the key all zeroes, whereas converting a
	// nil slice to a [16]byte would panic.
	var key [16]byte
	copy(key[:], masked.To16())
	return key
}