2
0
mirror of https://github.com/hibiken/asynq.git synced 2026-07-02 06:52:05 +08:00

Fix BatchEnqueueContext time comparison and add scheduled task support

BatchEnqueueContext had a time comparison bug where `now` was captured
before the loop but `processAt` was set to time.Now() inside
composeOptions during each iteration, causing all immediate tasks to be
incorrectly classified as scheduled and rejected.

Fix: move `now` capture inside the loop, after composeOptions.

Additionally, extend BatchEnqueueContext to support scheduled tasks in
the same pipeline. Tasks with a future ProcessAt are now routed to
scheduleCmd (ZADD to scheduled set) instead of being rejected. Only
unique and group tasks remain unsupported.

Changes:
- Add BatchEnqueueItem type pairing TaskMessage with optional ProcessAt
- Update Broker interface, RDB, and testbroker to use BatchEnqueueItem
- Route immediate tasks to enqueueCmd, scheduled tasks to scheduleCmd
- Return correct TaskState (Pending vs Scheduled) in results
- Add tests for immediate, scheduled, and mixed batch scenarios
This commit is contained in:
Erik Nilsen
2026-02-25 09:06:20 -08:00
parent 4e62d7e29d
commit 71ebcfa129
6 changed files with 253 additions and 37 deletions

View File

@@ -428,18 +428,23 @@ type BatchEnqueueResult struct {
// BatchEnqueueContext enqueues all given tasks using a single Redis pipeline round-trip. // BatchEnqueueContext enqueues all given tasks using a single Redis pipeline round-trip.
// Each task gets its own result so callers can handle partial success. // Each task gets its own result so callers can handle partial success.
// Only immediate-enqueue tasks are supported; scheduled, unique, and group tasks // Immediate and scheduled tasks are supported; unique and group tasks are
// are rejected with an error in the corresponding BatchEnqueueResult. // rejected with an error in the corresponding BatchEnqueueResult.
func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ...Option) []BatchEnqueueResult { func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ...Option) []BatchEnqueueResult {
results := make([]BatchEnqueueResult, len(tasks)) results := make([]BatchEnqueueResult, len(tasks))
if len(tasks) == 0 { if len(tasks) == 0 {
return results return results
} }
msgs := make([]*base.TaskMessage, 0, len(tasks)) type itemMeta struct {
msgIndexes := make([]int, 0, len(tasks)) state base.TaskState
processAt time.Time
}
items := make([]base.BatchEnqueueItem, 0, len(tasks))
itemIndexes := make([]int, 0, len(tasks))
itemMetas := make([]itemMeta, 0, len(tasks))
now := time.Now()
for i, task := range tasks { for i, task := range tasks {
if task == nil { if task == nil {
results[i] = BatchEnqueueResult{Err: fmt.Errorf("task cannot be nil")} results[i] = BatchEnqueueResult{Err: fmt.Errorf("task cannot be nil")}
@@ -455,10 +460,6 @@ func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ..
results[i] = BatchEnqueueResult{Err: err} results[i] = BatchEnqueueResult{Err: err}
continue continue
} }
if opt.processAt.After(now) {
results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support scheduled tasks")}
continue
}
if opt.group != "" { if opt.group != "" {
results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support group tasks")} results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support group tasks")}
continue continue
@@ -489,24 +490,38 @@ func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ..
Timeout: int64(timeout.Seconds()), Timeout: int64(timeout.Seconds()),
Retention: int64(opt.retention.Seconds()), Retention: int64(opt.retention.Seconds()),
} }
msgs = append(msgs, msg)
msgIndexes = append(msgIndexes, i) now := time.Now()
scheduled := opt.processAt.After(now)
item := base.BatchEnqueueItem{Msg: msg}
var meta itemMeta
if scheduled {
item.ProcessAt = opt.processAt
meta = itemMeta{state: base.TaskStateScheduled, processAt: opt.processAt}
} else {
meta = itemMeta{state: base.TaskStatePending, processAt: now}
}
items = append(items, item)
itemIndexes = append(itemIndexes, i)
itemMetas = append(itemMetas, meta)
} }
if len(msgs) == 0 { if len(items) == 0 {
return results return results
} }
_, err := c.broker.BatchEnqueue(ctx, msgs) _, err := c.broker.BatchEnqueue(ctx, items)
if err != nil { if err != nil {
for _, idx := range msgIndexes { for _, idx := range itemIndexes {
results[idx] = BatchEnqueueResult{Err: err} results[idx] = BatchEnqueueResult{Err: err}
} }
return results return results
} }
for j, idx := range msgIndexes { for j, idx := range itemIndexes {
info := newTaskInfo(msgs[j], base.TaskStatePending, now, nil) info := newTaskInfo(items[j].Msg, itemMetas[j].state, itemMetas[j].processAt, nil)
results[idx] = BatchEnqueueResult{TaskInfo: info} results[idx] = BatchEnqueueResult{TaskInfo: info}
} }
return results return results

