From d3dc0ec0f93f8c5fbec524997c6366089db9fa21 Mon Sep 17 00:00:00 2001
From: Ian Goodfellow
Date: Fri, 1 May 2015 13:12:28 -0700
Subject: [PATCH 1/3] SumOfCosts.cost_per_example

---
 pylearn2/costs/cost.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/pylearn2/costs/cost.py b/pylearn2/costs/cost.py
index 37554a75ab..fce9dee243 100644
--- a/pylearn2/costs/cost.py
+++ b/pylearn2/costs/cost.py
@@ -353,6 +353,37 @@ def expr(self, model, data, ** kwargs):

         return sum_of_costs

+    def cost_per_example(self, model, data, ** kwargs):
+        """
+        Returns the sum of the per-example constituent costs.
+
+        Parameters
+        ----------
+        model : pylearn2.models.model.Model
+            the model for which we want to calculate the sum of
+            per-example costs
+        data : flat tuple of tensor_like variables.
+            data has to follow the format defined by self.get_data_specs(),
+            but this format will always be a flat tuple.
+        """
+        self.get_data_specs(model)[0].validate(data)
+        composite_specs, mapping = self.get_composite_specs_and_mapping(model)
+        nested_data = mapping.nest(data)
+        costs = []
+        for cost, cost_data in safe_zip(self.costs, nested_data):
+            costs.append(cost.cost_per_example(model, cost_data, **kwargs))
+        assert len(costs) > 0
+
+        if any([cost is None for cost in costs]):
+            sum_of_costs = None
+        else:
+            costs = [coeff * cost
+                     for coeff, cost in safe_zip(self.coeffs, costs)]
+            assert len(costs) > 0
+            sum_of_costs = reduce(lambda x, y: x + y, costs)
+
+        return sum_of_costs
+
     def get_composite_data_specs(self, model):
         """
         Build and return a composite data_specs of all costs.
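The new method mirrors `SumOfCosts.expr`, except that each constituent cost now contributes one value per example rather than a single scalar: the per-example vectors are scaled by the matching coefficients and summed elementwise, and the whole result collapses to `None` if any constituent cannot report per-example values. A minimal numpy sketch of that reduction (the helper name and the toy cost vectors are illustrative, not part of the patch):

    import numpy as np

    def sum_of_costs_per_example(per_example_costs, coeffs):
        """Weighted elementwise sum of per-example cost vectors.

        per_example_costs : list of 1-D arrays (or None), one per constituent
        coeffs : list of floats, one coefficient per constituent
        Returns a 1-D array of per-example totals, or None if any constituent
        could not provide per-example values.
        """
        assert len(per_example_costs) > 0
        if any(c is None for c in per_example_costs):
            return None
        scaled = [coeff * c for coeff, c in zip(coeffs, per_example_costs)]
        return sum(scaled)  # elementwise sum over the batch axis

    # Hypothetical batch of four examples with two constituent costs.
    reconstruction = np.array([0.9, 1.1, 0.4, 0.7])
    weight_decay = np.array([0.2, 0.2, 0.2, 0.2])
    print(sum_of_costs_per_example([reconstruction, weight_decay], [1.0, 0.01]))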
From 7a1e45162fe5cdaca8b9bfcb94eb40fbaf93529d Mon Sep 17 00:00:00 2001
From: Ian Goodfellow
Date: Fri, 1 May 2015 13:44:48 -0700
Subject: [PATCH 2/3] MP-DBM cost per example

---
 pylearn2/costs/dbm.py        | 243 ++++++++++++++++++++++++++++++++++-
 pylearn2/models/dbm/layer.py | 127 +++++++++++++++++-
 2 files changed, 367 insertions(+), 3 deletions(-)

diff --git a/pylearn2/costs/dbm.py b/pylearn2/costs/dbm.py
index f206a4c084..b3114e94f5 100644
--- a/pylearn2/costs/dbm.py
+++ b/pylearn2/costs/dbm.py
@@ -1467,6 +1467,99 @@ def expr(self, model, data, drop_mask=None, drop_mask_Y=None,

         return total_cost

+    @wraps(Cost.cost_per_example)
+    def cost_per_example(self, model, data, drop_mask=None, drop_mask_Y=None,
+                     return_locals=False, include_toronto=True, ** kwargs):
+        if self.supervised:
+            X, Y = data
+        else:
+            X = data
+            Y = None
+
+        if not self.supervised:
+            assert drop_mask_Y is None
+            # ignore Y if some other cost is supervised and has made it get
+            # passed in (can this still happen after the (space, source)
+            # interface change?)
+            Y = None
+        if self.supervised:
+            assert Y is not None
+            if drop_mask is not None:
+                assert drop_mask_Y is not None
+
+        if not hasattr(model, 'cost'):
+            model.cost = self
+        if not hasattr(model, 'mask_gen'):
+            model.mask_gen = self.mask_gen
+
+        dbm = model
+
+        X_space = model.get_input_space()
+
+        if drop_mask is None:
+            if self.supervised:
+                drop_mask, drop_mask_Y = self.mask_gen(X, Y, X_space=X_space)
+            else:
+                drop_mask = self.mask_gen(X, X_space=X_space)
+
+        if drop_mask_Y is not None:
+            assert drop_mask_Y.ndim == 1
+
+        if drop_mask.ndim < X.ndim:
+            if self.mask_gen is not None:
+                assert self.mask_gen.sync_channels
+            if X.ndim != 4:
+                raise NotImplementedError()
+            drop_mask = drop_mask.dimshuffle(0, 1, 2, 'x')
+
+        if not hasattr(self, 'noise'):
+            self.noise = False
+
+        history = dbm.do_inpainting(X, Y=Y, drop_mask=drop_mask,
+                                    drop_mask_Y=drop_mask_Y,
+                                    return_history=True,
+                                    noise=self.noise,
+                                    niter=self.niter,
+                                    block_grad=self.block_grad)
+        final_state = history[-1]
+
+        new_drop_mask = None
+        new_drop_mask_Y = None
+        new_history = [None for state in history]
+
+        if not hasattr(self, 'both_directions'):
+            self.both_directions = False
+        if self.both_directions:
+            raise NotImplementedError("both_directions was a research "
+                    "feature not supported going forward.")
+
+        new_final_state = new_history[-1]
+
+        cfs = self.per_example_cost_from_states
+        out = cfs(final_state, new_final_state, dbm, X, Y, drop_mask,
+                  drop_mask_Y, new_drop_mask, new_drop_mask_Y,
+                  return_locals=True)
+        total_per_example_cost, sublocals = out
+        assert total_per_example_cost.ndim == 1
+
+        if not hasattr(self, 'robustness'):
+            self.robustness = None
+        if self.robustness is not None:
+            raise NotImplementedError("robustness was a research "
+                    "feature not supported going forward.")
+
+        if not hasattr(self, 'toronto_act_targets'):
+            self.toronto_act_targets = None
+        toronto_act_cost = None
+        if self.toronto_act_targets is not None and include_toronto:
+            raise NotImplementedError("Toronto sparsity was a research "
+                    "feature not supported going forward.")
+
+        if return_locals:
+            return locals()
+
+        return total_per_example_cost
+
     def get_fixed_var_descr(self, model, data):
         """
         Returns the FixedVarDescr object responsible for making sure the
@@ -1603,7 +1696,7 @@ def get_gradients(self, model, X, Y=None, **kwargs):
     def get_inpaint_cost(self, dbm, X, V_hat_unmasked, drop_mask, state,
                          Y, drop_mask_Y):
         """
-        Returns the generalized pseudolikelihood giving raw data, a mask,
+        Returns the generalized pseudolikelihood given raw data, a mask,
         and the output of inference.

         Parameters
@@ -1632,6 +1725,40 @@ def get_inpaint_cost(self, dbm, X, V_hat_unmasked, drop_mask, state,

         return rval

+    def get_inpaint_cost_per_example(self, dbm, X, V_hat_unmasked, drop_mask, state,
+                                     Y, drop_mask_Y):
+        """
+        Returns the generalized pseudolikelihood per example given raw data, a mask,
+        and the output of inference.
+
+        Parameters
+        ----------
+        dbm : DBM
+        X : a batch of inputs
+        V_hat_unmasked : A batch of reconstructions of X
+        drop_mask : A batch of mask values
+        state : Hidden states of the DBM
+        Y : a batch of labels
+        drop_mask_Y : A batch of Y mask values
+        """
+        rcpe = dbm.visible_layer.recons_cost_per_example
+        rval = rcpe(X, V_hat_unmasked, drop_mask, use_sum=self.use_sum)
+        assert rval.ndim == 1
+
+        if self.supervised:
+            # pyflakes is too dumb to see that both branches define `scale`
+            scale = None
+            if self.use_sum:
+                scale = 1.
+            else:
+                scale = 1. / float(dbm.get_input_space().get_total_dimension())
+            Y_hat_unmasked = state['Y_hat_unmasked']
+            rc = dbm.hidden_layers[-1].recons_cost_per_example
+            rval = rval + rc(Y, Y_hat_unmasked, drop_mask_Y, scale)
+            assert rval.ndim == 1
+
+        return rval
+
     def cost_from_states(self, state, new_state, dbm, X, Y, drop_mask,
                          drop_mask_Y, new_drop_mask, new_drop_mask_Y,
                          return_locals=False):
@@ -1820,6 +1947,120 @@ def cost_from_states(self, state, new_state, dbm, X, Y, drop_mask,

         return total_cost

+    def per_example_cost_from_states(self, state, new_state, dbm, X, Y,
+            drop_mask, drop_mask_Y, new_drop_mask, new_drop_mask_Y,
+            return_locals=False):
+        """
+        Returns the total cost per example, given the states produced by
+        inference.
+        This includes activity regularization costs, not just generalized
+        pseudolikelihood costs.
+
+        Parameters
+        ----------
+        state : The state of the model after inference.
+        new_state : OrderedDict
+            The state of the model after inference with a different mask.
+        dbm : DBM.
+        X : A batch of input pixels.
+        Y : A batch of output labels.
+        drop_mask : A batch of mask values determining which pixels are inputs.
+        drop_mask_Y : Theano matrix
+            A batch of mask values determining which labels are inputs.
+        new_drop_mask : The second mask.
+        new_drop_mask_Y : The second label mask.
+        return_locals : bool
+            If True, return all local variables
+
+        Returns
+        -------
+        cost_per_example : 1D Theano tensor containing cost per example
+        locals : Optional
+            If return_locals is True, returns the dictionary of all local
+            variables. Note that this means all implementation changes are
+            now API changes.
+        """
+
+        if hasattr(self, 'both_directions') and self.both_directions:
+            raise NotImplementedError("both_directions was a research"
+                    "feature that is no longer supported.")
+
+        if not self.supervised:
+            assert drop_mask_Y is None
+            assert new_drop_mask_Y is None
+        if self.supervised:
+            assert drop_mask_Y is not None
+            assert Y is not None
+
+        V_hat_unmasked = state['V_hat_unmasked']
+        assert V_hat_unmasked.ndim == X.ndim
+
+        if not hasattr(self, 'use_sum'):
+            self.use_sum = False
+
+        gicpe = self.get_inpaint_cost_per_example
+        inpaint_cost_per_example = gicpe(dbm, X, V_hat_unmasked, drop_mask,
+                                         state, Y, drop_mask_Y)
+        assert inpaint_cost_per_example.ndim == 1
+
+        total_cost_per_example = inpaint_cost_per_example
+
+        if hasattr(self, 'range_rewards') and self.range_rewards is not None:
+            raise NotImplementedError("range_rewards was a research feature "
+                    "not supported going forward.")
+
+        if hasattr(self, 'stdev_rewards') and self.stdev_rewards is not None:
+            raise NotImplementedError("stdev_rewards was a research feature "
+                    "not supported going forward.")
+
+        l1_act_cost_per_example = None
+        if self.l1_act_targets is not None:
+            l1_act_cost_per_example = T.zeros_like(total_cost_per_example)
+            if self.l1_act_eps is None:
+                self.l1_act_eps = [None] * len(self.l1_act_targets)
+            for layer, mf_state, targets, coeffs, eps in \
+                    safe_izip(dbm.hidden_layers, state['H_hat'],
+                              self.l1_act_targets, self.l1_act_coeffs,
+                              self.l1_act_eps):
+
+                assert not isinstance(targets, str)
+
+                try:
+                    gl1acpe = layer.get_l1_act_cost_per_example
+                    layer_cost_per_example = gl1acpe(mf_state, targets,
+                                                     coeffs, eps)
+                    assert layer_cost_per_example.ndim == 1
+                except NotImplementedError:
+                    if coeffs == 0.:
+                        layer_cost_per_example = 0.
+                    else:
+                        raise
+                if layer_cost_per_example != 0.:
+                    l1_act_cost_per_example += layer_cost_per_example
+            # end for substates
+            # end for layers
+            assert l1_act_cost_per_example.ndim == 1
+            total_cost_per_example += l1_act_cost_per_example
+        # end if act penalty
+
+        if not hasattr(self, 'hid_presynaptic_cost'):
+            self.hid_presynaptic_cost = None
+        if self.hid_presynaptic_cost is not None:
+            raise NotImplementedError("this was a research feature "
+                    "not supported going forward.")
+
+        if not hasattr(self, 'reweighted_act_targets'):
+            self.reweighted_act_targets = None
+        if self.reweighted_act_targets is not None:
+            raise NotImplementedError("this was a research feature "
+                    "not supported going forward.")
+
+        assert total_cost_per_example.ndim == 1
+        if return_locals:
+            return total_cost_per_example, locals()
+
+        return total_cost_per_example
+

 default_seed = 20120712
diff --git a/pylearn2/models/dbm/layer.py b/pylearn2/models/dbm/layer.py
index 99a69e5f3c..b02e6abf07 100644
--- a/pylearn2/models/dbm/layer.py
+++ b/pylearn2/models/dbm/layer.py
@@ -599,9 +599,17 @@ def inpaint_update(self, state_above, layer_above, drop_mask = None, V = None, r

     def recons_cost(self, V, V_hat_unmasked, drop_mask = None, use_sum=False):
         """
-        .. todo::
+        Returns the cost of reconstructing `V` as `V_hat`

-            WRITEME
+        Parameters
+        ----------
+        V : theano matrix
+            A design matrix of binary features
+        V_hat_unmasked : theano matrix
+            A design matrix of mean field predictions of V
+        drop_mask : theano matrix
+            Mask specifying which examples are inputs and which are targets
+        use_sum : WRITEME
         """
         if use_sum:
             raise NotImplementedError()
@@ -639,6 +647,61 @@ def recons_cost(self, V, V_hat_unmasked, drop_mask = None, use_sum=False):

         return masked_cost.mean()

+
+    def recons_cost_per_example(self, V, V_hat_unmasked, drop_mask = None, use_sum=False):
+        """
+        Returns the cost of reconstructing `V` as `V_hat`
+
+        Parameters
+        ----------
+        V : theano matrix
+            A design matrix of binary features
+        V_hat_unmasked : theano matrix
+            A design matrix of mean field predictions of V
+        drop_mask : theano matrix
+            Mask specifying which examples are inputs and which are targets
+        use_sum : WRITEME
+        """
+        if use_sum:
+            raise NotImplementedError()
+
+        V_hat = V_hat_unmasked
+
+        assert hasattr(V_hat, 'owner')
+        owner = V_hat.owner
+        assert owner is not None
+        op = owner.op
+        block_grad = False
+        if is_block_gradient(op):
+            assert isinstance(op.scalar_op, theano.scalar.Identity)
+            block_grad = True
+            real, = owner.inputs
+            owner = real.owner
+            op = owner.op
+
+        if not hasattr(op, 'scalar_op'):
+            raise ValueError("Expected V_hat_unmasked to be generated by an Elemwise op, got "+str(op)+" of type "+str(type(op)))
+        assert isinstance(op.scalar_op, T.nnet.sigm.ScalarSigmoid)
+        z ,= owner.inputs
+        if block_grad:
+            z = block_gradient(z)
+
+        if V.ndim != V_hat.ndim:
+            raise ValueError("V and V_hat_unmasked should have same ndim, but are %d and %d." % (V.ndim, V_hat.ndim))
+        unmasked_cost = V * T.nnet.softplus(-z) + (1 - V) * T.nnet.softplus(z)
+        assert unmasked_cost.ndim == V_hat.ndim
+
+        if drop_mask is None:
+            masked_cost = unmasked_cost
+        else:
+            masked_cost = drop_mask * unmasked_cost
+
+        if masked_cost.ndim != 2:
+            raise NotImplementedError()
+
+        return masked_cost.mean(axis=1)
+
+
 class BinaryVectorMaxPool(HiddenLayer):
     """
     A hidden layer that does max-pooling on binary vectors.
@@ -1844,6 +1907,66 @@ def recons_cost(self, Y, Y_hat_unmasked, drop_mask_Y, scale):

         return - rval

+    def recons_cost_per_example(self, Y, Y_hat_unmasked, drop_mask_Y, scale):
+        """
+        The per-example cost of reconstructing `Y` as `Y_hat`. Specifically,
+        the negative log probability.
+
+        This cost is for use with multi-prediction training.
+
+        Parameters
+        ----------
+        Y : target space batch
+            The data labels
+        Y_hat_unmasked : target space batch
+            The output of this layer's `mf_update`; the predicted
+            values of `Y`. Even though the model is only predicting
+            the dropped values, we take predictions for all the
+            values here.
+        drop_mask_Y : 1-D theano tensor
+            A batch of 0s/1s, with 1s indicating that variables
+            have been dropped, and should be included in the
+            reconstruction cost. One indicator per example in the
+            batch, since each example in this layer only has one
+            random variable in it.
+        scale : float
+            Multiply the cost by this amount.
+            We need to do this because the visible layer also goes into
+            the cost. We use the mean over units and examples, so that
+            the scale of the cost doesn't change too much with batch
+            size or example size.
+            We need to multiply this cost by scale to make sure that
+            it is put on the same scale as the reconstruction cost
+            for the visible units. ie, scale should be 1/nvis
+        """
+
+
+        Y_hat = Y_hat_unmasked
+        assert hasattr(Y_hat, 'owner')
+        owner = Y_hat.owner
+        assert owner is not None
+        op = owner.op
+        if isinstance(op, Print):
+            assert len(owner.inputs) == 1
+            Y_hat, = owner.inputs
+            owner = Y_hat.owner
+            op = owner.op
+        assert isinstance(op, T.nnet.Softmax)
+        z ,= owner.inputs
+        assert z.ndim == 2
+
+        z = z - z.max(axis=1).dimshuffle(0, 'x')
+        log_prob = z - T.log(T.exp(z).sum(axis=1).dimshuffle(0, 'x'))
+        # we use sum and not mean because this is really one variable per row
+        log_prob_of = (Y * log_prob).sum(axis=1)
+        masked = log_prob_of * drop_mask_Y
+        assert masked.ndim == 1
+
+        rval = masked * scale * self.copies
+        assert rval.ndim == 1
+
+        return - rval
+
     def init_mf_state(self):
         """
         .. todo::
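In both new layer methods the arithmetic matches the existing `recons_cost`; only the final reduction changes, taking the mean over units (`axis=1`) instead of over the whole masked matrix, so one value per row survives. For the binary visible layer the quantity being averaged is the masked cross-entropy, via the identity V * softplus(-z) + (1 - V) * softplus(z) = -[V * log sigmoid(z) + (1 - V) * log(1 - sigmoid(z))]. A small numpy restatement of that arithmetic (the function name and toy batch are illustrative; the patch itself works on the Theano graph):

    import numpy as np

    def softplus(z):
        # simple formulation, fine for illustration
        return np.log1p(np.exp(z))

    def binary_recons_cost_per_example(V, Z, drop_mask=None):
        """Masked binary cross-entropy with one value per example (row).

        V : (batch, nvis) binary targets
        Z : (batch, nvis) pre-sigmoid reconstructions
        drop_mask : (batch, nvis), 1 where a unit was dropped (is a target)
        """
        # V * softplus(-Z) + (1 - V) * softplus(Z) equals
        # -[V * log sigmoid(Z) + (1 - V) * log(1 - sigmoid(Z))]
        unmasked = V * softplus(-Z) + (1. - V) * softplus(Z)
        masked = unmasked if drop_mask is None else drop_mask * unmasked
        return masked.mean(axis=1)  # mean over units: one cost per example

    rng = np.random.RandomState(0)
    V = (rng.uniform(size=(5, 8)) > 0.5).astype('float64')
    Z = rng.normal(size=(5, 8))
    mask = (rng.uniform(size=(5, 8)) > 0.5).astype('float64')
    print(binary_recons_cost_per_example(V, Z, mask))  # shape (5,)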
From e64a78acc8424e4e722a6b8831f637727dd76917 Mon Sep 17 00:00:00 2001
From: Ian Goodfellow
Date: Fri, 1 May 2015 17:14:30 -0700
Subject: [PATCH 3/3] PEP8

---
 pylearn2/costs/dbm.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/pylearn2/costs/dbm.py b/pylearn2/costs/dbm.py
index b3114e94f5..93d1773535 100644
--- a/pylearn2/costs/dbm.py
+++ b/pylearn2/costs/dbm.py
@@ -1469,7 +1469,7 @@ def expr(self, model, data, drop_mask=None, drop_mask_Y=None,

     @wraps(Cost.cost_per_example)
     def cost_per_example(self, model, data, drop_mask=None, drop_mask_Y=None,
-                     return_locals=False, include_toronto=True, ** kwargs):
+                         return_locals=False, include_toronto=True, ** kwargs):
         if self.supervised:
             X, Y = data
         else:
@@ -1531,7 +1531,7 @@ def cost_per_example(self, model, data, drop_mask=None, drop_mask_Y=None,
             self.both_directions = False
         if self.both_directions:
             raise NotImplementedError("both_directions was a research "
-                    "feature not supported going forward.")
+                                      "feature not supported going forward.")

         new_final_state = new_history[-1]

@@ -1546,14 +1546,14 @@ def cost_per_example(self, model, data, drop_mask=None, drop_mask_Y=None,
             self.robustness = None
         if self.robustness is not None:
             raise NotImplementedError("robustness was a research "
-                    "feature not supported going forward.")
+                                      "feature not supported going forward.")

         if not hasattr(self, 'toronto_act_targets'):
             self.toronto_act_targets = None
         toronto_act_cost = None
         if self.toronto_act_targets is not None and include_toronto:
             raise NotImplementedError("Toronto sparsity was a research "
-                    "feature not supported going forward.")
+                                      "feature not supported going forward.")

         if return_locals:
             return locals()
@@ -1725,11 +1725,12 @@ def get_inpaint_cost(self, dbm, X, V_hat_unmasked, drop_mask, state,

         return rval

-    def get_inpaint_cost_per_example(self, dbm, X, V_hat_unmasked, drop_mask, state,
-                                     Y, drop_mask_Y):
+    def get_inpaint_cost_per_example(
+            self, dbm, X, V_hat_unmasked, drop_mask, state,
+            Y, drop_mask_Y):
         """
-        Returns the generalized pseudolikelihood per example given raw data, a mask,
-        and the output of inference.
+        Returns the generalized pseudolikelihood per example given raw
+        data, a mask, and the output of inference.

         Parameters
         ----------
@@ -1948,8 +1949,8 @@ def cost_from_states(self, state, new_state, dbm, X, Y, drop_mask,
         return total_cost

     def per_example_cost_from_states(self, state, new_state, dbm, X, Y,
-            drop_mask, drop_mask_Y, new_drop_mask, new_drop_mask_Y,
-            return_locals=False):
+                                     drop_mask, drop_mask_Y, new_drop_mask,
+                                     new_drop_mask_Y, return_locals=False):
         """
         Returns the total cost per example, given the states produced by
         inference.
@@ -1983,7 +1984,7 @@ def per_example_cost_from_states(self, state, new_state, dbm, X, Y,

         if hasattr(self, 'both_directions') and self.both_directions:
             raise NotImplementedError("both_directions was a research"
-                    "feature that is no longer supported.")
+                                      "feature that is no longer supported.")

         if not self.supervised:
             assert drop_mask_Y is None
@@ -2007,11 +2008,11 @@ def per_example_cost_from_states(self, state, new_state, dbm, X, Y,

         if hasattr(self, 'range_rewards') and self.range_rewards is not None:
             raise NotImplementedError("range_rewards was a research feature "
-                    "not supported going forward.")
+                                      "not supported going forward.")

         if hasattr(self, 'stdev_rewards') and self.stdev_rewards is not None:
             raise NotImplementedError("stdev_rewards was a research feature "
-                    "not supported going forward.")
+                                      "not supported going forward.")

         l1_act_cost_per_example = None
         if self.l1_act_targets is not None:
@@ -2047,13 +2048,13 @@ def per_example_cost_from_states(self, state, new_state, dbm, X, Y,
             self.hid_presynaptic_cost = None
         if self.hid_presynaptic_cost is not None:
             raise NotImplementedError("this was a research feature "
-                    "not supported going forward.")
+                                      "not supported going forward.")

         if not hasattr(self, 'reweighted_act_targets'):
             self.reweighted_act_targets = None
         if self.reweighted_act_targets is not None:
             raise NotImplementedError("this was a research feature "
-                    "not supported going forward.")
+                                      "not supported going forward.")

         assert total_cost_per_example.ndim == 1
         if return_locals:
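For the label layer, the per-example cost is the negative log probability of the correct class, masked by `drop_mask_Y` and multiplied by `scale` (1/nvis, as the docstring in patch 2 explains) so that it stays commensurate with the visible-layer term. A numpy restatement under the same caveat as above, with 784 standing in for an arbitrary visible-layer size:

    import numpy as np

    def softmax_recons_cost_per_example(Y, Z, drop_mask_Y, scale, copies=1):
        """Per-example negative log-likelihood of the label layer.

        Y : (batch, n_classes) one-hot targets
        Z : (batch, n_classes) pre-softmax activations
        drop_mask_Y : (batch,), 1 where the label was dropped (is a target)
        scale : float, typically 1. / n_visible_units
        """
        Z = Z - Z.max(axis=1, keepdims=True)      # stabilise the softmax
        log_prob = Z - np.log(np.exp(Z).sum(axis=1, keepdims=True))
        log_prob_of = (Y * log_prob).sum(axis=1)  # log p(correct class)
        return -log_prob_of * drop_mask_Y * scale * copies

    rng = np.random.RandomState(0)
    Z = rng.normal(size=(4, 10))
    Y = np.eye(10)[rng.randint(10, size=4)]
    drop_mask_Y = np.array([1., 0., 1., 1.])
    print(softmax_recons_cost_per_example(Y, Z, drop_mask_Y, scale=1. / 784))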