From efeaa104b4e7051208edce7e0230d17b6bf6149f Mon Sep 17 00:00:00 2001 From: Jeong Jinwoo Date: Wed, 17 Dec 2025 19:48:00 +0900 Subject: [PATCH 1/2] feat(pst_Q): add plot_pst_Q --- R/inst/plotting/plot_functions.R | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/R/inst/plotting/plot_functions.R b/R/inst/plotting/plot_functions.R index dbcf2b3e..20b7f8eb 100644 --- a/R/inst/plotting/plot_functions.R +++ b/R/inst/plotting/plot_functions.R @@ -405,6 +405,14 @@ plot_cra_exp <- function(obj, fontSize = 10, ncols = 3, binSize = 30) { return(h_all) } +plot_pst_Q <- function(obj, fontSize = 10, ncols = 2, binSize = 30) { + pars = obj$parVals + h1 = plotDist(sample = pars$mu_alpha, fontSize = fontSize, binSize = binSize, xLim = c(0,1), xLab = expression(paste(alpha, " (Learning Rate)"))) + h2 = plotDist(sample = pars$mu_beta, fontSize = fontSize, binSize = binSize, xLim = c(0,10), xLab = expression(paste(beta, " (Inverse Temp.)"))) + h_all = multiplot(h1, h2, cols = ncols) + return(h_all) +} + plot_pst_gainloss_Q <- function(obj, fontSize = 10, ncols = 3, binSize = 30) { pars = obj$parVals h1 = plotDist(sample = pars$mu_alpha_pos, fontSize = fontSize, binSize = binSize, xLim = c(0,2), xLab = expression(paste(alpha[pos], " (+Learning Rate)"))) From 32937f19996d2e5de71e76a6ec757a3bc6fa1e66 Mon Sep 17 00:00:00 2001 From: Jeong Jinwoo Date: Wed, 17 Dec 2025 19:48:25 +0900 Subject: [PATCH 2/2] feat(pst_Q): add prediction error as generated quantities --- R/R/hBayesDM_model.R | 3 +++ commons/stan_files/pst_Q.stan | 20 +++++++++----------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/R/R/hBayesDM_model.R b/R/R/hBayesDM_model.R index 603ac626..c94acaed 100644 --- a/R/R/hBayesDM_model.R +++ b/R/R/hBayesDM_model.R @@ -298,6 +298,9 @@ hBayesDM_model <- function(task_name = "", if ((model_name == "hgf_ibrb") && (model_type == "single")) { pars <- c(pars, paste0("logit_", names(parameters))) } + if ((task_name == "pst") && (model_name == "Q")) { + pars <- c(pars, "pe") + } pars <- c(pars, "log_lik") if (modelRegressor) { pars <- c(pars, names(regressors)) diff --git a/commons/stan_files/pst_Q.stan b/commons/stan_files/pst_Q.stan index 4a6cbb37..d3111606 100644 --- a/commons/stan_files/pst_Q.stan +++ b/commons/stan_files/pst_Q.stan @@ -31,24 +31,23 @@ transformed parameters { vector[N] alpha; vector[N] beta; - alpha = Phi_approx(mu_pr[1] + sigma[1] * alpha_pr); - beta = Phi_approx(mu_pr[2] + sigma[2] * beta_pr) * 10; + alpha = Phi_approx(mu_pr[1] + sigma[1] * alpha_pr); + beta = Phi_approx(mu_pr[2] + sigma[2] * beta_pr) * 10; } model { // Priors for group-level parameters - mu_pr ~ normal(0, 1); + mu_pr ~ normal(0, 1); sigma ~ normal(0, 0.2); // Priors for subject-level parameters alpha_pr ~ normal(0, 1); - beta_pr ~ normal(0, 1); + beta_pr ~ normal(0, 1); for (i in 1:N) { int co; // Chosen option real delta; // Difference between two options real pe; // Prediction error - //real alpha; vector[6] ev; // Expected values ev = initial_values; @@ -71,19 +70,18 @@ generated quantities { // For group-level parameters real mu_alpha; real mu_beta; + real pe[N, T]; // Prediction error // For log-likelihood calculation real log_lik[N]; - mu_alpha = Phi_approx(mu_pr[1]); - mu_beta = Phi_approx(mu_pr[2]) * 10; + mu_alpha = Phi_approx(mu_pr[1]); + mu_beta = Phi_approx(mu_pr[2]) * 10; { for (i in 1:N) { int co; // Chosen option real delta; // Difference between two options - real pe; // Prediction error - //real alpha; vector[6] ev; // Expected values ev = initial_values; @@ -97,8 +95,8 @@ generated quantities { delta = ev[option1[i, t]] - ev[option2[i, t]]; log_lik[i] += bernoulli_logit_lpmf(choice[i, t] | beta[i] * delta); - pe = reward[i, t] - ev[co]; - ev[co] += alpha[i] * pe; + pe[i, t] = reward[i, t] - ev[co]; + ev[co] += alpha[i] * pe[i, t]; } } }