mirror of
https://github.com/hibiken/asynq.git
synced 2026-07-02 10:50:50 +08:00
Fix BatchEnqueueContext time comparison and add scheduled task support
BatchEnqueueContext had a time comparison bug where `now` was captured before the loop but `processAt` was set to time.Now() inside composeOptions during each iteration, causing all immediate tasks to be incorrectly classified as scheduled and rejected. Fix: move `now` capture inside the loop, after composeOptions. Additionally, extend BatchEnqueueContext to support scheduled tasks in the same pipeline. Tasks with a future ProcessAt are now routed to scheduleCmd (ZADD to scheduled set) instead of being rejected. Only unique and group tasks remain unsupported. Changes: - Add BatchEnqueueItem type pairing TaskMessage with optional ProcessAt - Update Broker interface, RDB, and testbroker to use BatchEnqueueItem - Route immediate tasks to enqueueCmd, scheduled tasks to scheduleCmd - Return correct TaskState (Pending vs Scheduled) in results - Add tests for immediate, scheduled, and mixed batch scenarios
This commit is contained in:
47
client.go
47
client.go
@@ -428,18 +428,23 @@ type BatchEnqueueResult struct {
|
||||
|
||||
// BatchEnqueueContext enqueues all given tasks using a single Redis pipeline round-trip.
|
||||
// Each task gets its own result so callers can handle partial success.
|
||||
// Only immediate-enqueue tasks are supported; scheduled, unique, and group tasks
|
||||
// are rejected with an error in the corresponding BatchEnqueueResult.
|
||||
// Immediate and scheduled tasks are supported; unique and group tasks are
|
||||
// rejected with an error in the corresponding BatchEnqueueResult.
|
||||
func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ...Option) []BatchEnqueueResult {
|
||||
results := make([]BatchEnqueueResult, len(tasks))
|
||||
if len(tasks) == 0 {
|
||||
return results
|
||||
}
|
||||
|
||||
msgs := make([]*base.TaskMessage, 0, len(tasks))
|
||||
msgIndexes := make([]int, 0, len(tasks))
|
||||
type itemMeta struct {
|
||||
state base.TaskState
|
||||
processAt time.Time
|
||||
}
|
||||
|
||||
items := make([]base.BatchEnqueueItem, 0, len(tasks))
|
||||
itemIndexes := make([]int, 0, len(tasks))
|
||||
itemMetas := make([]itemMeta, 0, len(tasks))
|
||||
|
||||
now := time.Now()
|
||||
for i, task := range tasks {
|
||||
if task == nil {
|
||||
results[i] = BatchEnqueueResult{Err: fmt.Errorf("task cannot be nil")}
|
||||
@@ -455,10 +460,6 @@ func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ..
|
||||
results[i] = BatchEnqueueResult{Err: err}
|
||||
continue
|
||||
}
|
||||
if opt.processAt.After(now) {
|
||||
results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support scheduled tasks")}
|
||||
continue
|
||||
}
|
||||
if opt.group != "" {
|
||||
results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support group tasks")}
|
||||
continue
|
||||
@@ -489,24 +490,38 @@ func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ..
|
||||
Timeout: int64(timeout.Seconds()),
|
||||
Retention: int64(opt.retention.Seconds()),
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
msgIndexes = append(msgIndexes, i)
|
||||
|
||||
now := time.Now()
|
||||
scheduled := opt.processAt.After(now)
|
||||
|
||||
item := base.BatchEnqueueItem{Msg: msg}
|
||||
var meta itemMeta
|
||||
if scheduled {
|
||||
item.ProcessAt = opt.processAt
|
||||
meta = itemMeta{state: base.TaskStateScheduled, processAt: opt.processAt}
|
||||
} else {
|
||||
meta = itemMeta{state: base.TaskStatePending, processAt: now}
|
||||
}
|
||||
|
||||
items = append(items, item)
|
||||
itemIndexes = append(itemIndexes, i)
|
||||
itemMetas = append(itemMetas, meta)
|
||||
}
|
||||
|
||||
if len(msgs) == 0 {
|
||||
if len(items) == 0 {
|
||||
return results
|
||||
}
|
||||
|
||||
_, err := c.broker.BatchEnqueue(ctx, msgs)
|
||||
_, err := c.broker.BatchEnqueue(ctx, items)
|
||||
if err != nil {
|
||||
for _, idx := range msgIndexes {
|
||||
for _, idx := range itemIndexes {
|
||||
results[idx] = BatchEnqueueResult{Err: err}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
for j, idx := range msgIndexes {
|
||||
info := newTaskInfo(msgs[j], base.TaskStatePending, now, nil)
|
||||
for j, idx := range itemIndexes {
|
||||
info := newTaskInfo(items[j].Msg, itemMetas[j].state, itemMetas[j].processAt, nil)
|
||||
results[idx] = BatchEnqueueResult{TaskInfo: info}
|
||||
}
|
||||
return results
|
||||
|
||||
124
client_test.go
124
client_test.go
@@ -1661,3 +1661,127 @@ func TestClientEnqueueWithHeadersAndGroup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchEnqueueContext_ImmediateTasks(t *testing.T) {
|
||||
r := setup(t)
|
||||
client := NewClient(getRedisConnOpt(t))
|
||||
defer client.Close()
|
||||
|
||||
tasks := []*Task{
|
||||
NewTask("task1", []byte("payload1")),
|
||||
NewTask("task2", []byte("payload2")),
|
||||
NewTask("task3", []byte("payload3")),
|
||||
}
|
||||
|
||||
results := client.BatchEnqueueContext(context.Background(), tasks)
|
||||
if len(results) != 3 {
|
||||
t.Fatalf("BatchEnqueueContext returned %d results, want 3", len(results))
|
||||
}
|
||||
for i, res := range results {
|
||||
if res.Err != nil {
|
||||
t.Errorf("results[%d].Err = %v, want nil", i, res.Err)
|
||||
}
|
||||
if res.TaskInfo == nil {
|
||||
t.Errorf("results[%d].TaskInfo is nil, want non-nil", i)
|
||||
continue
|
||||
}
|
||||
if res.TaskInfo.Queue != "default" {
|
||||
t.Errorf("results[%d].TaskInfo.Queue = %q, want %q", i, res.TaskInfo.Queue, "default")
|
||||
}
|
||||
if res.TaskInfo.State != TaskStatePending {
|
||||
t.Errorf("results[%d].TaskInfo.State = %v, want %v", i, res.TaskInfo.State, TaskStatePending)
|
||||
}
|
||||
}
|
||||
|
||||
gotPending := h.GetPendingMessages(t, r, "default")
|
||||
if len(gotPending) != 3 {
|
||||
t.Errorf("len(pending) = %d, want 3", len(gotPending))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchEnqueueContext_ScheduledTask(t *testing.T) {
|
||||
r := setup(t)
|
||||
client := NewClient(getRedisConnOpt(t))
|
||||
defer client.Close()
|
||||
|
||||
future := time.Now().Add(1 * time.Hour)
|
||||
tasks := []*Task{
|
||||
NewTask("scheduled_task", []byte("payload"), ProcessAt(future)),
|
||||
}
|
||||
|
||||
results := client.BatchEnqueueContext(context.Background(), tasks)
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("BatchEnqueueContext returned %d results, want 1", len(results))
|
||||
}
|
||||
if results[0].Err != nil {
|
||||
t.Fatalf("results[0].Err = %v, want nil", results[0].Err)
|
||||
}
|
||||
if results[0].TaskInfo == nil {
|
||||
t.Fatal("results[0].TaskInfo is nil, want non-nil")
|
||||
}
|
||||
if results[0].TaskInfo.State != TaskStateScheduled {
|
||||
t.Errorf("results[0].TaskInfo.State = %v, want %v", results[0].TaskInfo.State, TaskStateScheduled)
|
||||
}
|
||||
|
||||
gotScheduled := h.GetScheduledMessages(t, r, "default")
|
||||
if len(gotScheduled) != 1 {
|
||||
t.Errorf("len(scheduled) = %d, want 1", len(gotScheduled))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchEnqueueContext_MixedBatch(t *testing.T) {
|
||||
r := setup(t)
|
||||
client := NewClient(getRedisConnOpt(t))
|
||||
defer client.Close()
|
||||
|
||||
future := time.Now().Add(1 * time.Hour)
|
||||
tasks := []*Task{
|
||||
NewTask("immediate1", []byte("p1")),
|
||||
NewTask("scheduled1", []byte("p2"), ProcessAt(future)),
|
||||
NewTask("immediate2", []byte("p3")),
|
||||
NewTask("grouped1", []byte("p4"), Group("mygroup")),
|
||||
NewTask("immediate3", []byte("p5")),
|
||||
}
|
||||
|
||||
results := client.BatchEnqueueContext(context.Background(), tasks)
|
||||
if len(results) != 5 {
|
||||
t.Fatalf("BatchEnqueueContext returned %d results, want 5", len(results))
|
||||
}
|
||||
|
||||
// Immediate tasks (indices 0, 2, 4) should succeed with Pending state.
|
||||
for _, idx := range []int{0, 2, 4} {
|
||||
if results[idx].Err != nil {
|
||||
t.Errorf("results[%d].Err = %v, want nil (immediate task)", idx, results[idx].Err)
|
||||
}
|
||||
if results[idx].TaskInfo == nil {
|
||||
t.Errorf("results[%d].TaskInfo is nil, want non-nil", idx)
|
||||
continue
|
||||
}
|
||||
if results[idx].TaskInfo.State != TaskStatePending {
|
||||
t.Errorf("results[%d].TaskInfo.State = %v, want %v", idx, results[idx].TaskInfo.State, TaskStatePending)
|
||||
}
|
||||
}
|
||||
|
||||
// Scheduled task (index 1) should succeed with Scheduled state.
|
||||
if results[1].Err != nil {
|
||||
t.Errorf("results[1].Err = %v, want nil (scheduled task)", results[1].Err)
|
||||
}
|
||||
if results[1].TaskInfo != nil && results[1].TaskInfo.State != TaskStateScheduled {
|
||||
t.Errorf("results[1].TaskInfo.State = %v, want %v", results[1].TaskInfo.State, TaskStateScheduled)
|
||||
}
|
||||
|
||||
// Grouped task (index 3) should be rejected.
|
||||
if results[3].Err == nil {
|
||||
t.Error("results[3].Err is nil, want error for group task")
|
||||
}
|
||||
|
||||
gotPending := h.GetPendingMessages(t, r, "default")
|
||||
if len(gotPending) != 3 {
|
||||
t.Errorf("len(pending) = %d, want 3", len(gotPending))
|
||||
}
|
||||
|
||||
gotScheduled := h.GetScheduledMessages(t, r, "default")
|
||||
if len(gotScheduled) != 1 {
|
||||
t.Errorf("len(scheduled) = %d, want 1", len(gotScheduled))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -684,6 +684,14 @@ func (l *Lease) IsValid() bool {
|
||||
return l.expireAt.After(now) || l.expireAt.Equal(now)
|
||||
}
|
||||
|
||||
// BatchEnqueueItem pairs a task message with optional scheduling metadata for
|
||||
// batch enqueue operations. If ProcessAt is zero, the task is enqueued for
|
||||
// immediate processing; otherwise it is added to the scheduled set.
|
||||
type BatchEnqueueItem struct {
|
||||
Msg *TaskMessage
|
||||
ProcessAt time.Time // zero value → immediate
|
||||
}
|
||||
|
||||
// Broker is a message broker that supports operations to manage task queues.
|
||||
//
|
||||
// See rdb.RDB as a reference implementation.
|
||||
@@ -692,7 +700,7 @@ type Broker interface {
|
||||
Close() error
|
||||
Enqueue(ctx context.Context, msg *TaskMessage) error
|
||||
EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error
|
||||
BatchEnqueue(ctx context.Context, msgs []*TaskMessage) (int, error)
|
||||
BatchEnqueue(ctx context.Context, items []BatchEnqueueItem) (int, error)
|
||||
Dequeue(qnames ...string) (*TaskMessage, time.Time, error)
|
||||
Done(ctx context.Context, msg *TaskMessage) error
|
||||
MarkAsComplete(ctx context.Context, msg *TaskMessage) error
|
||||
|
||||
@@ -142,37 +142,48 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error {
|
||||
// BatchEnqueue adds all given tasks to their respective pending lists using a
|
||||
// single Redis pipeline round-trip. It returns the number of newly enqueued
|
||||
// messages (tasks whose IDs already exist in Redis are silently skipped).
|
||||
func (r *RDB) BatchEnqueue(ctx context.Context, msgs []*base.TaskMessage) (int, error) {
|
||||
// BatchEnqueue adds all given tasks to Redis using a single pipeline round-trip.
|
||||
// Each item is either enqueued immediately or scheduled based on its ProcessAt field.
|
||||
func (r *RDB) BatchEnqueue(ctx context.Context, items []base.BatchEnqueueItem) (int, error) {
|
||||
var op errors.Op = "rdb.BatchEnqueue"
|
||||
if len(msgs) == 0 {
|
||||
if len(items) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
// Track which indices in the pipeline correspond to enqueueCmd results vs SADD commands.
|
||||
type cmdIndex struct{ pipeIdx int }
|
||||
scriptCmds := make([]cmdIndex, 0, len(msgs))
|
||||
scriptCmds := make([]cmdIndex, 0, len(items))
|
||||
pipeLen := 0
|
||||
|
||||
now := r.clock.Now().UnixNano()
|
||||
|
||||
for _, msg := range msgs {
|
||||
encoded, err := base.EncodeMessage(msg)
|
||||
for _, item := range items {
|
||||
encoded, err := base.EncodeMessage(item.Msg)
|
||||
if err != nil {
|
||||
return 0, errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
|
||||
}
|
||||
if _, found := r.queuesPublished.Load(msg.Queue); !found {
|
||||
pipe.SAdd(ctx, base.AllQueues, msg.Queue)
|
||||
r.queuesPublished.Store(msg.Queue, true)
|
||||
if _, found := r.queuesPublished.Load(item.Msg.Queue); !found {
|
||||
pipe.SAdd(ctx, base.AllQueues, item.Msg.Queue)
|
||||
r.queuesPublished.Store(item.Msg.Queue, true)
|
||||
pipeLen++
|
||||
}
|
||||
keys := []string{
|
||||
base.TaskKey(msg.Queue, msg.ID),
|
||||
base.PendingKey(msg.Queue),
|
||||
|
||||
if item.ProcessAt.IsZero() {
|
||||
keys := []string{
|
||||
base.TaskKey(item.Msg.Queue, item.Msg.ID),
|
||||
base.PendingKey(item.Msg.Queue),
|
||||
}
|
||||
argv := []interface{}{encoded, item.Msg.ID, now}
|
||||
enqueueCmd.Run(ctx, pipe, keys, argv...)
|
||||
} else {
|
||||
keys := []string{
|
||||
base.TaskKey(item.Msg.Queue, item.Msg.ID),
|
||||
base.ScheduledKey(item.Msg.Queue),
|
||||
}
|
||||
argv := []interface{}{encoded, item.ProcessAt.Unix(), item.Msg.ID}
|
||||
scheduleCmd.Run(ctx, pipe, keys, argv...)
|
||||
}
|
||||
argv := []interface{}{encoded, msg.ID, now}
|
||||
enqueueCmd.Run(ctx, pipe, keys, argv...)
|
||||
scriptCmds = append(scriptCmds, cmdIndex{pipeIdx: pipeLen})
|
||||
pipeLen++
|
||||
}
|
||||
|
||||
@@ -173,9 +173,13 @@ func TestBatchEnqueue(t *testing.T) {
|
||||
|
||||
t.Run("enqueue multiple tasks", func(t *testing.T) {
|
||||
h.FlushDB(t, r.client)
|
||||
msgs := []*base.TaskMessage{t1, t2, t3}
|
||||
items := []base.BatchEnqueueItem{
|
||||
{Msg: t1},
|
||||
{Msg: t2},
|
||||
{Msg: t3},
|
||||
}
|
||||
|
||||
n, err := r.BatchEnqueue(context.Background(), msgs)
|
||||
n, err := r.BatchEnqueue(context.Background(), items)
|
||||
if err != nil {
|
||||
t.Fatalf("BatchEnqueue returned error: %v", err)
|
||||
}
|
||||
@@ -183,7 +187,8 @@ func TestBatchEnqueue(t *testing.T) {
|
||||
t.Errorf("BatchEnqueue returned %d, want 3", n)
|
||||
}
|
||||
|
||||
for _, msg := range msgs {
|
||||
for _, item := range items {
|
||||
msg := item.Msg
|
||||
pendingKey := base.PendingKey(msg.Queue)
|
||||
pendingIDs := r.client.LRange(context.Background(), pendingKey, 0, -1).Val()
|
||||
found := false
|
||||
@@ -227,7 +232,11 @@ func TestBatchEnqueue(t *testing.T) {
|
||||
dup := *t1
|
||||
newMsg := h.NewTaskMessage("new_task", nil)
|
||||
|
||||
n, err := r.BatchEnqueue(context.Background(), []*base.TaskMessage{&dup, newMsg})
|
||||
items := []base.BatchEnqueueItem{
|
||||
{Msg: &dup},
|
||||
{Msg: newMsg},
|
||||
}
|
||||
n, err := r.BatchEnqueue(context.Background(), items)
|
||||
if err != nil {
|
||||
t.Fatalf("BatchEnqueue returned error: %v", err)
|
||||
}
|
||||
@@ -235,6 +244,55 @@ func TestBatchEnqueue(t *testing.T) {
|
||||
t.Errorf("BatchEnqueue returned %d, want 1 (duplicate should be skipped)", n)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("scheduled tasks", func(t *testing.T) {
|
||||
h.FlushDB(t, r.client)
|
||||
|
||||
future := time.Now().Add(1 * time.Hour)
|
||||
s1 := h.NewTaskMessage("deferred_email", nil)
|
||||
items := []base.BatchEnqueueItem{
|
||||
{Msg: t1},
|
||||
{Msg: s1, ProcessAt: future},
|
||||
}
|
||||
|
||||
n, err := r.BatchEnqueue(context.Background(), items)
|
||||
if err != nil {
|
||||
t.Fatalf("BatchEnqueue returned error: %v", err)
|
||||
}
|
||||
if n != 2 {
|
||||
t.Errorf("BatchEnqueue returned %d, want 2", n)
|
||||
}
|
||||
|
||||
// Immediate task should be in pending.
|
||||
pendingIDs := r.client.LRange(context.Background(), base.PendingKey(t1.Queue), 0, -1).Val()
|
||||
foundPending := false
|
||||
for _, id := range pendingIDs {
|
||||
if id == t1.ID {
|
||||
foundPending = true
|
||||
}
|
||||
}
|
||||
if !foundPending {
|
||||
t.Errorf("immediate task %s not found in pending list", t1.ID)
|
||||
}
|
||||
|
||||
// Scheduled task should be in scheduled set.
|
||||
scheduledIDs := r.client.ZRange(context.Background(), base.ScheduledKey(s1.Queue), 0, -1).Val()
|
||||
foundScheduled := false
|
||||
for _, id := range scheduledIDs {
|
||||
if id == s1.ID {
|
||||
foundScheduled = true
|
||||
}
|
||||
}
|
||||
if !foundScheduled {
|
||||
t.Errorf("scheduled task %s not found in scheduled set", s1.ID)
|
||||
}
|
||||
|
||||
taskKey := base.TaskKey(s1.Queue, s1.ID)
|
||||
state := r.client.HGet(context.Background(), taskKey, "state").Val()
|
||||
if state != "scheduled" {
|
||||
t.Errorf("state for scheduled task %s = %q, want %q", s1.ID, state, "scheduled")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnqueueQueueCache(t *testing.T) {
|
||||
|
||||
@@ -64,13 +64,13 @@ func (tb *TestBroker) EnqueueUnique(ctx context.Context, msg *base.TaskMessage,
|
||||
return tb.real.EnqueueUnique(ctx, msg, ttl)
|
||||
}
|
||||
|
||||
func (tb *TestBroker) BatchEnqueue(ctx context.Context, msgs []*base.TaskMessage) (int, error) {
|
||||
func (tb *TestBroker) BatchEnqueue(ctx context.Context, items []base.BatchEnqueueItem) (int, error) {
|
||||
tb.mu.Lock()
|
||||
defer tb.mu.Unlock()
|
||||
if tb.sleeping {
|
||||
return 0, errRedisDown
|
||||
}
|
||||
return tb.real.BatchEnqueue(ctx, msgs)
|
||||
return tb.real.BatchEnqueue(ctx, items)
|
||||
}
|
||||
|
||||
func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, time.Time, error) {
|
||||
|
||||
Reference in New Issue
Block a user