#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 12 19:02:11 2018

@author: mischaknabenhans
"""
import numpy as np

class position:
    """Mutable Cartesian point (x, y, z) that can report itself in other coordinate systems."""

    def __init__(self, x, y, z):
        self.__x = x
        self.__y = y
        self.__z = z

    def getPos(self, coords="cartesian"):
        """Return the position in the requested coordinate system.

        Parameters
        ----------
        coords : str
            "cartesian"   -> np.array([x, y, z])
            "cylindrical" -> np.array([rho, phi, z]) with rho = sqrt(x^2 + y^2)
                             and phi = arctan2(y, x) in [-pi, pi]
            "spherical"   -> np.array([r, theta, phi]) with r = |(x, y, z)|,
                             theta the polar angle from +z in [0, pi] and
                             phi = arctan2(y, x) in [-pi, pi]

        Returns
        -------
        numpy.ndarray of length 3, or None (after printing a message) when
        `coords` is not one of the names above.
        """
        if coords == "cartesian":
            return np.array([self.__x, self.__y, self.__z])

        elif coords == "cylindrical":
            rho = np.sqrt(self.__x**2 + self.__y**2)
            phi = np.arctan2(self.__y, self.__x)
            return np.array([rho, phi, self.__z])

        elif coords == "spherical":
            rho = np.sqrt(self.__x**2 + self.__y**2)
            r = np.sqrt(self.__x**2 + self.__y**2 + self.__z**2)
            # arctan2(rho, z) equals arccos(z / r) for r > 0, but it stays
            # finite at the origin (where z / r would be 0/0 -> NaN) and is
            # better conditioned close to the poles.
            theta = np.arctan2(rho, self.__z)
            phi = np.arctan2(self.__y, self.__x)
            # Return an ndarray like the other branches (was a plain list).
            return np.array([r, theta, phi])

        print("No valid coordinate system defined. Choose from 'cartesian', 'cylindrical' or 'spherical'!\n")
        return None

    def setPos(self, x, y, z):
        """Overwrite all three Cartesian components."""
        self.__x = x
        self.__y = y
        self.__z = z
    
class velocity:
    """Mutable Cartesian velocity vector (vx, vy, vz)."""

    def __init__(self, vx, vy, vz):
        # Store the three components in one tuple-unpacking assignment.
        self.__vx, self.__vy, self.__vz = vx, vy, vz

    def getVel(self):
        """Return the velocity as a numpy array [vx, vy, vz]."""
        components = (self.__vx, self.__vy, self.__vz)
        return np.array(components)

    def setVel(self, vx, vy, vz):
        """Replace all three velocity components at once."""
        self.__vx, self.__vy, self.__vz = vx, vy, vz
        
class KeplerBody:
    """A point mass with a name, a position and a velocity.

    Units follow the Gaussian gravitational constant below: lengths in AU,
    masses in solar masses, time in days.
    """

    # Gaussian constant of gravity k in AU^(3/2) / (M_solar^(1/2) * day)
    _K_GAUSS = 0.01720209895
    # Newton's constant G = k^2 in AU^3 / (M_solar * day^2)
    _G = _K_GAUSS * _K_GAUSS

    def __init__(self, name, m, x, y, z, vx, vy, vz):
        """Create body `name` with mass `m`, position (x, y, z) and velocity (vx, vy, vz)."""
        self.__name = name
        self.__m = m
        self.__pos = position(x, y, z)
        self.__vel = velocity(vx, vy, vz)

    def getName(self):
        """Return the body's name."""
        return self.__name

    def getMass(self):
        """Return the body's mass."""
        return self.__m

    def getPos(self, coordsyst="cartesian"):
        """Return the position in `coordsyst` ('cartesian', 'cylindrical' or 'spherical')."""
        return self.__pos.getPos(coords=coordsyst)

    def getVel(self):
        """Return the Cartesian velocity vector."""
        return self.__vel.getVel()

    def setPos(self, x, y, z):
        """Overwrite the Cartesian position."""
        self.__pos.setPos(x, y, z)

    def setVel(self, vx, vy, vz):
        """Overwrite the Cartesian velocity."""
        self.__vel.setVel(vx, vy, vz)

    @staticmethod
    def _norm(vec):
        """Return the Euclidean length of `vec`."""
        return np.sqrt(np.dot(vec, vec))

    def _others(self, KeplerBodyList):
        """Return the bodies of KeplerBodyList that are not this body (identity test)."""
        # Excluding self avoids a 0/0 self-interaction term (NaN).
        return [body for body in KeplerBodyList if body is not self]

    def getDistVec(self, KeplerObj, Norm=False):
        """Return the vector from this body to `KeplerObj`, or its length if Norm is True."""
        # Use the public accessor rather than KeplerObj's name-mangled attribute.
        vec = KeplerObj.getPos() - self.getPos()

        if Norm:
            return self._norm(vec)
        return vec

    def getLinMomentum(self, Norm=False):
        """Return the linear momentum m * v (or its magnitude if Norm is True)."""
        p = self.getMass() * self.getVel()
        return self._norm(p) if Norm else p

    def getEpot(self, KeplerBodyList):
        """Return this body's gravitational potential energy.

        U = -G * M * sum_j m_j / |r_j - r| over the bodies in KeplerBodyList,
        skipping this body itself. The sign is negative because gravity is
        attractive; bound orbits therefore have getEtot() < 0.
        """
        M = self.getMass()
        U = -self._G * M * sum([body.getMass() / self.getDistVec(body, Norm=True)
                                for body in self._others(KeplerBodyList)])
        return U

    def getEkin(self):
        """Return the kinetic energy 0.5 * m * |v|^2."""
        v = self.getVel()
        return 0.5 * self.getMass() * np.dot(v, v)

    def getEtot(self, KeplerBodyList):
        """Return E_kin + E_pot with respect to the bodies in KeplerBodyList."""
        return self.getEkin() + self.getEpot(KeplerBodyList)

    def getAcceleration(self, KeplerBodyList, Norm=False):
        """Return the gravitational acceleration due to KeplerBodyList.

        a = G * sum_j m_j (r_j - r) / |r_j - r|^3, skipping this body itself.
        Always returns a 3-vector (zeros for an empty list), or its magnitude
        if Norm is True.
        """
        # Start from a zero vector so an empty list still yields an array.
        a = np.zeros(3)
        for body in self._others(KeplerBodyList):
            d = self.getDistVec(body)
            a = a + body.getMass() * d / (self._norm(d)**3)
        a = self._G * a

        return self._norm(a) if Norm else a

    def getForce(self, KeplerBodyList, Norm=False):
        """Return the gravitational force m * a (or its magnitude if Norm is True)."""
        F = self.getMass() * self.getAcceleration(KeplerBodyList)
        return self._norm(F) if Norm else F

    def getAngMomentum(self, Norm=False):
        """Return the angular momentum r x (m v) about the origin (or its magnitude)."""
        r_vec = self.getPos()
        v_vec = self.getVel()

        L = np.cross(r_vec, self.getMass() * v_vec)
        return self._norm(L) if Norm else L

    def printInfo(self, env=None):
        """Print mass, state vectors, energies and angular momentum.

        env : iterable of KeplerBody, optional
            When given, the potential/total energy and the force with respect
            to these bodies are printed as well.
        """
        pos = self.getPos()
        vel = self.getVel()
        print("mass:\t\t%e" % self.getMass())
        print("position:\t[%e, %e, %e]" % (pos[0], pos[1], pos[2]))
        print("velocity:\t[%e, %e, %e]" % (vel[0], vel[1], vel[2]))
        print("E_kin:\t\t%e" % self.getEkin())

        if env is not None:
            # Materialise once: env is iterated by three separate calls below.
            env = list(env)
            print("E_pot:\t\t%e" % self.getEpot(env))
            print("E_total:\t%e" % self.getEtot(env))

            F = self.getForce(env)
            print("F_total:\t[%e, %e, %e]" % (F[0], F[1], F[2]))

        L = self.getAngMomentum()
        print("L_traj:\t\t[%e, %e, %e]" % (L[0], L[1], L[2]))
        
        