paste3.helper.kl_divergence
- paste3.helper.kl_divergence(a_exp_dissim, b_exp_dissim, is_generalized=False)
Calculates the Kullback-Leibler (KL) divergence, or the generalized KL divergence, between two distributions.
- Parameters:
a_exp_dissim (torch.Tensor) -- A tensor representing the first probability distribution.
b_exp_dissim (torch.Tensor) -- A tensor representing the second probability distribution.
is_generalized (bool) -- If True, computes the generalized KL divergence between the two distributions; if False (the default), computes the standard KL divergence.
- Returns:
divergence -- A tensor containing the Kullback-Leibler divergence for each sample.
- Return type:
torch.Tensor
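For reference, the standard KL divergence between probability distributions p and q is D(p || q) = sum_k p_k log(p_k / q_k). The generalized KL divergence, defined for non-negative vectors that need not sum to 1, is sum_k (p_k log(p_k / q_k) - p_k + q_k). The sketch below is a hypothetical, dependency-free reference (kl_divergence_reference is not part of paste3) that uses plain Python lists instead of torch.Tensor so both formulas are explicit. It assumes that each row is one sample, that all entries are strictly positive, that the non-generalized case normalizes each row to sum to 1, and that the result compares every row of a with every row of b. Check these assumptions against the paste3 source before relying on them.

```python
import math
from typing import List

Matrix = List[List[float]]


def _normalize(row: List[float]) -> List[float]:
    """Scale a non-negative row so that it sums to 1."""
    total = sum(row)
    return [x / total for x in row]


def kl_divergence_reference(a: Matrix, b: Matrix, is_generalized: bool = False) -> Matrix:
    """Hypothetical reference: pairwise (generalized) KL divergence.

    Assumption: every row of `a` and `b` is one sample over the same features,
    with strictly positive entries. Returns a len(a) x len(b) matrix whose
    (i, j) entry is D(a[i] || b[j]).
    """
    n_features = len(a[0])
    if any(len(row) != n_features for row in a + b):
        raise ValueError("a and b must have the same number of features")

    if not is_generalized:
        # Standard KL compares probability distributions, so each row is
        # normalized to sum to 1 first (assumed behavior).
        a = [_normalize(row) for row in a]
        b = [_normalize(row) for row in b]

    result: Matrix = []
    for p in a:
        out_row = []
        for q in b:
            # sum_k p_k * log(p_k / q_k)
            d = sum(pk * math.log(pk / qk) for pk, qk in zip(p, q))
            if is_generalized:
                # Generalized KL adds the mass correction sum_k (q_k - p_k).
                d += sum(q) - sum(p)
            out_row.append(d)
        result.append(out_row)
    return result
```

For example, kl_divergence_reference([[2.0, 6.0]], [[1.0, 3.0]]) is zero because the rows are proportional and become identical after normalization, whereas the generalized variant of the same call is positive because it also penalizes the difference in total mass.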