2
0
mirror of https://github.com/hibiken/asynq.git synced 2026-07-02 10:50:50 +08:00

Fix BatchEnqueueContext time comparison and add scheduled task support

BatchEnqueueContext had a time comparison bug where `now` was captured
before the loop but `processAt` was set to time.Now() inside
composeOptions during each iteration, causing all immediate tasks to be
incorrectly classified as scheduled and rejected.

Fix: move `now` capture inside the loop, after composeOptions.

Additionally, extend BatchEnqueueContext to support scheduled tasks in
the same pipeline. Tasks with a future ProcessAt are now routed to
scheduleCmd (ZADD to scheduled set) instead of being rejected. Only
unique and group tasks remain unsupported.

Changes:
- Add BatchEnqueueItem type pairing TaskMessage with optional ProcessAt
- Update Broker interface, RDB, and testbroker to use BatchEnqueueItem
- Route immediate tasks to enqueueCmd, scheduled tasks to scheduleCmd
- Return correct TaskState (Pending vs Scheduled) in results
- Add tests for immediate, scheduled, and mixed batch scenarios
This commit is contained in:
Erik Nilsen
2026-02-25 09:06:20 -08:00
parent 4e62d7e29d
commit 71ebcfa129
6 changed files with 253 additions and 37 deletions

View File

