From 71ebcfa129854909f2dbce2619d57f9264d99a11 Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Wed, 25 Feb 2026 09:06:20 -0800 Subject: [PATCH] Fix BatchEnqueueContext time comparison and add scheduled task support BatchEnqueueContext had a time comparison bug where `now` was captured before the loop but `processAt` was set to time.Now() inside composeOptions during each iteration, causing all immediate tasks to be incorrectly classified as scheduled and rejected. Fix: move `now` capture inside the loop, after composeOptions. Additionally, extend BatchEnqueueContext to support scheduled tasks in the same pipeline. Tasks with a future ProcessAt are now routed to scheduleCmd (ZADD to scheduled set) instead of being rejected. Only unique and group tasks remain unsupported. Changes: - Add BatchEnqueueItem type pairing TaskMessage with optional ProcessAt - Update Broker interface, RDB, and testbroker to use BatchEnqueueItem - Route immediate tasks to enqueueCmd, scheduled tasks to scheduleCmd - Return correct TaskState (Pending vs Scheduled) in results - Add tests for immediate, scheduled, and mixed batch scenarios --- client.go | 47 +++++++---- client_test.go | 124 ++++++++++++++++++++++++++++++ internal/base/base.go | 10 ++- internal/rdb/rdb.go | 39 ++++++---- internal/rdb/rdb_test.go | 66 +++++++++++++++- internal/testbroker/testbroker.go | 4 +- 6 files changed, 253 insertions(+), 37 deletions(-) diff --git a/client.go b/client.go index 6c0bec9..544b0d7 100644 --- a/client.go +++ b/client.go @@ -428,18 +428,23 @@ type BatchEnqueueResult struct { // BatchEnqueueContext enqueues all given tasks using a single Redis pipeline round-trip. // Each task gets its own result so callers can handle partial success. -// Only immediate-enqueue tasks are supported; scheduled, unique, and group tasks -// are rejected with an error in the corresponding BatchEnqueueResult. +// Immediate and scheduled tasks are supported; unique and group tasks are +// rejected with an error in the corresponding BatchEnqueueResult. func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ...Option) []BatchEnqueueResult { results := make([]BatchEnqueueResult, len(tasks)) if len(tasks) == 0 { return results } - msgs := make([]*base.TaskMessage, 0, len(tasks)) - msgIndexes := make([]int, 0, len(tasks)) + type itemMeta struct { + state base.TaskState + processAt time.Time + } + + items := make([]base.BatchEnqueueItem, 0, len(tasks)) + itemIndexes := make([]int, 0, len(tasks)) + itemMetas := make([]itemMeta, 0, len(tasks)) - now := time.Now() for i, task := range tasks { if task == nil { results[i] = BatchEnqueueResult{Err: fmt.Errorf("task cannot be nil")} @@ -455,10 +460,6 @@ func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts .. results[i] = BatchEnqueueResult{Err: err} continue } - if opt.processAt.After(now) { - results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support scheduled tasks")} - continue - } if opt.group != "" { results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support group tasks")} continue @@ -489,24 +490,38 @@ func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts .. Timeout: int64(timeout.Seconds()), Retention: int64(opt.retention.Seconds()), } - msgs = append(msgs, msg) - msgIndexes = append(msgIndexes, i) + + now := time.Now() + scheduled := opt.processAt.After(now) + + item := base.BatchEnqueueItem{Msg: msg} + var meta itemMeta + if scheduled { + item.ProcessAt = opt.processAt + meta = itemMeta{state: base.TaskStateScheduled, processAt: opt.processAt} + } else { + meta = itemMeta{state: base.TaskStatePending, processAt: now} + } + + items = append(items, item) + itemIndexes = append(itemIndexes, i) + itemMetas = append(itemMetas, meta) } - if len(msgs) == 0 { + if len(items) == 0 { return results } - _, err := c.broker.BatchEnqueue(ctx, msgs) + _, err := c.broker.BatchEnqueue(ctx, items) if err != nil { - for _, idx := range msgIndexes { + for _, idx := range itemIndexes { results[idx] = BatchEnqueueResult{Err: err} } return results } - for j, idx := range msgIndexes { - info := newTaskInfo(msgs[j], base.TaskStatePending, now, nil) + for j, idx := range itemIndexes { + info := newTaskInfo(items[j].Msg, itemMetas[j].state, itemMetas[j].processAt, nil) results[idx] = BatchEnqueueResult{TaskInfo: info} } return results diff --git a/client_test.go b/client_test.go index e4104f5..8f5935f 100644 --- a/client_test.go +++ b/client_test.go @@ -1661,3 +1661,127 @@ func TestClientEnqueueWithHeadersAndGroup(t *testing.T) { } } } + +func TestBatchEnqueueContext_ImmediateTasks(t *testing.T) { + r := setup(t) + client := NewClient(getRedisConnOpt(t)) + defer client.Close() + + tasks := []*Task{ + NewTask("task1", []byte("payload1")), + NewTask("task2", []byte("payload2")), + NewTask("task3", []byte("payload3")), + } + + results := client.BatchEnqueueContext(context.Background(), tasks) + if len(results) != 3 { + t.Fatalf("BatchEnqueueContext returned %d results, want 3", len(results)) + } + for i, res := range results { + if res.Err != nil { + t.Errorf("results[%d].Err = %v, want nil", i, res.Err) + } + if res.TaskInfo == nil { + t.Errorf("results[%d].TaskInfo is nil, want non-nil", i) + continue + } + if res.TaskInfo.Queue != "default" { + t.Errorf("results[%d].TaskInfo.Queue = %q, want %q", i, res.TaskInfo.Queue, "default") + } + if res.TaskInfo.State != TaskStatePending { + t.Errorf("results[%d].TaskInfo.State = %v, want %v", i, res.TaskInfo.State, TaskStatePending) + } + } + + gotPending := h.GetPendingMessages(t, r, "default") + if len(gotPending) != 3 { + t.Errorf("len(pending) = %d, want 3", len(gotPending)) + } +} + +func TestBatchEnqueueContext_ScheduledTask(t *testing.T) { + r := setup(t) + client := NewClient(getRedisConnOpt(t)) + defer client.Close() + + future := time.Now().Add(1 * time.Hour) + tasks := []*Task{ + NewTask("scheduled_task", []byte("payload"), ProcessAt(future)), + } + + results := client.BatchEnqueueContext(context.Background(), tasks) + if len(results) != 1 { + t.Fatalf("BatchEnqueueContext returned %d results, want 1", len(results)) + } + if results[0].Err != nil { + t.Fatalf("results[0].Err = %v, want nil", results[0].Err) + } + if results[0].TaskInfo == nil { + t.Fatal("results[0].TaskInfo is nil, want non-nil") + } + if results[0].TaskInfo.State != TaskStateScheduled { + t.Errorf("results[0].TaskInfo.State = %v, want %v", results[0].TaskInfo.State, TaskStateScheduled) + } + + gotScheduled := h.GetScheduledMessages(t, r, "default") + if len(gotScheduled) != 1 { + t.Errorf("len(scheduled) = %d, want 1", len(gotScheduled)) + } +} + +func TestBatchEnqueueContext_MixedBatch(t *testing.T) { + r := setup(t) + client := NewClient(getRedisConnOpt(t)) + defer client.Close() + + future := time.Now().Add(1 * time.Hour) + tasks := []*Task{ + NewTask("immediate1", []byte("p1")), + NewTask("scheduled1", []byte("p2"), ProcessAt(future)), + NewTask("immediate2", []byte("p3")), + NewTask("grouped1", []byte("p4"), Group("mygroup")), + NewTask("immediate3", []byte("p5")), + } + + results := client.BatchEnqueueContext(context.Background(), tasks) + if len(results) != 5 { + t.Fatalf("BatchEnqueueContext returned %d results, want 5", len(results)) + } + + // Immediate tasks (indices 0, 2, 4) should succeed with Pending state. + for _, idx := range []int{0, 2, 4} { + if results[idx].Err != nil { + t.Errorf("results[%d].Err = %v, want nil (immediate task)", idx, results[idx].Err) + } + if results[idx].TaskInfo == nil { + t.Errorf("results[%d].TaskInfo is nil, want non-nil", idx) + continue + } + if results[idx].TaskInfo.State != TaskStatePending { + t.Errorf("results[%d].TaskInfo.State = %v, want %v", idx, results[idx].TaskInfo.State, TaskStatePending) + } + } + + // Scheduled task (index 1) should succeed with Scheduled state. + if results[1].Err != nil { + t.Errorf("results[1].Err = %v, want nil (scheduled task)", results[1].Err) + } + if results[1].TaskInfo != nil && results[1].TaskInfo.State != TaskStateScheduled { + t.Errorf("results[1].TaskInfo.State = %v, want %v", results[1].TaskInfo.State, TaskStateScheduled) + } + + // Grouped task (index 3) should be rejected. + if results[3].Err == nil { + t.Error("results[3].Err is nil, want error for group task") + } + + gotPending := h.GetPendingMessages(t, r, "default") + if len(gotPending) != 3 { + t.Errorf("len(pending) = %d, want 3", len(gotPending)) + } + + gotScheduled := h.GetScheduledMessages(t, r, "default") + if len(gotScheduled) != 1 { + t.Errorf("len(scheduled) = %d, want 1", len(gotScheduled)) + } +} diff --git a/internal/base/base.go b/internal/base/base.go index e1bd58e..f5e3609 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -684,6 +684,14 @@ func (l *Lease) IsValid() bool { return l.expireAt.After(now) || l.expireAt.Equal(now) } +// BatchEnqueueItem pairs a task message with optional scheduling metadata for +// batch enqueue operations. If ProcessAt is zero, the task is enqueued for +// immediate processing; otherwise it is added to the scheduled set. +type BatchEnqueueItem struct { + Msg *TaskMessage + ProcessAt time.Time // zero value → immediate +} + // Broker is a message broker that supports operations to manage task queues. // // See rdb.RDB as a reference implementation. @@ -692,7 +700,7 @@ type Broker interface { Close() error Enqueue(ctx context.Context, msg *TaskMessage) error EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error - BatchEnqueue(ctx context.Context, msgs []*TaskMessage) (int, error) + BatchEnqueue(ctx context.Context, items []BatchEnqueueItem) (int, error) Dequeue(qnames ...string) (*TaskMessage, time.Time, error) Done(ctx context.Context, msg *TaskMessage) error MarkAsComplete(ctx context.Context, msg *TaskMessage) error diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 376f63a..13aae29 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -142,37 +142,48 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error { // BatchEnqueue adds all given tasks to their respective pending lists using a // single Redis pipeline round-trip. It returns the number of newly enqueued // messages (tasks whose IDs already exist in Redis are silently skipped). -func (r *RDB) BatchEnqueue(ctx context.Context, msgs []*base.TaskMessage) (int, error) { +// BatchEnqueue adds all given tasks to Redis using a single pipeline round-trip. +// Each item is either enqueued immediately or scheduled based on its ProcessAt field. +func (r *RDB) BatchEnqueue(ctx context.Context, items []base.BatchEnqueueItem) (int, error) { var op errors.Op = "rdb.BatchEnqueue" - if len(msgs) == 0 { + if len(items) == 0 { return 0, nil } pipe := r.client.Pipeline() - // Track which indices in the pipeline correspond to enqueueCmd results vs SADD commands. type cmdIndex struct{ pipeIdx int } - scriptCmds := make([]cmdIndex, 0, len(msgs)) + scriptCmds := make([]cmdIndex, 0, len(items)) pipeLen := 0 now := r.clock.Now().UnixNano() - for _, msg := range msgs { - encoded, err := base.EncodeMessage(msg) + for _, item := range items { + encoded, err := base.EncodeMessage(item.Msg) if err != nil { return 0, errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err)) } - if _, found := r.queuesPublished.Load(msg.Queue); !found { - pipe.SAdd(ctx, base.AllQueues, msg.Queue) - r.queuesPublished.Store(msg.Queue, true) + if _, found := r.queuesPublished.Load(item.Msg.Queue); !found { + pipe.SAdd(ctx, base.AllQueues, item.Msg.Queue) + r.queuesPublished.Store(item.Msg.Queue, true) pipeLen++ } - keys := []string{ - base.TaskKey(msg.Queue, msg.ID), - base.PendingKey(msg.Queue), + + if item.ProcessAt.IsZero() { + keys := []string{ + base.TaskKey(item.Msg.Queue, item.Msg.ID), + base.PendingKey(item.Msg.Queue), + } + argv := []interface{}{encoded, item.Msg.ID, now} + enqueueCmd.Run(ctx, pipe, keys, argv...) + } else { + keys := []string{ + base.TaskKey(item.Msg.Queue, item.Msg.ID), + base.ScheduledKey(item.Msg.Queue), + } + argv := []interface{}{encoded, item.ProcessAt.Unix(), item.Msg.ID} + scheduleCmd.Run(ctx, pipe, keys, argv...) } - argv := []interface{}{encoded, msg.ID, now} - enqueueCmd.Run(ctx, pipe, keys, argv...) scriptCmds = append(scriptCmds, cmdIndex{pipeIdx: pipeLen}) pipeLen++ } diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index a31c5ab..a1a873c 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -173,9 +173,13 @@ func TestBatchEnqueue(t *testing.T) { t.Run("enqueue multiple tasks", func(t *testing.T) { h.FlushDB(t, r.client) - msgs := []*base.TaskMessage{t1, t2, t3} + items := []base.BatchEnqueueItem{ + {Msg: t1}, + {Msg: t2}, + {Msg: t3}, + } - n, err := r.BatchEnqueue(context.Background(), msgs) + n, err := r.BatchEnqueue(context.Background(), items) if err != nil { t.Fatalf("BatchEnqueue returned error: %v", err) } @@ -183,7 +187,8 @@ func TestBatchEnqueue(t *testing.T) { t.Errorf("BatchEnqueue returned %d, want 3", n) } - for _, msg := range msgs { + for _, item := range items { + msg := item.Msg pendingKey := base.PendingKey(msg.Queue) pendingIDs := r.client.LRange(context.Background(), pendingKey, 0, -1).Val() found := false @@ -227,7 +232,11 @@ func TestBatchEnqueue(t *testing.T) { dup := *t1 newMsg := h.NewTaskMessage("new_task", nil) - n, err := r.BatchEnqueue(context.Background(), []*base.TaskMessage{&dup, newMsg}) + items := []base.BatchEnqueueItem{ + {Msg: &dup}, + {Msg: newMsg}, + } + n, err := r.BatchEnqueue(context.Background(), items) if err != nil { t.Fatalf("BatchEnqueue returned error: %v", err) } @@ -235,6 +244,55 @@ func TestBatchEnqueue(t *testing.T) { t.Errorf("BatchEnqueue returned %d, want 1 (duplicate should be skipped)", n) } }) + + t.Run("scheduled tasks", func(t *testing.T) { + h.FlushDB(t, r.client) + + future := time.Now().Add(1 * time.Hour) + s1 := h.NewTaskMessage("deferred_email", nil) + items := []base.BatchEnqueueItem{ + {Msg: t1}, + {Msg: s1, ProcessAt: future}, + } + + n, err := r.BatchEnqueue(context.Background(), items) + if err != nil { + t.Fatalf("BatchEnqueue returned error: %v", err) + } + if n != 2 { + t.Errorf("BatchEnqueue returned %d, want 2", n) + } + + // Immediate task should be in pending. + pendingIDs := r.client.LRange(context.Background(), base.PendingKey(t1.Queue), 0, -1).Val() + foundPending := false + for _, id := range pendingIDs { + if id == t1.ID { + foundPending = true + } + } + if !foundPending { + t.Errorf("immediate task %s not found in pending list", t1.ID) + } + + // Scheduled task should be in scheduled set. + scheduledIDs := r.client.ZRange(context.Background(), base.ScheduledKey(s1.Queue), 0, -1).Val() + foundScheduled := false + for _, id := range scheduledIDs { + if id == s1.ID { + foundScheduled = true + } + } + if !foundScheduled { + t.Errorf("scheduled task %s not found in scheduled set", s1.ID) + } + + taskKey := base.TaskKey(s1.Queue, s1.ID) + state := r.client.HGet(context.Background(), taskKey, "state").Val() + if state != "scheduled" { + t.Errorf("state for scheduled task %s = %q, want %q", s1.ID, state, "scheduled") + } + }) } func TestEnqueueQueueCache(t *testing.T) { diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index 63adf7a..b02d8ad 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -64,13 +64,13 @@ func (tb *TestBroker) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, return tb.real.EnqueueUnique(ctx, msg, ttl) } -func (tb *TestBroker) BatchEnqueue(ctx context.Context, msgs []*base.TaskMessage) (int, error) { +func (tb *TestBroker) BatchEnqueue(ctx context.Context, items []base.BatchEnqueueItem) (int, error) { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { return 0, errRedisDown } - return tb.real.BatchEnqueue(ctx, msgs) + return tb.real.BatchEnqueue(ctx, items) } func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, time.Time, error) {