## ----include = FALSE----------------------------------------------------------
# Chunk defaults applied to every code chunk of this document.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

## -----------------------------------------------------------------------------
library(XGeoRTR)

# Rescale a vector to the unit interval using only its finite values.
#
# x: any vector coercible with as.numeric().
#
# Returns a numeric vector of the same length as `x`. The smallest finite
# value maps to 0 and the largest to 1; non-finite entries are carried through
# the arithmetic unchanged (NA stays NA). When there is nothing to scale by
# (no finite values at all, or every finite value is identical) every element,
# including non-finite ones, is set to the midpoint 0.5.
scale01 <- function(x) {
  x <- as.numeric(x)
  finite_vals <- x[is.finite(x)]

  # Check for "no finite data" up front: calling range()/min()/max() on an
  # empty set would emit a warning before the fallback could be taken.
  if (length(finite_vals) == 0L) {
    return(rep(0.5, length(x)))
  }

  lo <- min(finite_vals)
  span <- max(finite_vals) - lo

  # A zero span would divide by zero; use the midpoint instead.
  if (span == 0) {
    return(rep(0.5, length(x)))
  }

  (x - lo) / span
}

# Run the shared post-processing pipeline on an XGeoRTR state object.
#
# state:            state object to process (passed straight to XGeoRTR).
# embedding_source: forwarded as `source` to the embedding and diagnostics
#                   steps; defaults to "explanations".
# k:                forwarded as `k` to compute_xgeo_diagnostics(); default 3L.
#
# Steps, in order: compute a 2-dimensional "pca" embedding, mark it active,
# compute diagnostics against it, then build levels of detail with the fixed
# settings levels = c(8L, 16L) and auto_threshold = 10L.
# Returns whatever build_xgeo_lod() returns.
finish_state <- function(state, embedding_source = "explanations", k = 3L) {
  with_embedding <- compute_xgeo_embedding(
    state,
    method = "pca",
    source = embedding_source,
    dims = 2
  )

  # NOTE(review): assumes XGeoRTR stores the embedding under
  # "pca_<source>" -- confirm against the package's naming convention.
  pca_key <- paste0("pca_", embedding_source)

  activated <- set_active_embedding(with_embedding, pca_key)

  diagnosed <- compute_xgeo_diagnostics(
    activated,
    embedding = pca_key,
    source = embedding_source,
    k = k
  )

  build_xgeo_lod(
    diagnosed,
    embedding = pca_key,
    levels = c(8L, 16L),
    auto_threshold = 10L
  )
}

## -----------------------------------------------------------------------------
# Example 1: per-feature contributions of a linear model on mtcars.

# Keep the car names as an ordinary column so they can act as point ids.
mt <- datasets::mtcars
mt[["car"]] <- rownames(mt)

# Predictors used both in the model formula and in the contribution table.
terms_lm <- c("wt", "hp", "qsec")
fit_lm <- stats::lm(mpg ~ wt + hp + qsec, data = mt)

# Contribution of a predictor for one car = coefficient * (value - column mean).
centered_lm <- scale(
  mt[, terms_lm, drop = FALSE],
  center = TRUE,
  scale = FALSE
)
contrib_lm <- sweep(
  centered_lm,
  MARGIN = 2,
  STATS = stats::coef(fit_lm)[terms_lm],
  FUN = `*`
)
fitted_lm <- stats::predict(fit_lm)
resid_lm <- stats::residuals(fit_lm)

# Long table: one row per (car, predictor). Transposing before flattening
# walks contrib_lm row by row, which matches point_id (repeated with `each`)
# and feature (repeated with `times`). Per-car quantities are repeated once
# per predictor.
lm_tbl <- data.frame(
  point_id = rep(mt$car, each = length(terms_lm)),
  feature  = rep(terms_lm, nrow(mt)),
  value    = c(t(contrib_lm)),
  x        = rep(scale01(mt$wt), each = length(terms_lm)),
  y        = rep(scale01(fitted_lm), each = length(terms_lm)),
  z        = rep(scale01(abs(resid_lm)), each = length(terms_lm)),
  response = rep(mt$mpg, each = length(terms_lm)),
  fitted   = rep(fitted_lm, each = length(terms_lm)),
  residual = rep(resid_lm, each = length(terms_lm))
)

