Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

safe.Checks + jobpool #64

Merged
merged 2 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
module github.com/aserto-dev/azm

go 1.22.10

toolchain go1.23.4
go 1.23.4

// replace github.com/aserto-dev/go-directory => ../go-directory

Expand Down
97 changes: 97 additions & 0 deletions jobpool/jobpool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package jobpool

import (
"sync"

"github.com/pkg/errors"
)

var ErrJobPool = errors.New("job pool error")

// Consumer transforms IN to OUT.
type Consumer[IN any, OUT any] func(IN) OUT

// JobPool runs a sequence of tasks concurrently.
type JobPool[IN any, OUT any] struct {
consumer Consumer[IN, OUT]
consumerCount int
jobCount int
wg sync.WaitGroup
inbox chan job[IN]
outbox chan result[OUT]
producedCount int
}

type job[IN any] struct {
index int
task IN
}

type result[OUT any] struct {
index int
result OUT
}

// NewJobPool creates a new JobPool.
//
// If jobCount is zero, the number of jobs is unbounded.
// The number of consumers is the minimum of maxWorkers and jobCount.
func NewJobPool[IN any, OUT any](jobCount, maxConsumers int, consumer Consumer[IN, OUT]) *JobPool[IN, OUT] {
return &JobPool[IN, OUT]{
consumer: consumer,
consumerCount: min(maxConsumers, jobCount),
jobCount: jobCount,
inbox: make(chan job[IN], jobCount),
outbox: make(chan result[OUT], jobCount),
}
}

// Produces adds a job the to pool.
//
// Returns ErrJobPool if the pool was created with a non-zero jobCount
// and all jobs have already been produced.
//
// Note: Produce is not thread-safe.
func (jp *JobPool[IN, OUT]) Produce(in IN) error {
if jp.jobCount > 0 && jp.producedCount >= jp.jobCount {
return errors.Wrap(ErrJobPool, "job count exceeded")
}

jp.inbox <- job[IN]{jp.producedCount, in}
jp.producedCount++

return nil
}

// Start consuming jobs.
func (jp *JobPool[IN, OUT]) Start() {
for range jp.consumerCount {
jp.wg.Add(1)
go func() {
defer jp.wg.Done()

for job := range jp.inbox {
out := jp.consumer(job.task)
jp.outbox <- result[OUT]{job.index, out}
}
}()
}
}

// Wait for all jobs to complete and return their results.
//
// Results are returned in the order that jobs were produced.
func (jp *JobPool[IN, OUT]) Wait() []OUT {
close(jp.inbox)
go func() {
jp.wg.Wait()
close(jp.outbox)
}()

results := make([]OUT, len(jp.inbox))
for result := range jp.outbox {
results[result.index] = result.result
}

return results
}
51 changes: 51 additions & 0 deletions safe/checks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package safe

import (
"iter"

dsc3 "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
dsr3 "github.com/aserto-dev/go-directory/aserto/directory/reader/v3"
)

type SafeChecks struct {
*dsr3.ChecksRequest
}

func Checks(i *dsr3.ChecksRequest) *SafeChecks {
return &SafeChecks{i}
}

// CheckRequests returns an iterator that materializes all checks in order.
func (c *SafeChecks) CheckRequests() iter.Seq[SafeCheck] {
return func(yield func(SafeCheck) bool) {
defaults := &dsc3.RelationIdentifier{
ObjectType: c.Default.ObjectType,
ObjectId: c.Default.ObjectId,
Relation: c.Default.Relation,
SubjectType: c.Default.SubjectType,
SubjectId: c.Default.SubjectId,
}

for _, check := range c.Checks {
req := &dsr3.CheckRequest{
ObjectType: fallback(check.ObjectType, defaults.ObjectType),
ObjectId: fallback(check.ObjectId, defaults.ObjectId),
Relation: fallback(check.Relation, defaults.Relation),
SubjectType: fallback(check.SubjectType, defaults.SubjectType),
SubjectId: fallback(check.SubjectId, defaults.SubjectId),
}
if !yield(SafeCheck{req}) {
break
}
}
}
}

func fallback[T comparable](val, fallback T) T {
var def T
if val == def {
return fallback
}
return val

}
Loading