Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix randomness for threading #7925

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from

Conversation

marcus-wirtz-snkeos
Copy link

Description

Fixes #7922 by updating the random state of the Randomizable transform BEFORE copying the transforms. In the current implementation self.randomizable() is only called within the __call__() function and thus only updated inside the copy.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@YanxuanLiu
Copy link
Collaborator

/build

if isinstance(_transform, ThreadUnsafe):
if isinstance(_transform, Randomizable):
# update the random state before deepcopy, otherwise there is no randomness
_transform.randomize(data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can definitely update the random state here, but I guess the issue here is that if the transform is thread unsafe, we can't guarantee that the same transform will be performed on all keys, which may cause problems.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As of my understanding, the state is frozen for a single thread after the subsequent deepcopy of the Transform. Since all keys are processed by this copied Transform, a consistent state is guaranteed.

Copy link
Author

@marcus-wirtz-snkeos marcus-wirtz-snkeos Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I realized that .randomize() is not necessarily updating the random state self.R (cf. monai.transforms.transform.RandomizableTranform)
image

Therefore the correct way here would be to call the _transform.set_random_state() which is implemented in the Randomizable base class und updates self.R

def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@KumoLiu are there more transforms which inherit directly from ThreadUnsafe? I can only find Randomizable in the monai codebase, which would be covered here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, only Randomizable but all random transform inherit from RandomizableTransform. I'm not sure whether this change can also works well with invert. May also need to check that.

@ericspod
Copy link
Member

I'd like @atbenmurray to have a chance to review this before approving please.

@atbenmurray
Copy link
Contributor

Hi folks. Thanks @marcus-wirtz-snkeos for taking the time to raise the issue and PR. I need to take a careful look at this fix. From a design standpoint, we are very much focused on an "as if the pytorch team wrote it" design philosophy and I need to destruct test the change from this standpoint.

Signed-off-by: marcus.wirtz <[email protected]>
@johnzielke
Copy link
Contributor

Thanks everyone for the amazing work on Monai.
Seeing this, and having looked at parts of Monai random generation before, this is my humble opinion on this topic:
While the docs do mention issues with randomness and threading, I would not expect them to have these consequences. If I recall correctly, it used to be that there would be errors when calling transforms from multiple threads, but the deep copying every iteration was introduced in March of last year if I read the git history correctly.

In my opinion this should be forbidden by default and throw an error that needs to be disabled with a flag to prevent users from accidentally stumbling on this. The problem with the proposed solution is that there would be no reproducibility since a new randomstate is used every time. That is fine in my opinion if users have to use a flag to manually enable this behavior and will turn off threading when they need reproducibility.

But in the future, the whole random generation of Monai needs a refactor that solves the problem of multi-threading and randomness (see #7582 ) .
This could be completed together with the move to a new random generator api (see this PR that I tried my hand at, but realized that without discussion with the core team would need to make too many breaking changes).

@atbenmurray
Copy link
Contributor

Thanks everyone for the amazing work on Monai. Seeing this, and having looked at parts of Monai random generation before, this is my humble opinion on this topic: While the docs do mention issues with randomness and threading, I would not expect them to have these consequences. If I recall correctly, it used to be that there would be errors when calling transforms from multiple threads, but the deep copying every iteration was introduced in March of last year if I read the git history correctly.

In my opinion this should be forbidden by default and throw an error that needs to be disabled with a flag to prevent users from accidentally stumbling on this. The problem with the proposed solution is that there would be no reproducibility since a new randomstate is used every time. That is fine in my opinion if users have to use a flag to manually enable this behavior and will turn off threading when they need reproducibility.

But in the future, the whole random generation of Monai needs a refactor that solves the problem of multi-threading and randomness (see #7582 ) . This could be completed together with the move to a new random generator api (see this PR that I tried my hand at, but realized that without discussion with the core team would need to make too many breaking changes).

Thanks for bringing this up @johnzielke. I'll take a look at these items also.

@marcus-wirtz-snkeos
Copy link
Author

@johnzielke thanks for the feedback, fully agreeing. This fix can only be a temporary one, since the earlier introduced deepcopy() is problematic per se.

I verified with local batch generation that there is no randomness for the threading=True code as it is right now and also that there is randomness (though not deterministic) with my proposed fix.

Originally I tried to use _transform.randomize(data) rather than _transform.set_random_state(), which would only iterate the random state and therefore maintain reproducibility. I experienced some issues though with certain transforms (maybe not implementing self.randomize() correctly). I'll have a look on that again and keep you posted!

@marcus-wirtz-snkeos
Copy link
Author

marcus-wirtz-snkeos commented Jul 26, 2024

Should work now, the issue was in some of my custom Transforms indeed not implementing .randomize() correctly. @atbenmurray can you run the destruction checks to bring this as a temporary workaround?

@lukas-folle-snkeos
Copy link

@atbenmurray @ericspod what is the state of this PR?

@atbenmurray
Copy link
Contributor

atbenmurray commented Oct 4, 2024

@atbenmurray @ericspod what is the state of this PR?

@lukas-folle-snkeos, I'm refamiliarizing myself with it. Ideally, we'd like to do more to improve the randomness for threading, but if this change isn't breaking any scenarios, then we can go ahead with it and think about that subsequently.

@johnzielke
Copy link
Contributor

I don't think anyone relies on the current non-randomization behavior. I think there is an issue with the proposed approach, which I think are both part of randomize() not being the "correct" function in this case

  1. Calling randomize is usually done inside the "call" function already, so this would call it twice.
  2. Since the .randomize() is called on the transform shared across all threads, this approach might lead to race conditions with multiple threads calling it.
    I see two solutions for this:
    A. Use set_random_state() instead of randomize(). E.g.
if threading and isinstance(_transform, ThreadUnsafe):
            _shared_transform = _transform
            _transform = deepcopy(_transform)
            if isinstance(_transform, Randomizable):
                seed = _shared_transform.R.randint(0, 2**32 - 1) # Max value allowed as seed for numpy.random.RandomState 
                _transform.set_random_state(seed)

This makes sure that each iteration uses a different randomstate. You do not have reproducibility though, since the inidividual threads might not be calling this in a reproducible order
B. Use some kind of thread-local variable to keep individual "persistent" individual instances of the transforms.
While I think this is the most future-proof option, I think it needs a bit more thinking to prevent memory and other possible issues.

@atbenmurray
Copy link
Contributor

If I understand the rationale correctly. I think that calling randomize on the shared transform is the point of this modification.
If shared_transform.R instance is at state s and you are running 2 threads, for example, calling randomize on it puts the shared_transform.R into state s+1 in one thread, and then puts it into state s+2 on the other thread.

Now, this can absolutely cause race conditions. One source of race conditions is mitigated by the fact that randomize gets called again on the deepcopied transform. However, this doesn't eliminate race conditions entirely. Neither deepcopy nor RandomState are thread safe. Again, due to the call to randomize being performed on the deepcopied transform, I don't think that torn-state on the transform during deepcopy should be a problem in this situation. However, I do worry that internal state managed by RandomState can also get torn and that could potentially cause all kinds of problems.

This can be fixed by locking the section that calls randomize and deepcopies the transform. I made a suggested code change in the review.

@atbenmurray
Copy link
Contributor

I think it comes down to one of three choices:

  1. We take the current approach (using randomize then copying, but under lock)
  2. We take the set_random_state approach (again, under lock)
  3. We evaluate what we really want from multi-thread / process random augmentation pipelines

I'm not a huge fan of 1, as it is relying on side-effects of the way transforms are implemented. That said, I'm also not a huge fan of 2, because we are overwriting the random state that has been set by the user with new random state instances. This means that if the user provides a RandomState-like mock for testing purposes, for example, it is defeated by our replacing their random_state like object with an actual RandomState. I've had this problem myself before.

It is definitely true that reproducibility is compromised by threading / multiprocessing of randomized augmentation pipelines. Given that the mutation of a shared RandomState across threads already makes every run not-reproducible from an augmentation perspective, maybe it doesn't matter whether we replace random states but then, as I mentioned, there are other reasons to not do it this way? I think a gold standard solution would involve random states / seeds being assigned up front at the point that the pipeline is instantiated and the number of threads set.

WDYT?

Comment on lines 110 to 114
if isinstance(_transform, ThreadUnsafe):
if isinstance(_transform, Randomizable):
# update the random state before deepcopy, otherwise there is no randomness
_transform.set_random_state()
_transform = deepcopy(_transform)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this needs a lock

monai/transforms/compose.py Outdated Show resolved Hide resolved
@johnzielke
Copy link
Contributor

Thanks for your analysis @atbenmurray, and I agree that a lock would be a good idea here. I think the source of race conditions when using randomize on the "main-thread" instance mainly results from the fact that transforms sometimes rely on instance attributes in the randomize() method. This also means that if expensive calculations are performed in this step (i.e. calculating some information from the input), they would be single-threaded here.
I am in favor of option 2. The fact that custom random states set by the user would be overwritten would only be relevant when using the ThreadDataloader and only to the extent that the actual randomize() call in the threads would be using a normal RandomState instance, but the seed of that instance would still be controlled by the user.
Another option, if compatibility with mocks and custom RandomState classes is very important, would be to just advance the "main-thread" RandomState using .rand() before deep-copying and don't call set_random_state. This way each transform would not receive exactly the same values, but they would still overlap.

Regardless of the option, I think we should add a single-time warning explaining whatever caveats the solution has.

Co-authored-by: Ben Murray <[email protected]>
Signed-off-by: Marcus Wirtz <[email protected]>
@@ -106,8 +106,13 @@ def execute_compose(
return data

for _transform in transforms[start:end]:
if threading:
_transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
with lock:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll still need to create the lock object somewhere and get it to this function. Note that the lock must be created somewhere where only a single thread of execution is occurring

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

No randomness for threading=True
7 participants