Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ifp_egm] update namedtuple #130

Closed
wants to merge 2 commits into from
Closed

[ifp_egm] update namedtuple #130

wants to merge 2 commits into from

Conversation

shlff
Copy link
Member

@shlff shlff commented Oct 16, 2023

This PR fix #128 .

@shlff shlff added the in-work label Oct 16, 2023
@netlify
Copy link

netlify bot commented Oct 16, 2023

Deploy Preview for incomparable-parfait-2417f8 ready!

Name Link
🔨 Latest commit 32a8f9b
🔍 Latest deploy log https://app.netlify.com/sites/incomparable-parfait-2417f8/deploys/65e531f94fddfe0008437304
😎 Deploy Preview https://deploy-preview-130--incomparable-parfait-2417f8.netlify.app
📱 Preview on mobile
Toggle QR Code...

QR Code

Use your smartphone camera to open QR code link.

To edit notification comments on pull requests, go to your Netlify site configuration.

@github-actions
Copy link

github-actions bot commented Oct 16, 2023

@github-actions github-actions bot temporarily deployed to pull request October 16, 2023 01:45 Inactive
@jstac
Copy link
Contributor

jstac commented Mar 3, 2024

@shlff , this is labeled as "in-work". Is it ready to review?

@shlff
Copy link
Member Author

shlff commented Mar 4, 2024

Thanks @jstac . I've been reviewing this PR, and I will get back to you very soon.

@github-actions github-actions bot temporarily deployed to pull request March 4, 2024 02:42 Inactive
@shlff
Copy link
Member Author

shlff commented Mar 4, 2024

Thanks @jstac . This PR is ready to review as well, and please find a preview:

https://65e53524b6933606366d8586--incomparable-parfait-2417f8.netlify.app/ifp_egm

In addition to using namedtuple, I also tidied up some code styles and words.

@shlff shlff added ready and removed in-work labels Mar 4, 2024
@mmcky mmcky requested review from jstac and HumphreyYang March 4, 2024 03:48
@HumphreyYang
Copy link
Collaborator

Many thanks @shlff, @jstac, and @mmcky. I will review PRs in this repo tonight.

Copy link
Collaborator

@HumphreyYang HumphreyYang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many thanks @shlff. Please find minor code improvement proposals below:

@@ -0,0 +1,619 @@
---
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this file as it is a checkpoint file.

Comment on lines +104 to +114
Household = namedtuple('Household', ('β', 'R', 'γ', 's_size', 'y_size', \
's_grid', 'y_grid', 'P'))

def create_household(R=1.01, # gross interest rate
β=0.99, # discount factor
γ=1.5, # CRRA preference parameter
s_max=16, # savings grid max
s_size=200, # savings grid size
ρ=0.99, # income persistence
ν=0.02, # income volatility
y_size=25): # income grid size
Copy link
Collaborator

@HumphreyYang HumphreyYang Mar 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's migrate the comments to where the namedtuple is defined : )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree --- in fact let's have both, because the arguments are not identical to the elements of the namedtuple. Some overlap is fine.

Comment on lines +129 to +130
return Household(β=β, R=R, γ=γ, s_size=s_size, y_size=y_size, \
s_grid=s_grid, y_grid=y_grid, P=P)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return Household(β=β, R=R, γ=γ, s_size=s_size, y_size=y_size, \
s_grid=s_grid, y_grid=y_grid, P=P)
return Household(β=β, R=R, γ=γ, s_size=s_size, y_size=y_size,
s_grid=s_grid, y_grid=y_grid, P=P)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch @HumphreyYang . Not needed after comma.

Comment on lines -386 to +390
constants, sizes, arrays = model
β, R, γ, s_size, y_size, s_grid, y_grid, P = model

β, R, γ = constants
s_size, y_size = sizes
s_grid, y_grid, P = arrays
constants = β, R, γ
sizes = s_size, y_size
arrays = s_grid, y_grid, P
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a small observation on the original code. Here we are packing constants, sizes, and arrays for the function K_egm, but later in K_egm, we unpack them again:

def K_egm(a_in, σ_in, constants, sizes, arrays):
    """
    The vectorized operator K using EGM.

    """
    
    # Unpack
    β, R, γ = constants
    s_size, y_size = sizes
    s_grid, y_grid, P = arrays

perhaps we can remove the packing and pass the model directly into K_egm to remove this redundancy.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @HumphreyYang . This is a good idea. I have thought and tried about this while coding but didn't find an elegant way to do then.

Comment on lines -439 to 443
model = ifp()
constants, sizes, arrays = model
β, R, γ = constants
s_size, y_size = sizes
model = create_household()
β, R, γ, s_size, y_size, s_grid, y_grid, P = model
arrays = s_grid, y_grid, P
s_grid, y_grid, P = (np.array(a) for a in arrays)

@jitclass(ifp_data)
Copy link
Collaborator

@HumphreyYang HumphreyYang Mar 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of creating global variables in the cell below, I propose that we define the default values for class IFP using model directly:

