Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New flagged storage for parallel joins #737

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ name = "cluster_bomb"
[[example]]
name = "bitset"

[[example]]
name = "track"
# TODO: restricted storage is unsound without streaming iterator or some changes made to it (i.e. to not allow access to component data on other entities in the mutable case)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My addition of GAT support of the storage traits should allow for this, I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, Join::Type would probably need to be made into a GAT as well. Are there any GAT based iterator crates that could serve as a replacement for the Iterator trait from std? I'm not seeing any 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's pretty trivial to create one.

trait StreamingIterator {
    type Item<'a>;

    fn next(&mut self) -> Self::Item<'_>;
}

I've not seen any 'canonical' streaming iterator crates that use GATs yet. Perhaps when std gets such a trait (here's hoping!) with for integration we can use that instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will need some of the common adapters/methods: map, for_each, filter, filter_map, collect. Which seems fairly doable but is a bit more than the just the minimal version.

# [[example]]
# name = "track"

[[example]]
name = "ordered_track"
Expand Down
2 changes: 1 addition & 1 deletion src/changeset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl<'a, T> Join for &'a mut ChangeSet<T> {
// exists yet.
unsafe fn get(v: &mut Self::Value, id: Index) -> Self::Type {
let value: *mut Self::Value = v as *mut Self::Value;
(*value).get_mut(id)
(*value).get_access_mut(id)
}
}

Expand Down
22 changes: 18 additions & 4 deletions src/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,24 @@ pub trait Join {
/// then illegal memory access can occur.
unsafe fn open(self) -> (Self::Mask, Self::Value);

// TODO: copy safety notes for callers to the impls of this method.
// TODO: also evaluate all impls
/// Get a joined component value by a given index.
///
/// # Safety
///
/// * A call to `get` must be preceded by a check if `id` is part of
/// `Self::Mask`
/// * The implementation of this method may use unsafe code, but has no
/// invariants to meet
/// `Self::Mask`.
/// * The caller must ensure the lifetimes of mutable references returned from a call
/// of this method end before subsequent calls with the same `id`.
/// * Conversly, the implementation of the method must never create a mutable reference (even if it isn't
/// returned) that was returned by a previous call with a distinct `id`. Since there is no
/// requirement that the caller (if there was `JoinIter` would half to be a streaming
/// iterator).
/// are no guarantees that the caller will release
/// * The implementation of this method may use `unsafe` to extend the lifetime of returned references but
/// must ensure that any references within Self::Type cannot outlive the references
/// they were derived from in Self::Value.
unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type;

/// If this `Join` typically returns all indices in the mask, then iterating
Expand Down Expand Up @@ -322,6 +332,10 @@ impl<J: Join> JoinIter<J> {
}

impl<J: Join> JoinIter<J> {
// TODO: these are unsound because they can be used to get two mutable references to the same
// component (in safe code)

/*
/// Allows getting joined values for specific entity.
///
/// ## Example
Expand Down Expand Up @@ -394,7 +408,7 @@ impl<J: Join> JoinIter<J> {
} else {
None
}
}
}*/
}

impl<J: Join> std::iter::Iterator for JoinIter<J> {
Expand Down
13 changes: 8 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#![warn(missing_docs)]
#![cfg_attr(feature = "nightly", feature(generic_associated_types, associated_type_defaults))]
#![cfg_attr(
feature = "nightly",
feature(generic_associated_types, associated_type_defaults)
)]

//! # SPECS Parallel ECS
//!
Expand Down Expand Up @@ -207,9 +210,9 @@ pub mod world;

pub use hibitset::BitSet;
pub use shred::{
Accessor, AccessorCow, BatchAccessor, BatchController, BatchUncheckedWorld,
Dispatcher, DispatcherBuilder, Read, ReadExpect, RunNow,
RunningTime, StaticAccessor, System, SystemData, World, Write, WriteExpect,
Accessor, AccessorCow, BatchAccessor, BatchController, BatchUncheckedWorld, Dispatcher,
DispatcherBuilder, Read, ReadExpect, RunNow, RunningTime, StaticAccessor, System, SystemData,
World, Write, WriteExpect,
};
pub use shrev::ReaderId;

