import numpy as np
import monteCarloPCP as mc
import scipy
import os
import sys
from scipy import sparse
from scipy.sparse import linalg as la
from scipy.sparse import coo_matrix
from scipy.sparse import csr_matrix
from numpy import linalg as LA
from scipy.sparse.linalg import eigsh
import time

def initializeInversePower(packing, nonAdditivity=40, neighborCut=2):
	"""Switch `packing` to a non-additive inverse-power potential and refresh it.

	Selects the non-additive distance type with the given `nonAdditivity`,
	selects the inverse-power potential and calls `setInversePower()` with
	the library's default arguments. It then calls
	`calcNeighborsCut(neighborCut)` (presumably rebuilding the neighbor list
	with that cut; confirm in monteCarloPCP) and `calcForceEnergy()`, so the
	packing's forces/energy reflect the new potential.

	Args:
		packing: an `mc.MCPacking` instance, mutated in place.
		nonAdditivity: value passed to `setNonAdditivity`.
		neighborCut: value passed to `calcNeighborsCut`.
	"""
	enums = mc.enums
	# Potential definition: distance metric first, then the pair potential.
	packing.setDistanceType(enums.distanceEnum.nonAdditive)
	packing.setNonAdditivity(nonAdditivity)
	packing.setPotentialType(enums.potentialEnum.inversePower)
	packing.setInversePower()
	# Refresh derived state so later queries see the new potential.
	packing.calcNeighborsCut(neighborCut)
	packing.calcForceEnergy()

def readNinarelloData(filename):
	"""Read particle positions and radii from a Ninarello-format text file.

	File layout (as parsed here): line 1 is the particle count, line 2 the box
	length, then one line per particle whose first three whitespace-separated
	fields are the x, y, z coordinates and whose fourth field is the radius.
	Lines beyond the first `nParticles` particle lines are ignored.

	Returns:
		(pos, radii, nParticles, Length): `pos` is an (nParticles, 3) float64
		array, `radii` an (nParticles, 1) float64 array, `nParticles` an int
		and `Length` a float64. If `filename` is not an existing file, all
		four values are the integer 0, so callers must check before use.

	Raises:
		ValueError or IndexError if the header or a particle line is malformed.
	"""
	if not os.path.isfile(filename):
		return 0, 0, 0, 0
	# `with` closes the handle even when a malformed line raises mid-parse
	# (a bare open()/close() pair leaks it in that case).
	with open(filename) as f:
		nParticles = int(f.readline())
		Length = np.float64(f.readline())
		pos = np.zeros((nParticles, 3), dtype=np.float64)
		radii = np.zeros((nParticles, 1), dtype=np.float64)
		for i in range(nParticles):
			fields = f.readline().split()
			radii[i, 0] = np.float64(fields[3])
			for j in range(3):
				pos[i, j] = np.float64(fields[j])
	return pos, radii, nParticles, Length

# Random-force (GCF) stress-strain driver.
# Usage: python <script> <xi> <data file> <output base directory>
# Loads a Ninarello-format configuration, minimizes it under a non-additive
# inverse-power potential, then applies `numSteps` strain-controlled steps of
# size `stepSize` along a generated GCF constraint, re-minimizing after each
# step and recording the constraint magnitude.

dim = 3
minForce = 0.1
critForce = 1e-5
numSteps = 5000
stepSize = 1e-4
checkpointEvery = 300  # save the full packing every this many steps

# NOTE(review): argv[1] is parsed (and so validated) but is overwritten below
# by twice the minimum radius; confirm which xi is intended.
xi = np.float64(sys.argv[1])
dataPath = sys.argv[2]
# Honor the command-line output directory instead of a hard-coded,
# user-specific cluster path.
basedir = sys.argv[3]

# Per-configuration output directory name. It was previously referenced
# without ever being defined (NameError); derive it from the input file name
# so different inputs write to separate directories.
configuration = os.path.splitext(os.path.basename(dataPath))[0]

pos, radii, nParticles, Length = readNinarelloData(dataPath)
if nParticles == 0:
	# readNinarelloData returns all zeros when the file is missing; fail with
	# a clear message instead of a later division-by-zero.
	sys.exit('Could not read particle data from ' + dataPath)

# Size the packing from the file rather than a hard-coded particle count, so
# it always matches the loaded positions.
p = mc.MCPacking(dim, nParticles, 1.0)

# Set the random seed used to generate the GCF. To make a run reproducible,
# set seed to a saved integer.
seed = np.random.randint(2**16)
p.setRandomSeed(seed)

# Rescale to the unit box; radii in the file are divided by 2*Length.
p.setPositions(pos / Length)
p.setRadii(radii / (2.0 * Length))
# This already builds neighbors with cut 2 and computes forces/energy.
initializeInversePower(p, nonAdditivity=0.4 * Length)

# Sanity checks on the loaded packing before running AQS: report the
# maximum unbalanced force and require a finite energy.
maxForce = p.getMaxUnbalancedForce()
energy = p.getEnergy()
print('Initial max unbalanced force: ' + str(maxForce))
print('Initial energy: ' + str(energy))
if not np.all(np.isfinite(energy)):
	sys.exit('Initial energy is not finite; aborting.')

p.setMaxForceRescale(minForce)
p.minimizeFIRE(criticalForce=critForce)

initPos = p.getPositions()

xi = np.min(p.getRadii()) * 2.0

dirName = os.path.join(basedir, str(configuration), 'RFStressStrain')
seedDir = os.path.join(dirName, 'Seed' + str(seed))
if not os.path.isdir(seedDir):
	# makedirs creates any missing parent directories in one call.
	os.makedirs(seedDir)

stressStrainFile = os.path.join(seedDir, 'stressStrain.dat')
if not os.path.exists(stressStrainFile):
	with open(os.path.join(seedDir, 'meanRadius.dat'), 'w') as f:
		f.write(str(np.mean(p.getRadii())))
	np.random.seed(seed)
	constraint = p.generateGCF(xi=xi)
	stressStrain = np.zeros((numSteps, 2))
	# Restore the minimized configuration and potential before straining.
	p.setPositions(initPos)
	initializeInversePower(p, nonAdditivity=0.4 * Length)
	p.minimizeFIRE(criticalForce=critForce)
	p.setConstraintType(mc.enums.constraintEnum.strainControl)
	p.setConstraint(constraint)
	mc.save2DArray(os.path.join(seedDir, 'constraint.dat'), constraint)
	for i in range(numSteps):
		p.translate(stepSize, constraint)
		p.calcNeighborsCut(1)
		p.minimizeFIRE(criticalForce=critForce)
		# NOTE(review): row i is recorded after i+1 translate steps but is
		# labeled stepSize*i; confirm the intended strain offset.
		stressStrain[i, 0] = stepSize * i
		stressStrain[i, 1] = p.getConstraintMagnitude()
		# Compare by value: `np.mod(i, 300) is 0` tested a NumPy integer by
		# identity, was never true, and so no checkpoint was ever saved.
		if i % checkpointEvery == 0:
			p.save(os.path.join(seedDir, str(stressStrain[i, 0])), overwrite=True)
		print(i)
	mc.save2DArray(stressStrainFile, stressStrain)
