-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #64 from aserto-dev/checks
safe.Checks + jobpool
- Loading branch information
Showing
3 changed files
with
149 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
package jobpool | ||
|
||
import ( | ||
"sync" | ||
|
||
"github.com/pkg/errors" | ||
) | ||
|
||
var ErrJobPool = errors.New("job pool error") | ||
|
||
// Consumer transforms IN to OUT. | ||
type Consumer[IN any, OUT any] func(IN) OUT | ||
|
||
// JobPool runs a sequence of tasks concurrently. | ||
type JobPool[IN any, OUT any] struct { | ||
consumer Consumer[IN, OUT] | ||
consumerCount int | ||
jobCount int | ||
wg sync.WaitGroup | ||
inbox chan job[IN] | ||
outbox chan result[OUT] | ||
producedCount int | ||
} | ||
|
||
type job[IN any] struct { | ||
index int | ||
task IN | ||
} | ||
|
||
type result[OUT any] struct { | ||
index int | ||
result OUT | ||
} | ||
|
||
// NewJobPool creates a new JobPool. | ||
// | ||
// If jobCount is zero, the number of jobs is unbounded. | ||
// The number of consumers is the minimum of maxWorkers and jobCount. | ||
func NewJobPool[IN any, OUT any](jobCount, maxConsumers int, consumer Consumer[IN, OUT]) *JobPool[IN, OUT] { | ||
return &JobPool[IN, OUT]{ | ||
consumer: consumer, | ||
consumerCount: min(maxConsumers, jobCount), | ||
jobCount: jobCount, | ||
inbox: make(chan job[IN], jobCount), | ||
outbox: make(chan result[OUT], jobCount), | ||
} | ||
} | ||
|
||
// Produces adds a job the to pool. | ||
// | ||
// Returns ErrJobPool if the pool was created with a non-zero jobCount | ||
// and all jobs have already been produced. | ||
// | ||
// Note: Produce is not thread-safe. | ||
func (jp *JobPool[IN, OUT]) Produce(in IN) error { | ||
if jp.jobCount > 0 && jp.producedCount >= jp.jobCount { | ||
return errors.Wrap(ErrJobPool, "job count exceeded") | ||
} | ||
|
||
jp.inbox <- job[IN]{jp.producedCount, in} | ||
jp.producedCount++ | ||
|
||
return nil | ||
} | ||
|
||
// Start consuming jobs. | ||
func (jp *JobPool[IN, OUT]) Start() { | ||
for range jp.consumerCount { | ||
jp.wg.Add(1) | ||
go func() { | ||
defer jp.wg.Done() | ||
|
||
for job := range jp.inbox { | ||
out := jp.consumer(job.task) | ||
jp.outbox <- result[OUT]{job.index, out} | ||
} | ||
}() | ||
} | ||
} | ||
|
||
// Wait for all jobs to complete and return their results. | ||
// | ||
// Results are returned in the order that jobs were produced. | ||
func (jp *JobPool[IN, OUT]) Wait() []OUT { | ||
close(jp.inbox) | ||
go func() { | ||
jp.wg.Wait() | ||
close(jp.outbox) | ||
}() | ||
|
||
results := make([]OUT, len(jp.inbox)) | ||
for result := range jp.outbox { | ||
results[result.index] = result.result | ||
} | ||
|
||
return results | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package safe | ||
|
||
import ( | ||
"iter" | ||
|
||
dsc3 "github.com/aserto-dev/go-directory/aserto/directory/common/v3" | ||
dsr3 "github.com/aserto-dev/go-directory/aserto/directory/reader/v3" | ||
) | ||
|
||
type SafeChecks struct { | ||
*dsr3.ChecksRequest | ||
} | ||
|
||
func Checks(i *dsr3.ChecksRequest) *SafeChecks { | ||
return &SafeChecks{i} | ||
} | ||
|
||
// CheckRequests returns an iterator that materializes all checks in order. | ||
func (c *SafeChecks) CheckRequests() iter.Seq[SafeCheck] { | ||
return func(yield func(SafeCheck) bool) { | ||
defaults := &dsc3.RelationIdentifier{ | ||
ObjectType: c.Default.ObjectType, | ||
ObjectId: c.Default.ObjectId, | ||
Relation: c.Default.Relation, | ||
SubjectType: c.Default.SubjectType, | ||
SubjectId: c.Default.SubjectId, | ||
} | ||
|
||
for _, check := range c.Checks { | ||
req := &dsr3.CheckRequest{ | ||
ObjectType: fallback(check.ObjectType, defaults.ObjectType), | ||
ObjectId: fallback(check.ObjectId, defaults.ObjectId), | ||
Relation: fallback(check.Relation, defaults.Relation), | ||
SubjectType: fallback(check.SubjectType, defaults.SubjectType), | ||
SubjectId: fallback(check.SubjectId, defaults.SubjectId), | ||
} | ||
if !yield(SafeCheck{req}) { | ||
break | ||
} | ||
} | ||
} | ||
} | ||
|
||
func fallback[T comparable](val, fallback T) T { | ||
var def T | ||
if val == def { | ||
return fallback | ||
} | ||
return val | ||
|
||
} |