# -*- coding: utf-8 -*-
# =============================================================================
# Imports nécessaires
# =============================================================================

from math import sqrt,pi
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt

# =============================================================================
# Constantes globales utiles
# =============================================================================
G=6.67e-11          #Cste de gravitation USI
MT=5.97e24          #Masse de la Terre en kg
m=730               #masse d'un satellite en kg
RT=6.370e6          #Rayon de la Terre en m

r0= 23616e3 + RT
v0= sqrt(G*MT/r0)   #vitesse initiale pour une trajectoire circulaire

# =============================================================================
# Tracé de la trajectoire normalisée
# =============================================================================
def traject(CI, dureemax):
    def deriv(Y,t):
        r2=Y[0]**2+Y[1]**2 #OM²
        r=sqrt(r2) #OM
        return [Y[2],Y[3],-G*MT/(r2)*Y[0]/r,-G*MT/(r2)*Y[1]/r]
    tabt=np.linspace(0,dureemax,1000)
    res=odeint(deriv,CI,tabt)
    return tabt, res

# =============================================================================
# Tracé des graphes
# =============================================================================
plt.figure(1)
X0=[r0,0,0,v0]
tabt,res=traject(X0,100000)
plt.plot(res[:,0]/RT,res[:,1]/RT,label='Trajectoire')
plt.plot(np.cos(np.linspace(0,2*pi)),np.sin(np.linspace(0,2*pi)),'k-') #Terre
plt.text(0, 0, 'Terre', horizontalalignment='center', verticalalignment='center', fontsize=12)
plt.xlabel("$x/R_T$")
plt.ylabel("$y/R_T$")
plt.axis('equal')
plt.legend()
plt.show()

#%% erreur distance
plt.figure('Erreur distance')
X0=[r0,0,0,v0*0.99]
tabt,res=traject(X0,100000)
plt.plot(res[:,0]/RT,res[:,1]/RT,label='Trajectoire modif')

X0=[r0,0,0,v0]
tabt,res=traject(X0,100000)
plt.plot(res[:,0]/RT,res[:,1]/RT,label='Trajectoire')

plt.plot(np.cos(np.linspace(0,2*pi)),np.sin(np.linspace(0,2*pi)),'k-') #Terre
plt.text(0, 0, 'Terre', horizontalalignment='center', verticalalignment='center', fontsize=12)
plt.xlabel("$x/R_T$")
plt.ylabel("$y/R_T$")
plt.axis('equal')
plt.legend(loc='upper right')
plt.show()

#%% erreur angle
plt.figure('Erreur angle')
theta=1*np.pi/180 #( ! en radians)
X0=[r0,0,v0*np.sin(theta),v0*np.cos(theta)]
tabt,res=traject(X0,100000)
plt.plot(res[:,0]/RT,res[:,1]/RT,label='Trajectoire modif')
Tableau_R_modif=np.sqrt(res[:,0]*res[:,0]+res[:,1]*res[:,1]);

X0=[r0,0,0,v0]
tabt,res=traject(X0,100000)
plt.plot(res[:,0]/RT,res[:,1]/RT,label='Trajectoire')
Tableau_R_base=np.sqrt(res[:,0]*res[:,0]+res[:,1]*res[:,1]);

diff_tableau=np.abs(Tableau_R_modif-Tableau_R_base)
max_dist=np.max(diff_tableau);
print(max_dist/1000) #en kilomètre
plt.plot(np.cos(np.linspace(0,2*pi)),np.sin(np.linspace(0,2*pi)),'k-') #Terre
plt.text(0, 0, 'Terre', horizontalalignment='center', verticalalignment='center', fontsize=12)
plt.xlabel("$x/R_T$")
plt.ylabel("$y/R_T$")
plt.axis('equal')
plt.legend(loc='upper right')
plt.show()