-
Notifications
You must be signed in to change notification settings - Fork 1
/
KL_divergence.py
53 lines (43 loc) · 1.36 KB
/
KL_divergence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 26 16:05:57 2020
@author: vkchlt0297
"""
from matplotlib import pyplot
from math import log2
import numpy as np
# Outcomes shared by both distributions.
events = ['red', 'green', 'blue']
# P and Q: one probability per event, in the same order as `events`.
p = [0.10, 0.40, 0.50]
q = [0.80, 0.15, 0.05]
# Sanity check: each list should sum to 1 to be a valid distribution.
print('p=%.3f q=%.3f' % (sum(p), sum(q)))
# Two stacked bar charts: P in the top panel, Q in the bottom panel.
pyplot.subplot(211)
pyplot.bar(events, p)
pyplot.subplot(212)
pyplot.bar(events, q)
# Display the figure.
pyplot.show()
def kl_divergence(p, q):
    """Return the Kullback-Leibler divergence KL(P || Q) in bits.

    Computes ``sum_i p_i * log2(p_i / q_i)``.

    Parameters
    ----------
    p, q : sequence of float
        Probabilities of the same events, in the same order. Plain lists
        and NumPy arrays are both accepted.

    Returns
    -------
    float
        The divergence. Terms with ``p_i == 0`` contribute 0 (the usual
        ``0 * log 0 = 0`` convention). If some ``q_i == 0`` while
        ``p_i > 0`` the divergence is unbounded and ``float('inf')`` is
        returned.

    Raises
    ------
    ValueError
        If ``p`` and ``q`` have different lengths.
    """
    if len(p) != len(q):
        raise ValueError(
            'p and q must have the same length, got %d and %d'
            % (len(p), len(q))
        )
    total = 0
    for p_i, q_i in zip(p, q):
        if p_i == 0:
            # x*log(x) -> 0 as x -> 0; math.log2(0) itself would raise.
            continue
        if q_i == 0:
            # P puts mass where Q has none: the divergence is infinite.
            return float('inf')
        total += p_i * log2(p_i / q_i)
    return total
def js_divergence(p, q):
    """Return the Jensen-Shannon divergence JS(P || Q) in bits.

    Computes ``0.5 * KL(P || M) + 0.5 * KL(Q || M)`` with
    ``M = 0.5 * (P + Q)``. The formula is symmetric in ``p`` and ``q``.

    Parameters
    ----------
    p, q : sequence of float
        Probabilities of the same events, in the same order. Plain lists
        and NumPy arrays are both accepted.

    Returns
    -------
    float
        The divergence, as returned by ``kl_divergence``.

    Raises
    ------
    ValueError
        If ``p`` and ``q`` do not have the same shape.
    """
    # Convert to arrays so `+` and `*` act element-wise. On plain lists,
    # `p + q` concatenates and `0.5 * list` raises TypeError.
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    if p.shape != q.shape:
        # Without this check NumPy could silently broadcast mismatched inputs.
        raise ValueError(
            'p and q must have the same shape, got %s and %s'
            % (p.shape, q.shape)
        )
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)
kl_pq = kl_divergence(p, q)
# Passing the plain lists to js_divergence can fail. For lists, `p + q`
# concatenates instead of adding element-wise, and multiplying a list by a
# float raises "can't multiply sequence by non-int of type 'float'". For
# example, [1] * 5 repeats the element five times, but [1] * 0.5 is an error.
# Converting to NumPy arrays makes both operators element-wise. Converting
# the existing lists, rather than retyping the numbers, keeps the
# distributions defined in one place.
p = np.asarray(p)
q = np.asarray(q)
js_pq = js_divergence(p, q)
print('KL(P || Q): %.3f bits' % kl_pq)
print('JS(P || Q): %.3f bits' % js_pq)
# Same measures in the reverse direction, (Q || P).
kl_qp = kl_divergence(q, p)
js_qp = js_divergence(q, p)
print('KL(Q || P): %.3f bits' % kl_qp)
print('JS(Q || P): %.3f bits' % js_qp)