@@ -428,18 +428,23 @@ type BatchEnqueueResult struct {
// BatchEnqueueContext enqueues all given tasks using a single Redis pipeline round-trip.
// Each task gets its own result so callers can handle partial success.
// Only immediate-enqueue tasks are supported; scheduled, unique, and group tasks
// are rejected with an error in the corresponding BatchEnqueueResult.
// Immediate and scheduled tasks are supported; unique and group tasks are
// rejected with an error in the corresponding BatchEnqueueResult.
func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ...Option) []BatchEnqueueResult {
results := make([]BatchEnqueueResult, len(tasks))
if len(tasks) == 0 {
return results
}
msgs := make([]*base.TaskMessage, 0, len(tasks))
msgIndexes := make([]int, 0, len(tasks))
type itemMeta struct {
state base.TaskState
processAt time.Time
}
items := make([]base.BatchEnqueueItem, 0, len(tasks))
itemIndexes := make([]int, 0, len(tasks))
itemMetas := make([]itemMeta, 0, len(tasks))
now := time.Now()
for i, task := range tasks {
if task == nil {
results[i] = BatchEnqueueResult{Err: fmt.Errorf("task cannot be nil")}
@@ -455,10 +460,6 @@ func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ..
results[i] = BatchEnqueueResult{Err: err}
continue
}
if opt.processAt.After(now) {
results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support scheduled tasks")}
continue
}
if opt.group != "" {
results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support group tasks")}
continue
@@ -489,24 +490,38 @@ func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ..
Timeout: int64(timeout.Seconds()),
Retention: int64(opt.retention.Seconds()),
}
msgs = append(msgs, msg)
msgIndexes = append(msgIndexes, i)
now := time.Now()
scheduled := opt.processAt.After(now)
item := base.BatchEnqueueItem{Msg: msg}
var meta itemMeta
if scheduled {
item.ProcessAt = opt.processAt
meta = itemMeta{state: base.TaskStateScheduled, processAt: opt.processAt}
} else {
meta = itemMeta{state: base.TaskStatePending, processAt: now}
}
items = append(items, item)
itemIndexes = append(itemIndexes, i)
itemMetas = append(itemMetas, meta)
}
if len(msgs) == 0 {
if len(items) == 0 {
return results
}
_, err := c.broker.BatchEnqueue(ctx, msgs)
_, err := c.broker.BatchEnqueue(ctx, items)
if err != nil {
for _, idx := range msgIndexes {
for _, idx := range itemIndexes {
results[idx] = BatchEnqueueResult{Err: err}
}
return results
}
for j, idx := range msgIndexes {
info := newTaskInfo(msgs[j], base.TaskStatePending, now, nil)
for j, idx := range itemIndexes {
info := newTaskInfo(items[j].Msg, itemMetas[j].state, itemMetas[j].processAt, nil)
results[idx] = BatchEnqueueResult{TaskInfo: info}
}
return results

View File

@@ -1661,3 +1661,127 @@ func TestClientEnqueueWithHeadersAndGroup(t *testing.T) {
}
}
}
func TestBatchEnqueueContext_ImmediateTasks(t *testing.T) {
r := setup(t)
client := NewClient(getRedisConnOpt(t))
defer client.Close()
tasks := []*Task{
NewTask("task1", []byte("payload1")),
NewTask("task2", []byte("payload2")),
NewTask("task3", []byte("payload3")),
}
results := client.BatchEnqueueContext(context.Background(), tasks)
if len(results) != 3 {
t.Fatalf("BatchEnqueueContext returned %d results, want 3", len(results))
}
for i, res := range results {
if res.Err != nil {
t.Errorf("results[%d].Err = %v, want nil", i, res.Err)
}
if res.TaskInfo == nil {
t.Errorf("results[%d].TaskInfo is nil, want non-nil", i)
continue
}
if res.TaskInfo.Queue != "default" {
t.Errorf("results[%d].TaskInfo.Queue = %q, want %q", i, res.TaskInfo.Queue, "default")
}
if res.TaskInfo.State != TaskStatePending {
t.Errorf("results[%d].TaskInfo.State = %v, want %v", i, res.TaskInfo.State, TaskStatePending)
}
}
gotPending := h.GetPendingMessages(t, r, "default")
if len(gotPending) != 3 {
t.Errorf("len(pending) = %d, want 3", len(gotPending))
}
}
func TestBatchEnqueueContext_ScheduledTask(t *testing.T) {
r := setup(t)
client := NewClient(getRedisConnOpt(t))
defer client.Close()
future := time.Now().Add(1 * time.Hour)
tasks := []*Task{
NewTask("scheduled_task", []byte("payload"), ProcessAt(future)),
}
results := client.BatchEnqueueContext(context.Background(), tasks)
if len(results) != 1 {
t.Fatalf("BatchEnqueueContext returned %d results, want 1", len(results))
}
if results[0].Err != nil {
t.Fatalf("results[0].Err = %v, want nil", results[0].Err)
}
if results[0].TaskInfo == nil {
t.Fatal("results[0].TaskInfo is nil, want non-nil")
}
if results[0].TaskInfo.State != TaskStateScheduled {
t.Errorf("results[0].TaskInfo.State = %v, want %v", results[0].TaskInfo.State, TaskStateScheduled)
}
gotScheduled := h.GetScheduledMessages(t, r, "default")
if len(gotScheduled) != 1 {
t.Errorf("len(scheduled) = %d, want 1", len(gotScheduled))
}
}
func TestBatchEnqueueContext_MixedBatch(t *testing.T) {
r := setup(t)
client := NewClient(getRedisConnOpt(t))
defer client.Close()
future := time.Now().Add(1 * time.Hour)
tasks := []*Task{
NewTask("immediate1", []byte("p1")),
NewTask("scheduled1", []byte("p2"), ProcessAt(future)),
NewTask("immediate2", []byte("p3")),
NewTask("grouped1", []byte("p4"), Group("mygroup")),
NewTask("immediate3", []byte("p5")),
}
results := client.BatchEnqueueContext(context.Background(), tasks)
if len(results) != 5 {
t.Fatalf("BatchEnqueueContext returned %d results, want 5", len(results))
}
// Immediate tasks (indices 0, 2, 4) should succeed with Pending state.
for _, idx := range []int{0, 2, 4} {
if results[idx].Err != nil {
t.Errorf("results[%d].Err = %v, want nil (immediate task)", idx, results[idx].Err)
}
if results[idx].TaskInfo == nil {
t.Errorf("results[%d].TaskInfo is nil, want non-nil", idx)
continue
}
if results[idx].TaskInfo.State != TaskStatePending {
t.Errorf("results[%d].TaskInfo.State = %v, want %v", idx, results[idx].TaskInfo.State, TaskStatePending)
}
}
// Scheduled task (index 1) should succeed with Scheduled state.
if results[1].Err != nil {
t.Errorf("results[1].Err = %v, want nil (scheduled task)", results[1].Err)
}
if results[1].TaskInfo != nil && results[1].TaskInfo.State != TaskStateScheduled {
t.Errorf("results[1].TaskInfo.State = %v, want %v", results[1].TaskInfo.State, TaskStateScheduled)
}
// Grouped task (index 3) should be rejected.
if results[3].Err == nil {
t.Error("results[3].Err is nil, want error for group task")
}
gotPending := h.GetPendingMessages(t, r, "default")
if len(gotPending) != 3 {
t.Errorf("len(pending) = %d, want 3", len(gotPending))
}
gotScheduled := h.GetScheduledMessages(t, r, "default")
if len(gotScheduled) != 1 {
t.Errorf("len(scheduled) = %d, want 1", len(gotScheduled))
}
}

View File

@@ -684,6 +684,14 @@ func (l *Lease) IsValid() bool {
return l.expireAt.After(now) || l.expireAt.Equal(now)
}
// BatchEnqueueItem pairs a task message with optional scheduling metadata for
// batch enqueue operations. If ProcessAt is zero, the task is enqueued for
// immediate processing; otherwise it is added to the scheduled set.
type BatchEnqueueItem struct {
Msg *TaskMessage
ProcessAt time.Time // zero value → immediate
}
// Broker is a message broker that supports operations to manage task queues.
//
// See rdb.RDB as a reference implementation.
@@ -692,7 +700,7 @@ type Broker interface {
Close() error
Enqueue(ctx context.Context, msg *TaskMessage) error
EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error
BatchEnqueue(ctx context.Context, msgs []*TaskMessage) (int, error)
BatchEnqueue(ctx context.Context, items []BatchEnqueueItem) (int, error)
Dequeue(qnames ...string) (*TaskMessage, time.Time, error)
Done(ctx context.Context, msg *TaskMessage) error
MarkAsComplete(ctx context.Context, msg *TaskMessage) error

View File

@@ -142,37 +142,48 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error {
// BatchEnqueue adds all given tasks to their respective pending lists using a
// single Redis pipeline round-trip. It returns the number of newly enqueued
// messages (tasks whose IDs already exist in Redis are silently skipped).
func (r *RDB) BatchEnqueue(ctx context.Context, msgs []*base.TaskMessage) (int, error) {
// BatchEnqueue adds all given tasks to Redis using a single pipeline round-trip.
// Each item is either enqueued immediately or scheduled based on its ProcessAt field.
func (r *RDB) BatchEnqueue(ctx context.Context, items []base.BatchEnqueueItem) (int, error) {
var op errors.Op = "rdb.BatchEnqueue"
if len(msgs) == 0 {
if len(items) == 0 {
return 0, nil
}
pipe := r.client.Pipeline()
// Track which indices in the pipeline correspond to enqueueCmd results vs SADD commands.
type cmdIndex struct{ pipeIdx int }
scriptCmds := make([]cmdIndex, 0, len(msgs))
scriptCmds := make([]cmdIndex, 0, len(items))
pipeLen := 0
now := r.clock.Now().UnixNano()
for _, msg := range msgs {
encoded, err := base.EncodeMessage(msg)
for _, item := range items {
encoded, err := base.EncodeMessage(item.Msg)
if err != nil {
return 0, errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
}
if _, found := r.queuesPublished.Load(msg.Queue); !found {
pipe.SAdd(ctx, base.AllQueues, msg.Queue)
r.queuesPublished.Store(msg.Queue, true)
if _, found := r.queuesPublished.Load(item.Msg.Queue); !found {
pipe.SAdd(ctx, base.AllQueues, item.Msg.Queue)
r.queuesPublished.Store(item.Msg.Queue, true)
pipeLen++
}
keys := []string{
base.TaskKey(msg.Queue, msg.ID),
base.PendingKey(msg.Queue),
if item.ProcessAt.IsZero() {
keys := []string{
base.TaskKey(item.Msg.Queue, item.Msg.ID),
base.PendingKey(item.Msg.Queue),
}
argv := []interface{}{encoded, item.Msg.ID, now}
enqueueCmd.Run(ctx, pipe, keys, argv...)
} else {
keys := []string{
base.TaskKey(item.Msg.Queue, item.Msg.ID),
base.ScheduledKey(item.Msg.Queue),
}
argv := []interface{}{encoded, item.ProcessAt.Unix(), item.Msg.ID}
scheduleCmd.Run(ctx, pipe, keys, argv...)
}
argv := []interface{}{encoded, msg.ID, now}
enqueueCmd.Run(ctx, pipe, keys, argv...)
scriptCmds = append(scriptCmds, cmdIndex{pipeIdx: pipeLen})
pipeLen++
}

View File

@@ -173,9 +173,13 @@ func TestBatchEnqueue(t *testing.T) {
t.Run("enqueue multiple tasks", func(t *testing.T) {
h.FlushDB(t, r.client)
msgs := []*base.TaskMessage{t1, t2, t3}
items := []base.BatchEnqueueItem{
{Msg: t1},
{Msg: t2},
{Msg: t3},
}
n, err := r.BatchEnqueue(context.Background(), msgs)
n, err := r.BatchEnqueue(context.Background(), items)
if err != nil {
t.Fatalf("BatchEnqueue returned error: %v", err)
}
@@ -183,7 +187,8 @@ func TestBatchEnqueue(t *testing.T) {
t.Errorf("BatchEnqueue returned %d, want 3", n)
}
for _, msg := range msgs {
for _, item := range items {
msg := item.Msg
pendingKey := base.PendingKey(msg.Queue)
pendingIDs := r.client.LRange(context.Background(), pendingKey, 0, -1).Val()
found := false
@@ -227,7 +232,11 @@ func TestBatchEnqueue(t *testing.T) {
dup := *t1
newMsg := h.NewTaskMessage("new_task", nil)
n, err := r.BatchEnqueue(context.Background(), []*base.TaskMessage{&dup, newMsg})
items := []base.BatchEnqueueItem{
{Msg: &dup},
{Msg: newMsg},
}
n, err := r.BatchEnqueue(context.Background(), items)
if err != nil {
t.Fatalf("BatchEnqueue returned error: %v", err)
}
@@ -235,6 +244,55 @@ func TestBatchEnqueue(t *testing.T) {
t.Errorf("BatchEnqueue returned %d, want 1 (duplicate should be skipped)", n)
}
})
t.Run("scheduled tasks", func(t *testing.T) {
h.FlushDB(t, r.client)
future := time.Now().Add(1 * time.Hour)
s1 := h.NewTaskMessage("deferred_email", nil)
items := []base.BatchEnqueueItem{
{Msg: t1},
{Msg: s1, ProcessAt: future},
}
n, err := r.BatchEnqueue(context.Background(), items)
if err != nil {
t.Fatalf("BatchEnqueue returned error: %v", err)
}
if n != 2 {
t.Errorf("BatchEnqueue returned %d, want 2", n)
}
// Immediate task should be in pending.
pendingIDs := r.client.LRange(context.Background(), base.PendingKey(t1.Queue), 0, -1).Val()
foundPending := false
for _, id := range pendingIDs {
if id == t1.ID {
foundPending = true
}
}
if !foundPending {
t.Errorf("immediate task %s not found in pending list", t1.ID)
}
// Scheduled task should be in scheduled set.
scheduledIDs := r.client.ZRange(context.Background(), base.ScheduledKey(s1.Queue), 0, -1).Val()
foundScheduled := false
for _, id := range scheduledIDs {
if id == s1.ID {
foundScheduled = true
}
}
if !foundScheduled {
t.Errorf("scheduled task %s not found in scheduled set", s1.ID)
}
taskKey := base.TaskKey(s1.Queue, s1.ID)
state := r.client.HGet(context.Background(), taskKey, "state").Val()
if state != "scheduled" {
t.Errorf("state for scheduled task %s = %q, want %q", s1.ID, state, "scheduled")
}
})
}
func TestEnqueueQueueCache(t *testing.T) {

View File

@@ -64,13 +64,13 @@ func (tb *TestBroker) EnqueueUnique(ctx context.Context, msg *base.TaskMessage,
return tb.real.EnqueueUnique(ctx, msg, ttl)
}
func (tb *TestBroker) BatchEnqueue(ctx context.Context, msgs []*base.TaskMessage) (int, error) {
func (tb *TestBroker) BatchEnqueue(ctx context.Context, items []base.BatchEnqueueItem) (int, error) {
tb.mu.Lock()
defer tb.mu.Unlock()
if tb.sleeping {
return 0, errRedisDown
}
return tb.real.BatchEnqueue(ctx, msgs)
return tb.real.BatchEnqueue(ctx, items)
}
func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, time.Time, error) {