How to plot a multivariate Gaussian in Python

I agree this is not very friendly. Below I write out the code and explain why it works.

The equation of a multivariate gaussian is as follows:

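$$f(\mathbf{x}) = \frac{1}{\sqrt{(2\pi)^n |\boldsymbol{\Sigma}|}} \exp\left(-\frac{1}{2} (\mathbf{x}-\boldsymbol{\mu})^T \boldsymbol{\Sigma}^{-1} (\mathbf{x}-\boldsymbol{\mu})\right)$$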
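The density formula of the multivariate Gaussian translates almost directly into NumPy for a single point. This is a minimal sketch, not part of the original answer: the helper name `mvn_pdf` is hypothetical, and it assumes Σ is symmetric positive definite (so it can be inverted and has a positive determinant).

```python
import numpy as np


def mvn_pdf(x, mu, sigma):
    """Hypothetical helper: multivariate Gaussian density at one point.

    f(x) = exp(-0.5 * (x - mu)^T Sigma^-1 (x - mu)) / sqrt((2*pi)^n * |Sigma|)
    x, mu: length-n vectors; sigma: (n, n) covariance, assumed positive definite.
    """
    x = np.asarray(x, dtype=float)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    n = mu.shape[0]
    d = x - mu                                   # x - mu
    quad = d @ np.linalg.inv(sigma) @ d          # (x - mu)^T Sigma^-1 (x - mu), a scalar
    norm = np.sqrt((2 * np.pi) ** n * np.linalg.det(sigma))
    return np.exp(-0.5 * quad) / norm


# At the mean of a standard 2D Gaussian the density is 1 / (2*pi)
p = mvn_pdf([0.0, 0.0], [0.0, 0.0], np.eye(2))
print(p)  # about 0.159
```

The same function works for any n, which is why the 2D case below only needs n = 2 plugged in.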

In the 2D case, $\mathbf{x}$ and $\boldsymbol{\mu}$ are 2D column vectors, $\boldsymbol{\Sigma}$ is a 2x2 covariance matrix, and n = 2.

So in the 2D case, the vector $\mathbf{x}$ is actually a point (x, y) at which we want to compute the function value, given the 2D mean vector $\boldsymbol{\mu}$, which we can also write as (mX, mY), and the covariance matrix $\boldsymbol{\Sigma}$.
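Substituting n = 2 and writing the vectors out by coordinates (with mX, mY written as m_X, m_Y) gives the 2D form that the code evaluates. Note that the normalizing constant simplifies because $\sqrt{(2\pi)^2} = 2\pi$:

```latex
f(x, y) = \frac{1}{2\pi\sqrt{|\boldsymbol{\Sigma}|}}
  \exp\left(-\frac{1}{2}
  \begin{pmatrix} x - m_X & y - m_Y \end{pmatrix}
  \boldsymbol{\Sigma}^{-1}
  \begin{pmatrix} x - m_X \\ y - m_Y \end{pmatrix}\right)
```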

To make it easier to implement, let's compute the result of $(\mathbf{x}-\boldsymbol{\mu})^T \boldsymbol{\Sigma}^{-1} (\mathbf{x}-\boldsymbol{\mu})$ step by step:

So $\mathbf{x}-\boldsymbol{\mu}$ is the column vector (x - mX, y - mY). Therefore, the result of computing $(\mathbf{x}-\boldsymbol{\mu})^T \boldsymbol{\Sigma}^{-1}$ is the 2D row vector:

(CI[0,0] * (x - mX) + CI[1,0] * (y - mY), CI[0,1] * (x - mX) + CI[1,1] * (y - mY)), where CI is the inverse of the covariance matrix, written $\boldsymbol{\Sigma}^{-1}$ in the equation, which is a 2x2 matrix, like $\boldsymbol{\Sigma}$.
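In matrix notation, writing $C_{ij}$ for CI[i, j], this row-vector-times-matrix product is:

```latex
\begin{pmatrix} x - m_X & y - m_Y \end{pmatrix}
\begin{pmatrix} C_{00} & C_{01} \\ C_{10} & C_{11} \end{pmatrix}
=
\begin{pmatrix}
C_{00}(x - m_X) + C_{10}(y - m_Y) &
C_{01}(x - m_X) + C_{11}(y - m_Y)
\end{pmatrix}
```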

Then this 2D row vector is multiplied (inner product) by the column vector $\mathbf{x}-\boldsymbol{\mu}$, which finally gives the scalar:

CI[0,0](x - mX)^2 + (CI[1,0] + CI[0,1])(y - mY)(x - mX) + CI[1,1](y - mY)^2
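A quick numerical check that the expanded scalar equals the matrix product $(\mathbf{x}-\boldsymbol{\mu})^T \boldsymbol{\Sigma}^{-1} (\mathbf{x}-\boldsymbol{\mu})$. This sketch reuses the mean and covariance from the example code in this answer; the test point (1.0, -0.5) is arbitrary.

```python
import numpy as np

# Mean and covariance taken from the example code in this answer
m_x, m_y = 0.2, 0.6
cov = np.array([[0.7, 0.4], [0.4, 0.25]])
CI = np.linalg.inv(cov)

x, y = 1.0, -0.5                   # an arbitrary test point
d = np.array([x - m_x, y - m_y])   # the vector x - mu

matrix_form = d @ CI @ d           # (x - mu)^T Sigma^-1 (x - mu)
expanded = (CI[0, 0] * (x - m_x) ** 2
            + (CI[1, 0] + CI[0, 1]) * (y - m_y) * (x - m_x)
            + CI[1, 1] * (y - m_y) ** 2)

print(np.isclose(matrix_form, expanded))  # True
```

Since CI is symmetric, CI[1,0] + CI[0,1] is simply 2 * CI[0,1], but keeping both terms mirrors the derivation.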

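This expression is easier to implement with NumPy than the matrix form $(\mathbf{x}-\boldsymbol{\mu})^T \boldsymbol{\Sigma}^{-1} (\mathbf{x}-\boldsymbol{\mu})$, even though the two have the same value: every term is an elementwise operation, so it can be applied directly to the meshgrid arrays X and Y.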
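For comparison, the matrix form can also be evaluated on a whole grid without expanding it by hand, by stacking the offsets into an array of shape (rows, cols, 2) and contracting it with `np.einsum`. This is a sketch of one alternative, not what the answer uses, and it assumes the same mean and covariance as the example.

```python
import numpy as np

mu = np.array([0.2, 0.6])
cov = np.array([[0.7, 0.4], [0.4, 0.25]])
CI = np.linalg.inv(cov)

X, Y = np.meshgrid(np.linspace(-2, 2), np.linspace(-2, 2))
D = np.stack([X - mu[0], Y - mu[1]], axis=-1)  # shape (50, 50, 2): x - mu at each grid point

# Matrix form: sum over i, j of D[..., i] * CI[i, j] * D[..., j]
quad_matrix = np.einsum('...i,ij,...j->...', D, CI, D)

# Expanded scalar form from the text
quad_expanded = (CI[0, 0] * (X - mu[0]) ** 2
                 + (CI[0, 1] + CI[1, 0]) * (X - mu[0]) * (Y - mu[1])
                 + CI[1, 1] * (Y - mu[1]) ** 2)

print(np.allclose(quad_matrix, quad_expanded))  # True
```

The einsum version generalizes to any n, while the expanded form is specific to 2D but needs nothing beyond basic array arithmetic.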
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> m = np.array([[0.2], [0.6]])  # mean of the Gaussian (mX = 0.2, mY = 0.6)
>>> cov = np.array([[0.7, 0.4], [0.4, 0.25]])  # covariance matrix
>>> cov_inv = np.linalg.inv(cov)  # inverse of the covariance matrix
>>> cov_det = np.linalg.det(cov)  # determinant of the covariance matrix
>>> x = np.linspace(-2, 2)  # 50 points by default
>>> y = np.linspace(-2, 2)
>>> X, Y = np.meshgrid(x, y)  # grid of (x, y) points to evaluate
>>> coe = 1.0 / ((2 * np.pi)**2 * cov_det)**0.5  # normalizing constant with n = 2
>>> Z = coe * np.exp(-0.5 * (cov_inv[0,0]*(X-m[0])**2 + (cov_inv[0,1] + cov_inv[1,0])*(X-m[0])*(Y-m[1]) + cov_inv[1,1]*(Y-m[1])**2))
>>> plt.contour(X, Y, Z)  # plotting
>>> plt.show()
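To sanity-check the hand-written formula, one can compare Z with `scipy.stats.multivariate_normal` evaluated on the same grid. This sketch assumes SciPy is installed and re-creates the variables from the session so it runs on its own; passing an array of shape (rows, cols, 2) to `pdf` evaluates every grid point at once.

```python
import numpy as np
from scipy.stats import multivariate_normal

m = np.array([[0.2], [0.6]])
cov = np.array([[0.7, 0.4], [0.4, 0.25]])
cov_inv = np.linalg.inv(cov)
cov_det = np.linalg.det(cov)

X, Y = np.meshgrid(np.linspace(-2, 2), np.linspace(-2, 2))
coe = 1.0 / ((2 * np.pi) ** 2 * cov_det) ** 0.5
Z = coe * np.exp(-0.5 * (cov_inv[0, 0] * (X - m[0]) ** 2
                         + (cov_inv[0, 1] + cov_inv[1, 0]) * (X - m[0]) * (Y - m[1])
                         + cov_inv[1, 1] * (Y - m[1]) ** 2))

# SciPy reads each (x, y) pair from the last axis and returns a (50, 50) array
Z_scipy = multivariate_normal(mean=m.ravel(), cov=cov).pdf(np.dstack((X, Y)))

print(np.allclose(Z, Z_scipy))  # True
```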

[Figure: resulting contour plot of the 2D Gaussian density]


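Alternatively, scipy.stats.multivariate_normal evaluates the density directly, so the quadratic form does not have to be written out by hand:

import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import multivariate_normal


def gauss2d(mu, sigma, to_plot=False):
    # Evaluate the 2D Gaussian N(mu, sigma) on a w x h grid spanning +/- 3 std
    sigma = np.asarray(sigma, dtype=float)  # 2x2 covariance matrix
    w, h = 100, 100
    std = [np.sqrt(sigma[0, 0]), np.sqrt(sigma[1, 1])]
    x = np.linspace(mu[0] - 3 * std[0], mu[0] + 3 * std[0], w)
    y = np.linspace(mu[1] - 3 * std[1], mu[1] + 3 * std[1], h)
    x, y = np.meshgrid(x, y)  # both have shape (h, w)
    x_ = x.flatten()
    y_ = y.flatten()
    xy = np.vstack((x_, y_)).T  # (h * w, 2) array of points
    normal_rv = multivariate_normal(mu, sigma)
    z = normal_rv.pdf(xy)
    # order='F' lays the values out so that z[i, j] is the density
    # at the i-th x value and the j-th y value
    z = z.reshape(w, h, order='F')
    if to_plot:
        plt.contourf(x, y, z.T)  # contourf expects rows to follow y
        plt.show()
    return z


MU = [50, 70]
SIGMA = np.diag([75.0, 90.0])  # must be a 2x2 covariance matrix, not a list of variances
z = gauss2d(MU, SIGMA, True)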
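The flatten / vstack / reshape steps can be avoided, because `multivariate_normal.pdf` accepts an array whose last axis holds the coordinates and returns densities with the grid's shape. A minimal sketch of that variant, assuming SciPy is available; the name `gauss2d_grid` is hypothetical. Here the result is indexed as Z[row, col] = density at (x[col], y[row]), which is the layout `plt.contourf(X, Y, Z)` expects, so no transpose is needed.

```python
import numpy as np
from scipy.stats import multivariate_normal


def gauss2d_grid(mu, sigma, w=100, h=100):
    """Hypothetical variant of gauss2d: returns the grid and the density on it."""
    sigma = np.asarray(sigma, dtype=float)
    std = np.sqrt(np.diag(sigma))  # marginal standard deviations
    x = np.linspace(mu[0] - 3 * std[0], mu[0] + 3 * std[0], w)
    y = np.linspace(mu[1] - 3 * std[1], mu[1] + 3 * std[1], h)
    X, Y = np.meshgrid(x, y)  # both have shape (h, w)
    # Stack into shape (h, w, 2); pdf evaluates every point and returns shape (h, w)
    Z = multivariate_normal(mu, sigma).pdf(np.dstack((X, Y)))
    return X, Y, Z


X, Y, Z = gauss2d_grid([50, 70], np.diag([75.0, 90.0]))
print(Z.shape)  # (100, 100)
```

Plotting is then just plt.contourf(X, Y, Z) followed by plt.show().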