Amortized variational inference

This is a demo of stochastic variational inference (SVI) for a mixed logit model using simulated data for 50,000 respondents.

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import sys
sys.path.insert(0, "/home/rodr/code/amortized-mxl-dev/release") 

import logging
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Fix random seed for reproducibility
np.random.seed(42)

Generate simulated data

We predefine the fixed effects parameters (true_alpha) and random effects parameters (true_beta), as well as the covariance matrix of the random effects (true_Omega), and sample simulated choice data for 50,000 respondents (num_resp), each with 10 choice situations (num_menus). The number of choice alternatives is set to 5.

from core.dcm_fakedata import generate_fake_data_wide

num_resp = 50000
num_menus = 10
num_alternatives = 5

true_alpha = np.array([-0.8, 0.8, 1.2])
true_beta = np.array([-0.8, 0.8, 1.0, -0.8, 1.5])
# build the true covariance matrix of the random effects
corr = 0.8
scale_factor = 1.0
true_Omega = corr*np.ones((len(true_beta),len(true_beta))) # off-diagonal values of cov matrix
true_Omega[np.arange(len(true_beta)), np.arange(len(true_beta))] = 1.0 # diagonal values of cov matrix
true_Omega *= scale_factor
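
With corr = 0.8 and scale_factor = 1.0, the resulting true_Omega is an equicorrelation covariance matrix with unit variances on the diagonal and covariances of 0.8 everywhere else:

print(true_Omega)
# [[1.  0.8 0.8 0.8 0.8]
#  [0.8 1.  0.8 0.8 0.8]
#  [0.8 0.8 1.  0.8 0.8]
#  [0.8 0.8 0.8 1.  0.8]
#  [0.8 0.8 0.8 0.8 1. ]]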

df = generate_fake_data_wide(num_resp, num_menus, num_alternatives, true_alpha, true_beta, true_Omega)
df.head()
Generating fake data...
Error: 46.0406
ALT1_XF1 ALT1_XF2 ALT1_XF3 ALT1_XR1 ALT1_XR2 ALT1_XR3 ALT1_XR4 ALT1_XR5 ALT2_XF1 ALT2_XF2 ... ALT5_XR1 ALT5_XR2 ALT5_XR3 ALT5_XR4 ALT5_XR5 choice indID menuID obsID ones
0 0.374540 0.950714 0.731994 0.340311 0.768842 0.936032 0.481857 0.802253 0.598658 0.156019 ... 0.809189 0.914568 0.789045 0.110962 0.312995 1 0 0 0 1
1 0.183405 0.304242 0.524756 0.194837 0.796218 0.651730 0.237908 0.199814 0.431945 0.291229 ... 0.536775 0.940645 0.826931 0.510653 0.914976 4 0 1 1 1
2 0.607545 0.170524 0.065052 0.485441 0.186343 0.322737 0.514758 0.090070 0.948886 0.965632 ... 0.595050 0.947324 0.927226 0.962351 0.211986 3 0 2 2 1
3 0.662522 0.311711 0.520068 0.442947 0.350437 0.031725 0.790190 0.481335 0.546710 0.184854 ... 0.846481 0.875187 0.624215 0.993924 0.414757 4 0 3 3 1
4 0.388677 0.271349 0.828738 0.011319 0.689977 0.119363 0.251674 0.674041 0.356753 0.280935 ... 0.890742 0.984823 0.531394 0.295977 0.364010 2 0 4 4 1

5 rows × 45 columns
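
Each respondent contributes one row per choice situation, so the generated dataframe should contain exactly num_menus rows per value of indID; a quick sanity check:

# verify the panel structure: every respondent has exactly num_menus (= 10) rows
assert (df.groupby('indID').size() == num_menus).all()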

Mixed Logit specification

We now make use of the package's formula interface to specify the utility functions of the mixed logit model.

We begin by defining the fixed effects parameters, the random effects parameters, and the observed variables. This creates instances of Python objects that can be put together to define the utility functions for the different alternatives.

Once the utilities are defined, we collect them in a Python dictionary mapping alternative names to their corresponding expressions.

from core.dcm_interface import FixedEffect, RandomEffect, ObservedVariable

# define fixed effects parameters
B_XF1 = FixedEffect('BETA_XF1')
B_XF2 = FixedEffect('BETA_XF2')
B_XF3 = FixedEffect('BETA_XF3')

# define random effects parameters
B_XR1 = RandomEffect('BETA_XR1')
B_XR2 = RandomEffect('BETA_XR2')
B_XR3 = RandomEffect('BETA_XR3')
B_XR4 = RandomEffect('BETA_XR4')
B_XR5 = RandomEffect('BETA_XR5')

# define observed variables
for attr in df.columns:
    exec("%s = ObservedVariable('%s')" % (attr,attr))

# define utility functions
V1 = B_XF1*ALT1_XF1 + B_XF2*ALT1_XF2 + B_XF3*ALT1_XF3 + B_XR1*ALT1_XR1 + B_XR2*ALT1_XR2 + B_XR3*ALT1_XR3 + B_XR4*ALT1_XR4 + B_XR5*ALT1_XR5
V2 = B_XF1*ALT2_XF1 + B_XF2*ALT2_XF2 + B_XF3*ALT2_XF3 + B_XR1*ALT2_XR1 + B_XR2*ALT2_XR2 + B_XR3*ALT2_XR3 + B_XR4*ALT2_XR4 + B_XR5*ALT2_XR5
V3 = B_XF1*ALT3_XF1 + B_XF2*ALT3_XF2 + B_XF3*ALT3_XF3 + B_XR1*ALT3_XR1 + B_XR2*ALT3_XR2 + B_XR3*ALT3_XR3 + B_XR4*ALT3_XR4 + B_XR5*ALT3_XR5
V4 = B_XF1*ALT4_XF1 + B_XF2*ALT4_XF2 + B_XF3*ALT4_XF3 + B_XR1*ALT4_XR1 + B_XR2*ALT4_XR2 + B_XR3*ALT4_XR3 + B_XR4*ALT4_XR4 + B_XR5*ALT4_XR5
V5 = B_XF1*ALT5_XF1 + B_XF2*ALT5_XF2 + B_XF3*ALT5_XF3 + B_XR1*ALT5_XR1 + B_XR2*ALT5_XR2 + B_XR3*ALT5_XR3 + B_XR4*ALT5_XR4 + B_XR5*ALT5_XR5

# associate utility functions with the names of the alternatives
utilities = {"ALT1": V1, "ALT2": V2, "ALT3": V3, "ALT4": V4, "ALT5": V5}

We are now ready to create a Specification object containing the utilities that we have just defined. Note that we must also specify the type of choice model to be used - a mixed logit model (MXL) in this case.

We can inspect the resulting specification by printing the dcm_spec object.

from core.dcm_interface import Specification

# create MXL specification object based on the utilities previously defined
dcm_spec = Specification('MXL', utilities)
print(dcm_spec)
----------------- MXL specification:
Alternatives: ['ALT1', 'ALT2', 'ALT3', 'ALT4', 'ALT5']
Utility functions:
   V_ALT1 = BETA_XF1*ALT1_XF1 + BETA_XF2*ALT1_XF2 + BETA_XF3*ALT1_XF3 + BETA_XR1_n*ALT1_XR1 + BETA_XR2_n*ALT1_XR2 + BETA_XR3_n*ALT1_XR3 + BETA_XR4_n*ALT1_XR4 + BETA_XR5_n*ALT1_XR5
   V_ALT2 = BETA_XF1*ALT2_XF1 + BETA_XF2*ALT2_XF2 + BETA_XF3*ALT2_XF3 + BETA_XR1_n*ALT2_XR1 + BETA_XR2_n*ALT2_XR2 + BETA_XR3_n*ALT2_XR3 + BETA_XR4_n*ALT2_XR4 + BETA_XR5_n*ALT2_XR5
   V_ALT3 = BETA_XF1*ALT3_XF1 + BETA_XF2*ALT3_XF2 + BETA_XF3*ALT3_XF3 + BETA_XR1_n*ALT3_XR1 + BETA_XR2_n*ALT3_XR2 + BETA_XR3_n*ALT3_XR3 + BETA_XR4_n*ALT3_XR4 + BETA_XR5_n*ALT3_XR5
   V_ALT4 = BETA_XF1*ALT4_XF1 + BETA_XF2*ALT4_XF2 + BETA_XF3*ALT4_XF3 + BETA_XR1_n*ALT4_XR1 + BETA_XR2_n*ALT4_XR2 + BETA_XR3_n*ALT4_XR3 + BETA_XR4_n*ALT4_XR4 + BETA_XR5_n*ALT4_XR5
   V_ALT5 = BETA_XF1*ALT5_XF1 + BETA_XF2*ALT5_XF2 + BETA_XF3*ALT5_XF3 + BETA_XR1_n*ALT5_XR1 + BETA_XR2_n*ALT5_XR2 + BETA_XR3_n*ALT5_XR3 + BETA_XR4_n*ALT5_XR4 + BETA_XR5_n*ALT5_XR5

Num. parameters to be estimated: 8
Fixed effects params: ['BETA_XF1', 'BETA_XF2', 'BETA_XF3']
Random effects params: ['BETA_XR1', 'BETA_XR2', 'BETA_XR3', 'BETA_XR4', 'BETA_XR5']

Once the Specification is defined, we need to define the DCM Dataset object that goes along with it. For this, we instantiate the Dataset class with the Pandas dataframe containing the data in the so-called “wide format”, the name of the column in the dataframe containing the observed choices, and the dcm_spec object that we previously created.

Note that since this is panel data, we must also specify the name of the column in the dataframe that contains the ID of the respondent (this should be an integer ranging from 0 to num_resp-1).

from core.dcm_interface import Dataset

# create DCM dataset object
dcm_dataset = Dataset(df, 'choice', dcm_spec, resp_id_col='indID')
Preparing dataset...
	Model type: MXL
	Num. observations: 500000
	Num. alternatives: 5
	Num. respondents: 50000
	Num. menus: 10
	Observations IDs: [     0      1      2 ... 499997 499998 499999]
	Alternative IDs: None
	Respondent IDs: [    0     0     0 ... 49999 49999 49999]
	Availability columns: None
	Attribute names: ['ALT1_XF1', 'ALT1_XF2', 'ALT1_XF3', 'ALT1_XR1', 'ALT1_XR2', 'ALT1_XR3', 'ALT1_XR4', 'ALT1_XR5', 'ALT2_XF1', 'ALT2_XF2', 'ALT2_XF3', 'ALT2_XR1', 'ALT2_XR2', 'ALT2_XR3', 'ALT2_XR4', 'ALT2_XR5', 'ALT3_XF1', 'ALT3_XF2', 'ALT3_XF3', 'ALT3_XR1', 'ALT3_XR2', 'ALT3_XR3', 'ALT3_XR4', 'ALT3_XR5', 'ALT4_XF1', 'ALT4_XF2', 'ALT4_XF3', 'ALT4_XR1', 'ALT4_XR2', 'ALT4_XR3', 'ALT4_XR4', 'ALT4_XR5', 'ALT5_XF1', 'ALT5_XF2', 'ALT5_XF3', 'ALT5_XR1', 'ALT5_XR2', 'ALT5_XR3', 'ALT5_XR4', 'ALT5_XR5']
	Fixed effects attribute names: ['ALT1_XF1', 'ALT1_XF2', 'ALT1_XF3', 'ALT2_XF1', 'ALT2_XF2', 'ALT2_XF3', 'ALT3_XF1', 'ALT3_XF2', 'ALT3_XF3', 'ALT4_XF1', 'ALT4_XF2', 'ALT4_XF3', 'ALT5_XF1', 'ALT5_XF2', 'ALT5_XF3']
	Fixed effects parameter names: ['BETA_XF1', 'BETA_XF2', 'BETA_XF3', 'BETA_XF1', 'BETA_XF2', 'BETA_XF3', 'BETA_XF1', 'BETA_XF2', 'BETA_XF3', 'BETA_XF1', 'BETA_XF2', 'BETA_XF3', 'BETA_XF1', 'BETA_XF2', 'BETA_XF3']
	Random effects attribute names: ['ALT1_XR1', 'ALT1_XR2', 'ALT1_XR3', 'ALT1_XR4', 'ALT1_XR5', 'ALT2_XR1', 'ALT2_XR2', 'ALT2_XR3', 'ALT2_XR4', 'ALT2_XR5', 'ALT3_XR1', 'ALT3_XR2', 'ALT3_XR3', 'ALT3_XR4', 'ALT3_XR5', 'ALT4_XR1', 'ALT4_XR2', 'ALT4_XR3', 'ALT4_XR4', 'ALT4_XR5', 'ALT5_XR1', 'ALT5_XR2', 'ALT5_XR3', 'ALT5_XR4', 'ALT5_XR5']
	Random effects parameter names: ['BETA_XR1', 'BETA_XR2', 'BETA_XR3', 'BETA_XR4', 'BETA_XR5', 'BETA_XR1', 'BETA_XR2', 'BETA_XR3', 'BETA_XR4', 'BETA_XR5', 'BETA_XR1', 'BETA_XR2', 'BETA_XR3', 'BETA_XR4', 'BETA_XR5', 'BETA_XR1', 'BETA_XR2', 'BETA_XR3', 'BETA_XR4', 'BETA_XR5', 'BETA_XR1', 'BETA_XR2', 'BETA_XR3', 'BETA_XR4', 'BETA_XR5']
	Alternative attributes ndarray.shape: (50000, 10, 40)
	Choices ndarray.shape: (50000, 10)
	Alternatives availability ndarray.shape: (50000, 10, 5)
	Data mask ndarray.shape: (50000, 10)
	Context data ndarray.shape: (50000, 0)
	Neural nets data ndarray.shape: (50000, 0)
Done!

As with the specification, we can inspect the DCM dataset by printing the dcm_dataset object:

print(dcm_dataset)
----------------- DCM dataset:
Model type: MXL
Num. observations: 500000
Num. alternatives: 5
Num. respondents: 50000
Num. menus: 10
Num. fixed effects: 15
Num. random effects: 25
Attribute names: ['ALT1_XF1', 'ALT1_XF2', 'ALT1_XF3', 'ALT1_XR1', 'ALT1_XR2', 'ALT1_XR3', 'ALT1_XR4', 'ALT1_XR5', 'ALT2_XF1', 'ALT2_XF2', 'ALT2_XF3', 'ALT2_XR1', 'ALT2_XR2', 'ALT2_XR3', 'ALT2_XR4', 'ALT2_XR5', 'ALT3_XF1', 'ALT3_XF2', 'ALT3_XF3', 'ALT3_XR1', 'ALT3_XR2', 'ALT3_XR3', 'ALT3_XR4', 'ALT3_XR5', 'ALT4_XF1', 'ALT4_XF2', 'ALT4_XF3', 'ALT4_XR1', 'ALT4_XR2', 'ALT4_XR3', 'ALT4_XR4', 'ALT4_XR5', 'ALT5_XF1', 'ALT5_XF2', 'ALT5_XF3', 'ALT5_XR1', 'ALT5_XR2', 'ALT5_XR3', 'ALT5_XR4', 'ALT5_XR5']

Bayesian Mixed Logit Model in PyTorch

It is now time to perform approximate Bayesian inference on the mixed logit model that we have specified. The generative process of the MXL model that we will be using is the following:

  1. Draw fixed taste parameters \(\boldsymbol\alpha \sim \mathcal{N}(\boldsymbol\lambda_0, \boldsymbol\Xi_0)\)

  2. Draw mean vector \(\boldsymbol\zeta \sim \mathcal{N}(\boldsymbol\mu_0, \boldsymbol\Sigma_0)\)

  3. Draw scales vector \(\boldsymbol\theta \sim \mbox{half-Cauchy}(\boldsymbol\sigma_0)\)

  4. Draw correlation matrix \(\boldsymbol\Psi \sim \mbox{LKJ}(\nu)\)

  5. For each decision-maker \(n \in \{1,\dots,N\}\)

    1. Draw random taste parameters \(\boldsymbol\beta_n \sim \mathcal{N}(\boldsymbol\zeta,\boldsymbol\Omega)\)

    2. For each choice occasion \(t \in \{1,\dots,T_n\}\)

      1. Draw observed choice \(y_{nt} \sim \mbox{MNL}(\boldsymbol\alpha, \boldsymbol\beta_n, \textbf{X}_{nt})\)

where \(\boldsymbol\Omega = \mbox{diag}(\boldsymbol\theta) \times \boldsymbol\Psi \times \mbox{diag}(\boldsymbol\theta)\), and \(\mbox{MNL}(\boldsymbol\alpha, \boldsymbol\beta_n, \textbf{X}_{nt})\) denotes the multinomial logit likelihood \(p(y_{nt}=j) = \exp(V_{ntj}) / \sum_{k} \exp(V_{ntk})\), with the utilities \(V_{ntj}\) given by the specification defined above.
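
To make this generative process concrete, here is a minimal PyTorch sketch that samples from it. The dimensions, hyperparameter values (e.g. the LKJ concentration) and the uniformly sampled attributes are hypothetical and chosen purely for illustration; this is not the implementation used by TorchMXL.

import torch
import torch.distributions as td

# hypothetical dimensions and hyperparameters, for illustration only
n_fixed, n_random, n_resp, n_menus, n_alts = 3, 5, 100, 10, 5

alpha = td.Normal(0.0, 1.0).sample((n_fixed,))                  # 1. fixed taste parameters
zeta  = td.Normal(0.0, 1.0).sample((n_random,))                 # 2. mean vector of the random effects
theta = td.HalfCauchy(1.0).sample((n_random,))                  # 3. scales vector
L_psi = td.LKJCholesky(n_random, concentration=2.0).sample()    # 4. Cholesky factor of the correlation matrix
Psi   = L_psi @ L_psi.T                                         #    correlation matrix
Omega = torch.diag(theta) @ Psi @ torch.diag(theta)             #    covariance of the random effects

beta = td.MultivariateNormal(zeta, Omega).sample((n_resp,))     # 5.1 per-respondent taste parameters

# hypothetical alternative attributes (uniform, as in the simulated data above)
X_fixed  = torch.rand(n_resp, n_menus, n_alts, n_fixed)
X_random = torch.rand(n_resp, n_menus, n_alts, n_random)

# 5.2.1 utilities and observed choices
V = X_fixed @ alpha + torch.einsum('ntjk,nk->ntj', X_random, beta)
y = td.Categorical(logits=V).sample()                           # shape: (n_resp, n_menus)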

We can instantiate this model using the TorchMXL class with the code below, and then run variational inference to approximate the posterior distribution of the latent variables in the model. Setting use_inference_net=True enables the inference network that amortizes the estimation of the per-respondent variational parameters (hence the title of this demo). Note that since in this case we know the true parameters that were used to generate the simulated choice data, we can pass them to the “infer” method in order to obtain additional information during the ELBO maximization (useful for tracking the progress of VI and for other debugging purposes).

%%time

from core.torch_mxl import TorchMXL

# instantiate MXL model
mxl = TorchMXL(dcm_dataset, batch_size=num_resp, use_inference_net=True, use_cuda=True)

# run Bayesian inference (variational inference)
results = mxl.infer(num_epochs=5000, true_alpha=true_alpha, true_beta=true_beta)
[Epoch     0] ELBO: 1224117; Loglik: -917052; Acc.: 0.243; Alpha RMSE: 0.948; Beta RMSE: 1.023
[Epoch   100] ELBO: 1038406; Loglik: -931396; Acc.: 0.226; Alpha RMSE: 0.408; Beta RMSE: 0.984
[Epoch   200] ELBO: 933971; Loglik: -813260; Acc.: 0.317; Alpha RMSE: 0.182; Beta RMSE: 0.941
[Epoch   300] ELBO: 993428; Loglik: -808890; Acc.: 0.314; Alpha RMSE: 0.072; Beta RMSE: 0.885
[Epoch   400] ELBO: 941356; Loglik: -806167; Acc.: 0.317; Alpha RMSE: 0.029; Beta RMSE: 0.834
[Epoch   500] ELBO: 1078565; Loglik: -801570; Acc.: 0.329; Alpha RMSE: 0.039; Beta RMSE: 0.769
[Epoch   600] ELBO: 908037; Loglik: -806531; Acc.: 0.302; Alpha RMSE: 0.030; Beta RMSE: 0.698
[Epoch   700] ELBO: 955672; Loglik: -786175; Acc.: 0.338; Alpha RMSE: 0.033; Beta RMSE: 0.637
[Epoch   800] ELBO: 900093; Loglik: -790306; Acc.: 0.338; Alpha RMSE: 0.019; Beta RMSE: 0.577
[Epoch   900] ELBO: 903703; Loglik: -791398; Acc.: 0.328; Alpha RMSE: 0.061; Beta RMSE: 0.513
[Epoch  1000] ELBO: 869773; Loglik: -785544; Acc.: 0.343; Alpha RMSE: 0.040; Beta RMSE: 0.467
[Epoch  1100] ELBO: 862214; Loglik: -784105; Acc.: 0.341; Alpha RMSE: 0.010; Beta RMSE: 0.427
[Epoch  1200] ELBO: 886401; Loglik: -783768; Acc.: 0.345; Alpha RMSE: 0.017; Beta RMSE: 0.392
[Epoch  1300] ELBO: 870709; Loglik: -783431; Acc.: 0.345; Alpha RMSE: 0.024; Beta RMSE: 0.348
[Epoch  1400] ELBO: 876040; Loglik: -783374; Acc.: 0.347; Alpha RMSE: 0.014; Beta RMSE: 0.304
[Epoch  1500] ELBO: 863066; Loglik: -782970; Acc.: 0.344; Alpha RMSE: 0.015; Beta RMSE: 0.264
[Epoch  1600] ELBO: 922069; Loglik: -780184; Acc.: 0.348; Alpha RMSE: 0.017; Beta RMSE: 0.232
[Epoch  1700] ELBO: 868486; Loglik: -779272; Acc.: 0.349; Alpha RMSE: 0.023; Beta RMSE: 0.205
[Epoch  1800] ELBO: 870762; Loglik: -779467; Acc.: 0.350; Alpha RMSE: 0.026; Beta RMSE: 0.175
[Epoch  1900] ELBO: 840950; Loglik: -780558; Acc.: 0.344; Alpha RMSE: 0.021; Beta RMSE: 0.152
[Epoch  2000] ELBO: 825965; Loglik: -777642; Acc.: 0.349; Alpha RMSE: 0.021; Beta RMSE: 0.132
[Epoch  2100] ELBO: 855912; Loglik: -777580; Acc.: 0.352; Alpha RMSE: 0.009; Beta RMSE: 0.114
[Epoch  2200] ELBO: 844465; Loglik: -776011; Acc.: 0.353; Alpha RMSE: 0.016; Beta RMSE: 0.100
[Epoch  2300] ELBO: 855954; Loglik: -775478; Acc.: 0.354; Alpha RMSE: 0.018; Beta RMSE: 0.087
[Epoch  2400] ELBO: 833110; Loglik: -773237; Acc.: 0.359; Alpha RMSE: 0.017; Beta RMSE: 0.079
[Epoch  2500] ELBO: 828358; Loglik: -771019; Acc.: 0.354; Alpha RMSE: 0.012; Beta RMSE: 0.078
[Epoch  2600] ELBO: 790738; Loglik: -772078; Acc.: 0.354; Alpha RMSE: 0.009; Beta RMSE: 0.072
[Epoch  2700] ELBO: 813134; Loglik: -770107; Acc.: 0.354; Alpha RMSE: 0.012; Beta RMSE: 0.057
[Epoch  2800] ELBO: 818668; Loglik: -768972; Acc.: 0.356; Alpha RMSE: 0.013; Beta RMSE: 0.053
[Epoch  2900] ELBO: 843629; Loglik: -770014; Acc.: 0.355; Alpha RMSE: 0.011; Beta RMSE: 0.047
[Epoch  3000] ELBO: 805063; Loglik: -767906; Acc.: 0.359; Alpha RMSE: 0.006; Beta RMSE: 0.039
[Epoch  3100] ELBO: 806284; Loglik: -766475; Acc.: 0.359; Alpha RMSE: 0.023; Beta RMSE: 0.042
[Epoch  3200] ELBO: 808099; Loglik: -764901; Acc.: 0.357; Alpha RMSE: 0.024; Beta RMSE: 0.042
[Epoch  3300] ELBO: 807587; Loglik: -765019; Acc.: 0.358; Alpha RMSE: 0.014; Beta RMSE: 0.035
[Epoch  3400] ELBO: 790571; Loglik: -762659; Acc.: 0.359; Alpha RMSE: 0.014; Beta RMSE: 0.040
[Epoch  3500] ELBO: 781814; Loglik: -762395; Acc.: 0.362; Alpha RMSE: 0.011; Beta RMSE: 0.040
[Epoch  3600] ELBO: 796089; Loglik: -758855; Acc.: 0.360; Alpha RMSE: 0.021; Beta RMSE: 0.028
[Epoch  3700] ELBO: 780398; Loglik: -758333; Acc.: 0.362; Alpha RMSE: 0.028; Beta RMSE: 0.031
[Epoch  3800] ELBO: 779150; Loglik: -756011; Acc.: 0.364; Alpha RMSE: 0.026; Beta RMSE: 0.033
[Epoch  3900] ELBO: 782050; Loglik: -756280; Acc.: 0.363; Alpha RMSE: 0.036; Beta RMSE: 0.045
[Epoch  4000] ELBO: 782104; Loglik: -754342; Acc.: 0.363; Alpha RMSE: 0.028; Beta RMSE: 0.039
[Epoch  4100] ELBO: 787934; Loglik: -754237; Acc.: 0.363; Alpha RMSE: 0.027; Beta RMSE: 0.038
[Epoch  4200] ELBO: 787634; Loglik: -752703; Acc.: 0.366; Alpha RMSE: 0.039; Beta RMSE: 0.034
[Epoch  4300] ELBO: 769860; Loglik: -752475; Acc.: 0.365; Alpha RMSE: 0.021; Beta RMSE: 0.040
[Epoch  4400] ELBO: 770447; Loglik: -752565; Acc.: 0.363; Alpha RMSE: 0.030; Beta RMSE: 0.032
[Epoch  4500] ELBO: 770664; Loglik: -749922; Acc.: 0.366; Alpha RMSE: 0.029; Beta RMSE: 0.040
[Epoch  4600] ELBO: 764101; Loglik: -749379; Acc.: 0.366; Alpha RMSE: 0.035; Beta RMSE: 0.038
[Epoch  4700] ELBO: 770784; Loglik: -748203; Acc.: 0.367; Alpha RMSE: 0.029; Beta RMSE: 0.032
[Epoch  4800] ELBO: 769441; Loglik: -749260; Acc.: 0.365; Alpha RMSE: 0.032; Beta RMSE: 0.039
[Epoch  4900] ELBO: 758093; Loglik: -746343; Acc.: 0.368; Alpha RMSE: 0.025; Beta RMSE: 0.039
Elapsed time: 608.7204344272614 

True alpha: [-0.8  0.8  1.2]
Est. alpha: [-0.78864455  0.7716446   1.1562718 ]
	BETA_XF1: -0.789
	BETA_XF2: 0.772
	BETA_XF3: 1.156

True zeta: [-0.8  0.8  1.  -0.8  1.5]
Est. zeta: [-0.7907126   0.7485124   0.96251327 -0.81944495  1.4549862 ]
	BETA_XR1: -0.791
	BETA_XR2: 0.749
	BETA_XR3: 0.963
	BETA_XR4: -0.819
	BETA_XR5: 1.455
CPU times: user 37min 41s, sys: 6min 34s, total: 44min 15s
Wall time: 10min 12s

On an NVIDIA RTX 2080 GPU this took just over 10 minutes, which is not bad for a dataset of 50,000 individuals with 10 menus per individual :-)

The “results” dictionary contains a summary of the results of variational inference, including the means of the posterior approximations for the different parameters in the model:

results
{'Estimation time': 608.7204344272614,
 'Est. alpha': array([-0.78864455,  0.7716446 ,  1.1562718 ], dtype=float32),
 'Est. zeta': array([-0.7907126 ,  0.7485124 ,  0.96251327, -0.81944495,  1.4549862 ],
       dtype=float32),
 'Est. beta_n': array([[-0.8350792 ,  0.7905707 ,  0.9477427 , -0.81516874,  1.4440804 ],
        [-0.83584404,  0.79179126,  0.9485877 , -0.8159292 ,  1.4461908 ],
        [-0.80974424,  0.7707391 ,  0.9281843 , -0.7975924 ,  1.4110804 ],
        ...,
        [-0.80880046,  0.7670302 ,  0.9265256 , -0.7959152 ,  1.4057504 ],
        [-0.8118125 ,  0.76915276,  0.92753565, -0.7983073 ,  1.4084319 ],
        [-0.8383087 ,  0.79370105,  0.95091426, -0.81761533,  1.4495381 ]],
       dtype=float32),
 'ELBO': 760327.625,
 'Loglikelihood': -747095.875,
 'Accuracy': 0.3681280016899109}
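
For example, we can compare the posterior means with the true values that generated the data:

# absolute errors of the posterior means w.r.t. the true data-generating parameters
print(np.abs(results['Est. alpha'] - true_alpha))
print(np.abs(results['Est. zeta'] - true_beta))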

This interface is currently being improved to include additional output information; for now, further details can be obtained from the attributes of the “mxl” object.