Simple pairwise model with arbitrary interactions on a 4-connected grid. The inputs contain a cross pattern with a separate state for the center. The crosses are placed randomly in the image and noise is added.
The center state is not encoded in the input, so the task cannot be solved without pairwise interactions.
Script output:
Training 1-slack dual structural SVM
new constraint too weak.
new constraint too weak.
new constraint too weak.
new constraint too weak.
no additional constraints
overall accuracy (training set): 0.975309
Python source code: plot_grid_crf.py
import numpy as np
import matplotlib.pyplot as plt
from pystruct.models import GridCRF
import pystruct.learners as ssvm
import pystruct.toy_datasets as toy
# Generate the toy data set: noisy images containing cross patterns, with
# one label map per image.
X, Y = toy.generate_crosses_explicit(n_samples=50, noise=10)
n_labels = len(np.unique(Y))

# Pairwise CRF on a 4-connected grid, trained with the 1-slack structural SVM.
crf = GridCRF(n_states=n_labels, inference_method="lp", neighborhood=4)
clf = ssvm.OneSlackSSVM(model=crf, max_iter=1000, C=100, verbose=0,
                        check_constraints=True, n_jobs=-1,
                        inference_cache=100, inactive_window=50, tol=.1)
clf.fit(X, Y)
Y_pred = np.array(clf.predict(X))
print("overall accuracy (training set): %f" % clf.score(X, Y))

# plot one example
x, y, y_pred = X[0], Y[0], Y_pred[0]
# Bring the prediction back to the image's two leading (spatial) dimensions.
y_pred = y_pred.reshape(x.shape[:2])
fig, plots = plt.subplots(1, 4, figsize=(12, 4))
plots[0].matshow(y)
plots[0].set_title("ground truth")
# The input is shown as the index of the largest entry along the last axis.
plots[1].matshow(np.argmax(x, axis=-1))
plots[1].set_title("input")
plots[2].matshow(y_pred)
plots[2].set_title("prediction")
loss_augmented = clf.model.loss_augmented_inference(x, y, clf.w)
loss_augmented = loss_augmented.reshape(y.shape)
plots[3].matshow(loss_augmented)
plots[3].set_title("loss augmented")
for p in plots:
    # Tick marks carry no information for these label images.
    p.set_xticks(())
    p.set_yticks(())

# visualize weights
# The first 3 * 3 entries of the weight vector are shown as a 3x3 matrix.
# NOTE(review): presumably laid out as states x features -- confirm against
# the GridCRF documentation.
w_un = clf.w[:3 * 3].reshape(3, 3)
# decode the symmetric pairwise potential
w_pw = np.zeros((crf.n_states, crf.n_states))
# Fill the lower triangle (diagonal included) from the remaining weights.
# The builtin ``bool`` is used because the ``np.bool`` alias was removed in
# NumPy 1.24; the resulting mask is identical.
w_pw[np.tri(crf.n_states, dtype=bool)] = clf.w[3 * 3:]
# Mirror the lower triangle to the upper one without doubling the diagonal.
w_pw = w_pw + w_pw.T - np.diag(np.diag(w_pw))

fig, plots = plt.subplots(1, 2, figsize=(8, 4))
plots[0].matshow(w_un, cmap='gray', vmin=-5, vmax=5)
plots[0].set_title("Unary weights")
plots[1].matshow(w_pw, cmap='gray', vmin=-5, vmax=5)
plots[1].set_title("Pairwise weights")
for p in plots:
    p.set_xticks(())
    p.set_yticks(())
plt.show()
Total running time of the example: 44.55 seconds