Expand All @@ -232,4 +235,4 @@ pub use crate::{
};

#[cfg(feature = "nightly")]
pub use crate::storage::DerefFlaggedStorage;
pub use crate::storage::UnsplitFlaggedStorage;
12 changes: 10 additions & 2 deletions src/storage/deref_flagged.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ where
}

impl<C: Component, T: UnprotectedStorage<C>> UnprotectedStorage<C> for DerefFlaggedStorage<C, T> {
#[rustfmt::skip]
type AccessMut<'a> where T: 'a = FlaggedAccessMut<'a, <T as UnprotectedStorage<C>>::AccessMut<'a>, C>;

unsafe fn clean<B>(&mut self, has: B)
Expand All @@ -68,13 +69,20 @@ impl<C: Component, T: UnprotectedStorage<C>> UnprotectedStorage<C> for DerefFlag
self.storage.get(id)
}

unsafe fn get_mut(&mut self, id: Index) -> Self::AccessMut<'_> {
unsafe fn get_mut(&mut self, id: Index) -> &mut C {
if self.emit_event() {
self.channel.single_write(ComponentEvent::Modified(id));
}
self.storage.get_mut(id),
}

unsafe fn get_access_mut(&mut self, id: Index) -> Self::AccessMut<'_> {
let emit = self.emit_event();
FlaggedAccessMut {
channel: &mut self.channel,
emit,
id,
access: self.storage.get_mut(id),
access: self.storage.get_access_mut(id),
phantom: PhantomData,
}
}
Expand Down
12 changes: 8 additions & 4 deletions src/storage/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,20 +194,24 @@ where
pub fn get_mut(&mut self) -> AccessMutReturn<'_, T> {
// SAFETY: This is safe since `OccupiedEntry` is only constructed
// after checking the mask.
unsafe { self.storage.data.inner.get_mut(self.id) }
unsafe { self.storage.data.inner.get_access_mut(self.id) }
}

/// Converts the `OccupiedEntry` into a mutable reference bounded by
/// the storage's lifetime.
pub fn into_mut(self) -> AccessMutReturn<'a, T> {
// SAFETY: This is safe since `OccupiedEntry` is only constructed
// after checking the mask.
unsafe { self.storage.data.inner.get_mut(self.id) }
unsafe { self.storage.data.inner.get_access_mut(self.id) }
}

/// Inserts a value into the storage and returns the old one.
pub fn insert(&mut self, mut component: T) -> T {
std::mem::swap(&mut component, self.get_mut().deref_mut());
// SAFETY: This is safe since `OccupiedEntry` is only constructed
// after checking the mask.
std::mem::swap(&mut component, unsafe {
self.storage.data.inner.get_mut(self.id)
});
component
}

Expand Down Expand Up @@ -235,7 +239,7 @@ where
// SAFETY: This is safe since we added `self.id` to the mask.
unsafe {
self.storage.data.inner.insert(self.id, component);
self.storage.data.inner.get_mut(self.id)
self.storage.data.inner.get_access_mut(self.id)
}
}
}
Expand Down
19 changes: 15 additions & 4 deletions src/storage/flagged.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ use shrev::EventChannel;
/// **Note:** Joining over all components of a `FlaggedStorage`
/// mutably will flag all components.
///
/// **Note:** restricted storages are currently removed since they need some changes to be sound so
/// the below advice won't currently work. Use `UnsplitFlaggedStorage` instead.
///
/// What you want to instead is to use `restrict_mut()` to first
/// get the entities which contain the component and then conditionally
/// modify the component after a call to `get_mut_unchecked()` or `get_mut()`.
Expand Down Expand Up @@ -201,6 +204,7 @@ where

