Skip to content

Commit

Permalink
fix: remove rate.Limiter and implement as requested.
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Stockton committed Sep 30, 2024
1 parent 00c0439 commit 994a9f4
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 101 deletions.
5 changes: 2 additions & 3 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"
"github.com/supabase/hibp"
"golang.org/x/time/rate"
)

const (
Expand All @@ -39,8 +38,8 @@ type API struct {
// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
overrideTime func() time.Time

emailRateLimiter *rate.Limiter
smsRateLimiter *rate.Limiter
emailRateLimiter *RateLimiter
smsRateLimiter *RateLimiter
}

func (a *API) Now() time.Time {
Expand Down
56 changes: 39 additions & 17 deletions internal/api/ratelimits.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,49 @@
package api

import (
"sync"
"time"

"github.com/supabase/auth/internal/conf"
"golang.org/x/time/rate"
)

// RateLimiter will limit the number of calls to Allow per interval.
type RateLimiter struct {
mu sync.Mutex
ival time.Duration // Count is reset and time updated every ival.
limit int // Limit calls to Allow() per ival.

// Guarded by mu.
last time.Time // When the limiter was last reset.
count int // Total calls to Allow() since time.
}

// newRateLimiter returns a rate limiter configured using the given conf.Rate.
//
// The returned *rate.Limiter will be configured with a token bucket containing
// a single token, which will fill up at a rate of r. For example to allow 100
// events every 24 hours. This will fill a token bucket approximately once every
// 864 seconds (14.4 minutes). See Example_newRateLimiter for a visualization.
func newRateLimiter(r conf.Rate) *rate.Limiter {
// The rate limiter deals in events per second.
eps := r.EventsPerSecond()
burst := int(r.Events)
if burst <= 0 {
burst = 1
func newRateLimiter(r conf.Rate) *RateLimiter {
return &RateLimiter{
ival: r.OverTime,
limit: int(r.Events),
last: time.Now(),
}
}

// NewLimiter will have an initial token bucket of size `burst`. It will
// be refilled at a rate of `eps` indefinitely. Note that the expression
// 100 / 24h is roughly equivelant to the expression 1 / 15m. The 100 is
// a rate, not a quota.
return rate.NewLimiter(rate.Limit(eps), burst)
func (rl *RateLimiter) Allow() bool {
rl.mu.Lock()
defer rl.mu.Unlock()

now := time.Now()
return rl.allowAt(now)
}

func (rl *RateLimiter) allowAt(at time.Time) bool {
since := at.Sub(rl.last)
if ivals := int64(since / rl.ival); ivals > 0 {
rl.last = rl.last.Add(time.Duration(ivals) * rl.ival)
rl.count = 0
}
if rl.count < rl.limit {
rl.count++
return true
}
return false
}
130 changes: 49 additions & 81 deletions internal/api/ratelimits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,101 +2,65 @@ package api

import (
"fmt"
"math"
"testing"
"time"

"github.com/supabase/auth/internal/conf"
"golang.org/x/time/rate"
)

func newUnlimitedLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Inf, 0)
func newUnlimitedLimiter() *RateLimiter {
cfg := conf.Rate{
Events: float64(math.MaxInt32),
OverTime: time.Second,
}
return newRateLimiter(cfg)
}

func Example_newRateLimiter() {
now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z")
{
cfg := conf.Rate{Events: 100, OverTime: time.Hour * 24}
rl := newRateLimiter(cfg)
cur := now
for limited, i := 0, 0; i < 160; i++ {
allowed := rl.AllowN(cur, 1)
if allowed {
limited++
cfg := conf.Rate{Events: 100, OverTime: time.Hour * 24}
rl := newRateLimiter(cfg)
rl.last = now

cur := now
allowed := 0

for days := 0; days < 2; days++ {
// First 100 events succeed.
for i := 0; i < 100; i++ {
allow := rl.allowAt(cur)
cur = cur.Add(time.Second)

if !allow {
fmt.Printf("false @ %v after %v events... [FAILED]\n", cur, allowed)
return
}
allowed++
}
fmt.Printf("true @ %v for last %v events...\n", cur, allowed)

switch {
case i == 100:
fmt.Printf("true @ %v for last %v events...\n", cur, limited)
case i > 100:
fmt.Printf("%-5v @ %v\n", allowed, cur)
cur = cur.Add(time.Minute)
// We try hourly until it allows us to make requests again.
denied := 0
for i := 0; i < 23; i++ {
cur = cur.Add(time.Hour)
allow := rl.allowAt(cur)
if allow {
fmt.Printf("true @ %v before quota reset... [FAILED]\n", cur)
return
}
denied++
}
fmt.Printf("false @ %v for last %v events...\n", cur, denied)

cur = cur.Add(time.Hour)
}

// Output:
// true @ 2024-09-24 10:00:00 +0000 UTC for last 100 events...
// false @ 2024-09-24 10:00:00 +0000 UTC
// false @ 2024-09-24 10:01:00 +0000 UTC
// false @ 2024-09-24 10:02:00 +0000 UTC
// false @ 2024-09-24 10:03:00 +0000 UTC
// false @ 2024-09-24 10:04:00 +0000 UTC
// false @ 2024-09-24 10:05:00 +0000 UTC
// false @ 2024-09-24 10:06:00 +0000 UTC
// false @ 2024-09-24 10:07:00 +0000 UTC
// false @ 2024-09-24 10:08:00 +0000 UTC
// false @ 2024-09-24 10:09:00 +0000 UTC
// false @ 2024-09-24 10:10:00 +0000 UTC
// false @ 2024-09-24 10:11:00 +0000 UTC
// false @ 2024-09-24 10:12:00 +0000 UTC
// false @ 2024-09-24 10:13:00 +0000 UTC
// false @ 2024-09-24 10:14:00 +0000 UTC
// true @ 2024-09-24 10:15:00 +0000 UTC
// false @ 2024-09-24 10:16:00 +0000 UTC
// false @ 2024-09-24 10:17:00 +0000 UTC
// false @ 2024-09-24 10:18:00 +0000 UTC
// false @ 2024-09-24 10:19:00 +0000 UTC
// false @ 2024-09-24 10:20:00 +0000 UTC
// false @ 2024-09-24 10:21:00 +0000 UTC
// false @ 2024-09-24 10:22:00 +0000 UTC
// false @ 2024-09-24 10:23:00 +0000 UTC
// false @ 2024-09-24 10:24:00 +0000 UTC
// false @ 2024-09-24 10:25:00 +0000 UTC
// false @ 2024-09-24 10:26:00 +0000 UTC
// false @ 2024-09-24 10:27:00 +0000 UTC
// false @ 2024-09-24 10:28:00 +0000 UTC
// true @ 2024-09-24 10:29:00 +0000 UTC
// false @ 2024-09-24 10:30:00 +0000 UTC
// false @ 2024-09-24 10:31:00 +0000 UTC
// false @ 2024-09-24 10:32:00 +0000 UTC
// false @ 2024-09-24 10:33:00 +0000 UTC
// false @ 2024-09-24 10:34:00 +0000 UTC
// false @ 2024-09-24 10:35:00 +0000 UTC
// false @ 2024-09-24 10:36:00 +0000 UTC
// false @ 2024-09-24 10:37:00 +0000 UTC
// false @ 2024-09-24 10:38:00 +0000 UTC
// false @ 2024-09-24 10:39:00 +0000 UTC
// false @ 2024-09-24 10:40:00 +0000 UTC
// false @ 2024-09-24 10:41:00 +0000 UTC
// false @ 2024-09-24 10:42:00 +0000 UTC
// false @ 2024-09-24 10:43:00 +0000 UTC
// true @ 2024-09-24 10:44:00 +0000 UTC
// false @ 2024-09-24 10:45:00 +0000 UTC
// false @ 2024-09-24 10:46:00 +0000 UTC
// false @ 2024-09-24 10:47:00 +0000 UTC
// false @ 2024-09-24 10:48:00 +0000 UTC
// false @ 2024-09-24 10:49:00 +0000 UTC
// false @ 2024-09-24 10:50:00 +0000 UTC
// false @ 2024-09-24 10:51:00 +0000 UTC
// false @ 2024-09-24 10:52:00 +0000 UTC
// false @ 2024-09-24 10:53:00 +0000 UTC
// false @ 2024-09-24 10:54:00 +0000 UTC
// false @ 2024-09-24 10:55:00 +0000 UTC
// false @ 2024-09-24 10:56:00 +0000 UTC
// false @ 2024-09-24 10:57:00 +0000 UTC
// true @ 2024-09-24 10:58:00 +0000 UTC

// true @ 2024-09-24 10:01:40 +0000 UTC for last 100 events...
// false @ 2024-09-25 09:01:40 +0000 UTC for last 23 events...
// true @ 2024-09-25 10:03:20 +0000 UTC for last 200 events...
// false @ 2024-09-26 09:03:20 +0000 UTC for last 23 events...
}

func TestNewRateLimiter(t *testing.T) {
Expand All @@ -120,18 +84,22 @@ func TestNewRateLimiter(t *testing.T) {
{true, now.Add(time.Minute), 98},
{false, now.Add(time.Minute), 0},
{false, now.Add(time.Minute * 14), 0},
{true, now.Add(time.Minute * 15), 0},
{false, now.Add(time.Minute * 15), 0},
{false, now.Add(time.Minute * 16), 0},
{false, now.Add(time.Minute * 17), 0},
{true, now.Add(time.Minute * 30), 0},
{false, now.Add(time.Minute * 17), 0},
{true, now.Add(time.Hour * 24), 0},
{true, now.Add(time.Hour * 25), 0},
},
},
}
for _, tc := range cases {
rl := newRateLimiter(tc.cfg)
rl.last = tc.now

for _, evt := range tc.evts {
for i := 0; i <= evt.r; i++ {
if exp, got := evt.ok, rl.AllowN(evt.at, 1); exp != got {
if exp, got := evt.ok, rl.allowAt(evt.at); exp != got {
t.Fatalf("exp AllowN(%v, 1) to be %v; got %v", evt.at, exp, got)
}
}
Expand Down

0 comments on commit 994a9f4

Please sign in to comment.