Skip to content

Commit

Permalink
feat: add useShallow (#2090)
Browse files Browse the repository at this point in the history
* feat: add useShallow

See
- pmndrs/zustand#1937
- pmndrs/zustand#1937 (reply in thread)
- pmndrs/zustand#1937 (reply in thread)

* chore(useShallow): improve unit tests

* chore(useShallow): PR feedback pmndrs/zustand#2090 (comment)

* fix(useShallow): tests not working on test_matrix (cjs, production, CI-MATRIX-NOSKIP)

* chore(useShallow): fix eslint warning issue (unused import)

* refactor(useShallow): simplify tests

* docs(useShallow): add guide

* fix(useShallow): prettier:ci error https://github.com/pmndrs/zustand/actions/runs/6369420511/job/17289749161?pr=2090

* docs(useShallow): update readme

* docs(useShallow): remove obsolete line from readme

Co-authored-by: Daishi Kato <[email protected]>

* doc(useShallow): PR feedback pmndrs/zustand#2090 (comment)

* docs(useShallow): small improvements of the useShallow guide

---------

Co-authored-by: Daishi Kato <[email protected]>
  • Loading branch information
FaberVitale and dai-shi authored Oct 2, 2023
1 parent 122ff8d commit 262fcb7
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 22 deletions.
63 changes: 63 additions & 0 deletions docs/guides/prevent-rerenders-with-use-shallow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
---
title: Prevent rerenders with useShallow
nav: 16
---

When you need to subscribe to a computed state from a store, the recommended way is to
use a selector.

The computed selector will cause a rererender if the output has changed according to [Object.is](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Object/is?retiredLocale=it).

In this case you might want to use `useShallow` to avoid a rerender if the computed value is always shallow
equal the previous one.

## Example

We have a store that associates to each bear a meal and we want to render their names.

```js
import { create } from 'zustand'

const useMeals = create(() => ({
papaBear: 'large porridge-pot',
mamaBear: 'middle-size porridge pot',
littleBear: 'A little, small, wee pot',
}))

export const BearNames = () => {
const names = useMeals((state) => Object.keys(state))

return <div>{names.join(', ')}</div>
}
```

Now papa bear wants a pizza instead:

```js
useMeals.setState({
papaBear: 'a large pizza',
})
```

This change causes `BearNames` rerenders even tho the actual output of `names` has not changed according to shallow equal.

We can fix that using `useShallow`!

```js
import { create } from 'zustand'
import { useShallow } from 'zustand/shallow'

const useMeals = create(() => ({
papaBear: 'large porridge-pot',
mamaBear: 'middle-size porridge pot',
littleBear: 'A little, small, wee pot',
}))

export const BearNames = () => {
const names = useMeals(useShallow((state) => Object.keys(state)))

return <div>{names.join(', ')}</div>
}
```

Now they can all order other meals without causing unnecessary rerenders of our `BearNames` component.
32 changes: 12 additions & 20 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,38 +84,30 @@ const nuts = useBearStore((state) => state.nuts)
const honey = useBearStore((state) => state.honey)
```

If you want to construct a single object with multiple state-picks inside, similar to redux's mapStateToProps, you can tell zustand that you want the object to be diffed shallowly by passing the `shallow` equality function.

To use a custom equality function, you need `createWithEqualityFn` instead of `create`. Usually you want to specify `Object.is` as the second argument for the default equality function, but it's configurable.
If you want to construct a single object with multiple state-picks inside, similar to redux's mapStateToProps, you can use [useShallow](./docs/guides/prevent-rerenders-with-use-shallow.md) to prevent unnecessary rerenders when the selector output does not change according to shallow equal.

```jsx
import { createWithEqualityFn } from 'zustand/traditional'
import { shallow } from 'zustand/shallow'

// Use createWithEqualityFn instead of create
const useBearStore = createWithEqualityFn(
(set) => ({
bears: 0,
increasePopulation: () => set((state) => ({ bears: state.bears + 1 })),
removeAllBears: () => set({ bears: 0 }),
}),
Object.is // Specify the default equality function, which can be shallow
)
import { create } from 'zustand'
import { useShallow } from 'zustand/shallow'

const useBearStore = create((set) => ({
bears: 0,
increasePopulation: () => set((state) => ({ bears: state.bears + 1 })),
removeAllBears: () => set({ bears: 0 }),
}))

// Object pick, re-renders the component when either state.nuts or state.honey change
const { nuts, honey } = useBearStore(
(state) => ({ nuts: state.nuts, honey: state.honey }),
shallow
useShallow((state) => ({ nuts: state.nuts, honey: state.honey }))
)

// Array pick, re-renders the component when either state.nuts or state.honey change
const [nuts, honey] = useBearStore(
(state) => [state.nuts, state.honey],
shallow
useShallow((state) => [state.nuts, state.honey])
)

// Mapped picks, re-renders the component when state.treats changes in order, count or keys
const treats = useBearStore((state) => Object.keys(state.treats), shallow)
const treats = useBearStore(useShallow((state) => Object.keys(state.treats)))
```

For more control over re-rendering, you may provide any custom equality function.
Expand Down
13 changes: 13 additions & 0 deletions src/shallow.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { useRef } from 'react'

export function shallow<T>(objA: T, objB: T) {
if (Object.is(objA, objB)) {
return true
Expand Down Expand Up @@ -59,3 +61,14 @@ export default ((objA, objB) => {
}
return shallow(objA, objB)
}) as typeof shallow

export function useShallow<S, U>(selector: (state: S) => U): (state: S) => U {
const prev = useRef<U>()

return (state) => {
const next = selector(state)
return shallow(prev.current, next)
? (prev.current as U)
: (prev.current = next)
}
}
159 changes: 157 additions & 2 deletions tests/shallow.test.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { describe, expect, it } from 'vitest'
import { useState } from 'react'
import { act, fireEvent, render } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { create } from 'zustand'
import { shallow } from 'zustand/shallow'
import { shallow, useShallow } from 'zustand/shallow'

describe('shallow', () => {
it('compares primitive values', () => {
Expand Down Expand Up @@ -131,3 +133,156 @@ describe('unsupported cases', () => {
).not.toBe(false)
})
})

describe('useShallow', () => {
const testUseShallowSimpleCallback =
vi.fn<[{ selectorOutput: string[]; useShallowOutput: string[] }]>()
const TestUseShallowSimple = ({
selector,
state,
}: {
state: Record<string, unknown>
selector: (state: Record<string, unknown>) => string[]
}) => {
const selectorOutput = selector(state)
const useShallowOutput = useShallow(selector)(state)

return (
<div
data-testid="test-shallow"
onClick={() =>
testUseShallowSimpleCallback({ selectorOutput, useShallowOutput })
}
/>
)
}

beforeEach(() => {
testUseShallowSimpleCallback.mockClear()
})

it('input and output selectors always return shallow equal values', () => {
const res = render(
<TestUseShallowSimple state={{ a: 1, b: 2 }} selector={Object.keys} />
)

expect(testUseShallowSimpleCallback).toHaveBeenCalledTimes(0)
fireEvent.click(res.getByTestId('test-shallow'))

const firstRender = testUseShallowSimpleCallback.mock.lastCall?.[0]

expect(testUseShallowSimpleCallback).toHaveBeenCalledTimes(1)
expect(firstRender).toBeTruthy()
expect(firstRender?.selectorOutput).toEqual(firstRender?.useShallowOutput)

res.rerender(
<TestUseShallowSimple
state={{ a: 1, b: 2, c: 3 }}
selector={Object.keys}
/>
)

fireEvent.click(res.getByTestId('test-shallow'))
expect(testUseShallowSimpleCallback).toHaveBeenCalledTimes(2)

const secondRender = testUseShallowSimpleCallback.mock.lastCall?.[0]

expect(secondRender).toBeTruthy()
expect(secondRender?.selectorOutput).toEqual(secondRender?.useShallowOutput)
})

it('returns the previously computed instance when possible', () => {
const state = { a: 1, b: 2 }
const res = render(
<TestUseShallowSimple state={state} selector={Object.keys} />
)

fireEvent.click(res.getByTestId('test-shallow'))
expect(testUseShallowSimpleCallback).toHaveBeenCalledTimes(1)
const output1 =
testUseShallowSimpleCallback.mock.lastCall?.[0]?.useShallowOutput
expect(output1).toBeTruthy()

// Change selector, same output
res.rerender(
<TestUseShallowSimple
state={state}
selector={(state) => Object.keys(state)}
/>
)

fireEvent.click(res.getByTestId('test-shallow'))
expect(testUseShallowSimpleCallback).toHaveBeenCalledTimes(2)

const output2 =
testUseShallowSimpleCallback.mock.lastCall?.[0]?.useShallowOutput
expect(output2).toBeTruthy()

expect(output2).toBe(output1)
})

it('only re-renders if selector output has changed according to shallow', () => {
let countRenders = 0
const useMyStore = create(
(): Record<string, unknown> => ({ a: 1, b: 2, c: 3 })
)
const TestShallow = ({
selector = (state) => Object.keys(state).sort(),
}: {
selector?: (state: Record<string, unknown>) => string[]
}) => {
const output = useMyStore(useShallow(selector))

++countRenders

return <div data-testid="test-shallow">{output.join(',')}</div>
}

expect(countRenders).toBe(0)
const res = render(<TestShallow />)

expect(countRenders).toBe(1)
expect(res.getByTestId('test-shallow').textContent).toBe('a,b,c')

act(() => {
useMyStore.setState({ a: 4 }) // This will not cause a re-render.
})

expect(countRenders).toBe(1)

act(() => {
useMyStore.setState({ d: 10 }) // This will cause a re-render.
})

expect(countRenders).toBe(2)
expect(res.getByTestId('test-shallow').textContent).toBe('a,b,c,d')
})

it('does not cause stale closure issues', () => {
const useMyStore = create(
(): Record<string, unknown> => ({ a: 1, b: 2, c: 3 })
)
const TestShallowWithState = () => {
const [count, setCount] = useState(0)
const output = useMyStore(
useShallow((state) => Object.keys(state).concat([count.toString()]))
)

return (
<div
data-testid="test-shallow"
onClick={() => setCount((prev) => ++prev)}>
{output.join(',')}
</div>
)
}

const res = render(<TestShallowWithState />)

expect(res.getByTestId('test-shallow').textContent).toBe('a,b,c,0')

fireEvent.click(res.getByTestId('test-shallow'))

expect(res.getByTestId('test-shallow').textContent).toBe('a,b,c,1')
})
})

0 comments on commit 262fcb7

Please sign in to comment.