View File

@@ -1661,3 +1661,127 @@ func TestClientEnqueueWithHeadersAndGroup(t *testing.T) {
} }
} }
} }
func TestBatchEnqueueContext_ImmediateTasks(t *testing.T) {
r := setup(t)
client := NewClient(getRedisConnOpt(t))
defer client.Close()
tasks := []*Task{
NewTask("task1", []byte("payload1")),
NewTask("task2", []byte("payload2")),
NewTask("task3", []byte("payload3")),
}
results := client.BatchEnqueueContext(context.Background(), tasks)
if len(results) != 3 {
t.Fatalf("BatchEnqueueContext returned %d results, want 3", len(results))
}
for i, res := range results {
if res.Err != nil {
t.Errorf("results[%d].Err = %v, want nil", i, res.Err)
}
if res.TaskInfo == nil {
t.Errorf("results[%d].TaskInfo is nil, want non-nil", i)
continue
}
if res.TaskInfo.Queue != "default" {
t.Errorf("results[%d].TaskInfo.Queue = %q, want %q", i, res.TaskInfo.Queue, "default")
}
if res.TaskInfo.State != TaskStatePending {
t.Errorf("results[%d].TaskInfo.State = %v, want %v", i, res.TaskInfo.State, TaskStatePending)
}
}
gotPending := h.GetPendingMessages(t, r, "default")
if len(gotPending) != 3 {
t.Errorf("len(pending) = %d, want 3", len(gotPending))
}
}
func TestBatchEnqueueContext_ScheduledTask(t *testing.T) {
r := setup(t)
client := NewClient(getRedisConnOpt(t))
defer client.Close()
future := time.Now().Add(1 * time.Hour)
tasks := []*Task{
NewTask("scheduled_task", []byte("payload"), ProcessAt(future)),
}
results := client.BatchEnqueueContext(context.Background(), tasks)
if len(results) != 1 {
t.Fatalf("BatchEnqueueContext returned %d results, want 1", len(results))
}
if results[0].Err != nil {
t.Fatalf("results[0].Err = %v, want nil", results[0].Err)
}
if results[0].TaskInfo == nil {
t.Fatal("results[0].TaskInfo is nil, want non-nil")
}
if results[0].TaskInfo.State != TaskStateScheduled {
t.Errorf("results[0].TaskInfo.State = %v, want %v", results[0].TaskInfo.State, TaskStateScheduled)
}
gotScheduled := h.GetScheduledMessages(t, r, "default")
if len(gotScheduled) != 1 {
t.Errorf("len(scheduled) = %d, want 1", len(gotScheduled))
}
}
func TestBatchEnqueueContext_MixedBatch(t *testing.T) {
r := setup(t)
client := NewClient(getRedisConnOpt(t))
defer client.Close()
future := time.Now().Add(1 * time.Hour)
tasks := []*Task{
NewTask("immediate1", []byte("p1")),
NewTask("scheduled1", []byte("p2"), ProcessAt(future)),
NewTask("immediate2", []byte("p3")),
NewTask("grouped1", []byte("p4"), Group("mygroup")),
NewTask("immediate3", []byte("p5")),
}
results := client.BatchEnqueueContext(context.Background(), tasks)
if len(results) != 5 {
t.Fatalf("BatchEnqueueContext returned %d results, want 5", len(results))
}
// Immediate tasks (indices 0, 2, 4) should succeed with Pending state.
for _, idx := range []int{0, 2, 4} {
if results[idx].Err != nil {
t.Errorf("results[%d].Err = %v, want nil (immediate task)", idx, results[idx].Err)
}
if results[idx].TaskInfo == nil {
t.Errorf("results[%d].TaskInfo is nil, want non-nil", idx)
continue
}
if results[idx].TaskInfo.State != TaskStatePending {
t.Errorf("results[%d].TaskInfo.State = %v, want %v", idx, results[idx].TaskInfo.State, TaskStatePending)
}
}
// Scheduled task (index 1) should succeed with Scheduled state.
if results[1].Err != nil {
t.Errorf("results[1].Err = %v, want nil (scheduled task)", results[1].Err)
}
if results[1].TaskInfo != nil && results[1].TaskInfo.State != TaskStateScheduled {
t.Errorf("results[1].TaskInfo.State = %v, want %v", results[1].TaskInfo.State, TaskStateScheduled)
}
// Grouped task (index 3) should be rejected.
if results[3].Err == nil {
t.Error("results[3].Err is nil, want error for group task")
}
gotPending := h.GetPendingMessages(t, r, "default")
if len(gotPending) != 3 {
t.Errorf("len(pending) = %d, want 3", len(gotPending))
}
gotScheduled := h.GetScheduledMessages(t, r, "default")
if len(gotScheduled) != 1 {
t.Errorf("len(scheduled) = %d, want 1", len(gotScheduled))
}
}

View File

@@ -684,6 +684,14 @@ func (l *Lease) IsValid() bool {
return l.expireAt.After(now) || l.expireAt.Equal(now) return l.expireAt.After(now) || l.expireAt.Equal(now)
} }
// BatchEnqueueItem pairs a task message with optional scheduling metadata for
// batch enqueue operations. If ProcessAt is zero, the task is enqueued for
// immediate processing; otherwise it is added to the scheduled set.
type BatchEnqueueItem struct {
Msg *TaskMessage
ProcessAt time.Time // zero value → immediate
}
// Broker is a message broker that supports operations to manage task queues. // Broker is a message broker that supports operations to manage task queues.
// //
// See rdb.RDB as a reference implementation. // See rdb.RDB as a reference implementation.
@@ -692,7 +700,7 @@ type Broker interface {
Close() error Close() error
Enqueue(ctx context.Context, msg *TaskMessage) error Enqueue(ctx context.Context, msg *TaskMessage) error
EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error
BatchEnqueue(ctx context.Context, msgs []*TaskMessage) (int, error) BatchEnqueue(ctx context.Context, items []BatchEnqueueItem) (int, error)
Dequeue(qnames ...string) (*TaskMessage, time.Time, error) Dequeue(qnames ...string) (*TaskMessage, time.Time, error)
Done(ctx context.Context, msg *TaskMessage) error Done(ctx context.Context, msg *TaskMessage) error
MarkAsComplete(ctx context.Context, msg *TaskMessage) error MarkAsComplete(ctx context.Context, msg *TaskMessage) error

View File

@@ -142,37 +142,48 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error {
// BatchEnqueue adds all given tasks to their respective pending lists using a // BatchEnqueue adds all given tasks to their respective pending lists using a
// single Redis pipeline round-trip. It returns the number of newly enqueued // single Redis pipeline round-trip. It returns the number of newly enqueued
// messages (tasks whose IDs already exist in Redis are silently skipped). // messages (tasks whose IDs already exist in Redis are silently skipped).
func (r *RDB) BatchEnqueue(ctx context.Context, msgs []*base.TaskMessage) (int, error) { // BatchEnqueue adds all given tasks to Redis using a single pipeline round-trip.
// Each item is either enqueued immediately or scheduled based on its ProcessAt field.
func (r *RDB) BatchEnqueue(ctx context.Context, items []base.BatchEnqueueItem) (int, error) {
var op errors.Op = "rdb.BatchEnqueue" var op errors.Op = "rdb.BatchEnqueue"
if len(msgs) == 0 { if len(items) == 0 {
return 0, nil return 0, nil
} }
pipe := r.client.Pipeline() pipe := r.client.Pipeline()
// Track which indices in the pipeline correspond to enqueueCmd results vs SADD commands.
type cmdIndex struct{ pipeIdx int } type cmdIndex struct{ pipeIdx int }
scriptCmds := make([]cmdIndex, 0, len(msgs)) scriptCmds := make([]cmdIndex, 0, len(items))
pipeLen := 0 pipeLen := 0
now := r.clock.Now().UnixNano() now := r.clock.Now().UnixNano()
for _, msg := range msgs { for _, item := range items {
encoded, err := base.EncodeMessage(msg) encoded, err := base.EncodeMessage(item.Msg)
if err != nil { if err != nil {
return 0, errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err)) return 0, errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
} }
if _, found := r.queuesPublished.Load(msg.Queue); !found { if _, found := r.queuesPublished.Load(item.Msg.Queue); !found {
pipe.SAdd(ctx, base.AllQueues, msg.Queue) pipe.SAdd(ctx, base.AllQueues, item.Msg.Queue)
r.queuesPublished.Store(msg.Queue, true) r.queuesPublished.Store(item.Msg.Queue, true)
pipeLen++ pipeLen++
} }
keys := []string{
base.TaskKey(msg.Queue, msg.ID), if item.ProcessAt.IsZero() {
base.PendingKey(msg.Queue), keys := []string{
base.TaskKey(item.Msg.Queue, item.Msg.ID),
base.PendingKey(item.Msg.Queue),
}
argv := []interface{}{encoded, item.Msg.ID, now}
enqueueCmd.Run(ctx, pipe, keys, argv...)
} else {
keys := []string{
base.TaskKey(item.Msg.Queue, item.Msg.ID),
base.ScheduledKey(item.Msg.Queue),
}
argv := []interface{}{encoded, item.ProcessAt.Unix(), item.Msg.ID}
scheduleCmd.Run(ctx, pipe, keys, argv...)
} }
argv := []interface{}{encoded, msg.ID, now}
enqueueCmd.Run(ctx, pipe, keys, argv...)
scriptCmds = append(scriptCmds, cmdIndex{pipeIdx: pipeLen}) scriptCmds = append(scriptCmds, cmdIndex{pipeIdx: pipeLen})
pipeLen++ pipeLen++
} }

View File

@@ -173,9 +173,13 @@ func TestBatchEnqueue(t *testing.T) {
t.Run("enqueue multiple tasks", func(t *testing.T) { t.Run("enqueue multiple tasks", func(t *testing.T) {
h.FlushDB(t, r.client) h.FlushDB(t, r.client)
msgs := []*base.TaskMessage{t1, t2, t3} items := []base.BatchEnqueueItem{
{Msg: t1},
{Msg: t2},
{Msg: t3},
}
n, err := r.BatchEnqueue(context.Background(), msgs) n, err := r.BatchEnqueue(context.Background(), items)
if err != nil { if err != nil {
t.Fatalf("BatchEnqueue returned error: %v", err) t.Fatalf("BatchEnqueue returned error: %v", err)
} }
@@ -183,7 +187,8 @@ func TestBatchEnqueue(t *testing.T) {
t.Errorf("BatchEnqueue returned %d, want 3", n) t.Errorf("BatchEnqueue returned %d, want 3", n)
} }
for _, msg := range msgs { for _, item := range items {
msg := item.Msg
pendingKey := base.PendingKey(msg.Queue) pendingKey := base.PendingKey(msg.Queue)
pendingIDs := r.client.LRange(context.Background(), pendingKey, 0, -1).Val() pendingIDs := r.client.LRange(context.Background(), pendingKey, 0, -1).Val()
found := false found := false
@@ -227,7 +232,11 @@ func TestBatchEnqueue(t *testing.T) {
dup := *t1 dup := *t1
newMsg := h.NewTaskMessage("new_task", nil) newMsg := h.NewTaskMessage("new_task", nil)
n, err := r.BatchEnqueue(context.Background(), []*base.TaskMessage{&dup, newMsg}) items := []base.BatchEnqueueItem{
{Msg: &dup},
{Msg: newMsg},
}
n, err := r.BatchEnqueue(context.Background(), items)
if err != nil { if err != nil {
t.Fatalf("BatchEnqueue returned error: %v", err) t.Fatalf("BatchEnqueue returned error: %v", err)
} }
@@ -235,6 +244,55 @@ func TestBatchEnqueue(t *testing.T) {
t.Errorf("BatchEnqueue returned %d, want 1 (duplicate should be skipped)", n) t.Errorf("BatchEnqueue returned %d, want 1 (duplicate should be skipped)", n)
} }
}) })
t.Run("scheduled tasks", func(t *testing.T) {
h.FlushDB(t, r.client)
future := time.Now().Add(1 * time.Hour)
s1 := h.NewTaskMessage("deferred_email", nil)
items := []base.BatchEnqueueItem{
{Msg: t1},
{Msg: s1, ProcessAt: future},
}
n, err := r.BatchEnqueue(context.Background(), items)
if err != nil {
t.Fatalf("BatchEnqueue returned error: %v", err)
}
if n != 2 {
t.Errorf("BatchEnqueue returned %d, want 2", n)
}
// Immediate task should be in pending.
pendingIDs := r.client.LRange(context.Background(), base.PendingKey(t1.Queue), 0, -1).Val()
foundPending := false
for _, id := range pendingIDs {
if id == t1.ID {
foundPending = true
}
}
if !foundPending {
t.Errorf("immediate task %s not found in pending list", t1.ID)
}
// Scheduled task should be in scheduled set.
scheduledIDs := r.client.ZRange(context.Background(), base.ScheduledKey(s1.Queue), 0, -1).Val()
foundScheduled := false
for _, id := range scheduledIDs {
if id == s1.ID {
foundScheduled = true
}
}
if !foundScheduled {
t.Errorf("scheduled task %s not found in scheduled set", s1.ID)
}
taskKey := base.TaskKey(s1.Queue, s1.ID)
state := r.client.HGet(context.Background(), taskKey, "state").Val()
if state != "scheduled" {
t.Errorf("state for scheduled task %s = %q, want %q", s1.ID, state, "scheduled")
}
})
} }
func TestEnqueueQueueCache(t *testing.T) { func TestEnqueueQueueCache(t *testing.T) {

View File

@@ -64,13 +64,13 @@ func (tb *TestBroker) EnqueueUnique(ctx context.Context, msg *base.TaskMessage,
return tb.real.EnqueueUnique(ctx, msg, ttl) return tb.real.EnqueueUnique(ctx, msg, ttl)
} }
func (tb *TestBroker) BatchEnqueue(ctx context.Context, msgs []*base.TaskMessage) (int, error) { func (tb *TestBroker) BatchEnqueue(ctx context.Context, items []base.BatchEnqueueItem) (int, error) {
tb.mu.Lock() tb.mu.Lock()
defer tb.mu.Unlock() defer tb.mu.Unlock()
if tb.sleeping { if tb.sleeping {
return 0, errRedisDown return 0, errRedisDown
} }
return tb.real.BatchEnqueue(ctx, msgs) return tb.real.BatchEnqueue(ctx, items)
} }
func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, time.Time, error) { func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, time.Time, error) {