
Initval refactoring #4924


Closed
ricardoV94 opened this issue Aug 13, 2021 · 0 comments · Fixed by #4983

ricardoV94 commented Aug 13, 2021

Proposal

Create a model.compile_initial_point_fn method that returns a compiled Aesara function which computes a (transformed) initial point for every RV in the model simultaneously.

This method takes as input a dictionary of optional user choices for initvals. Each value can be a numerical value, a symbolic expression (which may only depend on upstream RVs, Deterministics, or shared variables), or one of the strings "random" or "moment", meaning that a random draw should be taken from the RV or that a fixed moment should be extracted from it (using something like #4912), respectively.

These choices would be saved in a model dictionary such as model._user_initval_choices as new variables are defined.
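
To make the idea concrete, here is a minimal sketch of what compile_initial_point_fn could look like. The helpers get_moment (something along the lines of #4912) and get_transform are hypothetical placeholders, and RNG handling is glossed over; only aesara.function and at.as_tensor_variable are existing Aesara API:

import aesara
import aesara.tensor as at

def compile_initial_point_fn(self, initval_choices):
    outputs = []
    for rv in self.free_RVs:
        choice = initval_choices.get(rv, "moment")
        if choice == "random":
            # evaluating the RV's own graph yields a random draw
            # (RNG update handling is omitted in this sketch)
            initval = rv
        elif choice == "moment":
            # hypothetical helper returning a fixed moment of the RV (see #4912)
            initval = get_moment(rv)
        else:
            # a numerical value or a symbolic expression that depends only on
            # upstream RVs / Deterministics / shared variables
            initval = at.as_tensor_variable(choice)
        # hypothetical helper returning the transform of the RV's value variable
        transform = get_transform(rv)
        if transform is not None:
            initval = transform.forward(initval)
        outputs.append(initval)
    # a single compiled function computes all (transformed) initial points at once
    return aesara.function([], outputs)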

Finally, model.initial_point would look something like this:

def initial_point(self, recompute=False):
    if recompute or self.last_computed_initval is None:
        initial_point = self.compile_initial_point_fn(self._user_initval_choices)()
        initial_point_dict = ...  # convert initial_point to dict format
        self.last_computed_initval = initial_point_dict
    return self.last_computed_initval

The main goal is to decouple the model definition from the sampling phases, addressing issues like #4918.

A recompute flag is used to ensure the initial_point is not recomputed needlessly, as many places in the codebase call this property frequently. Alternatively, we could keep the old property as-is and add a recompute_initial_point() method that updates self.last_computed_initval.
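
A rough sketch of that alternative, assuming the same caching attribute as above and a hypothetical value_var_names list that maps the compiled outputs back to (transformed) value variable names:

@property
def initial_point(self):
    # cached: only computed on first access or after an explicit recompute
    if self.last_computed_initval is None:
        self.recompute_initial_point()
    return self.last_computed_initval

def recompute_initial_point(self):
    raw_values = self.compile_initial_point_fn(self._user_initval_choices)()
    # value_var_names is hypothetical: the value variable names in the same
    # order as the compiled function's outputs
    self.last_computed_initval = dict(zip(self.value_var_names, raw_values))
    return self.last_computed_initval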

Other benefits

This would allow us to simplify the slightly redundant and potentially defective model.update_start_vals (mentioned in #4484 (comment)): https://github.com/pymc-devs/pymc3/blob/6a75744b31f3ec015856ac6ea374fe12be8cc156/pymc3/model.py#L1548-L1556

to something like:

from copy import deepcopy

def update_start_vals(self, new_initval_choices):
    initval_choices = deepcopy(self._user_initval_choices)
    initval_choices.update(new_initval_choices)
    start_vals = self.compile_initial_point_fn(initval_choices)()
    start_vals_dict = ...  # convert to dict format
    return start_vals_dict
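
Hypothetical usage from a sampler's point of view ("mu" is just an illustrative variable name):

import numpy as np

# override the recorded initval choice for a single variable and get a
# consistently recomputed start dict back
start_vals = model.update_start_vals({"mu": np.array(0.5)})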

We could also revert some ugly changes to prior_predictive_sampling introduced by yours truly in 687f044:

https://github.com/pymc-devs/pymc3/blob/6d2aa5ddebed01d81c2ab66b9d4bd02194f82508/pymc3/sampling.py#L1983-L1999

This hack resulted from a different difficulty: obtaining multiple transformed initial points (to kick-start the SMC sampler). After the proposed changes, all that would be needed is something like the following:

random_initvals = {var: "random" for var in model._user_initval_choices}
initial_point_fn = model.compile_initial_point_fn(random_initvals)
# call the compiled function once per sample and transpose into per-variable tuples
initvals = zip(*(initial_point_fn() for _ in range(samples)))
...  # convert to dict format

And the changes in 687f044 could be removed altogether.
