import sympy as sym
import numpy as np
from IPython.display import display,Math
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
# Heading for the section ("極値" is Japanese for "extremum").
print('Critical points: 極値')
# Hand-worked example shown as LaTeX: differentiate f(x) = -x^4 + 3x^2,
# set the derivative to zero and solve. NOTE(review): this is a different
# function from the x^2 * exp(-x^2) analysed numerically and symbolically
# below -- confirm that is intentional.
_worked_steps = (
    'f(x) = -x^4 + 3x^2',
    r'\frac{df}{dx} = -4x^3 + 6x',
    '-4x^3 + 6x = 0',
    r'x = 0, \pm \sqrt{\frac{3}{2}}',
)
for _latex in _worked_steps:
    display(Math(_latex))
# Numerically locate the critical points of f(x) = x^2 * exp(-x^2) by
# sampling it on a uniform grid of 1001 points over [-5, 5].
x = np.linspace(-5, 5, 1001)
dx = x[1] - x[0]  # uniform grid spacing, reused below
fx = x**2 * np.exp(-x**2)
# Forward finite-difference approximation of df/dx:
# dfx[i] = (fx[i+1] - fx[i]) / dx, so dfx is one element shorter than x.
dfx = np.diff(fx) / dx
print(dx)
# find_peaks returns (peak_indices, properties); only the indices are kept.
# Peaks of f are its local maxima; peaks of -f are its local minima.
localmax = find_peaks(fx)[0]
print(localmax)
localmin = find_peaks(-fx)[0]
print(localmin)
print(f'The critical points are {x[localmax]} {x[localmin]}')
# Plot f, its finite-difference derivative and the detected extrema.
# dfx holds forward differences (fx[i+1] - fx[i]) / dx, which approximate
# f' at the midpoint between consecutive samples. Plotting them against
# the left sample x[:-1] would shift the derivative curve by half a step,
# so they are plotted at the midpoints instead.
x_mid = (x[:-1] + x[1:]) / 2
plt.plot(x, fx, label='f(x)')
plt.plot(x_mid, dfx, label="f'(x) (finite difference)")
plt.plot(x[localmax], fx[localmax], 'ro', label='local maxima')
plt.plot(x[localmin], fx[localmin], 'gs', label='local minima')
plt.xlabel('x')
plt.legend()
plt.show()
# Repeat the analysis symbolically: differentiate f(x) = x^2 * exp(-x^2)
# and solve f'(x) = 0 exactly. Note this rebinds x, fx and dfx from the
# NumPy arrays used above to SymPy objects.
x = sym.symbols('x')
fx = x**2 * sym.exp(-x**2)
# Pass the symbol explicitly so diff/solve stay well defined even if fx
# later gains another free symbol (e.g. a parameter).
dfx = sym.diff(fx, x)
print(dfx)
critpoints = sym.solve(dfx, x)
print(critpoints)