# Build the state from the long table, naming the id and feature columns.
state_lm <- as_xgeo_state(
  lm_tbl,
  point_id_col = "point_id",
  feature_col = "feature",
  method = "linear-model-coefficient-contributions",
  meta = list(dataset = "datasets::mtcars", model = "stats::lm")
)

# Embedding, diagnostics and levels of detail with the default settings.
state_lm <- finish_state(state_lm)
summary(state_lm)

## -----------------------------------------------------------------------------
# Example 2: per-feature contributions of a logistic model on mtcars.
# Relies on `mt` (including its `car` column) prepared earlier in the script.

# Binary target: 1 when a car's mpg is above the median, otherwise 0.
mt$efficient <- as.integer(mt$mpg > stats::median(mt$mpg))

terms_glm <- c("wt", "hp", "qsec")
fit_glm <- stats::glm(
  efficient ~ wt + hp + qsec,
  data = mt,
  family = stats::binomial()
)
prob_glm <- stats::predict(fit_glm, type = "response")

# Contribution on the link scale = coefficient * (value - column mean).
centered_glm <- scale(
  mt[, terms_glm, drop = FALSE],
  center = TRUE,
  scale = FALSE
)
contrib_glm <- sweep(
  centered_glm,
  MARGIN = 2,
  STATS = stats::coef(fit_glm)[terms_glm],
  FUN = `*`
)

# Long table: one row per (car, predictor); flattening the transpose keeps
# values aligned with point_id (`each`) and feature (`times`). `y` is the
# fitted probability as-is; `z` is the rescaled absolute linear predictor.
glm_tbl <- data.frame(
  point_id = rep(mt$car, each = length(terms_glm)),
  feature  = rep(terms_glm, nrow(mt)),
  value    = c(t(contrib_glm)),
  x        = rep(scale01(mt$wt), each = length(terms_glm)),
  y        = rep(prob_glm, each = length(terms_glm)),
  z        = rep(
    scale01(abs(stats::predict(fit_glm, type = "link"))),
    each = length(terms_glm)
  ),
  # efficient is 0/1, so adding 1L indexes the matching label.
  class    = rep(
    c("low_mpg", "high_mpg")[mt$efficient + 1L],
    each = length(terms_glm)
  ),
  probability = rep(prob_glm, each = length(terms_glm))
)

state_glm <- as_xgeo_state(
  glm_tbl,
  point_id_col = "point_id",
  feature_col = "feature",
  method = "logistic-model-coefficient-contributions",
  meta = list(dataset = "datasets::mtcars", model = "stats::glm")
)

state_glm <- finish_state(state_glm)
summary(state_glm)

## -----------------------------------------------------------------------------
# Example 3: residuals of standardized iris measurements around their
# k-means cluster centers.

# Standardize the four numeric measurement columns.
iris_x <- scale(datasets::iris[, 1:4])

# kmeans() draws random starting centers, so without a fixed seed the cluster
# numbering (and everything derived from it below: `cluster`, `z`, and the
# printed summary) can change from one run to the next.
set.seed(1L)
km <- stats::kmeans(iris_x, centers = 3L, nstart = 5L)

# iris_x is already standardized, so PCA runs without re-centering/scaling.
pca_iris <- stats::prcomp(iris_x, center = FALSE, scale. = FALSE)

# Each observation minus the center of the cluster it was assigned to.
cluster_residual <- iris_x - km$centers[km$cluster, , drop = FALSE]

