Source code for pystruct.utils

import itertools
import cPickle

import numpy as np


def unwrap_pairwise(y):
    """Strip accompanying pairwise marginals from a labeling, if present.

    A tuple ``y`` is treated as a ``(labeling, pairwise_marginals)`` pair and
    only its first element is returned. Any non-tuple ``y`` is returned
    unchanged.
    """
    return y[0] if isinstance(y, tuple) else y


def make_grid_edges(x, neighborhood=4, return_lists=False):
    """Enumerate the edges of a 4- or 8-connected grid graph.

    Nodes are numbered in row-major order over the first two axes of ``x``;
    each edge is a pair of node indices.

    Parameters
    ----------
    x : array with at least two dimensions
        Only ``x.shape[0]`` and ``x.shape[1]`` are used, to size the grid.

    neighborhood : int, 4 or 8 (default=4)
        4 yields right and down edges; 8 additionally yields both diagonals.

    return_lists : bool (default=False)
        If True, return a list with one ``(n_edges, 2)`` array per edge
        direction instead of a single stacked array.

    Returns
    -------
    edges : ndarray of shape (n_edges, 2), dtype int64, or list of such arrays
        Ordered as right, down and, for ``neighborhood=8``, up-right and
        down-right.

    Raises
    ------
    ValueError
        If ``neighborhood`` is neither 4 nor 8.
    """
    if neighborhood not in (4, 8):
        raise ValueError("neighborhood can only be '4' or '8', got %s" %
                         repr(neighborhood))

    height, width = x.shape[0], x.shape[1]
    node_ids = np.arange(height * width).reshape((height, width))
    node_ids = node_ids.astype(np.int64)

    def _pair_up(sources, targets):
        # One row per edge: (source node, target node).
        return np.column_stack((sources.ravel(), targets.ravel()))

    edge_groups = [
        _pair_up(node_ids[:, :-1], node_ids[:, 1:]),   # right neighbor
        _pair_up(node_ids[:-1, :], node_ids[1:, :]),   # neighbor below
    ]
    if neighborhood == 8:
        # anti-diagonal (lower-left to upper-right), then main diagonal
        edge_groups.append(_pair_up(node_ids[1:, :-1], node_ids[:-1, 1:]))
        edge_groups.append(_pair_up(node_ids[:-1, :-1], node_ids[1:, 1:]))

    return edge_groups if return_lists else np.vstack(edge_groups)


def _count_label_pairs(first, second, n_states):
    """Count label co-occurrences between two aligned one-hot blocks."""
    return np.dot(first.reshape(-1, n_states).T,
                  second.reshape(-1, n_states))


def compute_energy(x, y, unary_params, pairwise_params, neighborhood=4):
    """Compute the energy of a labeling on a grid.

    Parameters
    ----------
    x : ndarray
        Unaries; the last axis indexes the states. For a discrete ``y`` it
        is indexed as ``x[row, col, state]``.

    y : ndarray of int, or tuple (marginals, pairwise_marginals)
        A discrete labeling indexed as ``y[row, col]``, or a continuous
        solution (e.g. from an LP relaxation) that comes together with its
        accumulated edge marginals.

    unary_params : ndarray
        Weights combined with the per-state accumulated unaries via
        ``np.dot``.

    pairwise_params : ndarray
        Weights combined with the flattened lower triangle of the symmetric
        ``(n_states, n_states)`` label co-occurrence matrix.

    neighborhood : int, 4 or 8 (default=4)
        Grid connectivity used to count label pairs. Ignored when ``y`` is a
        tuple, because the edge marginals are then supplied by the caller.

    Returns
    -------
    energy : scalar
        ``dot(unaries_acc, unary_params) + dot(tril(pw), pairwise_params)``.

    Raises
    ------
    ValueError
        If ``y`` is a discrete labeling and ``neighborhood`` is not 4 or 8.
    """
    n_states = x.shape[-1]
    if isinstance(y, tuple):
        # y can also be continuous (from lp)
        # in this case, it comes with accumulated edge marginals
        y, pw = y
        x_flat = x.reshape(-1, n_states)
        y_flat = y.reshape(-1, y.shape[-1])
        unaries_acc = np.sum(x_flat * y_flat, axis=0)
        pw = pw.reshape(-1, n_states, n_states).sum(axis=0)
    else:
        if neighborhood not in (4, 8):
            # Previously this fell through and failed later on an unbound
            # ``pw``; fail early with a clear message instead.
            raise ValueError("neighborhood can only be '4' or '8', got %s" %
                             repr(neighborhood))

        # unary features: sum the selected unary per state
        gx, gy = np.ogrid[:x.shape[0], :x.shape[1]]
        selected_unaries = x[gx, gy, y]
        unaries_acc = np.bincount(y.ravel(), selected_unaries.ravel(),
                                  minlength=n_states)

        # accumulated pairwise: one-hot encode the labeling.
        # Builtin ``int`` replaces the removed ``np.int`` alias.
        labels = np.zeros((y.shape[0], y.shape[1], n_states), dtype=int)
        labels[gx, gy, y] = 1

        # vertical + horizontal edges
        pw = (_count_label_pairs(labels[1:, :, :], labels[:-1, :, :],
                                 n_states)
              + _count_label_pairs(labels[:, 1:, :], labels[:, :-1, :],
                                   n_states))
        if neighborhood == 8:
            # anti-diagonal edges: (row + 1, col) <-> (row, col + 1)
            anti_diag = _count_label_pairs(labels[1:, :-1, :],
                                           labels[:-1, 1:, :], n_states)
            # main-diagonal edges: (row + 1, col + 1) <-> (row, col)
            main_diag = _count_label_pairs(labels[1:, 1:, :],
                                           labels[:-1, :-1, :], n_states)
            pw = pw + anti_diag + main_diag

    # Edges are undirected: fold both orientations together, keeping the
    # diagonal counted once, then use only the lower triangle.
    pw = pw + pw.T - np.diag(np.diag(pw))
    energy = (np.dot(unaries_acc, unary_params)
              + np.dot(np.tril(pw).ravel(), pairwise_params.ravel()))
    return energy