ifp_data = [
    ('R', float64),              
    ('β', float64),             
    ('γ', float64),            
    ('P', float64[:, :]),     
    ('y_grid', float64[:]),  
    ('s_grid', float64[:])    
]

# Use the JAX IFP data as our defaults for the Numba version
model = create_household()
β, R, γ, s_size, y_size, s_grid, y_grid, P = model
arrays = s_grid, y_grid, P
s_grid, y_grid, P = (np.array(a) for a in arrays)

@jitclass(ifp_data)
class IFP:

    def __init__(self,
                 R=R,
                 β=β,
                 γ=γ,
                 P=np.array(P),
                 y_grid=np.array(y_grid),
                 s_grid=s_grid):

        self.R, self.β, self.γ = R, β, γ
        self.P, self.y_grid = P, y_grid
        self.s_grid = s_grid

        # Recall that we need R β < 1 for convergence.
        assert self.R * self.β < 1, "Stability condition violated."

    def u_prime(self, c):
        return c**(-self.γ)
    
    def u_prime_inv(self, u_prime):
        return u_prime**(-1/self.γ)

->

ifp_data = [
    ('R', float64),              
    ('β', float64),             
    ('γ', float64),   
    ('y_size', int64),
    ('s_size', int64),             
    ('P', float64[:, :]),     
    ('y_grid', float64[:]),  
    ('s_grid', float64[:]),
]

# Use the JAX IFP data as our defaults for the Numba version
model = create_household()

@jitclass(ifp_data)
class IFP:

    def __init__(self,
                 R=model.R,
                 β=model.β,
                 γ=model.γ,
                 P=np.array(model.P),
                 y_grid=np.array(model.y_grid),
                 s_grid=np.array(model.s_grid)):

        self.R, self.β, self.γ = R, β, γ
        self.P, self.y_grid = P, y_grid
        self.s_grid = s_grid
        self.y_size, self.s_size = len(y_grid), len(s_grid)

        # Recall that we need R β < 1 for convergence.
        assert self.R * self.β < 1, "Stability condition violated."

    def u_prime(self, c):
        return c**(-self.γ)
    
    def u_prime_inv(self, u_prime):
        return u_prime**(-1/self.γ)

Note that y_size will be removed from the global variable after this change. We will need to change the function successive_approx_numba below that is dependent on global y_size.

Also note that we will need to add int64 in the import:

from numba import njit, float64, int64

Hi @jstac, please let us know if you think it is a good improvement to the current version when you are available.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree @HumphreyYang, this is neater.

# Shift to JAX arrays
P, y_grid = jax.device_put((P, y_grid))

s_grid = jnp.linspace(0, s_max, s_size)
sizes = s_size, y_size
s_grid, y_grid, P = jax.device_put((s_grid, y_grid, P))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s_grid and y_grid are already on the device.

@jstac
Copy link
Contributor

jstac commented Mar 4, 2024

Many thanks for your input @shlff and @HumphreyYang . I'll tidy this up and merge it. Please leave it to me.

@jstac
Copy link
Contributor

jstac commented Mar 4, 2024

Hey @shlff , I just looked back at the original issue #128 and the aim is to get rid of the class.

The only class is for the numba version, which is not fixed in this PR. Let's be careful when we say the PR fixes an issue, since we don't want to close an issue that's not addressed.

As background, I believe we should move away from using classes whenever we jit-compile with either numba or JAX. The reasons are:

  • functions are easier to debug -- we can write each one in pure Python and then incrementally add @jit to each one and debug the optimization
  • If we use classes, then it's always hard to delineate what should be a method and what should be a function acting on the class. This leads to inconsistency.
  • Classes are a little dangerous -- if we create an instance foo and then change parameter foo.a, it will not change foo.b where b is defined internally by self.b = 2 * a. With a namedtuple, we are forced to create a new instance (which is usually cheap).

CC @mmcky @thomassargent30

@jstac jstac closed this Mar 4, 2024
@jstac
Copy link
Contributor

jstac commented Mar 4, 2024

@mmcky This PR includes the ipynb checkpoint. I've closed the PR but I wonder if there's an easy way to unify the .gitignore files across the repos. If not, would you mind to do a quick review of the .gitignore files and check they are roughly the same?

@shlff
Copy link
Member Author

shlff commented Mar 4, 2024

Thanks @HumphreyYang and @jstac . I've read your comments and those are valuable!

@mmcky
Copy link
Contributor

mmcky commented Mar 4, 2024

@mmcky This PR includes the ipynb checkpoint. I've closed the PR but I wonder if there's an easy way to unify the .gitignore files across the repos. If not, would you mind to do a quick review of the .gitignore files and check they are roughly the same?

I have tried a couple times to see if we could use quantecon/.github as a place to store default files for repositories. It haven't gotten that to work effectively to date. I will open an issue and see if there have been improvements on GitHub to support this.

In the meantime I will do a quick review of them. My feeling is lecture-jax is an outlier. I have come across a few issues recently with some of the syntax used that I am working to fix up.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[ifp_egm] use namedtuple
4 participants