
# coding: utf-8

# ## Finding centroids
# 
# In this example, we're going to find a "centroid" (representative structure) for a group of conformations. This group might potentially come from clustering, using a method like Ward hierarchical clustering.
# 
# Note that there are many possible ways to define the centroids. This is just one.

# In[ ]:

from __future__ import print_function
get_ipython().magic(u'matplotlib inline')
import mdtraj as md
import numpy as np


# Load up a trajectory to use for the example.

# In[ ]:

# Path of the example trajectory; md.load infers the file format from the
# extension.
trajectory_path = 'ala2.h5'
traj = md.load(trajectory_path)
print(traj)


# Let's compute all pairwise RMSDs between conformations.

# In[ ]:

# Restrict the RMSD computation to heavy atoms, i.e. skip every atom whose
# element symbol is 'H'.
heavy_atoms = (atom for atom in traj.topology.atoms if atom.element.symbol != 'H')
atom_indices = [atom.index for atom in heavy_atoms]

# Square (n_frames x n_frames) matrix: row `ref` holds the RMSD of every frame
# measured against frame `ref`. np.empty defaults to float64; each row view is
# filled in place.
distances = np.empty((traj.n_frames, traj.n_frames))
for ref, row in enumerate(distances):
    row[:] = md.rmsd(traj, traj, ref, atom_indices=atom_indices)


# The algorithm we're going to use is relatively simple:
# - Compute all of the pairwise RMSDs between the conformations. This is O(N^2), so it's not going to
#   scale extremely well to large datasets.
# - Transform these distances into similarity scores. Our similarities will be calculated as
#   $$ s_{ij} = e^{-\beta \cdot d_{ij} / d_\text{scale}} $$
#   where $s_{ij}$ is the pairwise similarity, $d_{ij}$ is the pairwise distance, and $d_\text{scale}$ is the standard deviation of
#   the values of $d$, to make the computation scale invariant.
# - Then, we define the centroid as
#   $$ \text{argmax}_i \sum_j s_{ij} $$
# 
# Using $\beta=1$, this is implemented with the following code:

# In[ ]:

# Sharpness of the similarity kernel s_ij = exp(-beta * d_ij / d_scale):
# a larger beta makes the similarity decay faster with distance.
beta = 1

# d_scale is the standard deviation of all pairwise distances, which makes the
# score scale-invariant. If every distance is zero (a single-frame trajectory
# or identical conformations), the std is 0 and the division would produce
# 0/0 = NaN everywhere (plus a RuntimeWarning). In that case all frames are
# equally central, so any positive scale gives the same answer (the first frame).
d_scale = distances.std()
if d_scale == 0:
    d_scale = 1.0

# The centroid is the frame with the largest total similarity to all frames.
similarity = np.exp(-beta * distances / d_scale)
index = similarity.sum(axis=1).argmax()
print(index)


# In[ ]:

# Pull the centroid frame out as its own single-frame Trajectory
# (indexing a Trajectory delegates to Trajectory.slice).
centroid = traj.slice(index)
print(centroid)


# In[ ]:



