Obtains predictions of topics for new documents from a fitted LDA model

# S3 method for lda_topic_model
predict(object, newdata, method = c("gibbs",
  "dot"), iterations = NULL, burnin = -1, ...)

Arguments

object

a fitted object of class lda_topic_model

newdata

a document-term matrix (DTM) or term co-occurrence matrix (TCM) of class dgCMatrix, or a numeric vector

method

one of "gibbs" or "dot". If "gibbs", Gibbs sampling is used and iterations must be specified.

iterations

If method = "gibbs", an integer number of iterations for the Gibbs sampler to run. A future version may include automatic stopping criteria.

burnin

If method = "gibbs", an integer number of burn-in iterations. If burnin is greater than -1, the entries of the resulting "theta" matrix are averaged over all iterations after the first burnin iterations.

...

Other arguments to be passed to TmParallelApply
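
To make the difference between the two method options more concrete, here is a minimal base-R sketch of what a dot-product prediction could look like. It assumes the "dot" method multiplies each document's normalized word frequencies by a topics-by-words matrix of P(topic | word); that is an assumption about the approach, not a copy of the package's code. The names dot_predict, gamma and dtm, and all the toy values, are hypothetical.

```r
# Toy inputs; every name and value here is hypothetical.
# gamma: topics x words, each column is P(topic | word) and sums to 1.
gamma <- matrix(
  c(0.9, 0.1,   # apple
    0.2, 0.8,   # bank
    0.5, 0.5),  # cell
  nrow = 2,
  dimnames = list(c("t_1", "t_2"), c("apple", "bank", "cell"))
)

# dtm: documents x words term counts for the new documents.
dtm <- matrix(
  c(2, 0, 2,
    0, 3, 1),
  nrow = 2, byrow = TRUE,
  dimnames = list(c("doc_1", "doc_2"), c("apple", "bank", "cell"))
)

dot_predict <- function(dtm, gamma) {
  # Keep only the words that both the new data and the model know about.
  vocab <- intersect(colnames(dtm), colnames(gamma))
  counts <- dtm[, vocab, drop = FALSE]

  # Row-normalize counts into word frequencies, P(word | document).
  freq <- counts / rowSums(counts)

  # P(topic | document) = sum over words of P(word | document) * P(topic | word).
  theta <- freq %*% t(gamma[, vocab, drop = FALSE])

  # Documents with no known words produce NaN rows; set those to 0.
  theta[is.na(theta)] <- 0
  theta
}

theta <- dot_predict(dtm, gamma)
print(theta)
```

Unlike Gibbs sampling, a projection of this kind is deterministic and needs no iterations or burnin argument, which is consistent with the usage above.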

Value

a "theta" matrix with one row per document and one column per topic
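
The averaging described for the burnin argument can be illustrated with a small base-R sketch. The per-iteration samples below are random stand-ins, not output from a real Gibbs sampler, and the names samples, kept and theta_avg are hypothetical; the sketch only shows the "average over iterations greater than burnin" step under those assumptions.

```r
set.seed(1)

iterations <- 10
burnin <- 6

# Hypothetical per-iteration "theta" samples: 2 documents x 2 topics,
# each row normalized to sum to 1.
samples <- lapply(seq_len(iterations), function(i) {
  x <- matrix(runif(4), nrow = 2)
  x / rowSums(x)
})

# Keep only the iterations whose index is greater than burnin.
kept <- samples[seq_len(iterations) > burnin]

# Element-wise mean of the kept samples.
theta_avg <- Reduce(`+`, kept) / length(kept)

print(length(kept))
print(rowSums(theta_avg))
```

Because every kept sample has rows that sum to 1, the averaged matrix does too, so it still has one row per document and one column per topic and can be read as a per-document topic distribution.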

Examples

if (FALSE) {
# load some data
data(nih_sample_dtm)

# fit a model
set.seed(12345)
m <- FitLdaModel(dtm = nih_sample_dtm[1:20,], k = 5,
                 iterations = 200, burnin = 175)

str(m)

# predict on held-out documents using gibbs sampling "fold in"
p1 <- predict(m, nih_sample_dtm[21:100,], method = "gibbs",
              iterations = 200, burnin = 175)

# predict on held-out documents using the dot product method
p2 <- predict(m, nih_sample_dtm[21:100,], method = "dot")

# compare the methods
barplot(rbind(p1[1,],p2[1,]), beside = TRUE, col = c("red", "blue"))
}