// Package ratelimit provides a simple window-based rate limiter. package ratelimit import ( "net" "sync" "time" ) // Limiter is a simple rate limiter with one or more fixed windows, e.g. the // last minute/hour/day/week, working on three classes/subnets of an IP. type Limiter struct { sync.Mutex WindowLimits []WindowLimit ipmasked [3][16]byte } // WindowLimit holds counters for one window, with limits for each IP class/subnet. type WindowLimit struct { Window time.Duration Limits [3]int64 // For "ipmasked1" through "ipmasked3". Time uint32 // Time/Window. Counts map[struct { Index uint8 IPMasked [16]byte }]int64 } // Add attempts to consume "n" items from the rate limiter. If the total for this // key and this interval would exceed limit, "n" is not counted and false is // returned. If now represents a different time interval, all counts are reset. func (l *Limiter) Add(ip net.IP, tm time.Time, n int64) bool { return l.checkAdd(true, ip, tm, n) } // CanAdd returns if n could be added to the limiter. func (l *Limiter) CanAdd(ip net.IP, tm time.Time, n int64) bool { return l.checkAdd(false, ip, tm, n) } func (l *Limiter) checkAdd(add bool, ip net.IP, tm time.Time, n int64) bool { l.Lock() defer l.Unlock() // First check. for i, pl := range l.WindowLimits { t := uint32(tm.UnixNano() / int64(pl.Window)) if t > pl.Time || pl.Counts == nil { l.WindowLimits[i].Time = t pl.Counts = map[struct { Index uint8 IPMasked [16]byte }]int64{} // Used below. l.WindowLimits[i].Counts = pl.Counts } for j := 0; j < 3; j++ { if i == 0 { l.ipmasked[j] = l.maskIP(j, ip) } v := pl.Counts[struct { Index uint8 IPMasked [16]byte }{uint8(j), l.ipmasked[j]}] if v+n > pl.Limits[j] { return false } } } if !add { return true } // Finally record. for _, pl := range l.WindowLimits { for j := 0; j < 3; j++ { pl.Counts[struct { Index uint8 IPMasked [16]byte }{uint8(j), l.ipmasked[j]}] += n } } return true } // Reset sets the counter to 0 for key and ip, and subtracts from the ipmasked counts. func (l *Limiter) Reset(ip net.IP, tm time.Time) { l.Lock() defer l.Unlock() // Prepare masked ip's. for i := 0; i < 3; i++ { l.ipmasked[i] = l.maskIP(i, ip) } for _, pl := range l.WindowLimits { t := uint32(tm.UnixNano() / int64(pl.Window)) if t != pl.Time || pl.Counts == nil { continue } var n int64 for j := 0; j < 3; j++ { k := struct { Index uint8 IPMasked [16]byte }{uint8(j), l.ipmasked[j]} if j == 0 { n = pl.Counts[k] } if pl.Counts != nil { pl.Counts[k] -= n } } } } func (l *Limiter) maskIP(i int, ip net.IP) [16]byte { isv4 := ip.To4() != nil var ipmasked net.IP if isv4 { switch i { case 0: ipmasked = ip case 1: ipmasked = ip.Mask(net.CIDRMask(26, 32)) case 2: ipmasked = ip.Mask(net.CIDRMask(21, 32)) default: panic("missing case for maskip ipv4") } } else { switch i { case 0: ipmasked = ip.Mask(net.CIDRMask(64, 128)) case 1: ipmasked = ip.Mask(net.CIDRMask(48, 128)) case 2: ipmasked = ip.Mask(net.CIDRMask(32, 128)) default: panic("missing case for masking ipv6") } } return *(*[16]byte)(ipmasked.To16()) }