Return best n categories when predicting #60

@sadrasabouri

Is your feature request related to a problem? Please describe.
In my case (using PLDA for information retrieval), it would be better to predict the best n options, say, rather than only the single best one for a given query.
As far as I can tell, the predict method does not support this, but it can be done using the calc_logp_pp_categories method.

Describe the solution you'd like
My quick solution was to use the code below:

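import numpy as np

def predict_doc_at(query, k=1):
    """
    Predict which documents best match the given query.

    :param query: input query
    :type query: str (or list of strs)
    :param k: number of documents to return
    :type k: int
    :return: names of the k best-matching documents, best first
    """
    # get_embeddings and PLDA_classifier are assumed to be defined elsewhere.
    query_embedding = get_embeddings(query)
    # Map the embedding from data space to the model's latent space.
    data = PLDA_classifier.model.transform(query_embedding,
                                           from_space='D',
                                           to_space='U_model')
    # Log posterior predictive per category, without normalization.
    logpps_k, K = PLDA_classifier.calc_logp_pp_categories(data,
                                                          False)
    # Sort along the category axis, highest score first, and keep k.
    best_k_idx = np.argsort(logpps_k, axis=-1)[..., ::-1][..., :k]
    predictions = K[best_k_idx]
    return predictions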
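
To make the selection step in the function above easier to check on its own, here is a minimal sketch that uses only NumPy. The helper name `top_k_categories` and the toy scores and labels are hypothetical and stand in for the two values returned by calc_logp_pp_categories. It assumes the scores have one entry per category along the last axis.

```python
import numpy as np

def top_k_categories(logps, categories, k=1):
    """Return the k highest-scoring category labels, best first.

    Hypothetical helper: `logps` has shape (..., n_categories) and
    `categories` has shape (n_categories,).
    """
    logps = np.asarray(logps)
    categories = np.asarray(categories)
    # Never ask for more categories than exist.
    k = min(k, logps.shape[-1])
    # Ascending argsort along the category axis, reversed, then truncated.
    order = np.argsort(logps, axis=-1)[..., ::-1][..., :k]
    return categories[order]

# Toy data: one query scored against three documents.
labels = np.array(["doc_a", "doc_b", "doc_c"])
scores = np.array([[-3.2, -0.5, -1.7]])
top2 = top_k_categories(scores, labels, k=2)
print(top2)
```

With k=1 this reduces to an argmax over the category axis, so a best-n option along these lines could probably sit alongside the existing predict behaviour without changing it.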