impl<C: Component, T: UnprotectedStorage<C>> UnprotectedStorage<C> for FlaggedStorage<C, T> {
#[cfg(feature = "nightly")]
#[rustfmt::skip]
type AccessMut<'a> where T: 'a = <T as UnprotectedStorage<C>>::AccessMut<'a>;

unsafe fn clean<B>(&mut self, has: B)
Expand All @@ -214,20 +218,27 @@ impl<C: Component, T: UnprotectedStorage<C>> UnprotectedStorage<C> for FlaggedSt
self.storage.get(id)
}

#[cfg(feature = "nightly")]
unsafe fn get_mut(&mut self, id: Index) -> <T as UnprotectedStorage<C>>::AccessMut<'_> {
unsafe fn get_mut(&mut self, id: Index) -> &mut C {
if self.emit_event() {
self.channel.single_write(ComponentEvent::Modified(id));
}
self.storage.get_mut(id)
}

#[cfg(feature = "nightly")]
unsafe fn get_access_mut(&mut self, id: Index) -> <T as UnprotectedStorage<C>>::AccessMut<'_> {
if self.emit_event() {
self.channel.single_write(ComponentEvent::Modified(id));
}
self.storage.get_access_mut(id)
}

#[cfg(not(feature = "nightly"))]
unsafe fn get_mut(&mut self, id: Index) -> &mut C {
unsafe fn get_access_mut(&mut self, id: Index) -> &mut C {
if self.emit_event() {
self.channel.single_write(ComponentEvent::Modified(id));
}
self.storage.get_mut(id)
self.storage.get_access_mut(id)
}

unsafe fn insert(&mut self, id: Index, comp: C) {
Expand Down
15 changes: 10 additions & 5 deletions src/storage/generic.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#[cfg(feature = "nightly")]
use std::ops::DerefMut;
use crate::{
storage::{InsertResult, ReadStorage, WriteStorage, AccessMutReturn},
storage::{AccessMutReturn, InsertResult, ReadStorage, WriteStorage},
world::{Component, Entity},
};

#[cfg(feature = "nightly")]
use crate::storage::UnprotectedStorage;

Expand Down Expand Up @@ -88,7 +87,8 @@ pub trait GenericWriteStorage {
type Component: Component;
/// The wrapper through with mutable access of a component is performed.
#[cfg(feature = "nightly")]
type AccessMut<'a>: DerefMut<Target=Self::Component> where Self: 'a;
#[rustfmt::skip]
type AccessMut<'a> where Self: 'a;

/// Get mutable access to an `Entity`s component
fn get_mut(&mut self, entity: Entity) -> Option<AccessMutReturn<'_, Self::Component>>;
Expand All @@ -97,7 +97,10 @@ pub trait GenericWriteStorage {
/// exist, it is automatically created using `Default::default()`.
///
/// Returns None if the entity is dead.
fn get_mut_or_default(&mut self, entity: Entity) -> Option<AccessMutReturn<'_, Self::Component>>
fn get_mut_or_default(
&mut self,
entity: Entity,
) -> Option<AccessMutReturn<'_, Self::Component>>
where
Self::Component: Default;

Expand All @@ -117,6 +120,7 @@ where
{
type Component = T;
#[cfg(feature = "nightly")]
#[rustfmt::skip]
type AccessMut<'b> where Self: 'b = <<T as Component>::Storage as UnprotectedStorage<T>>::AccessMut<'b>;

fn get_mut(&mut self, entity: Entity) -> Option<AccessMutReturn<'_, T>> {
Expand Down Expand Up @@ -155,6 +159,7 @@ where
{
type Component = T;
#[cfg(feature = "nightly")]
#[rustfmt::skip]
type AccessMut<'c> where Self: 'c = <<T as Component>::Storage as UnprotectedStorage<T>>::AccessMut<'c>;

fn get_mut(&mut self, entity: Entity) -> Option<AccessMutReturn<'_, T>> {
Expand Down
Loading