Mixed logit with Toyota data
This demo uses the dataset that was made available by Kenneth Train at https://eml.berkeley.edu/~train/ec244ps.html
The data represent consumers’ choices among vehicles in stated preference experiments. They come from a study that Kenneth Train did for Toyota and GM to help them analyze the potential marketability of electric and hybrid vehicles, before hybrids were introduced.
We begin by performing the necessary imports:
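import os
# make only the first GPU visible to CUDA (and hence to PyTorch)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
# make the amortized-mxl code importable; adjust this path to your local checkout
sys.path.insert(0, "/home/rodr/code/amortized-mxl-dev/release")
import logging
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# fix NumPy's random seed for reproducibility (this does not seed PyTorch)
np.random.seed(42)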
Load Toyota dataset
About the data:
In each choice experiment, the respondent was presented with three vehicles, with the price and other attributes of each vehicle described. The respondent was asked to state which of the three vehicles he/she would buy if these vehicles were the only ones available in the market. There are 100 respondents in our dataset. To reduce estimation time, this is a subset of the full dataset, which contains 500 respondents. Each respondent was presented with 15 choice experiments, and most respondents answered all 15. The attributes of the vehicles were varied over experiments, both for a given respondent and over respondents. The attributes are:

- price;
- operating cost in dollars per month;
- engine type (gas, electric, or hybrid);
- range if electric (in hundreds of miles between recharging);
- the performance level of the vehicle (high, medium, or low).

The performance level was described in terms of top speed and acceleration, and these descriptions did not vary within each level. For example, “High” performance was described as having a top speed of 100 mph and 12 seconds to reach 60 mph, and this description was the same for all “high” performance vehicles.
A detailed description of the data is provided by Kenneth Train at https://eml.berkeley.edu/~train/ec244ps.html
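Engine type and performance level are stored as dummy columns in the file. Two assumptions drive the sketch below. First, gas is the reference engine type: the utilities defined later include EV and Hybrid but not Gas. Second, MedHighPerf appears to be 1 for medium *or* high performance, while HighPerf is 1 only for high performance; this is consistent with the sample rows shown after loading. Under those assumptions, the categorical levels can be recovered as follows. `decode_levels` is a hypothetical helper, not part of the library.

```python
import numpy as np
import pandas as pd

def decode_levels(df: pd.DataFrame) -> pd.DataFrame:
    """Recover engine type and performance level from the dummy columns.

    Assumes EV/Gas/Hybrid are mutually exclusive engine dummies and that
    MedHighPerf == 1 for medium *or* high performance (cumulative coding).
    """
    out = pd.DataFrame(index=df.index)
    out["Engine"] = np.select(
        [df["EV"] == 1, df["Hybrid"] == 1, df["Gas"] == 1],
        ["electric", "hybrid", "gas"],
        default="unknown",
    )
    # check HighPerf first, since high-performance rows also have MedHighPerf == 1
    out["Performance"] = np.select(
        [df["HighPerf"] == 1, df["MedHighPerf"] == 1],
        ["high", "medium"],
        default="low",
    )
    return out

# dummy values of the first three rows of the long-format data
sample = pd.DataFrame({
    "EV": [0, 1, 1], "Gas": [0, 0, 0], "Hybrid": [1, 0, 0],
    "HighPerf": [0, 1, 0], "MedHighPerf": [0, 1, 1],
})
print(decode_levels(sample))
```

If the cumulative coding holds, B_HighPerf measures the extra utility of high performance *over* medium performance, rather than over low performance.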
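column_names = ["IndID", "ObsID", "Chosen", "Price", "OperCost", "Range", "EV", "Gas", "Hybrid", "HighPerf", "MedHighPerf"]
df = pd.read_csv("data/toyota.txt", delimiter=" ", names=column_names)
df["Price"] = df["Price"]/10000 # scale price to be in tens of thousands of dollars
df["OperCost"] = df["OperCost"]/10 # scale operating cost to be in tens of dollars
# add the columns expected by the interface: AltID numbers the alternatives
# (1-3) within each choice situation, MenuID numbers the choice situations
# (1-15) within each respondent
altID = []
menuID = []
curr_n = -1  # current respondent ID
curr_o = -1  # current observation (choice situation) ID
curr_a = -1  # alternative counter within the current choice situation
curr_t = -1  # choice situation counter within the current respondent
for n, o in df[["IndID", "ObsID"]].values:
    if n != curr_n:
        # new respondent: restart the menu counter
        curr_t = 0
    if o != curr_o:
        # new choice situation: advance the menu counter, restart the alternative counter
        curr_t += 1
        curr_a = 0
    curr_a += 1
    curr_n = n
    curr_o = o
    altID.append(curr_a)
    menuID.append(curr_t)
df["AltID"] = altID
df["MenuID"] = menuID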
df.head()
| | IndID | ObsID | Chosen | Price | OperCost | Range | EV | Gas | Hybrid | HighPerf | MedHighPerf | AltID | MenuID |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 1 | 0 | 4.6763 | 4.743 | 0.0 | 0 | 0 | 1 | 0 | 0 | 1 | 1 |
| 1 | 1 | 1 | 1 | 5.7209 | 2.743 | 1.3 | 1 | 0 | 0 | 1 | 1 | 2 | 1 |
| 2 | 1 | 1 | 0 | 8.7960 | 3.241 | 1.2 | 1 | 0 | 0 | 0 | 1 | 3 | 1 |
| 3 | 1 | 2 | 1 | 3.3768 | 0.489 | 1.3 | 1 | 0 | 0 | 1 | 1 | 1 | 2 |
| 4 | 1 | 2 | 0 | 9.0336 | 3.019 | 0.0 | 0 | 0 | 1 | 0 | 1 | 2 | 2 |
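The loop above numbers the alternatives within each choice situation and the choice situations within each respondent. Like the loop, the pandas sketch below assumes the rows are sorted by respondent and choice situation, and it builds the same two columns with `groupby(...).cumcount()` and a grouped cumulative sum. `add_alt_and_menu_ids` is a hypothetical helper name.

```python
import pandas as pd

def add_alt_and_menu_ids(df: pd.DataFrame) -> pd.DataFrame:
    """Vectorized version of the AltID/MenuID loop.

    Assumes one row per alternative, sorted by respondent and choice situation.
    """
    df = df.copy()
    # True on the first row of every choice situation (ObsID changes)
    new_obs = df["ObsID"] != df["ObsID"].shift()
    # consecutive run id identifying each choice situation
    obs_run = new_obs.cumsum()
    # 1, 2, 3, ... within each choice situation
    df["AltID"] = df.groupby(obs_run).cumcount() + 1
    # 1, 2, ... within each respondent, advancing at every new choice situation
    df["MenuID"] = new_obs.astype(int).groupby(df["IndID"]).cumsum()
    return df

long = pd.DataFrame({"IndID": [1, 1, 1, 1, 1, 1, 2, 2, 2],
                     "ObsID": [1, 1, 1, 2, 2, 2, 3, 3, 3]})
print(add_alt_and_menu_ids(long))
```

On the Toyota data, you can check agreement with the loop via `(add_alt_and_menu_ids(df)[["AltID", "MenuID"]] == df[["AltID", "MenuID"]]).all()`.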
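The data above are in “long format”, with one row per alternative. At the moment, the provided interface only supports data in the so-called “wide format”, with one row per choice situation, so we need to convert the data first: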
# convert to wide format: one row per choice situation, with the attributes of
# the three alternatives (rows ix, ix+1, ix+2 of the long data) side by side
attr_cols = ["Price", "OperCost", "Range", "EV", "Hybrid", "HighPerf", "MedHighPerf"]
data_wide = []
for ix in range(0, len(df), 3):
    new_row = df.loc[ix][["IndID", "ObsID", "MenuID"]].values.tolist()
    new_row += df.loc[ix][attr_cols].values.tolist()
    new_row += df.loc[ix+1][attr_cols].values.tolist()
    new_row += df.loc[ix+2][attr_cols].values.tolist()
    # 0-based index (0, 1 or 2) of the chosen alternative
    choice = np.argmax([df.loc[ix]["Chosen"], df.loc[ix+1]["Chosen"], df.loc[ix+2]["Chosen"]])
    new_row += [choice]
    data_wide.append(new_row)
column_names = ["IndID","ObsID","MenuID",
"Price1","OperCost1","Range1","EV1","Hybrid1","HighPerf1","MedHighPerf1",
"Price2","OperCost2","Range2","EV2","Hybrid2","HighPerf2","MedHighPerf2",
"Price3","OperCost3","Range3","EV3","Hybrid3","HighPerf3","MedHighPerf3",
"Chosen"]
df_wide = pd.DataFrame(data_wide, columns=column_names)
df_wide['ones'] = np.ones(len(data_wide)).astype(int)
df_wide.head()
| | IndID | ObsID | MenuID | Price1 | OperCost1 | Range1 | EV1 | Hybrid1 | HighPerf1 | MedHighPerf1 | ... | MedHighPerf2 | Price3 | OperCost3 | Range3 | EV3 | Hybrid3 | HighPerf3 | MedHighPerf3 | Chosen | ones |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 | 1.0 | 1.0 | 4.6763 | 4.743 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 1.0 | 8.7960 | 3.241 | 1.2 | 1.0 | 0.0 | 0.0 | 1.0 | 1 | 1 |
| 1 | 1.0 | 2.0 | 2.0 | 3.3768 | 0.489 | 1.3 | 1.0 | 0.0 | 1.0 | 1.0 | ... | 1.0 | 5.7099 | 2.716 | 1.8 | 1.0 | 0.0 | 1.0 | 1.0 | 0 | 1 |
| 2 | 1.0 | 3.0 | 3.0 | 4.5534 | 1.072 | 1.2 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 1.0 | 3.4031 | 6.062 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0 | 1 |
| 3 | 1.0 | 4.0 | 4.0 | 0.8639 | 2.216 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | ... | 0.0 | 6.9325 | 2.884 | 1.6 | 1.0 | 0.0 | 0.0 | 0.0 | 1 | 1 |
| 4 | 1.0 | 5.0 | 5.0 | 5.2145 | 3.975 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | ... | 1.0 | 2.1282 | 5.272 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0 | 1 |
5 rows × 26 columns
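Every choice situation occupies exactly three consecutive rows. Under that same assumption as the loop, the long-to-wide conversion can also be sketched as a single NumPy reshape, which avoids the per-row `df.loc` lookups. `to_wide` is a hypothetical helper. Unlike the loop, it takes the ID columns directly from the DataFrame, so they keep their integer dtype.

```python
import numpy as np
import pandas as pd

def to_wide(df, id_cols, attr_cols, choice_col="Chosen", n_alts=3):
    """Long (one row per alternative) -> wide (one row per choice situation).

    Assumes each choice situation occupies exactly n_alts consecutive rows.
    """
    if len(df) % n_alts != 0:
        raise ValueError("number of rows is not a multiple of n_alts")
    # IDs are taken from the first row of each choice situation
    wide = df[id_cols].iloc[::n_alts].reset_index(drop=True)
    # row-major reshape puts alternative 1's attributes first, then 2's, then 3's
    attrs = df[attr_cols].to_numpy().reshape(-1, n_alts * len(attr_cols))
    names = [f"{c}{j}" for j in range(1, n_alts + 1) for c in attr_cols]
    wide = pd.concat([wide, pd.DataFrame(attrs, columns=names)], axis=1)
    # 0-based index of the chosen alternative, as np.argmax gives in the loop
    wide[choice_col] = df[choice_col].to_numpy().reshape(-1, n_alts).argmax(axis=1)
    return wide

long = pd.DataFrame({"IndID": [1, 1, 1, 1, 1, 1], "ObsID": [1, 1, 1, 2, 2, 2],
                     "Chosen": [0, 1, 0, 1, 0, 0],
                     "Price": [4.6, 5.7, 8.8, 3.4, 9.0, 1.0], "EV": [0, 1, 1, 1, 0, 0]})
print(to_wide(long, ["IndID", "ObsID"], ["Price", "EV"]))
```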
Mixed Logit specification
We now use the formula interface in core.dcm_interface to specify the utilities of the mixed logit model.
We begin by defining the fixed effects parameters, the random effects parameters, and the observed variables. This creates instances of Python objects that can be put together to define the utility functions for the different alternatives.
Once the utilities are defined, we collect them in a Python dictionary mapping alternative names to their corresponding expressions.
from core.dcm_interface import FixedEffect, RandomEffect, ObservedVariable
# define fixed effects parameters
B_PRICE = FixedEffect('B_PRICE')
# define random effects parameters
B_OperCost = RandomEffect('B_OperCost')
B_Range = RandomEffect('B_Range')
B_EV = RandomEffect('B_EV')
B_Hybrid = RandomEffect('B_Hybrid')
B_HighPerf = RandomEffect('B_HighPerf')
B_MedHighPerf = RandomEffect('B_MedHighPerf')
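# define observed variables: one ObservedVariable per column of df_wide, bound to
# a global variable of the same name (e.g. Price1) so it can be used in the formulas
for attr in df_wide.columns:
    globals()[attr] = ObservedVariable(attr)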
# define utility functions
V1 = B_PRICE*Price1 + B_OperCost*OperCost1 + B_Range*Range1 + B_EV*EV1 + B_Hybrid*Hybrid1 + B_HighPerf*HighPerf1 + B_MedHighPerf*MedHighPerf1
V2 = B_PRICE*Price2 + B_OperCost*OperCost2 + B_Range*Range2 + B_EV*EV2 + B_Hybrid*Hybrid2 + B_HighPerf*HighPerf2 + B_MedHighPerf*MedHighPerf2
V3 = B_PRICE*Price3 + B_OperCost*OperCost3 + B_Range*Range3 + B_EV*EV3 + B_Hybrid*Hybrid3 + B_HighPerf*HighPerf3 + B_MedHighPerf*MedHighPerf3
# associate utility functions with the names of the alternatives
utilities = {"ALT1": V1, "ALT2": V2, "ALT3": V3}
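Each utility above is linear in the attributes. It has one price coefficient shared by all respondents and respondent-specific coefficients for the other attributes. The NumPy sketch below shows how such utilities become multinomial logit choice probabilities for a single choice situation. The attribute values come from the first choice situation of the long-format table above. For illustration only, the coefficients are the Biogeme reference values used later in this notebook, plugged in as if they were one respondent's tastes. `mnl_probabilities` is a hypothetical helper, not the library's code.

```python
import numpy as np

def mnl_probabilities(X, alpha, beta_n, fixed_idx, random_idx):
    """MNL choice probabilities for one choice situation.

    X: (n_alts, n_attrs) attributes, one row per alternative.
    alpha: fixed coefficients, applied to the columns in fixed_idx.
    beta_n: respondent-specific coefficients, applied to the columns in random_idx.
    """
    V = X[:, fixed_idx] @ alpha + X[:, random_idx] @ beta_n   # utilities, (n_alts,)
    V = V - V.max()                  # shift by the max for numerical stability
    expV = np.exp(V)
    return expV / expV.sum()

# columns: Price, OperCost, Range, EV, Hybrid, HighPerf, MedHighPerf
X = np.array([
    [4.6763, 4.743, 0.0, 0, 1, 0, 0],   # ALT1
    [5.7209, 2.743, 1.3, 1, 0, 1, 1],   # ALT2
    [8.7960, 3.241, 1.2, 1, 0, 0, 1],   # ALT3
])
alpha = np.array([-0.5080])
beta = np.array([-0.1355, 0.4759, -1.5995, 0.5256, 0.1116, 0.5333])
p = mnl_probabilities(X, alpha, beta, fixed_idx=[0], random_idx=[1, 2, 3, 4, 5, 6])
print(p.round(3))
```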
We are now ready to create a Specification object containing the utilities that we have just defined. We must also specify the type of choice model to be used: a mixed logit model (MXL) in this case.
We can inspect the specification by printing the dcm_spec object.
from core.dcm_interface import Specification
# create MXL specification object based on the utilities previously defined
dcm_spec = Specification('MXL', utilities)
print(dcm_spec)
----------------- MXL specification:
Alternatives: ['ALT1', 'ALT2', 'ALT3']
Utility functions:
V_ALT1 = B_PRICE*Price1 + B_OperCost_n*OperCost1 + B_Range_n*Range1 + B_EV_n*EV1 + B_Hybrid_n*Hybrid1 + B_HighPerf_n*HighPerf1 + B_MedHighPerf_n*MedHighPerf1
V_ALT2 = B_PRICE*Price2 + B_OperCost_n*OperCost2 + B_Range_n*Range2 + B_EV_n*EV2 + B_Hybrid_n*Hybrid2 + B_HighPerf_n*HighPerf2 + B_MedHighPerf_n*MedHighPerf2
V_ALT3 = B_PRICE*Price3 + B_OperCost_n*OperCost3 + B_Range_n*Range3 + B_EV_n*EV3 + B_Hybrid_n*Hybrid3 + B_HighPerf_n*HighPerf3 + B_MedHighPerf_n*MedHighPerf3
Num. parameters to be estimated: 7
Fixed effects params: ['B_PRICE']
Random effects params: ['B_OperCost', 'B_Range', 'B_EV', 'B_Hybrid', 'B_HighPerf', 'B_MedHighPerf']
Once the Specification is defined, we need to define the DCM Dataset object that goes along with it. For this, we instantiate the Dataset class with three arguments: the Pandas dataframe containing the data in the so-called “wide format”, the name of the column in the dataframe containing the observed choices, and the dcm_spec that we have previously created.
Since this is panel data, we must also specify the name of the column in the dataframe that contains the ID of the respondent (this should be an integer ranging from 0 to num_resp-1).
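In this dataset, the respondent IDs in `IndID` run from 1 to 100. This notebook does not show whether the Dataset class remaps such IDs internally. If a dataset's IDs are not consecutive integers starting at 0, one way to remap them is `pandas.factorize`, as sketched below. Treat this as a precaution. The helper name and the column name `RespID` in the usage note are hypothetical.

```python
import pandas as pd

def zero_based_ids(ids: pd.Series) -> pd.Series:
    """Map arbitrary respondent IDs to consecutive integers 0..num_resp-1,
    numbered in order of first appearance."""
    codes, _ = pd.factorize(ids)
    return pd.Series(codes, index=ids.index, name=ids.name)

ids = pd.Series([1.0, 1.0, 2.0, 2.0, 7.0, 7.0], name="IndID")
print(zero_based_ids(ids).tolist())  # [0, 0, 1, 1, 2, 2]
```

For example, set `df_wide["RespID"] = zero_based_ids(df_wide["IndID"])` and then pass `resp_id_col='RespID'`.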
from core.dcm_interface import Dataset
# create DCM dataset object
dcm_dataset = Dataset(df_wide, 'Chosen', dcm_spec, resp_id_col='IndID')
Preparing dataset...
Model type: MXL
Num. observations: 1484
Num. alternatives: 3
Num. respondents: 100
Num. menus: 15
Observations IDs: [ 0 1 2 ... 1481 1482 1483]
Alternative IDs: None
Respondent IDs: [ 1 1 1 ... 100 100 100]
Availability columns: None
Attribute names: ['Price1', 'OperCost1', 'Range1', 'EV1', 'Hybrid1', 'HighPerf1', 'MedHighPerf1', 'Price2', 'OperCost2', 'Range2', 'EV2', 'Hybrid2', 'HighPerf2', 'MedHighPerf2', 'Price3', 'OperCost3', 'Range3', 'EV3', 'Hybrid3', 'HighPerf3', 'MedHighPerf3']
Fixed effects attribute names: ['Price1', 'Price2', 'Price3']
Fixed effects parameter names: ['B_PRICE', 'B_PRICE', 'B_PRICE']
Random effects attribute names: ['OperCost1', 'Range1', 'EV1', 'Hybrid1', 'HighPerf1', 'MedHighPerf1', 'OperCost2', 'Range2', 'EV2', 'Hybrid2', 'HighPerf2', 'MedHighPerf2', 'OperCost3', 'Range3', 'EV3', 'Hybrid3', 'HighPerf3', 'MedHighPerf3']
Random effects parameter names: ['B_OperCost', 'B_Range', 'B_EV', 'B_Hybrid', 'B_HighPerf', 'B_MedHighPerf', 'B_OperCost', 'B_Range', 'B_EV', 'B_Hybrid', 'B_HighPerf', 'B_MedHighPerf', 'B_OperCost', 'B_Range', 'B_EV', 'B_Hybrid', 'B_HighPerf', 'B_MedHighPerf']
Alternative attributes ndarray.shape: (100, 15, 21)
Choices ndarray.shape: (100, 15)
Alternatives availability ndarray.shape: (100, 15, 3)
Data mask ndarray.shape: (100, 15)
Context data ndarray.shape: (100, 0)
Done!
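The shapes printed above show how the panel is organized. The 1,484 observations are arranged into arrays indexed by respondent (100), menu (up to 15) and attribute column (21). The data mask marks which (respondent, menu) cells hold an actual observation. This is needed because not every respondent answered all 15 menus: 100 × 15 = 1,500 > 1,484. The following NumPy sketch shows one way to build such padded arrays. It illustrates the idea and is not the library's implementation. The helper `pad_panel` is hypothetical, and it assumes 0-based respondent and menu indices (e.g. MenuID - 1).

```python
import numpy as np

def pad_panel(resp_idx, menu_idx, attrs, choices, num_resp, max_menus):
    """Scatter observations into (num_resp, max_menus, ...) arrays plus a mask.

    resp_idx: 0-based respondent index of each observation.
    menu_idx: 0-based menu index of each observation within its respondent.
    attrs: (num_obs, num_attrs) attributes; choices: (num_obs,) chosen alternatives.
    """
    X = np.zeros((num_resp, max_menus, attrs.shape[1]), dtype=attrs.dtype)
    y = np.zeros((num_resp, max_menus), dtype=int)
    mask = np.zeros((num_resp, max_menus), dtype=bool)
    X[resp_idx, menu_idx] = attrs
    y[resp_idx, menu_idx] = choices
    mask[resp_idx, menu_idx] = True   # padded cells stay False
    return X, y, mask

# two respondents: the first answered two menus, the second only one
X, y, mask = pad_panel(np.array([0, 0, 1]), np.array([0, 1, 0]),
                       np.arange(6.0).reshape(3, 2), np.array([2, 0, 1]),
                       num_resp=2, max_menus=2)
print(mask)
```

A mask like this lets padded cells be excluded from the log-likelihood, e.g. by summing `log_p * mask`.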
As with the specification, we can inspect the DCM dataset by printing the dcm_dataset object:
print(dcm_dataset)
----------------- DCM dataset:
Model type: MXL
Num. observations: 1484
Num. alternatives: 3
Num. respondents: 100
Num. menus: 15
Num. fixed effects: 3
Num. random effects: 18
Attribute names: ['Price1', 'OperCost1', 'Range1', 'EV1', 'Hybrid1', 'HighPerf1', 'MedHighPerf1', 'Price2', 'OperCost2', 'Range2', 'EV2', 'Hybrid2', 'HighPerf2', 'MedHighPerf2', 'Price3', 'OperCost3', 'Range3', 'EV3', 'Hybrid3', 'HighPerf3', 'MedHighPerf3']
Bayesian Mixed Logit Model in PyTorch
It is now time to perform approximate Bayesian inference on the mixed logit model that we have specified. The generative process of the MXL model that we will be using is the following:
1. Draw fixed taste parameters \(\boldsymbol\alpha \sim \mathcal{N}(\boldsymbol\lambda_0, \boldsymbol\Xi_0)\)
2. Draw mean vector \(\boldsymbol\zeta \sim \mathcal{N}(\boldsymbol\mu_0, \boldsymbol\Sigma_0)\)
3. Draw scales vector \(\boldsymbol\theta \sim \mbox{half-Cauchy}(\boldsymbol\sigma_0)\)
4. Draw correlation matrix \(\boldsymbol\Psi \sim \mbox{LKJ}(\nu)\)
5. For each decision-maker \(n \in \{1,\dots,N\}\)
    1. Draw random taste parameters \(\boldsymbol\beta_n \sim \mathcal{N}(\boldsymbol\zeta,\boldsymbol\Omega)\)
    2. For each choice occasion \(t \in \{1,\dots,T_n\}\)
        1. Draw observed choice \(y_{nt} \sim \mbox{MNL}(\boldsymbol\alpha, \boldsymbol\beta_n, \textbf{X}_{nt})\)
where \(\boldsymbol\Omega = \mbox{diag}(\boldsymbol\theta) \times \boldsymbol\Psi \times \mbox{diag}(\boldsymbol\theta)\).
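The NumPy sketch below simulates choices from this generative process. For the MNL step it uses the standard logit probabilities \(P(y_{nt}=j) = \exp(V_{ntj}) / \sum_k \exp(V_{ntk})\), with \(V_{ntj} = \boldsymbol\alpha^\top \textbf{x}^{F}_{ntj} + \boldsymbol\beta_n^\top \textbf{x}^{R}_{ntj}\). Here \(\textbf{x}^F\) and \(\textbf{x}^R\) (notation introduced for this sketch) are the attributes with fixed and random coefficients. The sketch makes several simplifying assumptions. The correlation matrix \(\boldsymbol\Psi\) is fixed to the identity rather than drawn from an LKJ prior, since NumPy has no LKJ sampler. The hyperparameters and attributes are arbitrary. Every decision-maker faces the same number T of choice occasions.

```python
import numpy as np

def simulate_mxl(N, T, J, K_fixed, K_random, rng, sigma0=1.0, Psi=None):
    """Simulate choices from the MXL generative process.

    Assumptions of this sketch: lambda_0 = mu_0 = 0, Xi_0 = Sigma_0 = I,
    the same sigma0 for every scale, T occasions per decision-maker,
    standard-normal attributes, and a fixed Psi (identity by default)
    standing in for the LKJ draw.
    """
    alpha = rng.normal(size=K_fixed)                              # fixed tastes
    zeta = rng.normal(size=K_random)                              # population means
    theta = np.abs(sigma0 * rng.standard_cauchy(size=K_random))   # half-Cauchy scales
    if Psi is None:
        Psi = np.eye(K_random)
    Omega = np.diag(theta) @ Psi @ np.diag(theta)                 # covariance of beta_n
    beta = rng.multivariate_normal(zeta, Omega, size=N)           # (N, K_random)

    XF = rng.normal(size=(N, T, J, K_fixed))     # attributes with fixed coefficients
    XR = rng.normal(size=(N, T, J, K_random))    # attributes with random coefficients
    V = XF @ alpha + np.einsum("ntjk,nk->ntj", XR, beta)          # utilities, (N, T, J)

    # MNL probabilities: softmax over alternatives, max-shifted for stability
    P = np.exp(V - V.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)

    # draw y_nt from Categorical(P[n, t]) by inverse-CDF sampling
    u = rng.random(size=(N, T, 1))
    y = (P.cumsum(axis=-1) < u).sum(axis=-1)
    y = np.minimum(y, J - 1)   # guard against round-off in the last cumulative sum
    return XF, XR, y, beta

XF, XR, y, beta = simulate_mxl(N=5, T=4, J=3, K_fixed=1, K_random=2,
                               rng=np.random.default_rng(0))
print(y)
```

With `N=100, T=15, J=3, K_fixed=1, K_random=6`, this yields a synthetic panel of the same size as the Toyota data, which can be useful for checking that inference recovers known parameters.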
We can instantiate this model with the TorchMXL class using the following code, and then run variational inference (VI) to approximate the posterior distribution of the latent variables in the model.
Note that we pass the parameter estimates obtained by Biogeme to the “infer” method (as true_alpha and true_beta) to track the progress of VI and for comparison and debugging purposes. That is, the Alpha RMSE and Beta RMSE values in the output are RMSEs with respect to the Biogeme results.
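For reference, the root mean squared error between an estimate and a reference vector is \(\sqrt{\frac{1}{K}\sum_k (\hat{x}_k - x_k)^2}\). The sketch below applies it to the final posterior means printed at the end of training and the Biogeme values passed to “infer”. It assumes that this textbook definition is what the logged Alpha/Beta RMSE uses, which is not confirmed here. The per-epoch values in the log need not match this final-estimate computation.

```python
import numpy as np

def rmse(est, ref):
    """Root mean squared error between two equally shaped arrays."""
    est, ref = np.asarray(est, dtype=float), np.asarray(ref, dtype=float)
    return float(np.sqrt(np.mean((est - ref) ** 2)))

biogeme_alpha = np.array([-0.5080])
biogeme_beta = np.array([-0.1355, 0.4759, -1.5995, 0.5256, 0.1116, 0.5333])
# posterior means reported at the end of training ("Est. alpha", "Est. zeta")
est_alpha = np.array([-0.5470354])
est_zeta = np.array([-0.14659663, 0.44419056, -1.6546605, 0.53451145, 0.11935437, 0.60611176])

print(round(rmse(est_alpha, biogeme_alpha), 3))  # 0.039
print(round(rmse(est_zeta, biogeme_beta), 3))
```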
%%time
from core.torch_mxl import TorchMXL
# instantiate MXL model
mxl = TorchMXL(dcm_dataset, batch_size=dcm_dataset.num_resp, use_inference_net=False, use_cuda=True)
# Biogeme's estimates are used as reference values for comparison
biogeme_alpha = np.array([-0.5080])
biogeme_beta = np.array([-0.1355, 0.4759, -1.5995, 0.5256, 0.1116, 0.5333])
# run Bayesian inference (variational inference)
results = mxl.infer(num_epochs=10000, true_alpha=biogeme_alpha, true_beta=biogeme_beta)
[Epoch 0] ELBO: 5683; Loglik: -4727; Acc.: 0.228; Alpha RMSE: 0.498; Beta RMSE: 0.752
[Epoch 100] ELBO: 2387; Loglik: -1793; Acc.: 0.489; Alpha RMSE: 0.072; Beta RMSE: 0.786
[Epoch 200] ELBO: 1785; Loglik: -1366; Acc.: 0.604; Alpha RMSE: 0.090; Beta RMSE: 0.766
[Epoch 300] ELBO: 2024; Loglik: -1530; Acc.: 0.617; Alpha RMSE: 0.158; Beta RMSE: 0.703
[Epoch 400] ELBO: 1864; Loglik: -1343; Acc.: 0.599; Alpha RMSE: 0.177; Beta RMSE: 0.645
[Epoch 500] ELBO: 1776; Loglik: -1234; Acc.: 0.653; Alpha RMSE: 0.188; Beta RMSE: 0.615
[Epoch 600] ELBO: 1730; Loglik: -1186; Acc.: 0.656; Alpha RMSE: 0.195; Beta RMSE: 0.581
[Epoch 700] ELBO: 1642; Loglik: -1208; Acc.: 0.679; Alpha RMSE: 0.201; Beta RMSE: 0.539
[Epoch 800] ELBO: 1602; Loglik: -1133; Acc.: 0.675; Alpha RMSE: 0.167; Beta RMSE: 0.521
[Epoch 900] ELBO: 1602; Loglik: -1173; Acc.: 0.660; Alpha RMSE: 0.164; Beta RMSE: 0.522
[Epoch 1000] ELBO: 1581; Loglik: -1112; Acc.: 0.687; Alpha RMSE: 0.179; Beta RMSE: 0.515
[Epoch 1100] ELBO: 1620; Loglik: -1185; Acc.: 0.667; Alpha RMSE: 0.175; Beta RMSE: 0.526
[Epoch 1200] ELBO: 1581; Loglik: -1158; Acc.: 0.681; Alpha RMSE: 0.184; Beta RMSE: 0.507
[Epoch 1300] ELBO: 1588; Loglik: -1139; Acc.: 0.660; Alpha RMSE: 0.200; Beta RMSE: 0.494
[Epoch 1400] ELBO: 1533; Loglik: -1126; Acc.: 0.668; Alpha RMSE: 0.165; Beta RMSE: 0.492
[Epoch 1500] ELBO: 1549; Loglik: -1132; Acc.: 0.689; Alpha RMSE: 0.164; Beta RMSE: 0.462
[Epoch 1600] ELBO: 1486; Loglik: -1078; Acc.: 0.691; Alpha RMSE: 0.172; Beta RMSE: 0.449
[Epoch 1700] ELBO: 1532; Loglik: -1119; Acc.: 0.685; Alpha RMSE: 0.177; Beta RMSE: 0.442
[Epoch 1800] ELBO: 1499; Loglik: -1146; Acc.: 0.664; Alpha RMSE: 0.148; Beta RMSE: 0.422
[Epoch 1900] ELBO: 1523; Loglik: -1157; Acc.: 0.672; Alpha RMSE: 0.179; Beta RMSE: 0.413
[Epoch 2000] ELBO: 1481; Loglik: -1126; Acc.: 0.683; Alpha RMSE: 0.141; Beta RMSE: 0.407
[Epoch 2100] ELBO: 1550; Loglik: -1147; Acc.: 0.648; Alpha RMSE: 0.119; Beta RMSE: 0.405
[Epoch 2200] ELBO: 1519; Loglik: -1137; Acc.: 0.659; Alpha RMSE: 0.134; Beta RMSE: 0.380
[Epoch 2300] ELBO: 1445; Loglik: -1109; Acc.: 0.675; Alpha RMSE: 0.139; Beta RMSE: 0.348
[Epoch 2400] ELBO: 1445; Loglik: -1122; Acc.: 0.681; Alpha RMSE: 0.108; Beta RMSE: 0.349
[Epoch 2500] ELBO: 1460; Loglik: -1122; Acc.: 0.671; Alpha RMSE: 0.124; Beta RMSE: 0.335
[Epoch 2600] ELBO: 1464; Loglik: -1144; Acc.: 0.671; Alpha RMSE: 0.108; Beta RMSE: 0.323
[Epoch 2700] ELBO: 1465; Loglik: -1140; Acc.: 0.675; Alpha RMSE: 0.110; Beta RMSE: 0.308
[Epoch 2800] ELBO: 1466; Loglik: -1124; Acc.: 0.683; Alpha RMSE: 0.108; Beta RMSE: 0.291
[Epoch 2900] ELBO: 1443; Loglik: -1114; Acc.: 0.677; Alpha RMSE: 0.099; Beta RMSE: 0.283
[Epoch 3000] ELBO: 1458; Loglik: -1140; Acc.: 0.657; Alpha RMSE: 0.081; Beta RMSE: 0.280
[Epoch 3100] ELBO: 1439; Loglik: -1137; Acc.: 0.667; Alpha RMSE: 0.080; Beta RMSE: 0.276
[Epoch 3200] ELBO: 1405; Loglik: -1104; Acc.: 0.675; Alpha RMSE: 0.097; Beta RMSE: 0.256
[Epoch 3300] ELBO: 1434; Loglik: -1115; Acc.: 0.673; Alpha RMSE: 0.079; Beta RMSE: 0.227
[Epoch 3400] ELBO: 1438; Loglik: -1126; Acc.: 0.691; Alpha RMSE: 0.069; Beta RMSE: 0.231
[Epoch 3500] ELBO: 1418; Loglik: -1127; Acc.: 0.666; Alpha RMSE: 0.063; Beta RMSE: 0.224
[Epoch 3600] ELBO: 1413; Loglik: -1104; Acc.: 0.685; Alpha RMSE: 0.053; Beta RMSE: 0.220
[Epoch 3700] ELBO: 1428; Loglik: -1123; Acc.: 0.672; Alpha RMSE: 0.064; Beta RMSE: 0.205
[Epoch 3800] ELBO: 1423; Loglik: -1126; Acc.: 0.665; Alpha RMSE: 0.056; Beta RMSE: 0.193
[Epoch 3900] ELBO: 1410; Loglik: -1119; Acc.: 0.673; Alpha RMSE: 0.078; Beta RMSE: 0.195
[Epoch 4000] ELBO: 1438; Loglik: -1151; Acc.: 0.660; Alpha RMSE: 0.052; Beta RMSE: 0.174
[Epoch 4100] ELBO: 1427; Loglik: -1141; Acc.: 0.685; Alpha RMSE: 0.058; Beta RMSE: 0.167
[Epoch 4200] ELBO: 1418; Loglik: -1131; Acc.: 0.648; Alpha RMSE: 0.067; Beta RMSE: 0.155
[Epoch 4300] ELBO: 1417; Loglik: -1134; Acc.: 0.672; Alpha RMSE: 0.049; Beta RMSE: 0.142
[Epoch 4400] ELBO: 1418; Loglik: -1139; Acc.: 0.663; Alpha RMSE: 0.073; Beta RMSE: 0.143
[Epoch 4500] ELBO: 1416; Loglik: -1132; Acc.: 0.680; Alpha RMSE: 0.075; Beta RMSE: 0.134
[Epoch 4600] ELBO: 1430; Loglik: -1157; Acc.: 0.660; Alpha RMSE: 0.029; Beta RMSE: 0.139
[Epoch 4700] ELBO: 1395; Loglik: -1113; Acc.: 0.679; Alpha RMSE: 0.073; Beta RMSE: 0.133
[Epoch 4800] ELBO: 1435; Loglik: -1150; Acc.: 0.666; Alpha RMSE: 0.050; Beta RMSE: 0.126
[Epoch 4900] ELBO: 1380; Loglik: -1102; Acc.: 0.675; Alpha RMSE: 0.056; Beta RMSE: 0.109
[Epoch 5000] ELBO: 1381; Loglik: -1110; Acc.: 0.671; Alpha RMSE: 0.045; Beta RMSE: 0.118
[Epoch 5100] ELBO: 1443; Loglik: -1160; Acc.: 0.649; Alpha RMSE: 0.056; Beta RMSE: 0.111
[Epoch 5200] ELBO: 1432; Loglik: -1159; Acc.: 0.661; Alpha RMSE: 0.049; Beta RMSE: 0.094
[Epoch 5300] ELBO: 1399; Loglik: -1129; Acc.: 0.668; Alpha RMSE: 0.020; Beta RMSE: 0.085
[Epoch 5400] ELBO: 1403; Loglik: -1124; Acc.: 0.681; Alpha RMSE: 0.025; Beta RMSE: 0.086
[Epoch 5500] ELBO: 1407; Loglik: -1132; Acc.: 0.683; Alpha RMSE: 0.035; Beta RMSE: 0.081
[Epoch 5600] ELBO: 1432; Loglik: -1162; Acc.: 0.653; Alpha RMSE: 0.037; Beta RMSE: 0.081
[Epoch 5700] ELBO: 1398; Loglik: -1126; Acc.: 0.668; Alpha RMSE: 0.029; Beta RMSE: 0.074
[Epoch 5800] ELBO: 1382; Loglik: -1114; Acc.: 0.664; Alpha RMSE: 0.031; Beta RMSE: 0.091
[Epoch 5900] ELBO: 1392; Loglik: -1117; Acc.: 0.691; Alpha RMSE: 0.045; Beta RMSE: 0.068
[Epoch 6000] ELBO: 1427; Loglik: -1145; Acc.: 0.661; Alpha RMSE: 0.042; Beta RMSE: 0.079
[Epoch 6100] ELBO: 1418; Loglik: -1148; Acc.: 0.654; Alpha RMSE: 0.036; Beta RMSE: 0.063
[Epoch 6200] ELBO: 1393; Loglik: -1126; Acc.: 0.678; Alpha RMSE: 0.042; Beta RMSE: 0.056
[Epoch 6300] ELBO: 1397; Loglik: -1134; Acc.: 0.673; Alpha RMSE: 0.051; Beta RMSE: 0.049
[Epoch 6400] ELBO: 1420; Loglik: -1159; Acc.: 0.658; Alpha RMSE: 0.031; Beta RMSE: 0.053
[Epoch 6500] ELBO: 1406; Loglik: -1136; Acc.: 0.670; Alpha RMSE: 0.031; Beta RMSE: 0.065
[Epoch 6600] ELBO: 1419; Loglik: -1152; Acc.: 0.677; Alpha RMSE: 0.027; Beta RMSE: 0.033
[Epoch 6700] ELBO: 1406; Loglik: -1139; Acc.: 0.659; Alpha RMSE: 0.037; Beta RMSE: 0.057
[Epoch 6800] ELBO: 1399; Loglik: -1136; Acc.: 0.668; Alpha RMSE: 0.023; Beta RMSE: 0.055
[Epoch 6900] ELBO: 1414; Loglik: -1148; Acc.: 0.656; Alpha RMSE: 0.037; Beta RMSE: 0.048
[Epoch 7000] ELBO: 1400; Loglik: -1131; Acc.: 0.661; Alpha RMSE: 0.034; Beta RMSE: 0.049
[Epoch 7100] ELBO: 1432; Loglik: -1168; Acc.: 0.652; Alpha RMSE: 0.028; Beta RMSE: 0.037
[Epoch 7200] ELBO: 1421; Loglik: -1149; Acc.: 0.668; Alpha RMSE: 0.025; Beta RMSE: 0.045
[Epoch 7300] ELBO: 1405; Loglik: -1144; Acc.: 0.670; Alpha RMSE: 0.036; Beta RMSE: 0.048
[Epoch 7400] ELBO: 1419; Loglik: -1152; Acc.: 0.651; Alpha RMSE: 0.039; Beta RMSE: 0.047
[Epoch 7500] ELBO: 1426; Loglik: -1158; Acc.: 0.653; Alpha RMSE: 0.009; Beta RMSE: 0.046
[Epoch 7600] ELBO: 1385; Loglik: -1123; Acc.: 0.671; Alpha RMSE: 0.043; Beta RMSE: 0.046
[Epoch 7700] ELBO: 1398; Loglik: -1128; Acc.: 0.674; Alpha RMSE: 0.038; Beta RMSE: 0.037
[Epoch 7800] ELBO: 1427; Loglik: -1154; Acc.: 0.653; Alpha RMSE: 0.023; Beta RMSE: 0.044
[Epoch 7900] ELBO: 1393; Loglik: -1122; Acc.: 0.683; Alpha RMSE: 0.050; Beta RMSE: 0.032
[Epoch 8000] ELBO: 1419; Loglik: -1158; Acc.: 0.652; Alpha RMSE: 0.019; Beta RMSE: 0.044
[Epoch 8100] ELBO: 1402; Loglik: -1140; Acc.: 0.668; Alpha RMSE: 0.038; Beta RMSE: 0.026
[Epoch 8200] ELBO: 1441; Loglik: -1170; Acc.: 0.665; Alpha RMSE: 0.019; Beta RMSE: 0.044
[Epoch 8300] ELBO: 1408; Loglik: -1148; Acc.: 0.664; Alpha RMSE: 0.024; Beta RMSE: 0.043
[Epoch 8400] ELBO: 1411; Loglik: -1146; Acc.: 0.662; Alpha RMSE: 0.012; Beta RMSE: 0.049
[Epoch 8500] ELBO: 1404; Loglik: -1116; Acc.: 0.671; Alpha RMSE: 0.022; Beta RMSE: 0.051
[Epoch 8600] ELBO: 1407; Loglik: -1141; Acc.: 0.668; Alpha RMSE: 0.038; Beta RMSE: 0.036
[Epoch 8700] ELBO: 1413; Loglik: -1150; Acc.: 0.670; Alpha RMSE: 0.019; Beta RMSE: 0.046
[Epoch 8800] ELBO: 1401; Loglik: -1141; Acc.: 0.666; Alpha RMSE: 0.013; Beta RMSE: 0.035
[Epoch 8900] ELBO: 1418; Loglik: -1153; Acc.: 0.662; Alpha RMSE: 0.001; Beta RMSE: 0.037
[Epoch 9000] ELBO: 1404; Loglik: -1138; Acc.: 0.670; Alpha RMSE: 0.041; Beta RMSE: 0.038
[Epoch 9100] ELBO: 1383; Loglik: -1126; Acc.: 0.675; Alpha RMSE: 0.030; Beta RMSE: 0.044
[Epoch 9200] ELBO: 1412; Loglik: -1150; Acc.: 0.666; Alpha RMSE: 0.016; Beta RMSE: 0.040
[Epoch 9300] ELBO: 1418; Loglik: -1156; Acc.: 0.661; Alpha RMSE: 0.025; Beta RMSE: 0.038
[Epoch 9400] ELBO: 1425; Loglik: -1160; Acc.: 0.654; Alpha RMSE: 0.029; Beta RMSE: 0.041
[Epoch 9500] ELBO: 1385; Loglik: -1124; Acc.: 0.681; Alpha RMSE: 0.047; Beta RMSE: 0.036
[Epoch 9600] ELBO: 1410; Loglik: -1145; Acc.: 0.660; Alpha RMSE: 0.018; Beta RMSE: 0.046
[Epoch 9700] ELBO: 1416; Loglik: -1149; Acc.: 0.665; Alpha RMSE: 0.053; Beta RMSE: 0.034
[Epoch 9800] ELBO: 1423; Loglik: -1161; Acc.: 0.646; Alpha RMSE: 0.027; Beta RMSE: 0.025
[Epoch 9900] ELBO: 1441; Loglik: -1180; Acc.: 0.651; Alpha RMSE: 0.016; Beta RMSE: 0.047
Elapsed time: 109.56723594665527
True alpha: [-0.508]
Est. alpha: [-0.5470354]
B_PRICE: -0.547
True zeta: [-0.1355 0.4759 -1.5995 0.5256 0.1116 0.5333]
Est. zeta: [-0.14659663 0.44419056 -1.6546605 0.53451145 0.11935437 0.60611176]
B_OperCost: -0.147
B_Range: 0.444
B_EV: -1.655
B_Hybrid: 0.535
B_HighPerf: 0.119
B_MedHighPerf: 0.606

CPU times: user 24min, sys: 1.69 s, total: 24min 1s
Wall time: 1min 52s
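Price enters the utilities with a coefficient α shared by all respondents. The estimates can therefore be turned into willingness-to-pay values for each attribute, \(\mathrm{WTP}_k = -\zeta_k / \alpha\), which is the mean of \(-\beta_{nk}/\alpha\) across the population at these point estimates. This is an illustrative addition, not part of the original analysis, and it ignores posterior uncertainty. Given the scaling applied when loading the data, each value is in tens of thousands of dollars per unit of the attribute: per $10/month of operating cost, per 100 miles of range, or per switch of a dummy from 0 to 1.

```python
import numpy as np

names = ["B_OperCost", "B_Range", "B_EV", "B_Hybrid", "B_HighPerf", "B_MedHighPerf"]
est_alpha = -0.5470354   # price coefficient (price in units of $10,000)
est_zeta = np.array([-0.14659663, 0.44419056, -1.6546605, 0.53451145, 0.11935437, 0.60611176])

# marginal rate of substitution with price: WTP_k = -zeta_k / alpha (in $10,000 units)
wtp = dict(zip(names, -est_zeta / est_alpha))
for name, value in wtp.items():
    print(f"{name:>14}: {value * 10000:>10,.0f} USD")
```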
The “results” dictionary contains a summary of the results of variational inference, including the means of the posterior approximations for the different parameters in the model:
results
{'Estimation time': 109.56723594665527,
'Est. alpha': array([-0.5470354], dtype=float32),
'Est. zeta': array([-0.14659663, 0.44419056, -1.6546605 , 0.53451145, 0.11935437,
0.60611176], dtype=float32),
'Est. beta_n': array([[-0.177707 , 1.0113539 , -0.7460837 , 1.0617609 , 0.27421653,
0.60733056],
[ 0.02112017, 0.7033575 , -1.8779882 , 0.1324655 , 0.2555254 ,
0.88266927],
[-0.46877855, 0.40089908, -1.4657907 , 0.5733812 , 0.4453968 ,
0.4732788 ],
[ 0.4595958 , 0.24240786, 0.12841338, 1.1730931 , -0.71696794,
0.57585335],
[ 0.1390024 , 0.15335542, -2.420903 , -0.7401863 , -0.30360064,
-0.5032247 ],
[-0.32350585, 0.45646194, -1.7088293 , 1.2187247 , 0.5599405 ,
1.4971392 ],
[-0.6117315 , 0.46711022, -3.4774067 , -0.32910496, 0.6010447 ,
-0.19895634],
[ 0.92446315, 0.6807212 , -0.22696076, 1.7059586 , -0.2456472 ,
1.2439083 ],
[-0.09174854, 0.70348144, -0.8749374 , 0.5783474 , 0.04826947,
1.1662296 ],
[ 0.12984519, 0.28342915, -2.7004666 , -0.53102976, 0.04108321,
0.8549261 ],
[-0.7333127 , 0.9637862 , -1.5531871 , 0.641572 , -0.47131714,
0.73706514],
[ 0.8362044 , 0.2756676 , -0.37897837, 1.8493651 , -0.13540296,
1.6530721 ],
[ 0.25754914, -0.09405021, -3.172244 , -0.51131094, 0.07619096,
0.83034545],
[ 0.04579633, 0.05631334, -0.52795535, 1.0814569 , 0.5003116 ,
0.6081643 ],
[-0.6590507 , 0.8525584 , -2.4398053 , 1.554091 , 0.7125349 ,
1.1074538 ],
[ 0.30353227, -0.18303353, 0.28051695, 2.706717 , -0.29144388,
0.667579 ],
[-0.29875475, 0.7697504 , -2.109047 , 1.0144535 , -0.31672242,
1.4924587 ],
[-0.21054259, 0.11951059, -3.0621219 , -0.7193891 , 0.34762073,
0.5762015 ],
[ 0.04357907, 0.3549275 , -1.1500883 , 0.5528347 , 0.00699469,
1.1988566 ],
[-0.02196664, 0.65842885, -1.7127484 , 1.5090653 , -0.51405674,
0.67175627],
[-0.1952134 , 0.27798685, -1.690919 , 0.70789415, 0.61291754,
1.1184491 ],
[ 0.5696251 , 0.77130747, -0.9967521 , 1.1786315 , -0.37863505,
1.2350929 ],
[ 0.2750134 , 0.22697286, -1.4674928 , 0.31419078, -0.3145748 ,
0.9684196 ],
[ 0.09756894, 0.55155325, -0.9178999 , 0.8059993 , -0.30238897,
0.36997637],
[-0.49000105, 0.89142257, -0.12127484, 0.79534197, 0.7437891 ,
0.0130516 ],
[-0.16105093, 0.41348612, -2.6368873 , 0.21805559, 0.40065786,
0.62451535],
[-0.5788692 , 0.29272774, -2.853478 , 0.6862823 , 0.28825903,
0.43893313],
[-0.4745806 , 0.41062862, -2.6093678 , 0.5830976 , -0.3794036 ,
0.5791176 ],
[ 0.13417561, 0.51926494, -2.007187 , 0.34334454, 0.8999807 ,
0.5003878 ],
[-0.36330462, 0.7096105 , -0.6794857 , 0.55922747, 1.3516222 ,
0.4547909 ],
[-0.8217943 , 0.48797315, -2.7661672 , -0.60169613, 0.15072615,
0.0810528 ],
[ 0.03605545, 0.05726599, -2.7004492 , -0.43817383, -0.72704196,
1.0329393 ],
[-0.54753566, 0.28528103, -3.3279772 , -0.58225995, 0.19522145,
-0.09694449],
[-1.1952757 , 0.9741091 , -0.58089614, 1.546057 , 0.7921876 ,
-0.13799347],
[-0.04887389, 0.1884401 , 0.8485187 , 2.4413493 , -0.20275696,
0.0700904 ],
[ 0.25447348, 0.4195048 , -2.102889 , 0.3986832 , -0.6470184 ,
1.0526685 ],
[ 0.5010214 , 0.09405853, -2.201361 , -0.2316083 , -0.23557864,
1.5389867 ],
[-0.07038663, 0.46781984, -1.9639238 , -0.24734204, 0.33527586,
0.29294652],
[ 0.1503094 , 0.23298982, -1.9607797 , 1.1741935 , -0.8740111 ,
0.97417295],
[-0.37940443, 0.37554666, -1.8523827 , 0.03588488, -0.31177834,
-0.42141438],
[-0.5819648 , 0.2691778 , -2.4946191 , 1.2719803 , -0.34214464,
0.61334294],
[ 0.29541776, 0.42502668, -1.5860796 , 0.18306787, 0.19034417,
0.7087213 ],
[ 0.141139 , 0.14624447, -1.3941545 , 1.6764681 , 0.00801994,
1.1356167 ],
[ 0.2878245 , 0.06642059, -1.4519833 , 1.4811226 , 0.06787331,
1.2093438 ],
[ 0.2366184 , 0.33093786, -1.7743711 , -0.6150913 , 0.23874493,
0.68332964],
[ 0.30008066, 0.17750189, -1.1383712 , 1.7899371 , -0.36139536,
0.68949676],
[-0.7036854 , 1.1573654 , -1.3029658 , 1.284133 , 0.85282916,
0.3058794 ],
[-0.3163713 , 0.2859441 , -1.7541716 , 0.9449666 , -0.27674296,
0.4321956 ],
[ 0.16240735, 0.7487032 , -1.2054157 , 0.35689217, -0.8797293 ,
0.43226606],
[-0.2532407 , 1.0531042 , -0.72793233, 2.361825 , -0.4945146 ,
1.0822694 ],
[-0.3206012 , 0.6440221 , -1.8012266 , -0.13124613, 0.08223417,
0.34207818],
[-0.4707393 , 0.30571786, -2.598116 , -0.76840836, -0.21016322,
0.06640501],
[ 0.22311091, 0.4311655 , -1.6283368 , 0.37116113, -0.02935001,
0.5076597 ],
[-0.5732584 , 0.764475 , 0.85053515, 2.1707242 , 1.0849651 ,
0.59486187],
[-0.45556068, -0.0309206 , -2.8480577 , -1.3389957 , 0.27272677,
-0.2950476 ],
[ 0.0232279 , 0.5938791 , -1.6596189 , 0.44760114, 0.13615698,
0.9790388 ],
[-0.7535434 , 0.54067254, -3.2077584 , -1.1829981 , 0.5262319 ,
-0.54130614],
[-0.27430007, 0.23566282, -2.5115805 , -1.8501102 , 0.45072752,
-0.11731756],
[ 0.12438589, 0.79021436, -1.8448889 , 0.12681451, 0.23276518,
0.59742624],
[ 0.23092309, 0.04065663, -2.9042256 , -1.1874903 , 0.03304958,
0.68960273],
[-0.18408623, 0.6796636 , -0.61042285, 0.92572254, 0.6931009 ,
1.2469733 ],
[ 0.40842015, 1.4191489 , 0.9158343 , 3.7165003 , 0.05681641,
0.9118288 ],
[-0.6732889 , 0.8358001 , -1.4125152 , -0.51843876, 0.61122996,
0.0757769 ],
[ 0.15961568, 0.5225313 , 0.70618683, 1.3960788 , -0.04634688,
0.6307678 ],
[-0.05050385, 0.52446634, -1.0928475 , 1.457288 , 0.00473524,
1.2484493 ],
[-0.2301843 , 0.47248495, -2.2114363 , -0.72258615, 0.68282944,
0.19176461],
[ 0.7929539 , 0.1988315 , -0.46857688, 2.066399 , -0.23438525,
1.3617845 ],
[-0.62669545, 0.54139423, -2.7335048 , -0.44224697, 0.9262033 ,
0.43287283],
[ 0.32092977, 0.02560924, -1.4327469 , 0.35616314, -0.32611033,
0.72870916],
[-0.2547146 , 0.6896984 , -1.2926666 , 0.6888522 , 0.60718244,
0.60635334],
[ 0.35429716, 0.03550514, 0.12470555, -0.26977608, 0.25745317,
1.0862434 ],
[-0.42555618, -0.05655691, -3.9245088 , -1.4244858 , -0.17335439,
-0.09013958],
[-0.4651504 , 0.7434606 , -2.239503 , 0.75055176, 0.54644156,
0.34595886],
[ 0.63010734, 0.6159445 , 1.2773534 , 2.219995 , 0.10865933,
1.9697647 ],
[-0.4101898 , 0.44985878, -3.3480463 , -1.4234467 , 0.2384841 ,
-0.43555364],
[-0.42685336, 0.13127893, -1.8134395 , -0.4010766 , 0.00424892,
0.17551205],
[ 0.35147732, 0.83848727, -0.86693704, 0.47967935, -0.16607812,
0.75730807],
[-0.93755615, 0.24146217, -3.8447652 , -0.6739151 , -0.03253016,
-0.10361817],
[-0.46006858, -0.08233909, -2.877979 , -1.4984815 , 0.19420421,
-0.50902283],
[ 0.1336475 , 0.08712339, -2.5506248 , -0.44947082, 0.06793094,
0.29414123],
[-0.05205289, 0.46750787, -0.562862 , 0.79309595, 0.9287599 ,
0.91958445],
[-0.43299416, 0.25338745, -3.470944 , -0.13188887, -0.19592729,
0.3431562 ],
[-0.34419057, 0.50629777, -1.8314816 , 1.0059108 , -0.19910304,
-0.14778145],
[ 0.06375652, 0.02395127, -2.800709 , -0.85930014, -0.09970355,
0.36512405],
[ 0.47356573, -0.09470741, 1.0054905 , 2.1352615 , -0.6509158 ,
1.2955196 ],
[-1.0344052 , 0.5275322 , -1.9389765 , 0.9547346 , 0.47533652,
0.4593783 ],
[-0.5062121 , 0.9244115 , -2.1304772 , -0.23069848, 0.780848 ,
0.6200896 ],
[-0.91686946, 0.46768087, -2.422462 , 1.3331742 , 0.06141543,
0.5350673 ],
[-0.6934724 , 0.05096803, -3.2469478 , -0.17226821, 0.42541552,
0.06623555],
[ 0.00524725, 1.0506604 , -0.0461508 , 2.1009803 , 0.01012792,
1.1727788 ],
[ 0.52080137, -0.03548645, -0.06771066, 1.1446848 , -0.42017722,
0.87145203],
[-0.2316351 , 0.27021438, -1.8255385 , 0.4829427 , 0.08318917,
0.82220113],
[-0.46092078, 0.12618686, -2.4406548 , 0.98854434, 0.09932676,
0.14322324],
[-0.5501464 , 0.36832312, -2.2889638 , -0.5813244 , 0.43697608,
0.09164318],
[-0.49933255, 0.1601074 , -2.3500824 , 0.1463487 , -0.09433421,
0.4768962 ],
[-0.6585893 , 0.3187867 , -1.3876522 , 0.62179625, 0.6058429 ,
0.7718066 ],
[-0.20308788, 0.34395215, -2.0297759 , 0.8884229 , 0.12686421,
1.0142841 ],
[-0.65177286, -0.24952668, -2.3811662 , 0.91291296, 0.087266 ,
0.08013618],
[-0.18831654, 1.2571774 , -1.0406775 , 1.5704321 , 0.16461624,
0.8958613 ],
[-0.8817895 , 0.5991308 , -2.9539628 , -0.31898227, 0.3813351 ,
0.22220296]], dtype=float32),
'ELBO': 1408.323486328125,
'Loglikelihood': -1140.353271484375,
'Accuracy': 0.6812668442726135}
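The 'Est. beta_n' entry holds one row of posterior-mean taste parameters per respondent (100 × 6 here). The columns appear to follow the order of the random-effects parameters in the specification. The pandas sketch below summarizes how these individual-level means vary across respondents. `summarize_beta_n` is a hypothetical helper, and `results` is assumed to be the dictionary returned by `mxl.infer` as printed above.

```python
import numpy as np
import pandas as pd

def summarize_beta_n(beta_n, param_names):
    """Cross-respondent summary of individual-level posterior means.

    beta_n: (num_resp, num_random_params) array, e.g. results['Est. beta_n'].
    """
    df = pd.DataFrame(np.asarray(beta_n), columns=param_names)
    summary = df.describe().T[["mean", "std", "min", "max"]]
    # share of respondents whose posterior mean is positive
    summary["share_positive"] = (df > 0).mean()
    return summary

# toy input; with the notebook objects, call e.g.
# summarize_beta_n(results['Est. beta_n'],
#                  ['B_OperCost', 'B_Range', 'B_EV', 'B_Hybrid', 'B_HighPerf', 'B_MedHighPerf'])
print(summarize_beta_n(np.array([[1.0, -1.0], [3.0, -2.0]]), ["a", "b"]))
```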