## global functions for easy parallelization
def find_constraint(model, x, y, w, y_hat=None, relaxed=True,
                    compute_difference=True):
    """Find the most violated constraint, or evaluate a given ``y_hat``.

    Parameters
    ----------
    model : object
        Must provide ``loss_augmented_inference``, ``psi``, ``loss`` and
        ``continuous_loss``.

    x, y : input and its ground-truth labeling.

    w : weight vector, combined with the feature difference via ``np.dot``.

    y_hat : labeling or None (default=None)
        If None, it is obtained from ``model.loss_augmented_inference``.
        Otherwise slack and feature difference are computed for this
        labeling.

    relaxed : bool (default=True)
        Forwarded to ``model.loss_augmented_inference``.

    compute_difference : bool (default=True)
        If False, ``psi(x, y)`` is not evaluated and ``-psi(x, y_hat)`` is
        used in place of the feature difference.

    Returns
    -------
    y_hat, delta_psi, slack, loss
        ``slack`` is ``max(loss - dot(w, delta_psi), 0)``.
    """
    if y_hat is None:
        y_hat = model.loss_augmented_inference(x, y, w, relaxed=relaxed)

    feature_map = model.psi
    delta_psi = (feature_map(x, y) - feature_map(x, y_hat)
                 if compute_difference else -feature_map(x, y_hat))

    # A tuple y_hat is a continuous labeling plus extra data; only its
    # first element is passed to the continuous loss.
    loss = (model.continuous_loss(y, y_hat[0])
            if isinstance(y_hat, tuple) else model.loss(y, y_hat))

    margin_violation = loss - np.dot(w, delta_psi)
    slack = max(margin_violation, 0)
    return y_hat, delta_psi, slack, loss


def find_constraint_latent(model, x, y, w, relaxed=True):
    """Find the most violated constraint for a latent-variable model.

    The latent configuration ``h`` is obtained from ``model.latent(x, y, w)``
    and then takes the place of the ground truth for loss-augmented
    inference and for the feature difference. The loss is measured between
    ``y`` and the inferred ``h_hat``.

    Parameters
    ----------
    model : object
        Must provide ``latent``, ``loss_augmented_inference``, ``psi`` and
        ``loss``.

    x, y : input and its ground-truth labeling.

    w : weight vector, combined with the feature difference via ``np.dot``.

    relaxed : bool (default=True)
        Forwarded to ``model.loss_augmented_inference``.

    Returns
    -------
    h_hat, delta_psi, slack, loss
        ``slack`` is ``max(loss - dot(w, delta_psi), 0)``.
    """
    latent_truth = model.latent(x, y, w)
    h_hat = model.loss_augmented_inference(x, latent_truth, w,
                                           relaxed=relaxed)

    feature_map = model.psi
    delta_psi = feature_map(x, latent_truth) - feature_map(x, h_hat)

    loss = model.loss(y, h_hat)
    margin_violation = loss - np.dot(w, delta_psi)
    return h_hat, delta_psi, max(margin_violation, 0), loss


def inference(model, x, w):
    """Module-level wrapper that returns ``model.inference(x, w)``."""
    prediction = model.inference(x, w)
    return prediction


def loss_augmented_inference(model, x, y, w, relaxed=True):
    """Module-level wrapper around ``model.loss_augmented_inference``.

    ``relaxed`` is forwarded as a keyword argument; the model's result is
    returned unchanged.
    """
    y_hat = model.loss_augmented_inference(x, y, w, relaxed=relaxed)
    return y_hat


