diff --git a/background.go b/background.go index 2d9ac6e..3fe01d6 100644 --- a/background.go +++ b/background.go @@ -6,8 +6,6 @@ import ( "os/signal" "sync" "time" - - "github.com/go-redis/redis/v7" ) // Background is a top-level entity for the background-task processing. @@ -21,12 +19,7 @@ type Background struct { // NewBackground returns a new Background instance. func NewBackground(numWorkers int, opt *RedisOpt) *Background { - client := redis.NewClient(&redis.Options{ - Addr: opt.Addr, - Password: opt.Password, - DB: opt.DB, - }) - rdb := newRDB(client) + rdb := newRDB(opt) poller := newPoller(rdb, 5*time.Second, []string{scheduled, retry}) processor := newProcessor(rdb, numWorkers, nil) return &Background{ diff --git a/client.go b/client.go index 6a4d9a6..5f61313 100644 --- a/client.go +++ b/client.go @@ -3,7 +3,6 @@ package asynq import ( "time" - "github.com/go-redis/redis/v7" "github.com/google/uuid" ) @@ -14,12 +13,7 @@ type Client struct { // NewClient creates and returns a new client. func NewClient(opt *RedisOpt) *Client { - client := redis.NewClient(&redis.Options{ - Addr: opt.Addr, - Password: opt.Password, - DB: opt.DB, - }) - return &Client{rdb: newRDB(client)} + return &Client{rdb: newRDB(opt)} } // Process enqueues the task to be performed at a given time. diff --git a/rdb.go b/rdb.go index 34493f1..430e2c0 100644 --- a/rdb.go +++ b/rdb.go @@ -32,7 +32,12 @@ type rdb struct { client *redis.Client } -func newRDB(client *redis.Client) *rdb { +func newRDB(opt *RedisOpt) *rdb { + client := redis.NewClient(&redis.Options{ + Addr: opt.Addr, + Password: opt.Password, + DB: opt.DB, + }) return &rdb{client} } diff --git a/rdb_test.go b/rdb_test.go index 4e27f5e..b723851 100644 --- a/rdb_test.go +++ b/rdb_test.go @@ -12,8 +12,6 @@ import ( "github.com/google/uuid" ) -var client *redis.Client - func init() { rand.Seed(time.Now().UnixNano()) } @@ -28,15 +26,15 @@ var sortStrOpt = cmp.Transformer("SortStr", func(in []string) []string { // before returning an instance of rdb. func setup(t *testing.T) *rdb { t.Helper() - client = redis.NewClient(&redis.Options{ + r := newRDB(&RedisOpt{ Addr: "localhost:6379", DB: 15, // use database 15 to separate from other applications }) // Start each test with a clean slate. - if err := client.FlushDB().Err(); err != nil { + if err := r.client.FlushDB().Err(); err != nil { panic(err) } - return newRDB(client) + return r } func randomTask(taskType, qname string, payload map[string]interface{}) *taskMessage { @@ -63,19 +61,19 @@ func TestEnqueue(t *testing.T) { for _, tc := range tests { // clean up db before each test case. - if err := client.FlushDB().Err(); err != nil { + if err := r.client.FlushDB().Err(); err != nil { t.Fatal(err) } err := r.enqueue(tc.msg) if err != nil { t.Error(err) } - res := client.LRange(defaultQueue, 0, -1).Val() + res := r.client.LRange(defaultQueue, 0, -1).Val() if len(res) != 1 { t.Errorf("LIST %q has length %d, want 1", defaultQueue, len(res)) continue } - if !client.SIsMember(allQueues, defaultQueue).Val() { + if !r.client.SIsMember(allQueues, defaultQueue).Val() { t.Errorf("SISMEMBER %q %q = false, want true", allQueues, defaultQueue) } var persisted taskMessage @@ -104,7 +102,7 @@ func TestDequeue(t *testing.T) { for _, tc := range tests { // clean up db before each test case. - if err := client.FlushDB().Err(); err != nil { + if err := r.client.FlushDB().Err(); err != nil { t.Fatal(err) } for _, m := range tc.queued { @@ -116,7 +114,7 @@ func TestDequeue(t *testing.T) { defaultQueue, got, err, tc.want, tc.err) continue } - if l := client.LLen(inProgress).Val(); l != tc.inProgress { + if l := r.client.LLen(inProgress).Val(); l != tc.inProgress { t.Errorf("LIST %q has length %d, want %d", inProgress, l, tc.inProgress) } } @@ -168,17 +166,17 @@ func TestMoveAll(t *testing.T) { for _, tc := range tests { // clean up db before each test case. - if err := client.FlushDB().Err(); err != nil { + if err := r.client.FlushDB().Err(); err != nil { t.Error(err) continue } // seed src list. for _, msg := range tc.beforeSrc { - client.LPush(inProgress, msg) + r.client.LPush(inProgress, msg) } // seed dst list. for _, msg := range tc.beforeDst { - client.LPush(defaultQueue, msg) + r.client.LPush(defaultQueue, msg) } if err := r.moveAll(inProgress, defaultQueue); err != nil { @@ -186,11 +184,11 @@ func TestMoveAll(t *testing.T) { continue } - gotSrc := client.LRange(inProgress, 0, -1).Val() + gotSrc := r.client.LRange(inProgress, 0, -1).Val() if diff := cmp.Diff(tc.afterSrc, gotSrc, sortStrOpt); diff != "" { t.Errorf("mismatch found in %q (-want, +got)\n%s", inProgress, diff) } - gotDst := client.LRange(defaultQueue, 0, -1).Val() + gotDst := r.client.LRange(defaultQueue, 0, -1).Val() if diff := cmp.Diff(tc.afterDst, gotDst, sortStrOpt); diff != "" { t.Errorf("mismatch found in %q (-want, +got)\n%s", defaultQueue, diff) } @@ -242,10 +240,10 @@ func TestForward(t *testing.T) { for _, tc := range tests { // clean up db before each test case. - if err := client.FlushDB().Err(); err != nil { + if err := r.client.FlushDB().Err(); err != nil { t.Fatal(err) } - if err := client.ZAdd(scheduled, tc.tasks...).Err(); err != nil { + if err := r.client.ZAdd(scheduled, tc.tasks...).Err(); err != nil { t.Error(err) continue } @@ -255,12 +253,12 @@ func TestForward(t *testing.T) { t.Errorf("(*rdb).forward(%q) = %v, want nil", scheduled, err) continue } - gotQueued := client.LRange(defaultQueue, 0, -1).Val() + gotQueued := r.client.LRange(defaultQueue, 0, -1).Val() if diff := cmp.Diff(tc.wantQueued, gotQueued, sortStrOpt); diff != "" { t.Errorf("%q has %d tasks, want %d tasks; (-want, +got)\n%s", defaultQueue, len(gotQueued), len(tc.wantQueued), diff) continue } - gotScheduled := client.ZRangeByScore(scheduled, &redis.ZRangeBy{Min: "-inf", Max: "+inf"}).Val() + gotScheduled := r.client.ZRangeByScore(scheduled, &redis.ZRangeBy{Min: "-inf", Max: "+inf"}).Val() if diff := cmp.Diff(tc.wantScheduled, gotScheduled, sortStrOpt); diff != "" { t.Errorf("%q has %d tasks, want %d tasks; (-want, +got)\n%s", scheduled, len(gotScheduled), len(tc.wantScheduled), diff) continue