paste3.helper.kl_divergence

paste3.helper.kl_divergence(a_exp_dissim, b_exp_dissim, is_generalized=False)

Calculates the Kullback-Leibler (KL) divergence, or the generalized KL divergence, between two probability distributions.
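
For reference, the standard textbook definitions for discrete inputs a and b are given below (these are the conventional formulas, not extracted from the paste3 source; the generalized form remains well defined when the inputs are not normalized to sum to 1):

    D_{\mathrm{KL}}(a \,\|\, b)  = \sum_i a_i \log \frac{a_i}{b_i}

    D_{\mathrm{gKL}}(a \,\|\, b) = \sum_i a_i \log \frac{a_i}{b_i} - \sum_i a_i + \sum_i b_i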

Parameters:
  • a_exp_dissim (torch.Tensor) -- A tensor representing the first probability distribution, one sample per row.

  • b_exp_dissim (torch.Tensor) -- A tensor representing the second probability distribution, one sample per row.

  • is_generalized (bool) -- If True, computes the generalized KL divergence between the two distributions.

Returns:

divergence -- A tensor containing the KL divergence (or, if is_generalized=True, the generalized KL divergence) for each sample.

Return type:

torch.Tensor
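
A minimal usage sketch, assuming paste3 and PyTorch are installed; the tensor values and shapes below are illustrative only, with each row treated as one sample's distribution:

    import torch
    from paste3.helper import kl_divergence

    # Illustrative inputs: two batches of samples, each row a distribution
    # over four bins (rows sum to 1 for the standard KL case).
    a = torch.tensor([[0.10, 0.20, 0.30, 0.40],
                      [0.25, 0.25, 0.25, 0.25]])
    b = torch.tensor([[0.40, 0.30, 0.20, 0.10],
                      [0.10, 0.10, 0.40, 0.40]])

    # Standard KL divergence between the sample distributions.
    div = kl_divergence(a, b)

    # Generalized KL divergence, which also accommodates unnormalized inputs.
    gen_div = kl_divergence(a, b, is_generalized=True)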