diff --git a/client.go b/client.go index ab2bcbb..6c0bec9 100644 --- a/client.go +++ b/client.go @@ -420,6 +420,98 @@ func (c *Client) EnqueueContext(ctx context.Context, task *Task, opts ...Option) return newTaskInfo(msg, state, opt.processAt, nil), nil } +// BatchEnqueueResult holds the result of enqueuing a single task within a batch. +type BatchEnqueueResult struct { + TaskInfo *TaskInfo + Err error +} + +// BatchEnqueueContext enqueues all given tasks using a single Redis pipeline round-trip. +// Each task gets its own result so callers can handle partial success. +// Only immediate-enqueue tasks are supported; scheduled, unique, and group tasks +// are rejected with an error in the corresponding BatchEnqueueResult. +func (c *Client) BatchEnqueueContext(ctx context.Context, tasks []*Task, opts ...Option) []BatchEnqueueResult { + results := make([]BatchEnqueueResult, len(tasks)) + if len(tasks) == 0 { + return results + } + + msgs := make([]*base.TaskMessage, 0, len(tasks)) + msgIndexes := make([]int, 0, len(tasks)) + + now := time.Now() + for i, task := range tasks { + if task == nil { + results[i] = BatchEnqueueResult{Err: fmt.Errorf("task cannot be nil")} + continue + } + if strings.TrimSpace(task.Type()) == "" { + results[i] = BatchEnqueueResult{Err: fmt.Errorf("task typename cannot be empty")} + continue + } + merged := append(task.opts, opts...) + opt, err := composeOptions(merged...) + if err != nil { + results[i] = BatchEnqueueResult{Err: err} + continue + } + if opt.processAt.After(now) { + results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support scheduled tasks")} + continue + } + if opt.group != "" { + results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support group tasks")} + continue + } + if opt.uniqueTTL > 0 { + results[i] = BatchEnqueueResult{Err: fmt.Errorf("batch enqueue does not support unique tasks")} + continue + } + deadline := noDeadline + if !opt.deadline.IsZero() { + deadline = opt.deadline + } + timeout := noTimeout + if opt.timeout != 0 { + timeout = opt.timeout + } + if deadline.Equal(noDeadline) && timeout == noTimeout { + timeout = defaultTimeout + } + msg := &base.TaskMessage{ + ID: opt.taskID, + Type: task.Type(), + Payload: task.Payload(), + Headers: task.Headers(), + Queue: opt.queue, + Retry: opt.retry, + Deadline: deadline.Unix(), + Timeout: int64(timeout.Seconds()), + Retention: int64(opt.retention.Seconds()), + } + msgs = append(msgs, msg) + msgIndexes = append(msgIndexes, i) + } + + if len(msgs) == 0 { + return results + } + + _, err := c.broker.BatchEnqueue(ctx, msgs) + if err != nil { + for _, idx := range msgIndexes { + results[idx] = BatchEnqueueResult{Err: err} + } + return results + } + + for j, idx := range msgIndexes { + info := newTaskInfo(msgs[j], base.TaskStatePending, now, nil) + results[idx] = BatchEnqueueResult{TaskInfo: info} + } + return results +} + // Ping performs a ping against the redis connection. func (c *Client) Ping() error { return c.broker.Ping() diff --git a/internal/base/base.go b/internal/base/base.go index 390e24d..e1bd58e 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -692,6 +692,7 @@ type Broker interface { Close() error Enqueue(ctx context.Context, msg *TaskMessage) error EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error + BatchEnqueue(ctx context.Context, msgs []*TaskMessage) (int, error) Dequeue(qnames ...string) (*TaskMessage, time.Time, error) Done(ctx context.Context, msg *TaskMessage) error MarkAsComplete(ctx context.Context, msg *TaskMessage) error diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 22df506..376f63a 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -139,6 +139,65 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error { return nil } +// BatchEnqueue adds all given tasks to their respective pending lists using a +// single Redis pipeline round-trip. It returns the number of newly enqueued +// messages (tasks whose IDs already exist in Redis are silently skipped). +func (r *RDB) BatchEnqueue(ctx context.Context, msgs []*base.TaskMessage) (int, error) { + var op errors.Op = "rdb.BatchEnqueue" + if len(msgs) == 0 { + return 0, nil + } + + pipe := r.client.Pipeline() + + // Track which indices in the pipeline correspond to enqueueCmd results vs SADD commands. + type cmdIndex struct{ pipeIdx int } + scriptCmds := make([]cmdIndex, 0, len(msgs)) + pipeLen := 0 + + now := r.clock.Now().UnixNano() + + for _, msg := range msgs { + encoded, err := base.EncodeMessage(msg) + if err != nil { + return 0, errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err)) + } + if _, found := r.queuesPublished.Load(msg.Queue); !found { + pipe.SAdd(ctx, base.AllQueues, msg.Queue) + r.queuesPublished.Store(msg.Queue, true) + pipeLen++ + } + keys := []string{ + base.TaskKey(msg.Queue, msg.ID), + base.PendingKey(msg.Queue), + } + argv := []interface{}{encoded, msg.ID, now} + enqueueCmd.Run(ctx, pipe, keys, argv...) + scriptCmds = append(scriptCmds, cmdIndex{pipeIdx: pipeLen}) + pipeLen++ + } + + cmds, err := pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return 0, errors.E(op, errors.Unknown, fmt.Sprintf("redis pipeline error: %v", err)) + } + + enqueued := 0 + for _, sc := range scriptCmds { + if sc.pipeIdx >= len(cmds) { + continue + } + res, err := cmds[sc.pipeIdx].(*redis.Cmd).Result() + if err != nil { + continue + } + if n, ok := res.(int64); ok && n == 1 { + enqueued++ + } + } + return enqueued, nil +} + // enqueueUniqueCmd enqueues the task message if the task is unique. // // KEYS[1] -> unique key diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index ffab6fe..63adf7a 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -64,6 +64,15 @@ func (tb *TestBroker) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, return tb.real.EnqueueUnique(ctx, msg, ttl) } +func (tb *TestBroker) BatchEnqueue(ctx context.Context, msgs []*base.TaskMessage) (int, error) { + tb.mu.Lock() + defer tb.mu.Unlock() + if tb.sleeping { + return 0, errRedisDown + } + return tb.real.BatchEnqueue(ctx, msgs) +} + func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, time.Time, error) { tb.mu.Lock() defer tb.mu.Unlock()