# easy debugging
def objective_primal(model, w, X, Y, C):
    """Evaluate the primal objective for weights ``w`` (debugging aid).

    For every ``(x, y)`` pair, loss-augmented inference is run and
    ``loss(y, y_hat) - dot(w, psi(x, y) - psi(x, y_hat))`` is accumulated.
    The sum is divided by ``len(X)`` and ``sum(w ** 2) / C / 2`` is added.

    Parameters
    ----------
    model : object
        Must provide ``psi``, ``loss`` and ``loss_augmented_inference``.

    w : ndarray
        Weight vector.

    X, Y : sequences of inputs and their ground-truth labelings.

    C : number
        Divides the squared norm of ``w`` in the regularization term.

    Returns
    -------
    objective : scalar
    """
    feature_map = model.psi
    total = 0
    for sample, target in zip(X, Y):
        prediction = model.loss_augmented_inference(sample, target, w)
        sample_loss = model.loss(target, prediction)
        feature_gap = (feature_map(sample, target)
                       - feature_map(sample, prediction))
        total += sample_loss - np.dot(w, feature_gap)

    # average over the samples, then add the regularizer
    n_samples = float(len(X))
    total /= n_samples
    total += np.sum(w ** 2) / float(C) / 2.
    return total


def exhaustive_loss_augmented_inference(model, x, y, w):
    """Brute-force loss-augmented inference over every possible labeling.

    Enumerates all ``model.n_states ** y.size`` labelings with the shape of
    ``y`` and returns the one minimizing
    ``-model.loss(y, y_hat) - dot(w, model.psi(x, y_hat))``. On ties the
    first labeling in enumeration order wins. Returns None if no labeling
    has an energy strictly below infinity.
    """
    n_nodes = y.size
    state_range = range(model.n_states)
    winner, lowest_energy = None, np.inf
    for assignment in itertools.product(state_range, repeat=n_nodes):
        candidate = np.array(assignment).reshape(y.shape)
        features = model.psi(x, candidate)
        candidate_energy = -model.loss(y, candidate) - np.dot(w, features)
        if candidate_energy < lowest_energy:
            winner, lowest_energy = candidate, candidate_energy
    return winner


def exhaustive_inference(model, x, w):
    """Brute-force inference over every possible labeling.

    The labeling shape is taken from the features with their last axis
    dropped: ``x`` itself if it is an ndarray, otherwise
    ``model.get_features(x)``. Every labeling with values in
    ``range(model.n_states)`` is scored with ``-dot(w, model.psi(x, y_hat))``
    and the minimizer is returned. On ties the first labeling in
    enumeration order wins. Returns None if no labeling has an energy
    strictly below infinity.
    """
    # hack to get the grid shape of x
    feats = x if isinstance(x, np.ndarray) else model.get_features(x)
    label_shape = feats.shape[:-1]
    n_nodes = np.prod(label_shape)

    winner, lowest_energy = None, np.inf
    for assignment in itertools.product(range(model.n_states),
                                        repeat=n_nodes):
        candidate = np.array(assignment).reshape(label_shape)
        features = model.psi(x, candidate)
        candidate_energy = -np.dot(w, features)
        if candidate_energy < lowest_energy:
            winner, lowest_energy = candidate, candidate_energy
    return winner


class SaveLogger(object):
    """Logging class that stores the model periodically.

    Can be used to back up a model during learning.
    Also a prototype to demonstrate the logging interface.

    Parameters
    ----------
    file_name : string
        File in which the model will be stored. If the string contains '%d',
        this will be replaced with the current iteration.

    save_every : int (default=10)
        How often the model should be stored (in iterations).

    verbose : int (default=0)
        Verbosity level.

    """

    def __init__(self, file_name, save_every=10, verbose=0):
        self.file_name = file_name
        self.save_every = save_every
        self.verbose = verbose

    def __repr__(self):
        return ('%s(file_name="%s", save_every=%s)'
                % (self.__class__.__name__, self.file_name, self.save_every))

    def __call__(self, learner, iteration=0):
        """Save learner if iteration is a multiple of save_every or "final".

        Parameters
        ----------
        learner : object
            Learning object to be saved.

        iteration : int or 'final' (default=0)
            If 'final' or iteration % save_every == 0, the model will be
            saved.
        """
        is_final = iteration == 'final'
        if not is_final and iteration % self.save_every:
            return

        file_name = self.file_name
        if "%d" in file_name:
            if is_final:
                # '%d' cannot format the string 'final' (TypeError), so
                # substitute it textually.
                file_name = file_name.replace("%d", "final")
            else:
                file_name = file_name % iteration
        if self.verbose > 0:
            print("saving %s to file %s" % (learner, file_name))

        # don't store the large inference cache!
        has_cache = hasattr(learner, 'inference_cache_')
        if has_cache:
            cache = learner.inference_cache_
            learner.inference_cache_ = None
        try:
            with open(file_name, "wb") as f:
                cPickle.dump(learner, f, -1)
        finally:
            # Restore the cache even if opening the file or pickling fails.
            if has_cache:
                learner.inference_cache_ = cache

    def load(self):
        """Load the model stored in file_name and return it.

        ``file_name`` is opened as given; a '%d' placeholder is not
        expanded here.
        """
        # NOTE: unpickling executes arbitrary code; only load trusted files.
        with open(self.file_name, "rb") as f:
            learner = cPickle.load(f)
        return learner