# Long table: one row per (observation, measurement). Flattening the transpose
# goes row by row, matching point_id (`each`) and feature (`times`).
# x / y are the first two principal component scores; z is the cluster id
# rescaled to [0, 1].
iris_tbl <- data.frame(
  point_id = rep(paste0("iris_", seq_len(nrow(iris_x))), each = ncol(iris_x)),
  feature = rep(colnames(iris_x), times = nrow(iris_x)),
  value = as.vector(t(cluster_residual)),
  x = rep(pca_iris$x[, 1], each = ncol(iris_x)),
  y = rep(pca_iris$x[, 2], each = ncol(iris_x)),
  z = rep(scale01(km$cluster), each = ncol(iris_x)),
  species = rep(as.character(datasets::iris$Species), each = ncol(iris_x)),
  cluster = rep(paste0("cluster_", km$cluster), each = ncol(iris_x))
)

state_km <- as_xgeo_state(
  iris_tbl,
  point_id_col = "point_id",
  feature_col = "feature",
  method = "kmeans-residual-geometry",
  meta = list(dataset = "datasets::iris", model = "stats::kmeans")
)

# Same pipeline as the other examples, but with k = 5L for the diagnostics.
state_km <- finish_state(state_km, k = 5L)
summary(state_km)

## -----------------------------------------------------------------------------
# Example 4: per-variable terms of the first principal component of USArrests.

arrests <- datasets::USArrests
pca_arrests <- stats::prcomp(arrests, center = TRUE, scale. = TRUE)

# Re-apply the centering and scaling that prcomp() used, then weight every
# standardized variable by its loading on the first component.
scaled_arrests <- scale(
  arrests,
  center = pca_arrests$center,
  scale = pca_arrests$scale
)
pc1_contrib <- sweep(
  scaled_arrests,
  MARGIN = 2,
  STATS = pca_arrests$rotation[, 1],
  FUN = `*`
)

# Long table: one row per (state, variable). Flattening the transpose goes row
# by row, matching point_id (`each`) and feature (`times`).
# x / y are the first two component scores; z is the rescaled total absolute
# contribution per row.
pca_tbl <- data.frame(
  point_id = rep(rownames(arrests), each = ncol(arrests)),
  feature  = rep(colnames(arrests), nrow(arrests)),
  value    = c(t(pc1_contrib)),
  x        = rep(pca_arrests$x[, 1], each = ncol(arrests)),
  y        = rep(pca_arrests$x[, 2], each = ncol(arrests)),
  z        = rep(scale01(rowSums(abs(pc1_contrib))), each = ncol(arrests))
)

state_pca <- as_xgeo_state(
  pca_tbl,
  point_id_col = "point_id",
  feature_col = "feature",
  method = "pca-loading-contribution-geometry",
  meta = list(dataset = "datasets::USArrests", model = "stats::prcomp")
)

# Same pipeline as the other examples, but with k = 4L for the diagnostics.
state_pca <- finish_state(state_pca, k = 4L)
summary(state_pca)

## -----------------------------------------------------------------------------
# Example 5: build a state directly from the volcano dataset rather than from
# a long table.
state_volcano <- as_xgeo_state(
  datasets::volcano,
  method = "matrix-regular-grid",
  meta = list(dataset = "datasets::volcano")
)

# Here the embedding and diagnostics use the "points" source, not the default.
state_volcano <- finish_state(
  state = state_volcano,
  embedding_source = "points",
  k = 4L
)

# Pull the per-point table out of the state, pass it to xgeo_regular_grid(),
# and list the names of the result.
point_tbl <- xgeo_point_values(state_volcano)
grid <- xgeo_regular_grid(point_tbl)
names(grid)

## -----------------------------------------------------------------------------
# Round trip: write the linear-model state to a temporary .json file and read
# it back.
json_file <- tempfile(pattern = "file", fileext = ".json")
write_xgeo_state(state_lm, json_file)
restored <- read_xgeo_state(json_file)

# Show the class of the restored object and its active embedding entry.
class(restored)
restored$attributes$embeddings$active

## -----------------------------------------------------------------------------
# Extract two tables from the linear-model state. Note that `point_tbl` is
# reassigned here and no longer refers to the volcano table.
long_tbl <- xgeo_explanation_table(state_lm)
point_tbl <- xgeo_point_values(state_lm)

# Preview the first six rows of each table.
utils::head(long_tbl, n = 6L)
utils::head(point_tbl, n = 6L)

