From 4937e44e4f6039c0e51ee61e522810a2f1b03e89 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 5 Jun 2019 13:25:33 -0400
Subject: [PATCH 001/129] Initial upgrade to Python 3
---
.gitignore | 2 ++
setup.py | 2 +-
sitator/SiteNetwork.py | 4 ++--
sitator/SiteTrajectory.py | 2 +-
sitator/__init__.py | 4 ++--
sitator/descriptors/ConfigurationalEntropy.py | 10 +++++-----
sitator/descriptors/__init__.py | 2 +-
sitator/dynamics/DiffusionPathwayAnalysis.py | 6 +++---
sitator/dynamics/JumpAnalysis.py | 12 ++++++------
sitator/dynamics/MergeSitesByDynamics.py | 10 +++++-----
sitator/dynamics/__init__.py | 4 ++--
sitator/landmark/LandmarkAnalysis.py | 12 ++++++------
sitator/landmark/__init__.py | 4 ++--
sitator/landmark/cluster/dbscan.py | 4 ++--
sitator/landmark/pointmerge.py | 2 +-
sitator/misc/GenerateAroundSites.py | 4 ++--
sitator/misc/NAvgsPerSite.py | 4 ++--
sitator/misc/SiteVolumes.py | 6 +++---
sitator/misc/__init__.py | 6 +++---
sitator/site_descriptors/SOAP.py | 9 ++++-----
sitator/site_descriptors/SiteTypeAnalysis.py | 10 +++++-----
sitator/site_descriptors/__init__.py | 4 ++--
sitator/util/__init__.py | 8 ++++----
sitator/util/qvoronoi.py | 4 ++--
sitator/visualization/SiteNetworkPlotter.py | 6 +++---
sitator/visualization/__init__.py | 6 +++---
sitator/visualization/atoms.py | 2 +-
sitator/visualization/common.py | 2 +-
sitator/voronoi/__init__.py | 2 +-
29 files changed, 77 insertions(+), 76 deletions(-)
diff --git a/.gitignore b/.gitignore
index 4800c95..81051d0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,5 @@ build/
/dist/
/*.egg-info
+
+*.bak
diff --git a/setup.py b/setup.py
index f6067cd..93afd4b 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
download_url = "https://github.com/Linux-cpp-lisp/sitator",
author = 'Alby Musaelian',
license = "MIT",
- python_requires = '>=2.7, <3',
+ python_requires = '>=3',
packages = find_packages(),
ext_modules = cythonize([
"sitator/landmark/helpers.pyx",
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 28909dc..3cc620a 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -258,11 +258,11 @@ def types(self):
@property
def site_attributes(self):
- return self._site_attrs.keys()
+ return list(self._site_attrs.keys())
@property
def edge_attributes(self):
- return self._edge_attrs.keys()
+ return list(self._edge_attrs.keys())
def has_attribute(self, attr):
return (attr in self._site_attrs) or (attr in self._edge_attrs)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index f934edd..3e845b2 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -156,7 +156,7 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True):
max_time_unknown = 0
total_reassigned = 0
- for i in xrange(self.n_frames):
+ for i in range(self.n_frames):
# All those unknown this frame
unknown = self._traj[i] == -1
# Update last_known for assigned sites
diff --git a/sitator/__init__.py b/sitator/__init__.py
index 0f393bc..1ca01b0 100644
--- a/sitator/__init__.py
+++ b/sitator/__init__.py
@@ -1,3 +1,3 @@
-from SiteNetwork import SiteNetwork
-from SiteTrajectory import SiteTrajectory
+from .SiteNetwork import SiteNetwork
+from .SiteTrajectory import SiteTrajectory
diff --git a/sitator/descriptors/ConfigurationalEntropy.py b/sitator/descriptors/ConfigurationalEntropy.py
index a1c31ac..92c0fa7 100644
--- a/sitator/descriptors/ConfigurationalEntropy.py
+++ b/sitator/descriptors/ConfigurationalEntropy.py
@@ -53,15 +53,15 @@ def compute(self, st):
forgivable = problems & (size_of_problems < self.acceptable_overshoot)
if self.verbose:
- print "n_i " + ("{:5.3} " * len(n_i)).format(*n_i)
- print "N_i " + ("{:>5} " * len(N_i)).format(*N_i)
- print " " + ("------" * len(n_i))
- print "P_2 " + ("{:5.3} " * len(p2)).format(*p2)
+ print("n_i " + ("{:5.3} " * len(n_i)).format(*n_i))
+ print("N_i " + ("{:>5} " * len(N_i)).format(*N_i))
+ print(" " + ("------" * len(n_i)))
+ print("P_2 " + ("{:5.3} " * len(p2)).format(*p2))
if not np.all(problems == forgivable):
raise ValueError("P_2 values for site types %s larger than 1.0 + acceptable_overshoot (%f)" % (np.where(problems)[0], self.acceptable_overshoot))
elif np.any(problems) and self.verbose:
- print ""
+ print("")
# Correct forgivable problems
p2[forgivable] = 1.0
diff --git a/sitator/descriptors/__init__.py b/sitator/descriptors/__init__.py
index 0f2d416..31d3db7 100644
--- a/sitator/descriptors/__init__.py
+++ b/sitator/descriptors/__init__.py
@@ -1 +1 @@
-from ConfigurationalEntropy import ConfigurationalEntropy
+from .ConfigurationalEntropy import ConfigurationalEntropy
diff --git a/sitator/dynamics/DiffusionPathwayAnalysis.py b/sitator/dynamics/DiffusionPathwayAnalysis.py
index 245f278..5510d66 100644
--- a/sitator/dynamics/DiffusionPathwayAnalysis.py
+++ b/sitator/dynamics/DiffusionPathwayAnalysis.py
@@ -57,8 +57,8 @@ def run(self, sn):
is_pathway = counts >= self.minimum_n_sites
if self.verbose:
- print "Taking all edges with at least %i/%i jumps..." % (threshold, n_non_self_jumps)
- print "Found %i connected components, of which %i are large enough to qualify as pathways." % (n_ccs, np.sum(is_pathway))
+ print("Taking all edges with at least %i/%i jumps..." % (threshold, n_non_self_jumps))
+ print("Found %i connected components, of which %i are large enough to qualify as pathways." % (n_ccs, np.sum(is_pathway)))
translation = np.empty(n_ccs, dtype = np.int)
translation[~is_pathway] = DiffusionPathwayAnalysis.NO_PATHWAY
@@ -68,7 +68,7 @@ def run(self, sn):
outmat = np.empty(shape = (sn.n_sites, sn.n_sites), dtype = np.int)
- for i in xrange(sn.n_sites):
+ for i in range(sn.n_sites):
rowmask = node_pathways[i] == node_pathways
outmat[i, rowmask] = node_pathways[i]
outmat[i, ~rowmask] = DiffusionPathwayAnalysis.NO_PATHWAY
diff --git a/sitator/dynamics/JumpAnalysis.py b/sitator/dynamics/JumpAnalysis.py
index af5085c..fc594f9 100644
--- a/sitator/dynamics/JumpAnalysis.py
+++ b/sitator/dynamics/JumpAnalysis.py
@@ -29,7 +29,7 @@ def run(self, st):
assert isinstance(st, SiteTrajectory)
if self.verbose:
- print "Running JumpAnalysis..."
+ print("Running JumpAnalysis...")
n_mobile = st.site_network.n_mobile
n_frames = st.n_frames
@@ -63,7 +63,7 @@ def run(self, st):
fknown = frame >= 0
if np.any(~fknown) and self.verbose:
- print " at frame %i, %i uncorrectable unassigned particles" % (i, np.sum(~fknown))
+ print(" at frame %i, %i uncorrectable unassigned particles" % (i, np.sum(~fknown)))
# -- Update stats
total_time_spent_at_site[frame[fknown]] += 1
@@ -95,7 +95,7 @@ def run(self, st):
assert not np.any(np.nonzero(avg_time_before_jump.diagonal()))
if self.verbose and n_problems != 0:
- print "Came across %i times where assignment and last known assignment were unassigned." % n_problems
+ print("Came across %i times where assignment and last known assignment were unassigned." % n_problems)
msk = avg_time_before_jump_n > 0
# Zeros -- i.e. no jumps -- should actualy be infs
@@ -108,7 +108,7 @@ def run(self, st):
st.site_network.add_edge_attribute('p_ij', n_ij / total_time_spent_at_site)
res_times = np.empty(shape = n_sites, dtype = np.float)
- for site in xrange(n_sites):
+ for site in range(n_sites):
times = avg_time_before_jump[site]
noninf = times < np.inf
if np.any(noninf):
@@ -148,7 +148,7 @@ def jump_lag_by_type(self,
if return_counts:
countmat = np.empty(shape = outmat.shape, dtype = np.int)
- for stype_from, stype_to in itertools.product(xrange(len(all_types)), repeat = 2):
+ for stype_from, stype_to in itertools.product(range(len(all_types)), repeat = 2):
lags = sn.jump_lag[site_types == all_types[stype_from]][:, site_types == all_types[stype_to]]
# Only take things that aren't inf
lags = lags[lags < np.inf]
@@ -192,7 +192,7 @@ def plot_jump_lag(self, sn, mode = 'site', min_n_events = 1, ax = None, fig = No
mat[counts < min_n_events] = np.nan
# Show diagonal
- ax.plot(*zip([0.0, 0.0], mat.shape), color = 'k', alpha = 0.5, linewidth = 1, linestyle = '--')
+ ax.plot(*list(zip([0.0, 0.0], mat.shape)), color = 'k', alpha = 0.5, linewidth = 1, linestyle = '--')
ax.grid()
im = ax.matshow(mat, zorder = 10, cmap = 'plasma')
diff --git a/sitator/dynamics/MergeSitesByDynamics.py b/sitator/dynamics/MergeSitesByDynamics.py
index 5288cca..7abbe04 100644
--- a/sitator/dynamics/MergeSitesByDynamics.py
+++ b/sitator/dynamics/MergeSitesByDynamics.py
@@ -71,7 +71,7 @@ def run(self, st):
n_alarming_ignored_edges = 0
# Apply distance threshold
- for i in xrange(n_sites_before):
+ for i in range(n_sites_before):
dists = pbcc.distances(centers_before[i], centers_before[i + 1:])
js_too_far = np.where(dists > self.distance_threshold)[0]
js_too_far += i + 1
@@ -93,7 +93,7 @@ def run(self, st):
new_n_sites = len(clusters)
if self.verbose:
- print "After merge there will be %i sites" % new_n_sites
+ print("After merge there will be %i sites" % new_n_sites)
if self.check_types:
new_types = np.empty(shape = new_n_sites, dtype = np.int)
@@ -103,7 +103,7 @@ def run(self, st):
translation = np.empty(shape = st.site_network.n_sites, dtype = np.int)
translation.fill(-1)
- for newsite in xrange(new_n_sites):
+ for newsite in range(new_n_sites):
mask = list(clusters[newsite])
# Update translation table
if np.any(translation[mask] != -1):
@@ -167,7 +167,7 @@ def _markov_clustering(self,
allcols = np.arange(m1.shape[1])
converged = False
- for i in xrange(self.iterlimit):
+ for i in range(self.iterlimit):
# -- Expansion
m2 = np.linalg.matrix_power(m1, expansion)
# -- Inflation
@@ -182,7 +182,7 @@ def _markov_clustering(self,
if np.allclose(m1, m2):
converged = True
if self.verbose:
- print "Markov Clustering converged in %i iterations" % i
+ print("Markov Clustering converged in %i iterations" % i)
break
m1[:] = m2
diff --git a/sitator/dynamics/__init__.py b/sitator/dynamics/__init__.py
index ce898fc..a78471a 100644
--- a/sitator/dynamics/__init__.py
+++ b/sitator/dynamics/__init__.py
@@ -1,3 +1,3 @@
-from JumpAnalysis import JumpAnalysis
+from .JumpAnalysis import JumpAnalysis
-from MergeSitesByDynamics import MergeSitesByDynamics
+from .MergeSitesByDynamics import MergeSitesByDynamics
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index a972ca3..2bde9c1 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -158,7 +158,7 @@ def run(self, sn, frames):
n_frames = len(frames)
if self.verbose:
- print "--- Running Landmark Analysis ---"
+ print("--- Running Landmark Analysis ---")
# Create PBCCalculator
self._pbcc = PBCCalculator(sn.structure.cell)
@@ -177,7 +177,7 @@ def run(self, sn, frames):
site_vert_dists[i, :len(polyhedron)] = dists
# -- Step 2: Compute landmark vectors
- if self.verbose: print " - computing landmark vectors -"
+ if self.verbose: print(" - computing landmark vectors -")
# Compute landmark vectors
# The dimension of one landmark vector is the number of Voronoi regions
@@ -197,7 +197,7 @@ def run(self, sn, frames):
tqdm = tqdm)
# -- Step 3: Cluster landmark vectors
- if self.verbose: print " - clustering landmark vectors -"
+ if self.verbose: print(" - clustering landmark vectors -")
# - Preprocess -
self._do_peak_evening()
@@ -211,7 +211,7 @@ def run(self, sn, frames):
verbose = self.verbose)
if self.verbose:
- print " Failed to assign %i%% of mobile particle positions to sites." % (100.0 * np.sum(lmk_lbls < 0) / float(len(lmk_lbls)))
+ print(" Failed to assign %i%% of mobile particle positions to sites." % (100.0 * np.sum(lmk_lbls < 0) / float(len(lmk_lbls))))
# reshape lables and confidences
lmk_lbls.shape = (n_frames, sn.n_mobile)
@@ -223,7 +223,7 @@ def run(self, sn, frames):
raise ValueError("There are %i mobile particles, but only identified %i sites. Check clustering_params." % (sn.n_mobile, n_sites))
if self.verbose:
- print " Identified %i sites with assignment counts %s" % (n_sites, cluster_counts)
+ print(" Identified %i sites with assignment counts %s" % (n_sites, cluster_counts))
# Check that multiple particles are never assigned to one site at the
# same time, cause that would be wrong.
@@ -246,7 +246,7 @@ def run(self, sn, frames):
# - Compute site centers
site_centers = np.empty(shape = (n_sites, 3), dtype = frames.dtype)
- for site in xrange(n_sites):
+ for site in range(n_sites):
mask = lmk_lbls == site
pts = frames[:, sn.mobile_mask][mask]
if self.weighted_site_positions:
diff --git a/sitator/landmark/__init__.py b/sitator/landmark/__init__.py
index e100d24..a084133 100644
--- a/sitator/landmark/__init__.py
+++ b/sitator/landmark/__init__.py
@@ -1,4 +1,4 @@
-from errors import StaticLatticeError, ZeroLandmarkError, LandmarkAnalysisError
+from .errors import StaticLatticeError, ZeroLandmarkError, LandmarkAnalysisError
-from LandmarkAnalysis import LandmarkAnalysis
+from .LandmarkAnalysis import LandmarkAnalysis
diff --git a/sitator/landmark/cluster/dbscan.py b/sitator/landmark/cluster/dbscan.py
index 0ff4856..381fa6d 100644
--- a/sitator/landmark/cluster/dbscan.py
+++ b/sitator/landmark/cluster/dbscan.py
@@ -54,8 +54,8 @@ def do_landmark_clustering(landmark_vectors,
lmk_lbls = trans_table[lmk_lbls]
if verbose:
- print "DBSCAN landmark: %i/%i assignment counts below threshold %f (%i); %i clusters remain." % \
- (len(to_remove), len(cluster_counts), min_samples, min_n_samples_cluster, len(cluster_counts) - len(to_remove))
+ print("DBSCAN landmark: %i/%i assignment counts below threshold %f (%i); %i clusters remain." % \
+ (len(to_remove), len(cluster_counts), min_samples, min_n_samples_cluster, len(cluster_counts) - len(to_remove)))
# Remove counts
cluster_counts = cluster_counts[~to_remove_mask]
diff --git a/sitator/landmark/pointmerge.py b/sitator/landmark/pointmerge.py
index ece3107..e70bae7 100644
--- a/sitator/landmark/pointmerge.py
+++ b/sitator/landmark/pointmerge.py
@@ -59,7 +59,7 @@ def merge_points_soap_paths(tsoap,
assert edge_length <= sanity_check_cutoff, "edge_length %s" % edge_length
# Points along the line
- for i in xrange(n_steps):
+ for i in range(n_steps):
points_along[i] = step_vec
points_along *= step_vec_mult
points_along += edge_from_pt
diff --git a/sitator/misc/GenerateAroundSites.py b/sitator/misc/GenerateAroundSites.py
index dea0306..84f391f 100644
--- a/sitator/misc/GenerateAroundSites.py
+++ b/sitator/misc/GenerateAroundSites.py
@@ -14,9 +14,9 @@ def run(self, sn):
out = sn.copy()
pbcc = PBCCalculator(sn.structure.cell)
- print out.centers.shape
+ print(out.centers.shape)
newcenters = out.centers.repeat(self.n, axis = 0)
- print newcenters.shape
+ print(newcenters.shape)
newcenters += self.sigma * np.random.standard_normal(size = newcenters.shape)
pbcc.wrap_points(newcenters)
diff --git a/sitator/misc/NAvgsPerSite.py b/sitator/misc/NAvgsPerSite.py
index 41d74bf..bd9552d 100644
--- a/sitator/misc/NAvgsPerSite.py
+++ b/sitator/misc/NAvgsPerSite.py
@@ -35,7 +35,7 @@ def run(self, st):
types = np.empty(shape = centers.shape[0], dtype = np.int)
current_idex = 0
- for site in xrange(st.site_network.n_sites):
+ for site in range(st.site_network.n_sites):
if self.weighted:
pts, confs = st.real_positions_for_site(site, return_confidences = True)
else:
@@ -46,7 +46,7 @@ def run(self, st):
if len(pts) > self.n:
sanity = 0
- for i in xrange(self.n):
+ for i in range(self.n):
ps = pts[i::self.n]
sanity += len(ps)
c = confs[i::self.n]
diff --git a/sitator/misc/SiteVolumes.py b/sitator/misc/SiteVolumes.py
index 7a8f48c..095bc24 100644
--- a/sitator/misc/SiteVolumes.py
+++ b/sitator/misc/SiteVolumes.py
@@ -20,14 +20,14 @@ def run(self, st):
pbcc = PBCCalculator(st.site_network.structure.cell)
- for site in xrange(st.site_network.n_sites):
+ for site in range(st.site_network.n_sites):
pos = st.real_positions_for_site(site)
assert pos.flags['OWNDATA']
vol = np.inf
area = None
- for i in xrange(self.n_recenterings):
+ for i in range(self.n_recenterings):
# Recenter
offset = pbcc.cell_centroid - pos[int(i * (len(pos)/self.n_recenterings))]
pos += offset
@@ -36,7 +36,7 @@ def run(self, st):
try:
hull = ConvexHull(pos)
except QhullError as qhe:
- print "For site %i, iter %i: %s" % (site, i, qhe)
+ print("For site %i, iter %i: %s" % (site, i, qhe))
vols[site] = np.nan
areas[site] = np.nan
continue
diff --git a/sitator/misc/__init__.py b/sitator/misc/__init__.py
index 66901f4..5a7095e 100644
--- a/sitator/misc/__init__.py
+++ b/sitator/misc/__init__.py
@@ -1,4 +1,4 @@
-from NAvgsPerSite import NAvgsPerSite
-from GenerateAroundSites import GenerateAroundSites
-from SiteVolumes import SiteVolumes
+from .NAvgsPerSite import NAvgsPerSite
+from .GenerateAroundSites import GenerateAroundSites
+from .SiteVolumes import SiteVolumes
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index f2575ae..7d8cd97 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -35,7 +35,7 @@
def tqdm(iterable, **kwargs):
return iterable
-class SOAP(object):
+class SOAP(object, metaclass=ABCMeta):
"""Abstract base class for computing SOAP vectors in a SiteNetwork.
SOAP computations are *not* thread-safe; use one SOAP object per thread.
@@ -56,7 +56,6 @@ class SOAP(object):
For ideal performance: Specify environment and soap_mask correctly!
:param dict soap_params = {}: Any custom SOAP params.
"""
- __metaclass__ = ABCMeta
def __init__(self, tracer_atomic_number, environment = None,
soap_mask=None, soap_params={}, verbose =True):
from ase.data import atomic_numbers
@@ -237,7 +236,7 @@ class SOAPDescriptorAverages(SOAP):
then averaged in SOAP space to give the final SOAP vectors for each site.
This method often performs better than SOAPSampledCenters on more dynamic
- systems, but requires significantly more computation.
+ systems, but requires significantly more computation.
:param int stepsize: Stride (in frames) when computing SOAPs. Default 1.
:param int averaging: Number of SOAP vectors to average for each output vector.
@@ -297,7 +296,7 @@ def _get_descriptors(self, site_trajectory, structure, tracer_index, soap_mask):
# Now, I need to allocate the output
# so for each site, I count how much data there is!
- counts = np.array([np.count_nonzero(site_traj==site_idx) for site_idx in xrange(nsit)], dtype=int)
+ counts = np.array([np.count_nonzero(site_traj==site_idx) for site_idx in range(nsit)], dtype=int)
if self._averaging is not None:
averaging = self._averaging
@@ -349,5 +348,5 @@ def _get_descriptors(self, site_trajectory, structure, tracer_index, soap_mask):
if max_index[site_idx] == desc_index[site_idx]:
blocked[site_idx] = True
- desc_to_site = np.repeat(range(nsit), nr_of_descs)
+ desc_to_site = np.repeat(list(range(nsit)), nr_of_descs)
return descs, desc_to_site
diff --git a/sitator/site_descriptors/SiteTypeAnalysis.py b/sitator/site_descriptors/SiteTypeAnalysis.py
index 30a7143..7167a93 100644
--- a/sitator/site_descriptors/SiteTypeAnalysis.py
+++ b/sitator/site_descriptors/SiteTypeAnalysis.py
@@ -79,7 +79,7 @@ def run(self, descriptor_input, **kwargs):
self.dvecs = pca_dvecs
if self.verbose:
- print(" Accounted for %.0f%% of variance in %i dimensions" % (100.0 * np.sum(self.pca.explained_variance_ratio_), self.dvecs.shape[1]))
+ print((" Accounted for %.0f%% of variance in %i dimensions" % (100.0 * np.sum(self.pca.explained_variance_ratio_), self.dvecs.shape[1])))
# -- Do clustering
# pydpc requires a C-contiguous array
@@ -129,7 +129,7 @@ def run(self, descriptor_input, **kwargs):
# -- Voting
types = np.empty(shape = sn.n_sites, dtype = np.int)
self.winning_vote_percentages = np.empty(shape = sn.n_sites, dtype = np.float)
- for site in xrange(sn.n_sites):
+ for site in range(sn.n_sites):
corresponding_samples = dvecs_to_site == site
votes = assignments[corresponding_samples]
n_votes = len(votes)
@@ -146,7 +146,7 @@ def run(self, descriptor_input, **kwargs):
sn.site_types = types
if self.verbose:
- print((" " + "Type {:<2} " * self.n_types).format(*xrange(self.n_types)))
+ print((" " + "Type {:<2} " * self.n_types).format(*range(self.n_types)))
print(("# of sites " + "{:<8}" * self.n_types).format(*n_sites_of_each_type))
if np.any(n_sites_of_each_type == 0):
@@ -172,11 +172,11 @@ def plot_dvecs(self, fig = None, ax = None, **kwargs):
@plotter(is3D = False)
def plot_clustering(self, fig = None, ax = None, **kwargs):
ccycle = itertools.cycle(DEFAULT_COLORS)
- for cluster in xrange(self.n_types):
+ for cluster in range(self.n_types):
mask = self.dpc.membership == cluster
dvecs_core = self.dvecs[mask & ~self.dpc.border_member]
dvecs_border = self.dvecs[mask & self.dpc.border_member]
- color = ccycle.next()
+ color = next(ccycle)
ax.scatter(dvecs_core[:,0], dvecs_core[:,1], s = 3, color = color, label = "Type %i" % cluster)
ax.scatter(dvecs_border[:,0], dvecs_border[:,1], s = 3, color = color, alpha = 0.3)
diff --git a/sitator/site_descriptors/__init__.py b/sitator/site_descriptors/__init__.py
index ca132b2..cf5cdaf 100644
--- a/sitator/site_descriptors/__init__.py
+++ b/sitator/site_descriptors/__init__.py
@@ -1,3 +1,3 @@
-from SiteTypeAnalysis import SiteTypeAnalysis
+from .SiteTypeAnalysis import SiteTypeAnalysis
-from SOAP import SOAPCenters, SOAPSampledCenters, SOAPDescriptorAverages
+from .SOAP import SOAPCenters, SOAPSampledCenters, SOAPDescriptorAverages
diff --git a/sitator/util/__init__.py b/sitator/util/__init__.py
index 76ae207..b59424f 100644
--- a/sitator/util/__init__.py
+++ b/sitator/util/__init__.py
@@ -1,8 +1,8 @@
-from PBCCalculator import PBCCalculator
+from .PBCCalculator import PBCCalculator
-from DotProdClassifier import DotProdClassifier
+from .DotProdClassifier import DotProdClassifier
-from zeo import Zeopy
+from .zeo import Zeopy
-from RecenterTrajectory import RecenterTrajectory
+from .RecenterTrajectory import RecenterTrajectory
diff --git a/sitator/util/qvoronoi.py b/sitator/util/qvoronoi.py
index 4156a06..aaa878e 100644
--- a/sitator/util/qvoronoi.py
+++ b/sitator/util/qvoronoi.py
@@ -74,7 +74,7 @@ def periodic_voronoi(structure, logfile = sys.stdout):
facet_index = -1
next_our_index = 0
for facet_match in facets_regex.finditer(qhull_output):
- center = np.asarray(map(float, facet_match.group('center').split()))
+ center = np.asarray(list(map(float, facet_match.group('center').split())))
facet_index += 1
all_facets_centers.append(center)
@@ -82,7 +82,7 @@ def periodic_voronoi(structure, logfile = sys.stdout):
if not pbcc.is_in_image_of_cell(center, (1, 1, 1)):
continue
- verts = map(int, vertices_re.findall(facet_match.group('vertices')))
+ verts = list(map(int, vertices_re.findall(facet_match.group('vertices'))))
verts_in_main_cell = tuple(v % len(structure) for v in verts)
facet_indexes_taken.add(facet_index)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 157da1d..8804b8a 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -201,8 +201,8 @@ def _plot_edges(self, sn, ax = None, *args, **kwargs):
sites_to_plot = []
sites_to_plot_positions = []
- for i in xrange(n_sites):
- for j in xrange(n_sites):
+ for i in range(n_sites):
+ for j in range(n_sites):
# No self edges
if i == j:
continue
@@ -254,7 +254,7 @@ def _plot_edges(self, sn, ax = None, *args, **kwargs):
lccolors = np.empty(shape = (len(cs), 4), dtype = np.float)
# Group colors
if do_groups:
- for i in xrange(len(cs)):
+ for i in range(len(cs)):
lccolors[i] = matplotlib.colors.to_rgba(SiteNetworkPlotter.EDGE_GROUP_COLORS[groups[i]])
else:
lccolors[:] = matplotlib.colors.to_rgba(SiteNetworkPlotter.EDGE_GROUP_COLORS[0])
diff --git a/sitator/visualization/__init__.py b/sitator/visualization/__init__.py
index 4624148..68cc72a 100644
--- a/sitator/visualization/__init__.py
+++ b/sitator/visualization/__init__.py
@@ -1,5 +1,5 @@
-from common import layers, grid, plotter, DEFAULT_COLORS
+from .common import layers, grid, plotter, DEFAULT_COLORS
-from atoms import plot_atoms, plot_points
+from .atoms import plot_atoms, plot_points
-from SiteNetworkPlotter import SiteNetworkPlotter
+from .SiteNetworkPlotter import SiteNetworkPlotter
diff --git a/sitator/visualization/atoms.py b/sitator/visualization/atoms.py
index 1d0cc9c..153efa0 100644
--- a/sitator/visualization/atoms.py
+++ b/sitator/visualization/atoms.py
@@ -36,7 +36,7 @@ def plot_atoms(atoms, positions = None, hide_species = (), wrap = False, fig = N
all_cvecs = []
- whos_left = set(xrange(len(atoms.cell)))
+ whos_left = set(range(len(atoms.cell)))
for i, cvec1 in enumerate(atoms.cell):
all_cvecs.append(np.array([[0.0, 0.0, 0.0], cvec1]))
for j, cvec2 in enumerate(atoms.cell[list(whos_left - {i})]):
diff --git a/sitator/visualization/common.py b/sitator/visualization/common.py
index 9c4e702..b79a84f 100644
--- a/sitator/visualization/common.py
+++ b/sitator/visualization/common.py
@@ -61,7 +61,7 @@ def plotter_wraped(*args, **kwargs):
@plotter(is3D = True)
def layers(*args, **fax):
i = fax['i']
- print i
+ print(i)
for p, kwargs in args:
p(fig = fax['fig'], ax = fax['ax'], i = i, **kwargs)
i += 1
diff --git a/sitator/voronoi/__init__.py b/sitator/voronoi/__init__.py
index 17c3b95..6a5d16a 100644
--- a/sitator/voronoi/__init__.py
+++ b/sitator/voronoi/__init__.py
@@ -1,2 +1,2 @@
-from VoronoiSiteGenerator import VoronoiSiteGenerator
+from .VoronoiSiteGenerator import VoronoiSiteGenerator
From 8d0dbb1ca223e678a8fcdcd8d4717d27976a7643 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 6 Jun 2019 10:30:27 -0400
Subject: [PATCH 002/129] Minor Py3 Fixes
---
.gitignore | 3 +++
sitator/landmark/LandmarkAnalysis.py | 4 ++--
sitator/util/zeo.py | 9 +++------
sitator/visualization/SiteNetworkPlotter.py | 4 ++--
sitator/visualization/__init__.py | 2 +-
5 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/.gitignore b/.gitignore
index 81051d0..f71864f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,6 @@ build/
/*.egg-info
*.bak
+
+.tags
+.tags_swap
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 2bde9c1..1a33dcf 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -20,7 +20,7 @@ def tqdm(iterable, **kwargs):
import importlib
import tempfile
-import helpers
+from . import helpers
from sitator import SiteNetwork, SiteTrajectory
@@ -150,7 +150,7 @@ def run(self, sn, frames):
raise ValueError("Cannot rerun LandmarkAnalysis!")
if frames.shape[1:] != (sn.n_total, 3):
- raise ValueError("Wrong shape %s for frames." % frames.shape)
+ raise ValueError("Wrong shape %s for frames." % (frames.shape,))
if sn.vertices is None:
raise ValueError("Input SiteNetwork must have vertices")
diff --git a/sitator/util/zeo.py b/sitator/util/zeo.py
index 430e4bd..2c2931b 100644
--- a/sitator/util/zeo.py
+++ b/sitator/util/zeo.py
@@ -1,9 +1,6 @@
#zeopy: simple Python interface to the Zeo++ `network` tool.
# Alby Musaelian 2018
-from __future__ import (absolute_import, division,
- print_function)
-
import os
import sys
import tempfile
@@ -130,12 +127,12 @@ def parse_v1_cell(v1lines):
# remove blank lines:
v1lines = iter(filter(None, v1lines))
# First line is just "Unit cell vectors:"
- assert v1lines.next().strip() == "Unit cell vectors:"
+ assert next(v1lines).strip() == "Unit cell vectors:"
# Unit cell:
cell = np.empty(shape = (3, 3), dtype = np.float)
cellvec_re = re.compile('v[abc]=')
- for i in xrange(3):
- cellvec = v1lines.next().strip().split()
+ for i in range(3):
+ cellvec = next(v1lines).strip().split()
assert cellvec_re.match(cellvec[0])
cell[i] = [float(e) for e in cellvec[1:]]
# number of atoms, etc.
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 8804b8a..3711272 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -10,7 +10,7 @@
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from sitator.util import PBCCalculator
-from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS
+from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS, set_axes_equal
class SiteNetworkPlotter(object):
"""Plot a SiteNetwork.
@@ -74,7 +74,7 @@ def __call__(self, sn, *args, **kwargs):
ax.set_title(self.title)
- ax.set_aspect('equal')
+ set_axes_equal(ax)
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
diff --git a/sitator/visualization/__init__.py b/sitator/visualization/__init__.py
index 68cc72a..025308b 100644
--- a/sitator/visualization/__init__.py
+++ b/sitator/visualization/__init__.py
@@ -1,4 +1,4 @@
-from .common import layers, grid, plotter, DEFAULT_COLORS
+from .common import layers, grid, plotter, DEFAULT_COLORS, set_axes_equal
from .atoms import plot_atoms, plot_points
From 41d2f2ccc7c98defbdb61ef95b2b75f62b9d338b Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 6 Jun 2019 10:43:31 -0400
Subject: [PATCH 003/129] Fixed edge case where vertices are numpy arrays
---
.gitignore | 5 +++++
sitator/landmark/LandmarkAnalysis.py | 2 +-
2 files changed, 6 insertions(+), 1 deletion(-)
diff --git a/.gitignore b/.gitignore
index 4800c95..f71864f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,8 @@ build/
/dist/
/*.egg-info
+
+*.bak
+
+.tags
+.tags_swap
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index a972ca3..8f74eb3 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -167,7 +167,7 @@ def run(self, sn, frames):
self._landmark_dimension = sn.n_sites
longest_vert_set = np.max([len(v) for v in sn.vertices])
- verts_np = np.array([v + [-1] * (longest_vert_set - len(v)) for v in sn.vertices])
+ verts_np = np.array([np.concatenate((v, [-1] * (longest_vert_set - len(v)))) for v in sn.vertices])
site_vert_dists = np.empty(shape = verts_np.shape, dtype = np.float)
site_vert_dists.fill(np.nan)
From 12ea6039e0c0354ab9894c6588eb7a94aeebbdf7 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 6 Jun 2019 11:01:39 -0400
Subject: [PATCH 004/129] Py3 datatype fix
---
sitator/landmark/LandmarkAnalysis.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 8a5a0c5..c46fa12 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -167,7 +167,7 @@ def run(self, sn, frames):
self._landmark_dimension = sn.n_sites
longest_vert_set = np.max([len(v) for v in sn.vertices])
- verts_np = np.array([np.concatenate((v, [-1] * (longest_vert_set - len(v)))) for v in sn.vertices])
+ verts_np = np.array([np.concatenate((v, [-1] * (longest_vert_set - len(v)))) for v in sn.vertices], dtype = np.int)
site_vert_dists = np.empty(shape = verts_np.shape, dtype = np.float)
site_vert_dists.fill(np.nan)
From b6e7501c34f030d68c9dbd8387e1de5623425a41 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 6 Jun 2019 12:01:30 -0400
Subject: [PATCH 005/129] Fixed axes aspect ratios for Py3 matplotlib
---
sitator/visualization/SiteNetworkPlotter.py | 1 -
sitator/visualization/common.py | 10 +++++++++-
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 3711272..630a866 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -74,7 +74,6 @@ def __call__(self, sn, *args, **kwargs):
ax.set_title(self.title)
- set_axes_equal(ax)
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
diff --git a/sitator/visualization/common.py b/sitator/visualization/common.py
index b79a84f..3886dbc 100644
--- a/sitator/visualization/common.py
+++ b/sitator/visualization/common.py
@@ -33,8 +33,10 @@ def plotter_wrapper(func):
def plotter_wraped(*args, **kwargs):
fig = None
ax = None
+ toplevel = False
if not ('ax' in kwargs and 'fig' in kwargs):
- # No existing axis/figure
+ # No existing axis/figure - toplevel
+ toplevel = True
fig = plt.figure(**outer)
if is3D:
ax = fig.add_subplot(111, projection = '3d')
@@ -54,8 +56,14 @@ def plotter_wraped(*args, **kwargs):
kwargs['i'] = 0
func(*args, fig = fig, ax = ax, **kwargs)
+
+ if is3D and toplevel:
+ set_axes_equal(ax)
+
return fig, ax
+
return plotter_wraped
+
return plotter_wrapper
@plotter(is3D = True)
From 286b04b46a8a197eae81cb14b90011beed2207b9 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 10 Jun 2019 11:18:19 -0400
Subject: [PATCH 006/129] Catch error in recentering with better message
---
sitator/util/RecenterTrajectory.pyx | 2 ++
1 file changed, 2 insertions(+)
diff --git a/sitator/util/RecenterTrajectory.pyx b/sitator/util/RecenterTrajectory.pyx
index 39e5d7c..c70430f 100644
--- a/sitator/util/RecenterTrajectory.pyx
+++ b/sitator/util/RecenterTrajectory.pyx
@@ -32,6 +32,8 @@ class RecenterTrajectory(object):
masses of all atoms in the system.
"""
+ assert np.any(static_mask), "Static mask all false; there must be static atoms to recenter on."
+
factors = static_mask.astype(np.float)
n_static = np.sum(static_mask)
From 194b5198545bd00e1c62509fcecd56716d1882ff Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 10 Jun 2019 13:22:25 -0400
Subject: [PATCH 007/129] Updated SOAP support for Python 3
---
README.md | 8 +-
setup.py | 5 +-
sitator/site_descriptors/SOAP.py | 121 +++++++------------
sitator/site_descriptors/SiteTypeAnalysis.py | 1 -
sitator/site_descriptors/__init__.py | 2 +-
sitator/site_descriptors/backend/__init__.py | 0
sitator/site_descriptors/backend/dscribe.py | 35 ++++++
sitator/site_descriptors/backend/quip.py | 84 +++++++++++++
8 files changed, 175 insertions(+), 81 deletions(-)
create mode 100644 sitator/site_descriptors/backend/__init__.py
create mode 100644 sitator/site_descriptors/backend/dscribe.py
create mode 100644 sitator/site_descriptors/backend/quip.py
diff --git a/README.md b/README.md
index 607bcd1..4a7cf1d 100644
--- a/README.md
+++ b/README.md
@@ -19,10 +19,12 @@ If you use `sitator` in your research, please consider citing this paper. The Bi
## Installation
-`sitator` is currently built for Python 2.7. We recommend the use of a Python 2.7 virtual environment (`virtualenv`, `conda`, etc.). `sitator` has two external dependencies:
+`sitator` is built for Python >=3.2 (the older version supports Python 2.7). We recommend the use of a virtual environment (`virtualenv`, `conda`, etc.). `sitator` has one mandatory external dependency:
- The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does *not* have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.)
- - If you want to use the site type analysis features, a working installation of [Quippy](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) and Python bindings is required for computing SOAP descriptors.
+
+
+If you want to use the site type analysis features, the `quip` binary from an installation of [QUIP](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) can be used to compute the SOAP vectors. The Python 2.7 bindings (`quippy`) are **not** required. SOAP vectors can **also** be computed with [`DScribe`](https://singroup.github.io/dscribe/index.html) and the installation of QUIP avoided; note, however, that the descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on your system.
After downloading, the package is installed with `pip`:
@@ -32,7 +34,7 @@ cd sitator
pip install .
```
-To enable site type analysis, add the `[SiteTypeAnalysis]` option:
+To enable site type analysis, add the `[SiteTypeAnalysis]` option (this adds two dependencies -- Python packages `pydpc` and `dscribe`):
```
pip install ".[SiteTypeAnalysis]"
diff --git a/setup.py b/setup.py
index 93afd4b..54bbfa8 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
download_url = "https://github.com/Linux-cpp-lisp/sitator",
author = 'Alby Musaelian',
license = "MIT",
- python_requires = '>=3',
+ python_requires = '>=3.2',
packages = find_packages(),
ext_modules = cythonize([
"sitator/landmark/helpers.pyx",
@@ -27,7 +27,8 @@
],
extras_require = {
"SiteTypeAnalysis" : [
- "pydpc"
+ "pydpc",
+ "dscribe"
]
},
zip_safe = True)
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index 7d8cd97..f98c750 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -5,20 +5,8 @@
from sitator.SiteNetwork import SiteNetwork
from sitator.SiteTrajectory import SiteTrajectory
-try:
- import quippy as qp
- from quippy import descriptors
-except ImportError:
- raise ImportError("Quippy with GAP is required for using SOAP descriptors.")
-
from ase.data import atomic_numbers
-DEFAULT_SOAP_PARAMS = {
- 'cutoff' : 3.0,
- 'cutoff_transition_width' : 1.0,
- 'l_max' : 6, 'n_max' : 6,
- 'atom_sigma' : 0.4
-}
# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
import sys
@@ -55,31 +43,31 @@ class SOAP(object, metaclass=ABCMeta):
Even not masked, species not considered in environment will be not accounted for.
For ideal performance: Specify environment and soap_mask correctly!
:param dict soap_params = {}: Any custom SOAP params.
+ :param func backend: A function that can be called with `sn, soap_mask, tracer_atomic_number, environment_list` as
+ parameters, returning a function that, given the current soap structure
+ along with tracer atoms, returns SOAP vectors in a numpy array. (i.e.
+ its signature is `soap(structure, positions)`)
"""
+
+ from .backend.quip import quip_soap_backend as backend_quip
+ from .backend.dscribe import dscribe_soap_backend as backend_dscribe
+
def __init__(self, tracer_atomic_number, environment = None,
- soap_mask=None, soap_params={}, verbose =True):
+ soap_mask = None, verbose =True,
+ backend = None):
from ase.data import atomic_numbers
- # Creating a dictionary for convenience, to check the types and values:
- self.tracer_atomic_number = 3
- centers_list = [self.tracer_atomic_number]
+ self.tracer_atomic_number = tracer_atomic_number
self._soap_mask = soap_mask
- # -- Create the descriptor object
- soap_opts = dict(DEFAULT_SOAP_PARAMS)
- soap_opts.update(soap_params)
- soap_cmd_line = ["soap"]
-
- # User options
- for opt in soap_opts:
- soap_cmd_line.append("{}={}".format(opt, soap_opts[opt]))
+ self._verbose = verbose
- #
- soap_cmd_line.append('n_Z={} Z={{{}}}'.format(len(centers_list), ' '.join(map(str, centers_list))))
+ if backend is None:
+ backend = SOAP.dscribe_soap_backend
+ self._backend = backend
- # - Add environment species controls if given
- self._environment = None
- if not environment is None:
+ # - Standardize environment species controls if given
+ if not environment is None: # User given environment
if not isinstance(environment, (list, tuple)):
raise TypeError('environment has to be a list or tuple of species (atomic number'
' or symbol of the environment to consider')
@@ -98,22 +86,9 @@ def __init__(self, tracer_atomic_number, environment = None,
raise TypeError("Environment has to be a list of atomic numbers or atomic symbols")
self._environment = environment_list
- soap_cmd_line.append('n_species={} species_Z={{{}}}'.format(len(environment_list), ' '.join(map(str, environment_list))))
-
- soap_cmd_line = " ".join(soap_cmd_line)
-
- if verbose:
- print("SOAP command line: %s" % soap_cmd_line)
-
- self._soaper = descriptors.Descriptor(soap_cmd_line)
- self._verbose = verbose
- self._cutoff = soap_opts['cutoff']
-
-
+ else:
+ self._environment = None
- @property
- def n_dim(self):
- return self._soaper.n_dim
def get_descriptors(self, stn):
"""
@@ -124,14 +99,24 @@ def get_descriptors(self, stn):
"""
# Build SOAP host structure
if isinstance(stn, SiteTrajectory):
- structure, tracer_index, soap_mask = self._make_structure(stn.site_network)
+ sn = stn.site_network
elif isinstance(stn, SiteNetwork):
- structure, tracer_index, soap_mask = self._make_structure(stn)
+ sn = stn
else:
raise TypeError("`stn` must be SiteNetwork or SiteTrajectory")
+ structure, tracer_atomic_number, soap_mask = self._make_structure(sn)
+
+ if self._environment is not None:
+ environment_list = self._environment
+ else:
+ # Set it to all species represented by the soap_mask
+ environment_list = np.unique(sn.structure.get_atomic_numbers()[soap_mask])
+
+ soaper = self._backend(sn, soap_mask, tracer_atomic_number, environment_list)
+
# Compute descriptors
- return self._get_descriptors(stn, structure, tracer_index, soap_mask)
+ return self._get_descriptors(stn, structure, tracer_atomic_number, soap_mask, soaper)
# ----
@@ -139,7 +124,7 @@ def _make_structure(self, sn):
if self._soap_mask is None:
# Make a copy of the static structure
- structure = qp.Atoms(sn.static_structure)
+ structure = sn.static_structure.copy()
soap_mask = sn.static_mask # soap mask is the
else:
if isinstance(self._soap_mask, tuple):
@@ -148,7 +133,7 @@ def _make_structure(self, sn):
soap_mask = self._soap_mask
assert not np.any(soap_mask & sn.mobile_mask), "Error for atoms %s; No atom can be both static and mobile" % np.where(soap_mask & sn.mobile_mask)[0]
- structure = qp.Atoms(sn.structure[soap_mask])
+ structure = sn.structure[soap_mask]
assert np.any(soap_mask), "Given `soap_mask` excluded all host atoms."
if not self._environment is None:
@@ -160,14 +145,16 @@ def _make_structure(self, sn):
else:
tracer_atomic_number = self.tracer_atomic_number
- structure.add_atoms((0.0, 0.0, 0.0), tracer_atomic_number)
+ if np.any(structure.get_atomic_numbers() == tracer_atomic_number):
+ raise ValueError("Structure cannot have static atoms (that are enabled in the SOAP mask) of the same species as `tracer_atomic_number`.")
+
structure.set_pbc([True, True, True])
- tracer_index = len(structure) - 1
- return structure, tracer_index, soap_mask
+ return structure, tracer_atomic_number, soap_mask
+
@abstractmethod
- def _get_descriptors(self, stn, structure, tracer_index):
+ def _get_descriptors(self, stn, structure, tracer_atomic_number, soaper):
pass
@@ -178,24 +165,12 @@ class SOAPCenters(SOAP):
Requires a SiteNetwork as input.
"""
- def _get_descriptors(self, sn, structure, tracer_index, soap_mask):
+ def _get_descriptors(self, sn, structure, tracer_atomic_number, soap_mask, soaper):
assert isinstance(sn, SiteNetwork), "SOAPCenters requires a SiteNetwork, not `%s`" % sn
pts = sn.centers
- out = np.empty(shape = (len(pts), self.n_dim), dtype = np.float)
-
- structure.set_cutoff(self._soaper.cutoff())
-
- for i, pt in enumerate(tqdm(pts, desc="SOAP") if self._verbose else pts):
- # Move tracer
- structure.positions[tracer_index] = pt
-
- # SOAP requires connectivity data to be computed first
- structure.calc_connect()
-
- #There should only be one descriptor, since there should only be one Li
- out[i] = self._soaper.calc(structure)['descriptor'][0]
+ out = soaper(structure, pts)
return out, np.arange(sn.n_sites)
@@ -280,7 +255,7 @@ def __init__(self, *args, **kwargs):
super(SOAPDescriptorAverages, self).__init__(*args, **kwargs)
- def _get_descriptors(self, site_trajectory, structure, tracer_index, soap_mask):
+ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soap_mask, soaper):
"""
calculate descriptors
"""
@@ -317,23 +292,21 @@ def _get_descriptors(self, site_trajectory, structure, tracer_index, soap_mask):
count_of_site = np.zeros(len(nr_of_descs), dtype=int)
blocked = np.empty(nsit, dtype=bool)
blocked[:] = False
- structure.set_cutoff(self._soaper.cutoff())
+
for site_traj_t, pos in tqdm(zip(site_traj, real_traj), desc="SOAP"):
# I update the host lattice positions here, once for every timestep
- structure.positions[:tracer_index] = pos[soap_mask]
+ structure.positions[:] = pos[soap_mask]
+
for mob_idx, site_idx in enumerate(site_traj_t):
if site_idx >= 0 and not blocked[site_idx]:
# Now, for every lithium that has been associated to a site of index site_idx,
# I take my structure and load the position of this mobile atom:
- structure.positions[tracer_index] = pos[mob_indices[mob_idx]]
# calc_connect to calculated distance
# structure.calc_connect()
#There should only be one descriptor, since there should only be one mobile
# I also divide by averaging, to avoid getting into large numbers.
# soapv = self._soaper.calc(structure)['descriptor'][0] / self._averaging
- structure.set_cutoff(self._cutoff)
- structure.calc_connect()
- soapv = self._soaper.calc(structure, grad=False)["descriptor"]
+ soapv = soaper(structure, [pos[mob_indices[mob_idx]]])
#~ soapv ,_,_ = get_fingerprints([structure], d)
# So, now I need to figure out where to load the soapv into desc
diff --git a/sitator/site_descriptors/SiteTypeAnalysis.py b/sitator/site_descriptors/SiteTypeAnalysis.py
index 7167a93..465ff53 100644
--- a/sitator/site_descriptors/SiteTypeAnalysis.py
+++ b/sitator/site_descriptors/SiteTypeAnalysis.py
@@ -24,7 +24,6 @@ class SiteTypeAnalysis(object):
-- descriptor --
Some kind of object implementing:
- - n_dim: the number of components in a descriptor vector
- get_descriptors(site_traj or site_network): returns an array of descriptor vectors
of dimension (M, n_dim) and an array of length M indicating which
descriptor vectors correspond to which sites in (site_traj.)site_network.
diff --git a/sitator/site_descriptors/__init__.py b/sitator/site_descriptors/__init__.py
index cf5cdaf..050a51e 100644
--- a/sitator/site_descriptors/__init__.py
+++ b/sitator/site_descriptors/__init__.py
@@ -1,3 +1,3 @@
from .SiteTypeAnalysis import SiteTypeAnalysis
-from .SOAP import SOAPCenters, SOAPSampledCenters, SOAPDescriptorAverages
+from .SOAP import SOAPCenters, SOAPSampledCenters, SOAPDescriptorAverages, SOAP
diff --git a/sitator/site_descriptors/backend/__init__.py b/sitator/site_descriptors/backend/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sitator/site_descriptors/backend/dscribe.py b/sitator/site_descriptors/backend/dscribe.py
new file mode 100644
index 0000000..62d3eef
--- /dev/null
+++ b/sitator/site_descriptors/backend/dscribe.py
@@ -0,0 +1,35 @@
+
+import numpy as np
+
+DEFAULT_SOAP_PARAMS = {
+ 'cutoff' : 3.0,
+ 'l_max' : 6, 'n_max' : 6,
+ 'atom_sigma' : 0.4,
+ 'rbf' : 'gto'
+}
+
+def dscribe_soap_backend(soap_params = {}):
+ from dscribe.descriptors import SOAP
+
+ soap_opts = dict(DEFAULT_SOAP_PARAMS)
+ soap_opts.update(soap_params)
+
+ def backend(sn, soap_mask, tracer_atomic_number, environment_list):
+
+ def dscribe_soap(structure, positions):
+ soap = SOAP(
+ species = environment_list,
+ crossover = False,
+ rcut = soap_opts['cutoff'],
+ nmax = soap_opts['n_max'],
+ lmax = soap_opts['l_max'],
+ rbf = soap_opts['rbf']
+ periodic = np.all(structure.pbc),
+ sparse = False
+ )
+
+ return soap.create(structure, positions = positions).astype(np.float)
+
+ return dscribe_soap
+
+ return backend
diff --git a/sitator/site_descriptors/backend/quip.py b/sitator/site_descriptors/backend/quip.py
new file mode 100644
index 0000000..576675a
--- /dev/null
+++ b/sitator/site_descriptors/backend/quip.py
@@ -0,0 +1,84 @@
+"""
+quip.py: Compute SOAP vectors for given positions in a structure using the command line QUIP tool
+"""
+
+import numpy as np
+
+import ase
+
+from tempfile import NamedTemporaryFile
+import subprocess
+
+DEFAULT_SOAP_PARAMS = {
+ 'cutoff' : 3.0,
+ 'cutoff_transition_width' : 1.0,
+ 'l_max' : 6, 'n_max' : 6,
+ 'atom_sigma' : 0.4
+}
+
+def quip_soap_backend(soap_params = {}, quip_path = 'quip'):
+ def backend(sn, soap_mask, tracer_atomic_number, environment_list):
+
+ soap_opts = dict(DEFAULT_SOAP_PARAMS)
+ soap_opts.update(soap_params)
+ soap_cmd_line = ["soap"]
+
+ # User options
+ for opt in soap_opts:
+ soap_cmd_line.append("{}={}".format(opt, soap_opts[opt]))
+
+ #
+ soap_cmd_line.append('n_Z=1 Z={{{}}}'.format(tracer_atomic_number))
+
+ soap_cmd_line.append('n_species={} species_Z={{{}}}'.format(len(environment_list), ' '.join(map(str, environment_list))))
+
+ soap_cmd_line = " ".join(soap_cmd_line)
+
+ def soaper(structure, positions):
+ structure = structure.copy()
+ for i in range(len(positions)):
+ structure.append(ase.Atom(position = tuple(positions[i]), symbol = tracer_atomic_number))
+ return _soap(soap_cmd_line, structure, quip_path = quip_path)
+
+ return soaper
+ return backend
+
+
+def _soap(descriptor_str,
+ structure,
+ quip_path = 'quip'):
+ """Calculate SOAP vectors by calling `quip` as a subprocess.
+
+ Args:
+ - descriptor_str (str): The QUIP descriptor str, i.e. `soap cutoff=3 ...`
+ - structure (ase.Atoms)
+ - quip_path (str): Path to `quip` executable
+ """
+
+ with NamedTemporaryFile() as xyz:
+ structure.write(xyz.name, format = 'extxyz')
+
+ quip_cmd = [
+ quip_path,
+ "atoms_filename=" + xyz.name,
+ "descriptor_str=\"" + descriptor_str + "\""
+ ]
+
+ result = subprocess.run(quip_cmd, stdout = subprocess.PIPE, check = True, text = True).stdout
+
+ lines = result.splitlines()
+
+ soaps = []
+ for line in lines:
+ if line.startswith("DESC"):
+ soaps.append(np.fromstring(line.lstrip("DESC"), dtype = np.float, sep = ' '))
+ elif line.startswith("Error"):
+ e = subprocess.CalledProcessError(returncode = 0, cmd = quip_cmd)
+ e.stdout = result
+ raise e
+ else:
+ continue
+
+ soaps = np.asarray(soaps)
+
+ return soaps
From b561c4f0c35200dd560959b0b514963b70dcd5b7 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 10 Jun 2019 16:40:17 -0400
Subject: [PATCH 008/129] DScribe improvements
---
README.md | 2 +-
sitator/site_descriptors/backend/dscribe.py | 10 ++++++----
2 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/README.md b/README.md
index 4a7cf1d..9aa7a09 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ If you use `sitator` in your research, please consider citing this paper. The Bi
- The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does *not* have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.)
-If you want to use the site type analysis features, the `quip` binary from an installation of [QUIP](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) can be used to compute the SOAP vectors. The Python 2.7 bindings (`quippy`) are **not** required. SOAP vectors can **also** be computed with [`DScribe`](https://singroup.github.io/dscribe/index.html) and the installation of QUIP avoided; note, however, that the descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on your system.
+If you want to use the site type analysis features, the `quip` binary from an installation of [QUIP](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) can be used to compute the SOAP vectors. The Python 2.7 bindings (`quippy`) are **not** required. SOAP vectors can **also** be computed with [`DScribe`](https://singroup.github.io/dscribe/index.html) and the installation of QUIP avoided; note, however, that the descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on the system you are analyzing.
After downloading, the package is installed with `pip`:
diff --git a/sitator/site_descriptors/backend/dscribe.py b/sitator/site_descriptors/backend/dscribe.py
index 62d3eef..d9f9de3 100644
--- a/sitator/site_descriptors/backend/dscribe.py
+++ b/sitator/site_descriptors/backend/dscribe.py
@@ -5,7 +5,8 @@
'cutoff' : 3.0,
'l_max' : 6, 'n_max' : 6,
'atom_sigma' : 0.4,
- 'rbf' : 'gto'
+ 'rbf' : 'gto',
+ 'crossover' : False
}
def dscribe_soap_backend(soap_params = {}):
@@ -19,16 +20,17 @@ def backend(sn, soap_mask, tracer_atomic_number, environment_list):
def dscribe_soap(structure, positions):
soap = SOAP(
species = environment_list,
- crossover = False,
+ crossover = soap_opts['crossover'],
rcut = soap_opts['cutoff'],
nmax = soap_opts['n_max'],
lmax = soap_opts['l_max'],
- rbf = soap_opts['rbf']
+ rbf = soap_opts['rbf'],
periodic = np.all(structure.pbc),
sparse = False
)
- return soap.create(structure, positions = positions).astype(np.float)
+ out = soap.create(structure, positions = positions).astype(np.float)
+ return out
return dscribe_soap
From 32b216e56c3db875efc7fd441647c0dacf4063bc Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 10 Jun 2019 17:57:18 -0400
Subject: [PATCH 009/129] Improved exception system
---
sitator/landmark/LandmarkAnalysis.py | 7 ++++---
sitator/landmark/__init__.py | 2 +-
sitator/landmark/errors.py | 7 +++++--
3 files changed, 10 insertions(+), 6 deletions(-)
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index c46fa12..402fa8f 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -1,5 +1,6 @@
import numpy as np
+from sitator.landmark import MultipleOccupancyError
from sitator.util import PBCCalculator
# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
@@ -219,8 +220,8 @@ def run(self, sn, frames):
n_sites = len(cluster_counts)
- if n_sites < sn.n_mobile:
- raise ValueError("There are %i mobile particles, but only identified %i sites. Check clustering_params." % (sn.n_mobile, n_sites))
+ if n_sites < (sn.n_mobile / self.max_mobile_per_site):
+ raise MultipleOccupancyError("There are %i mobile particles, but only identified %i sites. With %i max_mobile_per_site, this is an error. Check clustering_params." % (sn.n_mobile, n_sites, self.max_mobile_per_site))
if self.verbose:
print(" Identified %i sites with assignment counts %s" % (n_sites, cluster_counts))
@@ -234,7 +235,7 @@ def run(self, sn, frames):
_, counts = np.unique(site_frame[site_frame >= 0], return_counts = True)
count_msk = counts > self.max_mobile_per_site
if np.any(count_msk):
- raise ValueError("%i mobile particles were assigned to only %i site(s) (%s) at frame %i." % (np.sum(counts[count_msk]), np.sum(count_msk), np.where(count_msk)[0], frame_i))
+ raise MultipleOccupancyError("%i mobile particles were assigned to only %i site(s) (%s) at frame %i." % (np.sum(counts[count_msk]), np.sum(count_msk), np.where(count_msk)[0], frame_i))
n_more_than_ones += np.sum(counts > 1)
avg_mobile_per_site += np.sum(counts)
divisor += len(counts)
diff --git a/sitator/landmark/__init__.py b/sitator/landmark/__init__.py
index a084133..2af5cf2 100644
--- a/sitator/landmark/__init__.py
+++ b/sitator/landmark/__init__.py
@@ -1,4 +1,4 @@
-from .errors import StaticLatticeError, ZeroLandmarkError, LandmarkAnalysisError
+from .errors import StaticLatticeError, ZeroLandmarkError, LandmarkAnalysisError, MultipleOccupancyError
from .LandmarkAnalysis import LandmarkAnalysis
diff --git a/sitator/landmark/errors.py b/sitator/landmark/errors.py
index 55dcfb2..a170cd3 100644
--- a/sitator/landmark/errors.py
+++ b/sitator/landmark/errors.py
@@ -2,7 +2,7 @@
class LandmarkAnalysisError(Exception):
pass
-class StaticLatticeError(Exception):
+class StaticLatticeError(LandmarkAnalysisError):
"""Error raised when static lattice atoms break any limits on their movement/position.
Attributes:
@@ -25,7 +25,7 @@ def __init__(self, message, lattice_atoms = None, frame = None, try_recentering
self.lattice_atoms = lattice_atoms
self.frame = frame
-class ZeroLandmarkError(Exception):
+class ZeroLandmarkError(LandmarkAnalysisError):
def __init__(self, mobile_index, frame):
message = "Encountered a zero landmark vector for mobile ion %i at frame %i. Try increasing `cutoff_midpoint` and/or decreasing `cutoff_steepness`." % (mobile_index, frame)
@@ -34,3 +34,6 @@ def __init__(self, mobile_index, frame):
self.mobile_index = mobile_index
self.frame = frame
+
+class MultipleOccupancyError(LandmarkAnalysisError):
+ pass
From cfe50c076db27730c8d5ebf3bab6ea23fedde6be Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 11 Jun 2019 17:32:20 -0400
Subject: [PATCH 010/129] Fixed bug in NAvgsPerSite with insufficient data
---
sitator/misc/NAvgsPerSite.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/sitator/misc/NAvgsPerSite.py b/sitator/misc/NAvgsPerSite.py
index bd9552d..afcd2a5 100644
--- a/sitator/misc/NAvgsPerSite.py
+++ b/sitator/misc/NAvgsPerSite.py
@@ -32,7 +32,9 @@ def run(self, st):
pbcc = PBCCalculator(st.site_network.structure.cell)
# Maximum length
centers = np.empty(shape = (self.n * st.site_network.n_sites, 3), dtype = st.real_trajectory.dtype)
+ centers.fill(np.nan)
types = np.empty(shape = centers.shape[0], dtype = np.int)
+ types.fill(np.nan)
current_idex = 0
for site in range(st.site_network.n_sites):
@@ -64,7 +66,9 @@ def run(self, st):
types[old_idex:current_idex] = site
sn = st.site_network.copy()
- sn.centers = centers
- sn.site_types = types
+ sn.centers = centers[:current_idex]
+ sn.site_types = types[:current_idex]
+
+ assert not (np.any(np.isnan(sn.centers)) or np.any(np.isnan(sn.site_types)))
return sn
From 3a6a37606c6422c26279a9bad7fe47a2d3e69d94 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 11 Jun 2019 17:33:30 -0400
Subject: [PATCH 011/129] Cleaned up tqdm imports
---
sitator/landmark/LandmarkAnalysis.py | 13 +++----------
sitator/landmark/pointmerge.py | 13 +++----------
sitator/site_descriptors/SOAP.py | 13 +++----------
sitator/util/DotProdClassifier.pyx | 15 ++++-----------
4 files changed, 13 insertions(+), 41 deletions(-)
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 402fa8f..0401c70 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -6,17 +6,10 @@
# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
import sys
try:
- ipy_str = str(type(get_ipython()))
- if 'zmqshell' in ipy_str:
- from tqdm import tqdm_notebook as tqdm
- if 'terminal' in ipy_str:
- from tqdm import tqdm
+ from tqdm.autonotebook import tqdm
except:
- if sys.stderr.isatty():
- from tqdm import tqdm
- else:
- def tqdm(iterable, **kwargs):
- return iterable
+ def tqdm(iterable, **kwargs):
+ return iterable
import importlib
import tempfile
diff --git a/sitator/landmark/pointmerge.py b/sitator/landmark/pointmerge.py
index e70bae7..40cdd7a 100644
--- a/sitator/landmark/pointmerge.py
+++ b/sitator/landmark/pointmerge.py
@@ -4,17 +4,10 @@
# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
import sys
try:
- ipy_str = str(type(get_ipython()))
- if 'zmqshell' in ipy_str:
- from tqdm import tqdm_notebook as tqdm
- if 'terminal' in ipy_str:
- from tqdm import tqdm
+ from tqdm.autonotebook import tqdm
except:
- if sys.stderr.isatty():
- from tqdm import tqdm
- else:
- def tqdm(iterable, **kwargs):
- return iterable
+ def tqdm(iterable, **kwargs):
+ return iterable
def merge_points_soap_paths(tsoap,
pbcc,
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index f98c750..9ba7a9b 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -11,17 +11,10 @@
# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
import sys
try:
- ipy_str = str(type(get_ipython()))
- if 'zmqshell' in ipy_str:
- from tqdm import tqdm_notebook as tqdm
- if 'terminal' in ipy_str:
- from tqdm import tqdm
+ from tqdm.autonotebook import tqdm
except:
- if sys.stderr.isatty():
- from tqdm import tqdm
- else:
- def tqdm(iterable, **kwargs):
- return iterable
+ def tqdm(iterable, **kwargs):
+ return iterable
class SOAP(object, metaclass=ABCMeta):
"""Abstract base class for computing SOAP vectors in a SiteNetwork.
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index d933a5c..213e957 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -5,19 +5,12 @@ import numpy as np
import numbers
import sys
+
try:
- from IPython import get_ipython
- ipy_str = str(type(get_ipython()))
- if 'zmqshell' in ipy_str:
- from tqdm import tqdm_notebook as tqdm
- if 'terminal' in ipy_str:
- from tqdm import tqdm
+ from tqdm.autonotebook import tqdm
except:
- if sys.stderr.isatty():
- from tqdm import tqdm
- else:
- def tqdm(iterable, **kwargs):
- return iterable
+ def tqdm(iterable, **kwargs):
+ return iterable
N_SITES_ALLOC_INCREMENT = 100
From 34431d8d5b3977514bc8aab3c4a10d51ce0ec0f5 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 12 Jun 2019 10:31:20 -0400
Subject: [PATCH 012/129] Converted all `verbose`/`print` to `logging`
---
sitator/SiteNetwork.py | 4 --
sitator/SiteTrajectory.py | 22 ++++-----
sitator/descriptors/ConfigurationalEntropy.py | 17 ++++---
sitator/dynamics/DiffusionPathwayAnalysis.py | 13 +++---
sitator/dynamics/JumpAnalysis.py | 18 ++++----
sitator/dynamics/MergeSitesByDynamics.py | 20 ++++-----
sitator/landmark/LandmarkAnalysis.py | 25 ++++++-----
sitator/landmark/cluster/dbscan.py | 9 ++--
sitator/landmark/helpers.pyx | 17 +++++--
sitator/landmark/pointmerge.py | 3 +-
sitator/misc/GenerateAroundSites.py | 4 +-
sitator/misc/SiteVolumes.py | 5 ++-
sitator/site_descriptors/SOAP.py | 4 +-
sitator/site_descriptors/SiteTypeAnalysis.py | 45 ++++++++-----------
sitator/util/DotProdClassifier.pyx | 12 ++---
sitator/util/zeo.py | 12 ++---
sitator/visualization/SiteNetworkPlotter.py | 4 --
sitator/visualization/common.py | 1 -
sitator/voronoi/VoronoiSiteGenerator.py | 7 +--
19 files changed, 118 insertions(+), 124 deletions(-)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 3cc620a..7e40cff 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -1,7 +1,3 @@
-from __future__ import (absolute_import, division,
- print_function, unicode_literals)
-from builtins import *
-
import numpy as np
import re
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index 3e845b2..5c5f2ff 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -1,7 +1,3 @@
-from __future__ import (absolute_import, division,
- print_function, unicode_literals)
-from builtins import *
-
import numpy as np
from sitator.util import PBCCalculator
@@ -10,6 +6,9 @@
import matplotlib
from matplotlib.collections import LineCollection
+import logging
+logger = logging.getLogger(__name__)
+
class SiteTrajectory(object):
"""A trajectory capturing the dynamics of particles through a SiteNetwork."""
@@ -137,7 +136,7 @@ def compute_site_occupancies(self):
self.site_network.add_site_attribute('occupancies', occ)
return occ
- def assign_to_last_known_site(self, frame_threshold = 1, verbose = True):
+ def assign_to_last_known_site(self, frame_threshold = 1):
"""Assign unassigned mobile particles to their last known site within
`frame_threshold` frames.
@@ -145,8 +144,7 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True):
"""
total_unknown = self.n_unassigned
- if verbose:
- print("%i unassigned positions (%i%%); assigning unassigned mobile particles to last known positions within %i frames..." % (total_unknown, 100.0 * self.percent_unassigned, frame_threshold))
+ logger.info("%i unassigned positions (%i%%); assigning unassigned mobile particles to last known positions within %i frames..." % (total_unknown, 100.0 * self.percent_unassigned, frame_threshold))
last_known = np.empty(shape = self._sn.n_mobile, dtype = np.int)
last_known.fill(-1)
@@ -184,10 +182,9 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True):
if avg_time_unknown_div > 0: # We corrected some unknowns
avg_time_unknown = float(avg_time_unknown) / avg_time_unknown_div
- if verbose:
- print(" Maximum # of frames any mobile particle spent unassigned: %i" % max_time_unknown)
- print(" Avg. # of frames spent unassigned: %f" % avg_time_unknown)
- print(" Assigned %i/%i unassigned positions, leaving %i (%i%%) unknown" % (total_reassigned, total_unknown, self.n_unassigned, self.percent_unassigned))
+ logger.info(" Maximum # of frames any mobile particle spent unassigned: %i" % max_time_unknown)
+ logger.info(" Avg. # of frames spent unassigned: %f" % avg_time_unknown)
+ logger.info(" Assigned %i/%i unassigned positions, leaving %i (%i%%) unknown" % (total_reassigned, total_unknown, self.n_unassigned, self.percent_unassigned))
res = {
'max_time_unknown' : max_time_unknown,
@@ -195,8 +192,7 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True):
'total_reassigned' : total_reassigned
}
else:
- if self.verbose:
- print(" None to correct.")
+ logger.info(" None to correct.")
res = {
'max_time_unknown' : 0,
diff --git a/sitator/descriptors/ConfigurationalEntropy.py b/sitator/descriptors/ConfigurationalEntropy.py
index 92c0fa7..de61258 100644
--- a/sitator/descriptors/ConfigurationalEntropy.py
+++ b/sitator/descriptors/ConfigurationalEntropy.py
@@ -3,6 +3,9 @@
from sitator import SiteTrajectory
from sitator.dynamics import JumpAnalysis
+import logging
+logger = logging.getLogger(__name__)
+
class ConfigurationalEntropy(object):
"""Compute the S~ configurational entropy.
@@ -15,9 +18,8 @@ class ConfigurationalEntropy(object):
Chemistry of Materials 2017 29 (21), 9142-9153
DOI: 10.1021/acs.chemmater.7b02902
"""
- def __init__(self, acceptable_overshoot = 0.0001, verbose = True):
+ def __init__(self, acceptable_overshoot = 0.0001):
self.acceptable_overshoot = acceptable_overshoot
- self.verbose = verbose
def compute(self, st):
assert isinstance(st, SiteTrajectory)
@@ -52,16 +54,13 @@ def compute(self, st):
size_of_problems = p2 - 1.0
forgivable = problems & (size_of_problems < self.acceptable_overshoot)
- if self.verbose:
- print("n_i " + ("{:5.3} " * len(n_i)).format(*n_i))
- print("N_i " + ("{:>5} " * len(N_i)).format(*N_i))
- print(" " + ("------" * len(n_i)))
- print("P_2 " + ("{:5.3} " * len(p2)).format(*p2))
+ logger.info("n_i " + ("{:5.3} " * len(n_i)).format(*n_i))
+ logger.info("N_i " + ("{:>5} " * len(N_i)).format(*N_i))
+ logger.info(" " + ("------" * len(n_i)))
+ logger.info("P_2 " + ("{:5.3} " * len(p2)).format(*p2))
if not np.all(problems == forgivable):
raise ValueError("P_2 values for site types %s larger than 1.0 + acceptable_overshoot (%f)" % (np.where(problems)[0], self.acceptable_overshoot))
- elif np.any(problems) and self.verbose:
- print("")
# Correct forgivable problems
p2[forgivable] = 1.0
diff --git a/sitator/dynamics/DiffusionPathwayAnalysis.py b/sitator/dynamics/DiffusionPathwayAnalysis.py
index 5510d66..13da941 100644
--- a/sitator/dynamics/DiffusionPathwayAnalysis.py
+++ b/sitator/dynamics/DiffusionPathwayAnalysis.py
@@ -5,6 +5,9 @@
from scipy.sparse.csgraph import connected_components
+import logging
+logger = logging.getLogger(__name__)
+
class DiffusionPathwayAnalysis(object):
"""Find connected diffusion pathways in a SiteNetwork.
@@ -19,15 +22,12 @@ class DiffusionPathwayAnalysis(object):
def __init__(self,
connectivity_threshold = 0.001,
- minimum_n_sites = 4,
- verbose = True):
+ minimum_n_sites = 4):
assert minimum_n_sites >= 0
self.connectivity_threshold = connectivity_threshold
self.minimum_n_sites = minimum_n_sites
- self.verbose = verbose
-
def run(self, sn):
"""
Expects a SiteNetwork that has had a JumpAnalysis run on it.
@@ -56,9 +56,8 @@ def run(self, sn):
is_pathway = counts >= self.minimum_n_sites
- if self.verbose:
- print("Taking all edges with at least %i/%i jumps..." % (threshold, n_non_self_jumps))
- print("Found %i connected components, of which %i are large enough to qualify as pathways." % (n_ccs, np.sum(is_pathway)))
+ logger.info("Taking all edges with at least %i/%i jumps..." % (threshold, n_non_self_jumps))
+ logger.info("Found %i connected components, of which %i are large enough to qualify as pathways." % (n_ccs, np.sum(is_pathway)))
translation = np.empty(n_ccs, dtype = np.int)
translation[~is_pathway] = DiffusionPathwayAnalysis.NO_PATHWAY
diff --git a/sitator/dynamics/JumpAnalysis.py b/sitator/dynamics/JumpAnalysis.py
index fc594f9..24b304b 100644
--- a/sitator/dynamics/JumpAnalysis.py
+++ b/sitator/dynamics/JumpAnalysis.py
@@ -5,6 +5,9 @@
from sitator import SiteNetwork, SiteTrajectory
from sitator.visualization import plotter, plot_atoms, layers
+import logging
+logger = logging.getLogger(__name__)
+
class JumpAnalysis(object):
"""Given a SiteTrajectory, compute various statistics about the jumps it contains.
@@ -18,8 +21,8 @@ class JumpAnalysis(object):
- `total_corrected_residences`: Total number of frames when a particle was at the site,
*including* frames when an unassigned particle's last known site was this site.
"""
- def __init__(self, verbose = True):
- self.verbose = verbose
+ def __init__(self):
+ pass
def run(self, st):
"""Run the analysis.
@@ -28,8 +31,7 @@ def run(self, st):
"""
assert isinstance(st, SiteTrajectory)
- if self.verbose:
- print("Running JumpAnalysis...")
+ logger.info("Running JumpAnalysis...")
n_mobile = st.site_network.n_mobile
n_frames = st.n_frames
@@ -62,8 +64,8 @@ def run(self, st):
frame[unassigned] = last_known[unassigned]
fknown = frame >= 0
- if np.any(~fknown) and self.verbose:
- print(" at frame %i, %i uncorrectable unassigned particles" % (i, np.sum(~fknown)))
+ if np.any(~fknown):
+ logger.warning(" at frame %i, %i uncorrectable unassigned particles" % (i, np.sum(~fknown)))
# -- Update stats
total_time_spent_at_site[frame[fknown]] += 1
@@ -94,8 +96,8 @@ def run(self, st):
# The time before jumping to self should always be inf
assert not np.any(np.nonzero(avg_time_before_jump.diagonal()))
- if self.verbose and n_problems != 0:
- print("Came across %i times where assignment and last known assignment were unassigned." % n_problems)
+ if n_problems != 0:
+ logger.warning("Came across %i times where assignment and last known assignment were unassigned." % n_problems)
msk = avg_time_before_jump_n > 0
# Zeros -- i.e. no jumps -- should actualy be infs
diff --git a/sitator/dynamics/MergeSitesByDynamics.py b/sitator/dynamics/MergeSitesByDynamics.py
index 7abbe04..ba19cc3 100644
--- a/sitator/dynamics/MergeSitesByDynamics.py
+++ b/sitator/dynamics/MergeSitesByDynamics.py
@@ -4,6 +4,9 @@
from sitator.dynamics import JumpAnalysis
from sitator.util import PBCCalculator
+import logging
+logger = logging.getLogger(__name__)
+
class MergeSitesByDynamics(object):
"""Merges sites using dynamical data.
@@ -28,11 +31,8 @@ def __init__(self,
distance_threshold = 1.0,
post_check_thresh_factor = 1.5,
check_types = True,
- verbose = True,
iterlimit = 100,
markov_parameters = {}):
-
- self.verbose = verbose
self.distance_threshold = distance_threshold
self.post_check_thresh_factor = post_check_thresh_factor
self.check_types = check_types
@@ -47,7 +47,7 @@ def run(self, st):
# -- Compute jump statistics
if not st.site_network.has_attribute('p_ij'):
- ja = JumpAnalysis(verbose = self.verbose)
+ ja = JumpAnalysis()
ja.run(st)
pbcc = PBCCalculator(st.site_network.structure.cell)
@@ -83,17 +83,16 @@ def run(self, st):
connectivity_matrix[i, js_too_far] = 0
connectivity_matrix[js_too_far, i] = 0 # Symmetry
- if self.verbose and n_alarming_ignored_edges > 0:
- print(" At least %i site pairs with high (z-score > 3) fluxes were over the given distance cutoff.\n"
- " This may or may not be a problem; but if `distance_threshold` is low, consider raising it." % n_alarming_ignored_edges)
+ if n_alarming_ignored_edges > 0:
+ logger.warning(" At least %i site pairs with high (z-score > 3) fluxes were over the given distance cutoff.\n"
+ " This may or may not be a problem; but if `distance_threshold` is low, consider raising it." % n_alarming_ignored_edges)
# -- Do Markov Clustering
clusters = self._markov_clustering(connectivity_matrix, **self.markov_parameters)
new_n_sites = len(clusters)
- if self.verbose:
- print("After merge there will be %i sites" % new_n_sites)
+ logger.info("After merge there will be %i sites" % new_n_sites)
if self.check_types:
new_types = np.empty(shape = new_n_sites, dtype = np.int)
@@ -181,8 +180,7 @@ def _markov_clustering(self,
# -- Check converged
if np.allclose(m1, m2):
converged = True
- if self.verbose:
- print("Markov Clustering converged in %i iterations" % i)
+ logger.info("Markov Clustering converged in %i iterations" % i)
break
m1[:] = m2
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 0401c70..198e7ba 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -17,6 +17,8 @@ def tqdm(iterable, **kwargs):
from . import helpers
from sitator import SiteNetwork, SiteTrajectory
+import logging
+logger = logging.getLogger(__name__)
from functools import wraps
def analysis_result(func):
@@ -85,7 +87,8 @@ def __init__(self,
and self.n_multiple_assignments.
:param bool force_no_memmap: if True, landmark vectors will be stored only in memory.
Only useful if access to landmark vectors after the analysis has run is desired.
- :param bool verbose: If `True`, progress bars and messages will be printed to stdout.
+ :param bool verbose: If `True`, progress bars will be printed to stdout.
+ Other output is handled separately through the `logging` module.
"""
self._cutoff_midpoint = cutoff_midpoint
@@ -151,8 +154,7 @@ def run(self, sn, frames):
n_frames = len(frames)
- if self.verbose:
- print("--- Running Landmark Analysis ---")
+ logger.info("--- Running Landmark Analysis ---")
# Create PBCCalculator
self._pbcc = PBCCalculator(sn.structure.cell)
@@ -171,7 +173,7 @@ def run(self, sn, frames):
site_vert_dists[i, :len(polyhedron)] = dists
# -- Step 2: Compute landmark vectors
- if self.verbose: print(" - computing landmark vectors -")
+ logger.info(" - computing landmark vectors -")
# Compute landmark vectors
# The dimension of one landmark vector is the number of Voronoi regions
@@ -188,10 +190,15 @@ def run(self, sn, frames):
helpers._fill_landmark_vectors(self, sn, verts_np, site_vert_dists,
frames, check_for_zeros = self.check_for_zero_landmarks,
- tqdm = tqdm)
+ tqdm = tqdm, logger = logger)
+
+ if not self.check_for_zero_landmarks and self.n_all_zero_lvecs > 0:
+ logger.warning(" Had %i all-zero landmark vectors; no error because `check_for_zero_landmarks = False`." % self.n_all_zero_lvecs)
+ elif self.check_for_zero_landmarks:
+ assert self.n_all_zero_lvecs == 0
# -- Step 3: Cluster landmark vectors
- if self.verbose: print(" - clustering landmark vectors -")
+ logger.info(" - clustering landmark vectors -")
# - Preprocess -
self._do_peak_evening()
@@ -204,8 +211,7 @@ def run(self, sn, frames):
min_samples = self._minimum_site_occupancy / float(sn.n_mobile),
verbose = self.verbose)
- if self.verbose:
- print(" Failed to assign %i%% of mobile particle positions to sites." % (100.0 * np.sum(lmk_lbls < 0) / float(len(lmk_lbls))))
+ logger.info(" Failed to assign %i%% of mobile particle positions to sites." % (100.0 * np.sum(lmk_lbls < 0) / float(len(lmk_lbls))))
# reshape lables and confidences
lmk_lbls.shape = (n_frames, sn.n_mobile)
@@ -216,8 +222,7 @@ def run(self, sn, frames):
if n_sites < (sn.n_mobile / self.max_mobile_per_site):
raise MultipleOccupancyError("There are %i mobile particles, but only identified %i sites. With %i max_mobile_per_site, this is an error. Check clustering_params." % (sn.n_mobile, n_sites, self.max_mobile_per_site))
- if self.verbose:
- print(" Identified %i sites with assignment counts %s" % (n_sites, cluster_counts))
+ logger.info(" Identified %i sites with assignment counts %s" % (n_sites, cluster_counts))
# Check that multiple particles are never assigned to one site at the
# same time, cause that would be wrong.
diff --git a/sitator/landmark/cluster/dbscan.py b/sitator/landmark/cluster/dbscan.py
index 381fa6d..769152b 100644
--- a/sitator/landmark/cluster/dbscan.py
+++ b/sitator/landmark/cluster/dbscan.py
@@ -4,6 +4,9 @@
import numbers
from sklearn.cluster import DBSCAN
+import logging
+logger = logging.getLogger(__name__)
+
DEFAULT_PARAMS = {
'eps' : 0.05,
'min_samples' : 5,
@@ -13,7 +16,8 @@
def do_landmark_clustering(landmark_vectors,
clustering_params,
min_samples,
- verbose):
+ verbose = False):
+ # `verbose` ignored.
tmp = DEFAULT_PARAMS.copy()
tmp.update(clustering_params)
@@ -53,8 +57,7 @@ def do_landmark_clustering(landmark_vectors,
# Do the remapping
lmk_lbls = trans_table[lmk_lbls]
- if verbose:
- print("DBSCAN landmark: %i/%i assignment counts below threshold %f (%i); %i clusters remain." % \
+ logger.info("DBSCAN landmark: %i/%i assignment counts below threshold %f (%i); %i clusters remain." % \
(len(to_remove), len(cluster_counts), min_samples, min_n_samples_cluster, len(cluster_counts) - len(to_remove)))
# Remove counts
diff --git a/sitator/landmark/helpers.pyx b/sitator/landmark/helpers.pyx
index d54bbfd..de508b9 100644
--- a/sitator/landmark/helpers.pyx
+++ b/sitator/landmark/helpers.pyx
@@ -9,7 +9,7 @@ from sitator.landmark import StaticLatticeError, ZeroLandmarkError
ctypedef double precision
-def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_for_zeros = True, tqdm = lambda i: i):
+def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_for_zeros = True, tqdm = lambda i: i, logger = None):
if self._landmark_dimension is None:
raise ValueError("_fill_landmark_vectors called before Voronoi!")
@@ -42,6 +42,8 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
self._cutoff_steepness,
0.0001)
+ cdef Py_ssize_t n_all_zero_lvecs = 0
+
cdef Py_ssize_t landmark_dim = self._landmark_dimension
cdef Py_ssize_t current_landmark_i = 0
# Iterate through time
@@ -66,7 +68,7 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
if static_positions_seen[nearest_static_position]:
# We've already seen this one... error
- print "Static atom %i is the closest to more than one static lattice position" % nearest_static_position
+ logger.warning("Static atom %i is the closest to more than one static lattice position" % nearest_static_position)
#raise ValueError("Static atom %i is the closest to more than one static lattice position" % nearest_static_position)
static_positions_seen[nearest_static_position] = True
@@ -111,17 +113,24 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
cutoff_round_to_zero,
temp_distbuff)
- if check_for_zeros and (np.count_nonzero(self._landmark_vectors[current_landmark_i]) == 0):
- raise ZeroLandmarkError(mobile_index = j, frame = i)
+ if np.count_nonzero(self._landmark_vectors[current_landmark_i]) == 0:
+ if check_for_zeros:
+ raise ZeroLandmarkError(mobile_index = j, frame = i)
+ else:
+ n_all_zero_lvecs += 1
current_landmark_i += 1
+ self.n_all_zero_lvecs = n_all_zero_lvecs
+
+
cdef precision cutoff_round_to_zero_point(precision cutoff_midpoint,
precision cutoff_steepness,
precision threshold):
# Computed by solving for x:
return cutoff_midpoint + log((1/threshold) - 1.) / cutoff_steepness
+
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void fill_landmark_vec(precision [:,:] landmark_vectors,
diff --git a/sitator/landmark/pointmerge.py b/sitator/landmark/pointmerge.py
index 40cdd7a..d94682d 100644
--- a/sitator/landmark/pointmerge.py
+++ b/sitator/landmark/pointmerge.py
@@ -15,8 +15,7 @@ def merge_points_soap_paths(tsoap,
connectivity_dict,
threshold,
n_steps = 5,
- sanity_check_cutoff = np.inf,
- verbose = True):
+ sanity_check_cutoff = np.inf):
"""Merge points using SOAP paths method.
:param SOAP tsoap: to compute SOAPs with.
diff --git a/sitator/misc/GenerateAroundSites.py b/sitator/misc/GenerateAroundSites.py
index 84f391f..12a6138 100644
--- a/sitator/misc/GenerateAroundSites.py
+++ b/sitator/misc/GenerateAroundSites.py
@@ -14,9 +14,9 @@ def run(self, sn):
out = sn.copy()
pbcc = PBCCalculator(sn.structure.cell)
- print(out.centers.shape)
+
newcenters = out.centers.repeat(self.n, axis = 0)
- print(newcenters.shape)
+ assert len(newcenters) == self.n * len(out.centers)
newcenters += self.sigma * np.random.standard_normal(size = newcenters.shape)
pbcc.wrap_points(newcenters)
diff --git a/sitator/misc/SiteVolumes.py b/sitator/misc/SiteVolumes.py
index 095bc24..5095dae 100644
--- a/sitator/misc/SiteVolumes.py
+++ b/sitator/misc/SiteVolumes.py
@@ -6,6 +6,9 @@
from sitator import SiteTrajectory
from sitator.util import PBCCalculator
+import logging
+logger = logging.getLogger(__name__)
+
class SiteVolumes(object):
"""Computes the volumes of convex hulls around all positions associated with a site.
@@ -36,7 +39,7 @@ def run(self, st):
try:
hull = ConvexHull(pos)
except QhullError as qhe:
- print("For site %i, iter %i: %s" % (site, i, qhe))
+ logger.warning("For site %i, iter %i: %s" % (site, i, qhe))
vols[site] = np.nan
areas[site] = np.nan
continue
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index 9ba7a9b..0eb22f8 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -46,15 +46,13 @@ class SOAP(object, metaclass=ABCMeta):
from .backend.dscribe import dscribe_soap_backend as backend_dscribe
def __init__(self, tracer_atomic_number, environment = None,
- soap_mask = None, verbose =True,
+ soap_mask = None,
backend = None):
from ase.data import atomic_numbers
self.tracer_atomic_number = tracer_atomic_number
self._soap_mask = soap_mask
- self._verbose = verbose
-
if backend is None:
backend = SOAP.dscribe_soap_backend
self._backend = backend
diff --git a/sitator/site_descriptors/SiteTypeAnalysis.py b/sitator/site_descriptors/SiteTypeAnalysis.py
index 465ff53..afeccce 100644
--- a/sitator/site_descriptors/SiteTypeAnalysis.py
+++ b/sitator/site_descriptors/SiteTypeAnalysis.py
@@ -1,7 +1,3 @@
-from __future__ import (absolute_import, division,
- print_function, unicode_literals)
-from builtins import *
-
import numpy as np
from sitator.misc import GenerateAroundSites
@@ -14,6 +10,9 @@
import itertools
+import logging
+logger = logging.getLogger(__name__)
+
try:
import pydpc
except ImportError:
@@ -30,11 +29,10 @@ class SiteTypeAnalysis(object):
"""
def __init__(self, descriptor,
min_pca_variance = 0.9, min_pca_dimensions = 2,
- verbose = True, n_site_types_max = 20):
+ n_site_types_max = 20):
self.descriptor = descriptor
self.min_pca_variance = min_pca_variance
self.min_pca_dimensions = min_pca_dimensions
- self.verbose = verbose
self.n_site_types_max = n_site_types_max
self._n_dvecs = None
@@ -44,8 +42,7 @@ def run(self, descriptor_input, **kwargs):
raise ValueError("Can't run SiteTypeAnalysis more than once!")
# -- Sample enough points
- if self.verbose:
- print(" -- Running SiteTypeAnalysis --")
+ logger.info(" -- Running SiteTypeAnalysis --")
if isinstance(descriptor_input, SiteNetwork):
sn = descriptor_input.copy()
@@ -55,30 +52,26 @@ def run(self, descriptor_input, **kwargs):
raise RuntimeError("Input {}".format(type(descriptor_input)))
# -- Compute descriptor vectors
- if self.verbose:
- print(" - Computing Descriptor Vectors")
+ logger.info(" - Computing Descriptor Vectors")
self.dvecs, dvecs_to_site = self.descriptor.get_descriptors(descriptor_input, **kwargs)
assert len(self.dvecs) == len(dvecs_to_site), "Length mismatch in descriptor return values"
assert np.min(dvecs_to_site) == 0 and np.max(dvecs_to_site) < sn.n_sites
# -- Dimensionality Reduction
- if self.verbose:
- print(" - Clustering Descriptor Vectors")
+ logger.info(" - Clustering Descriptor Vectors")
self.pca = PCA(self.min_pca_variance)
pca_dvecs = self.pca.fit_transform(self.dvecs)
if pca_dvecs.shape[1] < self.min_pca_dimensions:
- if self.verbose:
- print(" PCA accounted for %.0f%% variance in only %i dimensions; less than minimum of %.0f." % (100.0 * np.sum(self.pca.explained_variance_ratio_), pca_dvecs.shape[1], self.min_pca_dimensions))
- print(" Forcing PCA to use %i dimensions." % self.min_pca_dimensions)
- self.pca = PCA(n_components = self.min_pca_dimensions)
- pca_dvecs = self.pca.fit_transform(self.dvecs)
+ logger.info(" PCA accounted for %.0f%% variance in only %i dimensions; less than minimum of %.0f." % (100.0 * np.sum(self.pca.explained_variance_ratio_), pca_dvecs.shape[1], self.min_pca_dimensions))
+ logger.info(" Forcing PCA to use %i dimensions." % self.min_pca_dimensions)
+ self.pca = PCA(n_components = self.min_pca_dimensions)
+ pca_dvecs = self.pca.fit_transform(self.dvecs)
self.dvecs = pca_dvecs
- if self.verbose:
- print((" Accounted for %.0f%% of variance in %i dimensions" % (100.0 * np.sum(self.pca.explained_variance_ratio_), self.dvecs.shape[1])))
+ logger.info((" Accounted for %.0f%% of variance in %i dimensions" % (100.0 * np.sum(self.pca.explained_variance_ratio_), self.dvecs.shape[1])))
# -- Do clustering
# pydpc requires a C-contiguous array
@@ -121,9 +114,8 @@ def run(self, descriptor_input, **kwargs):
assert self.n_types == len(site_type_counts), "Got %i types from pydpc, but counted %i" % (self.n_types, len(site_type_counts))
- if self.verbose:
- print(" Found %i site type clusters" % self.n_types )
- print(" Failed to assign %i/%i descriptor vectors to clusters." % (self._n_unassigned, self._n_dvecs))
+ logger.info(" Found %i site type clusters" % self.n_types )
+ logger.info(" Failed to assign %i/%i descriptor vectors to clusters." % (self._n_unassigned, self._n_dvecs))
# -- Voting
types = np.empty(shape = sn.n_sites, dtype = np.int)
@@ -144,12 +136,11 @@ def run(self, descriptor_input, **kwargs):
n_sites_of_each_type = np.bincount(types, minlength = self.n_types)
sn.site_types = types
- if self.verbose:
- print((" " + "Type {:<2} " * self.n_types).format(*range(self.n_types)))
- print(("# of sites " + "{:<8}" * self.n_types).format(*n_sites_of_each_type))
+ logger.info((" " + "Type {:<2} " * self.n_types).format(*range(self.n_types)))
+ logger.info(("# of sites " + "{:<8}" * self.n_types).format(*n_sites_of_each_type))
- if np.any(n_sites_of_each_type == 0):
- print("WARNING: Had site types with no sites; check clustering settings/voting!")
+ if np.any(n_sites_of_each_type == 0):
+ logger.warning("WARNING: Had site types with no sites; check clustering settings/voting!")
return sn
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index 213e957..52cf883 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -14,6 +14,9 @@ except:
N_SITES_ALLOC_INCREMENT = 100
+import logging
+logger = logging.getLogger(__name__)
+
class OneValueListlike(object):
def __init__(self, value, length = np.inf):
self.length = length
@@ -222,9 +225,8 @@ class DotProdClassifier(object):
# Then we removed everything...
raise ValueError("`min_samples` too large; all %i clusters under threshold." % len(count_mask))
- if verbose:
- print "DotProdClassifier: %i/%i assignment counts below threshold %s (%s); %i clusters remain." % \
- (np.sum(~count_mask), len(count_mask), self._min_samples, min_samples, len(self._cluster_counts))
+ logger.info("DotProdClassifier: %i/%i assignment counts below threshold %s (%s); %i clusters remain." % \
+ (np.sum(~count_mask), len(count_mask), self._min_samples, min_samples, len(self._cluster_counts)))
# Do another predict -- this could be more efficient, but who cares?
labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold)
@@ -297,8 +299,8 @@ class DotProdClassifier(object):
labels[i] = assigned_to
confidences[i] = assignment_confidence
- if verbose and zeros_count > 0:
- print "Encountered %i zero vectors during prediction" % zeros_count
+ if zeros_count > 0:
+ logger.warning("Encountered %i zero vectors during prediction" % zeros_count)
if return_confidences:
return labels, confidences
diff --git a/sitator/util/zeo.py b/sitator/util/zeo.py
index 2c2931b..290903a 100644
--- a/sitator/util/zeo.py
+++ b/sitator/util/zeo.py
@@ -16,6 +16,9 @@
from sitator.util import PBCCalculator
+import logging
+logger = logging.getLogger(__name__)
+
# TODO: benchmark CUC vs CIF
class Zeopy(object):
@@ -40,7 +43,7 @@ def __enter__(self):
def __exit__(self, *args):
shutil.rmtree(self._tmpdir)
- def voronoi(self, structure, radial = False, verbose=True):
+ def voronoi(self, structure, radial = False):
"""
:param Atoms structure: The ASE Atoms to compute the Voronoi decomposition of.
"""
@@ -66,12 +69,11 @@ def voronoi(self, structure, radial = False, verbose=True):
output = subprocess.check_output([self._exe] + args + ["-v1", v1out, "-nt2", outp, inp],
stderr = subprocess.STDOUT)
except subprocess.CalledProcessError as e:
- print("Zeo++ returned an error:", file = sys.stderr)
- print(e.output, file = sys.stderr)
+ logger.error("Zeo++ returned an error:")
+ logger.error(e.output)
raise
- if verbose:
- print(output)
+ logger.debug(output)
with open(outp, "r") as outf:
verts, edges = self.parse_nt2(outf.readlines())
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 630a866..73ededb 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -1,7 +1,3 @@
-from __future__ import (absolute_import, division,
- print_function, unicode_literals)
-from builtins import *
-
import numpy as np
import itertools
diff --git a/sitator/visualization/common.py b/sitator/visualization/common.py
index 3886dbc..74db185 100644
--- a/sitator/visualization/common.py
+++ b/sitator/visualization/common.py
@@ -69,7 +69,6 @@ def plotter_wraped(*args, **kwargs):
@plotter(is3D = True)
def layers(*args, **fax):
i = fax['i']
- print(i)
for p, kwargs in args:
p(fig = fax['fig'], ax = fax['ax'], i = i, **kwargs)
i += 1
diff --git a/sitator/voronoi/VoronoiSiteGenerator.py b/sitator/voronoi/VoronoiSiteGenerator.py
index 64c2058..544ec16 100644
--- a/sitator/voronoi/VoronoiSiteGenerator.py
+++ b/sitator/voronoi/VoronoiSiteGenerator.py
@@ -10,12 +10,10 @@ class VoronoiSiteGenerator(object):
:param str zeopp_path: Path to the Zeo++ `network` executable
:param bool radial: Whether to use the radial Voronoi transform. Defaults to,
and should typically be, False.
- :param bool verbose:
"""
- def __init__(self, zeopp_path = "network", radial = False, verbose = True):
+ def __init__(self, zeopp_path = "network", radial = False):
self._radial = radial
- self._verbose = verbose
self._zeopy = Zeopy(zeopp_path)
def run(self, sn):
@@ -24,8 +22,7 @@ def run(self, sn):
with self._zeopy:
nodes, verts, edges, _ = self._zeopy.voronoi(sn.static_structure,
- radial = self._radial,
- verbose = self._verbose)
+ radial = self._radial)
out = sn.copy()
out.centers = nodes
From 863232f812a1cc4d51948e52cdff2f53641629ee Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 12 Jun 2019 15:18:55 -0400
Subject: [PATCH 013/129] Pass through atomic sigma
---
sitator/site_descriptors/backend/dscribe.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/sitator/site_descriptors/backend/dscribe.py b/sitator/site_descriptors/backend/dscribe.py
index d9f9de3..ed7b140 100644
--- a/sitator/site_descriptors/backend/dscribe.py
+++ b/sitator/site_descriptors/backend/dscribe.py
@@ -25,6 +25,7 @@ def dscribe_soap(structure, positions):
nmax = soap_opts['n_max'],
lmax = soap_opts['l_max'],
rbf = soap_opts['rbf'],
+ sigma = soap_opts['atom_sigma'],
periodic = np.all(structure.pbc),
sparse = False
)
From aa8fa998a061109e0ba604816da285ba4561fcaa Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 12 Jun 2019 15:19:16 -0400
Subject: [PATCH 014/129] Correct case where no edges cross cell boundaries
---
sitator/visualization/SiteNetworkPlotter.py | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 73ededb..8e8ebfa 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -268,12 +268,14 @@ def _plot_edges(self, sn, ax = None, *args, **kwargs):
ax.add_collection(lc)
# -- Plot new sites
- sn2 = sn[sites_to_plot]
- sn2.update_centers(np.asarray(sites_to_plot_positions))
-
- pts_params = dict(self.plot_points_params)
- pts_params['alpha'] = 0.2
- return self._site_layers(sn2, pts_params)
+ if len(sites_to_plot) > 0:
+ sn2 = sn[sites_to_plot]
+ sn2.update_centers(np.asarray(sites_to_plot_positions))
+ pts_params = dict(self.plot_points_params)
+ pts_params['alpha'] = 0.2
+ return self._site_layers(sn2, pts_params)
+ else:
+ return []
else:
return []
From 1d68b653f3881d83c938bbc900df08160a37511e Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 14 Jun 2019 11:58:49 -0400
Subject: [PATCH 015/129] Correct IO behaviour under MPI
---
sitator/SiteNetwork.py | 2 +-
sitator/util/zeo.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 7e40cff..d4d9891 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -113,7 +113,7 @@ def save(self, file):
"""Save this SiteNetwork to a tar archive."""
with tempfile.TemporaryDirectory() as tmpdir:
# -- Write the structure
- ase.io.write(os.path.join(tmpdir, self._STRUCT_FNAME), self.structure)
+ ase.io.write(os.path.join(tmpdir, self._STRUCT_FNAME), self.structure, parallel = False)
# -- Write masks
np.save(os.path.join(tmpdir, self._SMASK_FNAME), self.static_mask)
np.save(os.path.join(tmpdir, self._MMASK_FNAME), self.mobile_mask)
diff --git a/sitator/util/zeo.py b/sitator/util/zeo.py
index 290903a..338e6e0 100644
--- a/sitator/util/zeo.py
+++ b/sitator/util/zeo.py
@@ -55,7 +55,7 @@ def voronoi(self, structure, radial = False):
outp = os.path.join(self._tmpdir, "out.nt2")
v1out = os.path.join(self._tmpdir, "out.v1")
- ase.io.write(inp, structure)
+ ase.io.write(inp, structure, parallel = False)
# with open(inp, "w") as inf:
# inf.write(self.ase2cuc(structure))
@@ -70,7 +70,7 @@ def voronoi(self, structure, radial = False):
stderr = subprocess.STDOUT)
except subprocess.CalledProcessError as e:
logger.error("Zeo++ returned an error:")
- logger.error(e.output)
+ logger.error(str(e.output))
raise
logger.debug(output)
From ffebc1a6eee41d6c49b0c4c3289e064365d75d34 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 14 Jun 2019 11:59:09 -0400
Subject: [PATCH 016/129] Removed unneeded dependencies
---
setup.py | 2 --
1 file changed, 2 deletions(-)
diff --git a/setup.py b/setup.py
index 54bbfa8..11024c9 100644
--- a/setup.py
+++ b/setup.py
@@ -21,8 +21,6 @@
"matplotlib",
"ase",
"tqdm",
- "backports.tempfile",
- "future",
"sklearn"
],
extras_require = {
From 8cb05cf00e9ddf3f9cbe19bd05f931595f97ec40 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 14 Jun 2019 13:18:34 -0400
Subject: [PATCH 017/129] Refactored out TQDM imports
---
README.md | 4 ++++
sitator/landmark/LandmarkAnalysis.py | 8 ++------
sitator/landmark/pointmerge.py | 8 +-------
sitator/site_descriptors/SOAP.py | 10 +---------
sitator/util/DotProdClassifier.pyx | 8 +-------
sitator/util/progress.py | 14 ++++++++++++++
6 files changed, 23 insertions(+), 29 deletions(-)
create mode 100644 sitator/util/progress.py
diff --git a/README.md b/README.md
index 9aa7a09..6dd63a2 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,10 @@ Two example Jupyter notebooks for conducting full landmark analyses of LiAlSiO4
All individual classes and parameters are documented with docstrings in the source code.
+## Global Options
+
+`sitator` uses the `tqdm.autonotebook` tool to automatically produce the correct fancy progress bars for terminals and IPython notebooks. To disable all progress bars, run with the environment variable `SITATOR_PROGRESSBAR` set to `false`.
+
## License
This software is made available under the MIT License. See `LICENSE` for more details.
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 198e7ba..c9ffa2c 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -2,14 +2,10 @@
from sitator.landmark import MultipleOccupancyError
from sitator.util import PBCCalculator
+from sitator.util.progress import tqdm
-# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
import sys
-try:
- from tqdm.autonotebook import tqdm
-except:
- def tqdm(iterable, **kwargs):
- return iterable
+
import importlib
import tempfile
diff --git a/sitator/landmark/pointmerge.py b/sitator/landmark/pointmerge.py
index d94682d..c8dcc8c 100644
--- a/sitator/landmark/pointmerge.py
+++ b/sitator/landmark/pointmerge.py
@@ -1,13 +1,7 @@
import numpy as np
-# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
-import sys
-try:
- from tqdm.autonotebook import tqdm
-except:
- def tqdm(iterable, **kwargs):
- return iterable
+from sitator.util.progress import tqdm
def merge_points_soap_paths(tsoap,
pbcc,
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index 0eb22f8..3a05fe9 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -4,18 +4,10 @@
from sitator.SiteNetwork import SiteNetwork
from sitator.SiteTrajectory import SiteTrajectory
+from sitator.util.progress import tqdm
from ase.data import atomic_numbers
-
-# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
-import sys
-try:
- from tqdm.autonotebook import tqdm
-except:
- def tqdm(iterable, **kwargs):
- return iterable
-
class SOAP(object, metaclass=ABCMeta):
"""Abstract base class for computing SOAP vectors in a SiteNetwork.
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index 52cf883..6e03f17 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -4,13 +4,7 @@ import numpy as np
import numbers
-import sys
-
-try:
- from tqdm.autonotebook import tqdm
-except:
- def tqdm(iterable, **kwargs):
- return iterable
+from sitator.util.progress import tqdm
N_SITES_ALLOC_INCREMENT = 100
diff --git a/sitator/util/progress.py b/sitator/util/progress.py
new file mode 100644
index 0000000..4ffe766
--- /dev/null
+++ b/sitator/util/progress.py
@@ -0,0 +1,14 @@
+import os
+
+progress = os.getenv('SITATOR_PROGRESSBAR', 'true').lower()
+progress = (progress == 'true') or (progress == 'yes') or (progress == 'on')
+
+if progress:
+ try:
+ from tqdm.autonotebook import tqdm
+ except:
+ def tqdm(iterable, **kwargs):
+ return iterable
+else:
+ def tqdm(iterable, **kwargs):
+ return iterable
From e74649f421f614644d36cff61d4d74ce05547af6 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 14 Jun 2019 15:40:32 -0400
Subject: [PATCH 018/129] More flexible site merging behaviour
---
sitator/dynamics/JumpAnalysis.py | 8 +++
sitator/dynamics/MergeSitesByDynamics.py | 87 ++++++++++++++++++++++--
sitator/misc/NAvgsPerSite.py | 2 +-
sitator/util/PBCCalculator.pyx | 15 ++++
4 files changed, 106 insertions(+), 6 deletions(-)
diff --git a/sitator/dynamics/JumpAnalysis.py b/sitator/dynamics/JumpAnalysis.py
index 24b304b..64725d6 100644
--- a/sitator/dynamics/JumpAnalysis.py
+++ b/sitator/dynamics/JumpAnalysis.py
@@ -105,6 +105,14 @@ def run(self, st):
# Do mean
avg_time_before_jump[msk] /= avg_time_before_jump_n[msk]
+ if st.site_network.has_attribute('n_ij'):
+ st.site_network.remove_attribute('n_ij')
+ st.site_network.remove_attribute('p_ij')
+ st.site_network.remove_attribute('jump_lag')
+ st.site_network.remove_attribute('residence_times')
+ st.site_network.remove_attribute('occupancy_freqs')
+ st.site_network.remove_attribute('total_corrected_residences')
+
st.site_network.add_edge_attribute('jump_lag', avg_time_before_jump)
st.site_network.add_edge_attribute('n_ij', n_ij)
st.site_network.add_edge_attribute('p_ij', n_ij / total_time_spent_at_site)
diff --git a/sitator/dynamics/MergeSitesByDynamics.py b/sitator/dynamics/MergeSitesByDynamics.py
index ba19cc3..732622c 100644
--- a/sitator/dynamics/MergeSitesByDynamics.py
+++ b/sitator/dynamics/MergeSitesByDynamics.py
@@ -7,13 +7,26 @@
import logging
logger = logging.getLogger(__name__)
+class MergeSitesError(Exception):
+ pass
+
+class MergedSitesTooDistantError(MergeSitesError):
+ pass
+
+class TooFewMergedSitesError(MergeSitesError):
+ pass
+
+
+
class MergeSitesByDynamics(object):
"""Merges sites using dynamical data.
Given a SiteTrajectory, merges sites using Markov Clustering.
:param float distance_threshold: Don't merge sites further than this
- in real space.
+ in real space. Zeros out the connectivity_matrix at distances greater than
+ this; a hard, step function style cutoff. For a more gentle cutoff, try
+ changing `connectivity_matrix_generator` to incorporate distance.
:param float post_check_thresh_factor: Throw an error if proposed merge sites
are further than this * distance_threshold away. Only a sanity check; not
a hard guarantee. Can be `None`; defaults to `1.5`. Can be loosely
@@ -28,17 +41,77 @@ class MergeSitesByDynamics(object):
Valid keys are ``'inflation'``, ``'expansion'``, and ``'pruning_threshold'``.
"""
def __init__(self,
+ connectivity_matrix_generator = None,
distance_threshold = 1.0,
post_check_thresh_factor = 1.5,
check_types = True,
iterlimit = 100,
markov_parameters = {}):
+
+ if connectivity_matrix_generator is None:
+ connectivity_matrix_generator = MergeSitesByDynamics.connectivity_n_ij
+ assert callable(connectivity_matrix_generator)
+
+ self.connectivity_matrix_generator = connectivity_matrix_generator
self.distance_threshold = distance_threshold
self.post_check_thresh_factor = post_check_thresh_factor
self.check_types = check_types
self.iterlimit = iterlimit
self.markov_parameters = markov_parameters
+ # Connectivity Matrix Generation Schemes:
+
+ @staticmethod
+ def connectivity_n_ij(sn):
+ """Basic default connectivity scheme: uses n_ij directly as connectivity matrix.
+
+ Works well for systems with sufficient statistics.
+ """
+ return sn.n_ij
+
+ @staticmethod
+ def connectivity_jump_lag_biased(jump_lag_coeff = 1.0,
+ jump_lag_sigma = 20.0,
+ jump_lag_cutoff = np.inf,
+ distance_coeff = 0.5,
+ distance_sigma = 1.0):
+ """Bias the typical connectivity matrix p_ij with jump lag and distance contributions.
+
+ The jump lag and distance are processed through Gaussian functions with
+ the given sigmas (i.e. higher jump lag/larger distance => lower
+ connectivity value). These matrices are then added to p_ij, with a prefactor
+ of `jump_lag_coeff` and `distance_coeff`.
+
+ Site pairs with jump lags greater than `jump_lag_cutoff` have their bias
+ set to zero regardless of `jump_lag_sigma`. Defaults to `inf`.
+ """
+ def cfunc(sn):
+ jl = sn.jump_lag.copy()
+ jl -= 1.0 # Center it around 1 since that's the minimum lag, 1 frame
+ jl /= jump_lag_sigma
+ np.square(jl, out = jl)
+ jl *= -0.5
+ np.exp(jl, out = jl) # exp correctly takes the -infs to 0
+
+ jl[sn.jump_lag > jump_lag_cutoff] = 0.
+
+ # Distance term
+ pbccalc = PBCCalculator(sn.structure.cell)
+ dists = pbccalc.pairwise_distances(sn.centers)
+ dmat = dists.copy()
+
+ # We want to strongly boost the similarity of *very* close sites
+ dmat /= distance_sigma
+ np.square(dmat, out = dmat)
+ dmat *= -0.5
+ np.exp(dmat, out = dmat)
+
+ return sn.p_ij + jump_lag_coeff * jl + distance_coeff * dmat
+
+ return cfunc
+
+ # Real methods
+
def run(self, st):
"""Takes a SiteTrajectory and returns a SiteTrajectory, including a new SiteNetwork."""
@@ -56,7 +129,7 @@ def run(self, st):
site_types = st.site_network.site_types
# -- Build connectivity_matrix
- connectivity_matrix = st.site_network.n_ij.copy()
+ connectivity_matrix = self.connectivity_matrix_generator(st.site_network).copy()
n_sites_before = st.site_network.n_sites
assert n_sites_before == connectivity_matrix.shape[0]
@@ -94,6 +167,9 @@ def run(self, st):
logger.info("After merge there will be %i sites" % new_n_sites)
+ if new_n_sites < np.sum(st.site_network.mobile_mask):
+ raise TooFewMergedSitesError("There are %i mobile atoms in this system, but only %i sites after merge" % (np.sum(st.site_network.mobile_mask), new_n_sites))
+
if self.check_types:
new_types = np.empty(shape = new_n_sites, dtype = np.int)
@@ -108,7 +184,7 @@ def run(self, st):
if np.any(translation[mask] != -1):
# We've assigned a different cluster for this before... weird
# degeneracy
- raise ValueError("Markov clustering tried to merge site(s) into more than one new site")
+ raise ValueError("Markov clustering tried to merge site(s) into more than one new site. This shouldn't happen.")
translation[mask] = newsite
to_merge = site_centers[mask]
@@ -116,8 +192,8 @@ def run(self, st):
# Check distances
if not self.post_check_thresh_factor is None:
dists = pbcc.distances(to_merge[0], to_merge[1:])
- assert np.all(dists < self.post_check_thresh_factor * self.distance_threshold), \
- "Markov clustering tried to merge sites more than %f * %f apart. Lower your distance_threshold?" % (self.post_check_thresh_factor, self.distance_threshold)
+ if not np.all(dists < self.post_check_thresh_factor * self.distance_threshold):
+ raise MergedSitesTooDistantError("Markov clustering tried to merge sites more than %f * %f apart. Lower your distance_threshold?" % (self.post_check_thresh_factor, self.distance_threshold))
# New site center
new_centers[newsite] = pbcc.average(to_merge)
@@ -142,6 +218,7 @@ def run(self, st):
return newst
+
def _markov_clustering(self,
transition_matrix,
expansion = 2,
diff --git a/sitator/misc/NAvgsPerSite.py b/sitator/misc/NAvgsPerSite.py
index afcd2a5..013c4e2 100644
--- a/sitator/misc/NAvgsPerSite.py
+++ b/sitator/misc/NAvgsPerSite.py
@@ -69,6 +69,6 @@ def run(self, st):
sn.centers = centers[:current_idex]
sn.site_types = types[:current_idex]
- assert not (np.any(np.isnan(sn.centers)) or np.any(np.isnan(sn.site_types)))
+ assert not (np.isnan(np.sum(sn.centers)) or np.isnan(np.sum(sn.site_types)))
return sn
diff --git a/sitator/util/PBCCalculator.pyx b/sitator/util/PBCCalculator.pyx
index f34830d..99ef20d 100644
--- a/sitator/util/PBCCalculator.pyx
+++ b/sitator/util/PBCCalculator.pyx
@@ -36,6 +36,21 @@ cdef class PBCCalculator(object):
def cell_centroid(self):
return self._cell_centroid
+ cpdef pairwise_distances(self, pts):
+ """Compute the pairwise distance matrix of `pts` with itself.
+
+ :returns ndarray (len(pts), len(pts)): distances
+ """
+ out = np.empty(shape = (len(pts), len(pts)), dtype = pts.dtype)
+
+ buf = pts.copy()
+
+ for i in xrange(len(pts)):
+ self.distances(pts[i], buf, in_place = True, out = out[i])
+ buf[:] = pts
+
+ return out
+
cpdef distances(self, pt1, pts2, in_place = False, out = None):
"""Compute the Euclidean distances from pt1 to all points in pts2, using
shift-and-wrap.
From d97cb6b4233695fa2ec184b6a7f99d15b34d62e3 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 19 Jun 2019 13:25:50 -0400
Subject: [PATCH 019/129] Added `const` for unmodified buffers
---
sitator/util/PBCCalculator.pyx | 22 ++++++++++++++++++----
1 file changed, 18 insertions(+), 4 deletions(-)
diff --git a/sitator/util/PBCCalculator.pyx b/sitator/util/PBCCalculator.pyx
index 99ef20d..fe97200 100644
--- a/sitator/util/PBCCalculator.pyx
+++ b/sitator/util/PBCCalculator.pyx
@@ -18,6 +18,7 @@ cdef class PBCCalculator(object):
cdef cell_precision [:] _cell_centroid
cdef cell_precision [:, :] _cell
+
def __init__(self, cell):
"""
:param DxD ndarray: the unit cell -- an array of cell vectors, like the
@@ -32,10 +33,12 @@ cdef class PBCCalculator(object):
self._cell_mat_inverse_array = np.asarray(cellmat.I)
self._cell_centroid = np.sum(0.5 * cell, axis = 0)
+
@property
def cell_centroid(self):
return self._cell_centroid
+
cpdef pairwise_distances(self, pts):
"""Compute the pairwise distance matrix of `pts` with itself.
@@ -51,6 +54,7 @@ cdef class PBCCalculator(object):
return out
+
cpdef distances(self, pt1, pts2, in_place = False, out = None):
"""Compute the Euclidean distances from pt1 to all points in pts2, using
shift-and-wrap.
@@ -91,6 +95,7 @@ cdef class PBCCalculator(object):
#return np.linalg.norm(self._cell_centroid - pts2, axis = 1)
return np.sqrt(out, out = out)
+
cpdef average(self, points, weights = None):
"""Average position of a "cloud" of points using the shift-and-wrap hack.
@@ -121,6 +126,7 @@ cdef class PBCCalculator(object):
return out
+
cpdef time_average(self, frames):
"""Do multiple PBC correct means. Frames is n_frames x n_pts x 3.
@@ -152,6 +158,7 @@ cdef class PBCCalculator(object):
del posbuf
return out
+
cpdef void wrap_point(self, precision [:] pt):
"""Wrap a single point into the unit cell, IN PLACE. 3D only."""
cdef cell_precision [:, :] cell = self._cell_mat_array
@@ -173,7 +180,8 @@ cdef class PBCCalculator(object):
pt[0] = buf[0]; pt[1] = buf[1]; pt[2] = buf[2];
- cpdef bint is_in_unit_cell(self, precision [:] pt):
+
+ cpdef bint is_in_unit_cell(self, const precision [:] pt):
cdef cell_precision [:, :] cell = self._cell_mat_array
cdef cell_precision [:, :] cell_I = self._cell_mat_inverse_array
@@ -189,13 +197,15 @@ cdef class PBCCalculator(object):
return (buf[0] < 1.0) and (buf[1] < 1.0) and (buf[2] < 1.0) and \
(buf[0] >= 0.0) and (buf[1] >= 0.0) and (buf[2] >= 0.0)
- cpdef bint all_in_unit_cell(self, precision [:, :] pts):
+
+ cpdef bint all_in_unit_cell(self, const precision [:, :] pts):
for pt in pts:
if not self.is_in_unit_cell(pt):
return False
return True
- cpdef bint is_in_image_of_cell(self, precision [:] pt, image):
+
+ cpdef bint is_in_image_of_cell(self, const precision [:] pt, image):
cdef cell_precision [:, :] cell = self._cell_mat_array
cdef cell_precision [:, :] cell_I = self._cell_mat_inverse_array
@@ -214,6 +224,7 @@ cdef class PBCCalculator(object):
return out
+
cpdef void to_cell_coords(self, precision [:, :] points):
"""Convert to cell coordinates in place."""
assert points.shape[1] == 3, "Points must be 3D"
@@ -235,7 +246,8 @@ cdef class PBCCalculator(object):
# Store into points
points[i, 0] = buf[0]; points[i, 1] = buf[1]; points[i, 2] = buf[2];
- cpdef int min_image(self, precision [:] ref, precision [:] pt):
+
+ cpdef int min_image(self, const precision [:] ref, precision [:] pt):
"""Find the minimum image of `pt` relative to `ref`. In place in pt.
Uses the brute force algorithm for correctness; returns the minimum image.
@@ -288,6 +300,7 @@ cdef class PBCCalculator(object):
return 100 * minimg[0] + 10 * minimg[1] + 1 * minimg[2]
+
cpdef void to_real_coords(self, precision [:, :] points):
"""Convert to real coords from crystal coords in place."""
assert points.shape[1] == 3, "Points must be 3D"
@@ -309,6 +322,7 @@ cdef class PBCCalculator(object):
# Store into points
points[i, 0] = buf[0]; points[i, 1] = buf[1]; points[i, 2] = buf[2];
+
cpdef void wrap_points(self, precision [:, :] points):
"""Wrap `points` into a unit cell, IN PLACE. 3D only.
"""
From 90086e4e45be7264504ca9402d47f9a6b4a96081 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 19 Jun 2019 15:26:23 -0400
Subject: [PATCH 020/129] Added true periodic pathway analysis
---
sitator/dynamics/DiffusionPathwayAnalysis.py | 142 +++++++++++++++++--
1 file changed, 134 insertions(+), 8 deletions(-)
diff --git a/sitator/dynamics/DiffusionPathwayAnalysis.py b/sitator/dynamics/DiffusionPathwayAnalysis.py
index 13da941..51cee5d 100644
--- a/sitator/dynamics/DiffusionPathwayAnalysis.py
+++ b/sitator/dynamics/DiffusionPathwayAnalysis.py
@@ -2,9 +2,13 @@
import numpy as np
import numbers
+import itertools
+from scipy.sparse import lil_matrix
from scipy.sparse.csgraph import connected_components
+from sitator.util import PBCCalculator
+
import logging
logger = logging.getLogger(__name__)
@@ -21,14 +25,16 @@ class DiffusionPathwayAnalysis(object):
NO_PATHWAY = -1
def __init__(self,
- connectivity_threshold = 0.001,
- minimum_n_sites = 4):
+ connectivity_threshold = 1,
+ true_periodic_pathways = True,
+ minimum_n_sites = 0):
assert minimum_n_sites >= 0
+ self.true_periodic_pathways = true_periodic_pathways
self.connectivity_threshold = connectivity_threshold
self.minimum_n_sites = minimum_n_sites
- def run(self, sn):
+ def run(self, sn, return_count = False):
"""
Expects a SiteNetwork that has had a JumpAnalysis run on it.
"""
@@ -48,20 +54,55 @@ def run(self, sn):
connectivity_matrix = sn.n_ij >= threshold
+ if self.true_periodic_pathways:
+ connectivity_matrix, mask_000 = self._build_mic_connmat(sn, connectivity_matrix)
+
n_ccs, ccs = connected_components(connectivity_matrix,
directed = False, # even though the matrix is symmetric
connection = 'weak') # diffusion could be unidirectional
_, counts = np.unique(ccs, return_counts = True)
- is_pathway = counts >= self.minimum_n_sites
+ if self.true_periodic_pathways:
+ # is_pathway = np.ones(shape = n_ccs, dtype = np.bool)
+ # We have to check that the pathways include a site and its periodic
+ # image, and throw out those that don't
+
+ new_n_ccs = 1
+ new_ccs = np.zeros(shape = len(sn), dtype = np.int)
+
+ for pathway_i in np.arange(n_ccs):
+ path_mask = ccs == pathway_i
+
+ if not np.any(path_mask & mask_000):
+ continue
+
+ sitenums = np.where(path_mask)[0] % len(sn) # Get unit cell site numbers of all sites in pathway
+ _, site_counts = np.unique(sitenums, return_counts = True)
+ if np.sum(site_counts) <= len(site_counts):
+ # Not a percolating path
+ continue
- logging.info("Taking all edges with at least %i/%i jumps..." % (threshold, n_non_self_jumps))
- logging.info("Found %i connected components, of which %i are large enough to qualify as pathways." % (n_ccs, np.sum(is_pathway)))
+ if len(site_counts) < self.minimum_n_sites:
+ continue
+ new_ccs[sitenums] = new_n_ccs
+ new_n_ccs += 1
+
+ n_ccs = new_n_ccs
+ ccs = new_ccs
+ is_pathway = np.in1d(np.arange(n_ccs), ccs)
+ is_pathway[0] = False # Because this was the "unassigned" value
+ else:
+ is_pathway = counts >= self.minimum_n_sites
+
+ logging.info("Taking all edges with at least %i/%i jumps..." % (threshold, n_non_self_jumps))
+ logging.info("Found %i connected components, of which %i are large enough to qualify as pathways (%i sites)." % (n_ccs, np.sum(is_pathway), self.minimum_n_sites))
+
+ n_pathway = np.sum(is_pathway)
translation = np.empty(n_ccs, dtype = np.int)
translation[~is_pathway] = DiffusionPathwayAnalysis.NO_PATHWAY
- translation[is_pathway] = np.arange(np.sum(is_pathway))
+ translation[is_pathway] = np.arange(n_pathway)
node_pathways = translation[ccs]
@@ -74,4 +115,89 @@ def run(self, sn):
sn.add_site_attribute('site_diffusion_pathway', node_pathways)
sn.add_edge_attribute('edge_diffusion_pathway', outmat)
- return sn
+
+ if return_count:
+ return sn, n_pathway
+ else:
+ return sn
+
+
+ def _build_mic_connmat(self, sn, connectivity_matrix):
+ # We use a 3x3x3 = 27 supercell, so there are 27x as many sites
+ assert len(sn) == connectivity_matrix.shape[0]
+
+ images = np.asarray(list(itertools.product(range(-1, 2), repeat = 3)))
+ image_to_idex = dict((100 * (image[0] + 1) + 10 * (image[1] + 1) + (image[2] + 1), i) for i, image in enumerate(images))
+ n_images = len(images)
+ assert n_images == 27
+
+ n_sites = len(sn)
+ pos = sn.centers.copy() # TODO: copy not needed after reinstall of sitator!
+ n_total_sites = len(images) * n_sites
+ newmat = lil_matrix((n_total_sites, n_total_sites), dtype = np.bool)
+
+ mask_000 = np.zeros(shape = n_total_sites, dtype = np.bool)
+ index_000 = image_to_idex[111]
+ mask_000[index_000:index_000 + n_sites] = True
+
+ pbcc = PBCCalculator(sn.structure.cell)
+ buf = np.empty(shape = 3)
+
+ internal_mat = np.zeros_like(connectivity_matrix)
+ external_connections = []
+ for from_site, to_site in zip(*np.where(connectivity_matrix)):
+ buf[:] = pos[to_site]
+ if pbcc.min_image(pos[from_site], buf) == 111:
+ # If we're in the main image, keep the connection: it's internal
+ internal_mat[from_site, to_site] = True
+ else:
+ external_connections.append((from_site, to_site))
+
+ for image_idex, image in enumerate(images):
+ # Make the block diagonal
+ newmat[image_idex * n_sites:(image_idex + 1) * n_sites,
+ image_idex * n_sites:(image_idex + 1) * n_sites] = internal_mat
+
+ # Check all external connections from this image; add other sparse entries
+ for from_site, to_site in external_connections:
+ buf[:] = pos[to_site]
+ to_mic = pbcc.min_image(pos[from_site], buf)
+ to_in_image = image + [to_mic // 10**(2 - i) % 10 for i in range(3)]
+ if not np.any(np.abs(to_in_image) > 1):
+ to_in_image = 100 * (to_in_image[0] + 1) + 10 * (to_in_image[1] + 1) + (to_in_image[2] + 1)
+ newmat[image_idex * n_sites + from_site,
+ image_to_idex[to_in_image] * n_sites + to_site] = True
+
+ assert np.sum(newmat) >= n_images * np.sum(internal_mat) # Lowest it can be is if every one is internal
+
+ return newmat, mask_000
+
+ # def _is_pbc_connected(self, sn, connectivity_matrix, component_mask):
+ # n_sites = np.sum(component_mask)
+ # mat = connectivity_matrix[component_mask, component_mask]
+ # pos = sn.centers[component_mask]
+ #
+ # #mic_offsets = coo_matrix(shape = (n_sites, n_sites, 3), dtype = np.int)
+ # mic_offsets = np.full(shape = (n_sites, n_sites, 3), fill_value = -20, dtype = np.int)
+ # for from_site, to_site in zip(*np.where(mat)):
+ # to_mic = pbcc.min_image(pos[from_site], pos[to_site])
+ # mic_offsets[from_site, to_site] = [int(d) - 1 for d in str(to_mic)] # Get the individual digits
+ #
+ # # We always have seen every site in the 000 image, so it's implicit
+ # have_seen = {}
+ #
+ # for from_site in range(n_sites):
+ # # mark all images reachable from from_site as seen
+ # can_reach = np.where(mat[from_site])[0]
+ # reach_the_image = mic_offsets[from_site, can_reach]
+ # have_seen.update(set(zip(can_reach, reach_the_image)))
+ #
+ # for saw_site, as_image in have_seen:
+ # # Can we reach anything new from it?
+ # can_reach = np.where(mat[saw_site])[0]
+ # reach_the_image = as_image + mic_offsets[saw_site, can_reach]
+ #
+ # if any(p not in have_seen for p in zip(can_reach, reach_the_image)):
+ # return True
+ #
+ # return False
From df810c2d2bf60ca1fd64326875be3464156bdcca Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 19 Jun 2019 15:43:39 -0400
Subject: [PATCH 021/129] Refactoring visualization code
---
sitator/SiteTrajectory.py | 133 +++---------------
sitator/visualization/SiteNetworkPlotter.py | 2 +-
.../visualization/SiteTrajectoryPlotter.py | 127 +++++++++++++++++
sitator/visualization/__init__.py | 2 +
4 files changed, 146 insertions(+), 118 deletions(-)
create mode 100644 sitator/visualization/SiteTrajectoryPlotter.py
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index 5c5f2ff..f7c2e93 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -1,10 +1,7 @@
import numpy as np
from sitator.util import PBCCalculator
-from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS
-
-import matplotlib
-from matplotlib.collections import LineCollection
+from sitator.visualization import SiteTrajectoryPlotter
import logging
logger = logging.getLogger(__name__)
@@ -41,6 +38,8 @@ def __init__(self,
self._real_traj = None
+ self._default_plotter = None
+
def __len__(self):
return self.n_frames
@@ -202,119 +201,19 @@ def assign_to_last_known_site(self, frame_threshold = 1):
return res
- @plotter(is3D = True)
- def plot_frame(self, frame, **kwargs):
- sites_of_frame = np.unique(self._traj[frame])
- frame_sn = self._sn[sites_of_frame]
-
- frame_sn.plot(**kwargs)
-
- if not self._real_traj is None:
- mobile_atoms = self._sn.structure.copy()
- del mobile_atoms[~self._sn.mobile_mask]
-
- mobile_atoms.positions[:] = self._real_traj[frame, self._sn.mobile_mask]
- plot_atoms(atoms = mobile_atoms, **kwargs)
-
- kwargs['ax'].set_title("Frame %i/%i" % (frame, self.n_frames))
-
- @plotter(is3D = True)
- def plot_site(self, site, **kwargs):
- pbcc = PBCCalculator(self._sn.structure.cell)
- pts = self.real_positions_for_site(site).copy()
- offset = pbcc.cell_centroid - pts[3]
- pts += offset
- pbcc.wrap_points(pts)
- lattice_pos = self._sn.static_structure.positions.copy()
- lattice_pos += offset
- pbcc.wrap_points(lattice_pos)
- site_pos = self._sn.centers[site:site+1].copy()
- site_pos += offset
- pbcc.wrap_points(site_pos)
- # Plot point cloud
- plot_points(points = pts, alpha = 0.3, marker = '.', color = 'k', **kwargs)
- # Plot site
- plot_points(points = site_pos, color = 'cyan', **kwargs)
- # Plot everything else
- plot_atoms(self._sn.static_structure, positions = lattice_pos, **kwargs)
-
- title = "Site %i/%i" % (site, len(self._sn))
-
- if not self._sn.site_types is None:
- title += " (type %i)" % self._sn.site_types[site]
-
- kwargs['ax'].set_title(title)
-
- @plotter(is3D = False)
- def plot_particle_trajectory(self, particle, ax = None, fig = None, **kwargs):
- types = not self._sn.site_types is None
- if types:
- type_height_percent = 0.1
- axpos = ax.get_position()
- typeax_height = type_height_percent * axpos.height
- typeax = fig.add_axes([axpos.x0, axpos.y0, axpos.width, typeax_height], sharex = ax)
- ax.set_position([axpos.x0, axpos.y0 + typeax_height, axpos.width, axpos.height - typeax_height])
- type_height = 1
- # Draw trajectory
- segments = []
- linestyles = []
- colors = []
-
- traj = self._traj[:, particle]
- current_value = traj[0]
- last_value = traj[0]
- if types:
- last_type = None
- current_segment_start = 0
- puttext = False
-
- for i, f in enumerate(traj):
- if f != current_value or i == len(traj) - 1:
- val = last_value if current_value == -1 else current_value
- segments.append([[current_segment_start, last_value], [current_segment_start, val], [i, val]])
- linestyles.append(':' if current_value == -1 else '-')
- colors.append('lightgray' if current_value == -1 else 'k')
-
- if types:
- rxy = (current_segment_start, 0)
- this_type = self._sn.site_types[val]
- typerect = matplotlib.patches.Rectangle(rxy, i - current_segment_start, type_height,
- color = DEFAULT_COLORS[this_type], linewidth = 0)
- typeax.add_patch(typerect)
- if this_type != last_type:
- typeax.annotate("T%i" % this_type,
- xy = (rxy[0], rxy[1] + 0.5 * type_height),
- xytext = (3, -1),
- textcoords = 'offset points',
- fontsize = 'xx-small',
- va = 'center',
- fontweight = 'bold')
- last_type = this_type
-
- last_value = val
- current_segment_start = i
- current_value = f
-
- lc = LineCollection(segments, linestyles = linestyles, colors = colors, linewidth=1.5)
- ax.add_collection(lc)
-
- if types:
- typeax.set_xlabel("Frame")
- ax.tick_params(axis = 'x', which = 'both', bottom = False, top = False, labelbottom = False)
- typeax.tick_params(axis = 'y', which = 'both', left = False, right = False, labelleft = False)
- typeax.annotate("Type", xy = (0, 0.5), xytext = (-25, 0), xycoords = 'axes fraction', textcoords = 'offset points', va = 'center', fontsize = 'x-small')
- else:
- ax.set_xlabel("Frame")
- ax.set_ylabel("Atom %i's site" % particle)
- ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
- ax.grid()
+ # ---- Plotting code
+ def plot_frame(self, *args, **kwargs):
+ if self._default_plotter is None:
+ self._default_plotter = SiteTrajectoryPlotter()
+ self._default_plotter.plot_frame(self, *args, **kwargs)
- ax.set_xlim((0, self.n_frames - 1))
- margin_percent = 0.04
- ymargin = (margin_percent * self._sn.n_sites)
- ax.set_ylim((-ymargin, self._sn.n_sites - 1.0 + ymargin))
+ def plot_site(self, *args, **kwargs):
+ if self._default_plotter is None:
+ self._default_plotter = SiteTrajectoryPlotter()
+ self._default_plotter.plot_site(self, *args, **kwargs)
- if types:
- typeax.set_xlim((0, self.n_frames - 1))
- typeax.set_ylim((0, type_height))
+ def plot_particle_trajectory(self, *args, **kwargs):
+ if self._default_plotter is None:
+ self._default_plotter = SiteTrajectoryPlotter()
+ self._default_plotter.plot_particle_trajectory(self, *args, **kwargs)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 8e8ebfa..775f8b8 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -30,7 +30,7 @@ class SiteNetworkPlotter(object):
DEFAULT_MARKERS = ['x', '+', 'v', '<', '^', '>', '*', 'd', 'h', 'p']
DEFAULT_LINESTYLES = ['--', ':', '-.', '-']
- EDGE_GROUP_COLORS = ['b', 'g', 'm', 'lightseagreen', 'crimson'] + ['gray'] # gray last for -1's
+ EDGE_GROUP_COLORS = ['b', 'g', 'm', 'crimson', 'lightseagreen', 'darkorange', 'sandybrown', 'gold', 'hotpink'] + ['gray'] # gray last for -1's
def __init__(self,
site_mappings = DEFAULT_SITE_MAPPINGS,
diff --git a/sitator/visualization/SiteTrajectoryPlotter.py b/sitator/visualization/SiteTrajectoryPlotter.py
new file mode 100644
index 0000000..7649c8b
--- /dev/null
+++ b/sitator/visualization/SiteTrajectoryPlotter.py
@@ -0,0 +1,127 @@
+
+import matplotlib
+from matplotlib.collections import LineCollection
+
+from sitator.util import PBCCalculator
+from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS
+
+
+class SiteTrajectoryPlotter(object):
+ @plotter(is3D = True)
+ def plot_frame(self, st, frame, **kwargs):
+ sites_of_frame = np.unique(st._traj[frame])
+ frame_sn = st._sn[sites_of_frame]
+
+ frame_sn.plot(**kwargs)
+
+ if not st._real_traj is None:
+ mobile_atoms = st._sn.structure.copy()
+ del mobile_atoms[~st._sn.mobile_mask]
+
+ mobile_atoms.positions[:] = st._real_traj[frame, st._sn.mobile_mask]
+ plot_atoms(atoms = mobile_atoms, **kwargs)
+
+ kwargs['ax'].set_title("Frame %i/%i" % (frame, st.n_frames))
+
+
+ @plotter(is3D = True)
+ def plot_site(self, st, site, **kwargs):
+ pbcc = PBCCalculator(st._sn.structure.cell)
+ pts = st.real_positions_for_site(site).copy()
+ offset = pbcc.cell_centroid - pts[3]
+ pts += offset
+ pbcc.wrap_points(pts)
+ lattice_pos = st._sn.static_structure.positions.copy()
+ lattice_pos += offset
+ pbcc.wrap_points(lattice_pos)
+ site_pos = st._sn.centers[site:site+1].copy()
+ site_pos += offset
+ pbcc.wrap_points(site_pos)
+ # Plot point cloud
+ plot_points(points = pts, alpha = 0.3, marker = '.', color = 'k', **kwargs)
+ # Plot site
+ plot_points(points = site_pos, color = 'cyan', **kwargs)
+ # Plot everything else
+ plot_atoms(st._sn.static_structure, positions = lattice_pos, **kwargs)
+
+ title = "Site %i/%i" % (site, len(st._sn))
+
+ if not st._sn.site_types is None:
+ title += " (type %i)" % st._sn.site_types[site]
+
+ kwargs['ax'].set_title(title)
+
+
+ @plotter(is3D = False)
+ def plot_particle_trajectory(self, st, particle, ax = None, fig = None, **kwargs):
+ types = not st._sn.site_types is None
+ if types:
+ type_height_percent = 0.1
+ axpos = ax.get_position()
+ typeax_height = type_height_percent * axpos.height
+ typeax = fig.add_axes([axpos.x0, axpos.y0, axpos.width, typeax_height], sharex = ax)
+ ax.set_position([axpos.x0, axpos.y0 + typeax_height, axpos.width, axpos.height - typeax_height])
+ type_height = 1
+ # Draw trajectory
+ segments = []
+ linestyles = []
+ colors = []
+
+ traj = st._traj[:, particle]
+ current_value = traj[0]
+ last_value = traj[0]
+ if types:
+ last_type = None
+ current_segment_start = 0
+ puttext = False
+
+ for i, f in enumerate(traj):
+ if f != current_value or i == len(traj) - 1:
+ val = last_value if current_value == -1 else current_value
+ segments.append([[current_segment_start, last_value], [current_segment_start, val], [i, val]])
+ linestyles.append(':' if current_value == -1 else '-')
+ colors.append('lightgray' if current_value == -1 else 'k')
+
+ if types:
+ rxy = (current_segment_start, 0)
+ this_type = st._sn.site_types[val]
+ typerect = matplotlib.patches.Rectangle(rxy, i - current_segment_start, type_height,
+ color = DEFAULT_COLORS[this_type], linewidth = 0)
+ typeax.add_patch(typerect)
+ if this_type != last_type:
+ typeax.annotate("T%i" % this_type,
+ xy = (rxy[0], rxy[1] + 0.5 * type_height),
+ xytext = (3, -1),
+ textcoords = 'offset points',
+ fontsize = 'xx-small',
+ va = 'center',
+ fontweight = 'bold')
+ last_type = this_type
+
+ last_value = val
+ current_segment_start = i
+ current_value = f
+
+ lc = LineCollection(segments, linestyles = linestyles, colors = colors, linewidth=1.5)
+ ax.add_collection(lc)
+
+ if types:
+ typeax.set_xlabel("Frame")
+ ax.tick_params(axis = 'x', which = 'both', bottom = False, top = False, labelbottom = False)
+ typeax.tick_params(axis = 'y', which = 'both', left = False, right = False, labelleft = False)
+ typeax.annotate("Type", xy = (0, 0.5), xytext = (-25, 0), xycoords = 'axes fraction', textcoords = 'offset points', va = 'center', fontsize = 'x-small')
+ else:
+ ax.set_xlabel("Frame")
+ ax.set_ylabel("Atom %i's site" % particle)
+
+ ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
+ ax.grid()
+
+ ax.set_xlim((0, st.n_frames - 1))
+ margin_percent = 0.04
+ ymargin = (margin_percent * st._sn.n_sites)
+ ax.set_ylim((-ymargin, st._sn.n_sites - 1.0 + ymargin))
+
+ if types:
+ typeax.set_xlim((0, st.n_frames - 1))
+ typeax.set_ylim((0, type_height))
diff --git a/sitator/visualization/__init__.py b/sitator/visualization/__init__.py
index 025308b..e96d843 100644
--- a/sitator/visualization/__init__.py
+++ b/sitator/visualization/__init__.py
@@ -3,3 +3,5 @@
from .atoms import plot_atoms, plot_points
from .SiteNetworkPlotter import SiteNetworkPlotter
+
+from .SiteTrajectoryPlotter import SiteTrajectoryPlotter
From 622d6b74be27751d8bf8e95386dc8293e4722e69 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 21 Jun 2019 16:09:04 -0400
Subject: [PATCH 022/129] Cleanup
---
README.md | 16 +++++-----
sitator/dynamics/DiffusionPathwayAnalysis.py | 32 +-------------------
sitator/util/PBCCalculator.pyx | 3 ++
3 files changed, 13 insertions(+), 38 deletions(-)
diff --git a/README.md b/README.md
index 6dd63a2..11909e4 100644
--- a/README.md
+++ b/README.md
@@ -9,22 +9,24 @@ A modular framework for conducting and visualizing site analysis of molecular dy
`sitator` contains an efficient implementation of our method, landmark analysis, as well as visualization tools, generic data structures for site analysis, pre- and post-processing tools, and more.
-For details on the method and its application, please see our paper:
+For details on landmark analysis and its application, please see our paper:
> L. Kahle, A. Musaelian, N. Marzari, and B. Kozinsky
> [Unsupervised landmark analysis for jump detection in molecular dynamics simulations](https://doi.org/10.1103/PhysRevMaterials.3.055404)
> Phys. Rev. Materials 3, 055404 – 21 May 2019
-If you use `sitator` in your research, please consider citing this paper. The BibTex citation can be found in [`CITATION.bib`](CITATION.bib).
+If you use `sitator` in your research, please consider citing this paper. The BibTeX citation can be found in [`CITATION.bib`](CITATION.bib).
## Installation
-`sitator` is built for Python >=3.2 (the older version supports Python 2.7). We recommend the use of a virtual environment (`virtualenv`, `conda`, etc.). `sitator` has one mandatory external dependency:
+`sitator` is built for Python >=3.2 (the older version supports Python 2.7). We recommend the use of a virtual environment (`virtualenv`, `conda`, etc.). `sitator` has a number of optional dependencies that enable various features:
- - The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does *not* have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.)
-
-
-If you want to use the site type analysis features, the `quip` binary from an installation of [QUIP](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) can be used to compute the SOAP vectors. The Python 2.7 bindings (`quippy`) are **not** required. SOAP vectors can **also** be computed with [`DScribe`](https://singroup.github.io/dscribe/index.html) and the installation of QUIP avoided; note, however, that the descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on the system you are analyzing.
+ * **Landmark Analysis**
+ * The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does not have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.)
+ * **Site Type Analysis**
+ * For computing SOAP vectors: the `quip` binary from [QUIP](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) **or** the [`DScribe`](https://singroup.github.io/dscribe/index.html) Python library.
+
+ The Python 2.7 bindings for QUIP (`quippy`) are **not** required. Generally, `DScribe` is much simpler to install than QUIP. **Please note**, however, that the SOAP descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on the system you are analyzing.
After downloading, the package is installed with `pip`:
diff --git a/sitator/dynamics/DiffusionPathwayAnalysis.py b/sitator/dynamics/DiffusionPathwayAnalysis.py
index 51cee5d..de2801a 100644
--- a/sitator/dynamics/DiffusionPathwayAnalysis.py
+++ b/sitator/dynamics/DiffusionPathwayAnalysis.py
@@ -162,7 +162,7 @@ def _build_mic_connmat(self, sn, connectivity_matrix):
for from_site, to_site in external_connections:
buf[:] = pos[to_site]
to_mic = pbcc.min_image(pos[from_site], buf)
- to_in_image = image + [to_mic // 10**(2 - i) % 10 for i in range(3)]
+ to_in_image = image + [(to_mic // 10**(2 - i) % 10) - 1 for i in range(3)] # FIXME: is the -1 right
if not np.any(np.abs(to_in_image) > 1):
to_in_image = 100 * (to_in_image[0] + 1) + 10 * (to_in_image[1] + 1) + (to_in_image[2] + 1)
newmat[image_idex * n_sites + from_site,
@@ -171,33 +171,3 @@ def _build_mic_connmat(self, sn, connectivity_matrix):
assert np.sum(newmat) >= n_images * np.sum(internal_mat) # Lowest it can be is if every one is internal
return newmat, mask_000
-
- # def _is_pbc_connected(self, sn, connectivity_matrix, component_mask):
- # n_sites = np.sum(component_mask)
- # mat = connectivity_matrix[component_mask, component_mask]
- # pos = sn.centers[component_mask]
- #
- # #mic_offsets = coo_matrix(shape = (n_sites, n_sites, 3), dtype = np.int)
- # mic_offsets = np.full(shape = (n_sites, n_sites, 3), fill_value = -20, dtype = np.int)
- # for from_site, to_site in zip(*np.where(mat)):
- # to_mic = pbcc.min_image(pos[from_site], pos[to_site])
- # mic_offsets[from_site, to_site] = [int(d) - 1 for d in str(to_mic)] # Get the individual digits
- #
- # # We always have seen every site in the 000 image, so it's implicit
- # have_seen = {}
- #
- # for from_site in range(n_sites):
- # # mark all images reachable from from_site as seen
- # can_reach = np.where(mat[from_site])[0]
- # reach_the_image = mic_offsets[from_site, can_reach]
- # have_seen.update(set(zip(can_reach, reach_the_image)))
- #
- # for saw_site, as_image in have_seen:
- # # Can we reach anything new from it?
- # can_reach = np.where(mat[saw_site])[0]
- # reach_the_image = as_image + mic_offsets[saw_site, can_reach]
- #
- # if any(p not in have_seen for p in zip(can_reach, reach_the_image)):
- # return True
- #
- # return False
diff --git a/sitator/util/PBCCalculator.pyx b/sitator/util/PBCCalculator.pyx
index fe97200..a4022e6 100644
--- a/sitator/util/PBCCalculator.pyx
+++ b/sitator/util/PBCCalculator.pyx
@@ -252,6 +252,9 @@ cdef class PBCCalculator(object):
Uses the brute force algorithm for correctness; returns the minimum image.
+ Assumes that `ref` and `pt` are already in the *same* cell (though not
+ necessarily the <0,0,0> cell -- any periodic image will do).
+
:returns int[3] minimg: Which image was the minimum image.
"""
# # There are 27 possible minimum images
From 82f3c42358a5712989d074fc457d2f3e6a8f7191 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 21 Jun 2019 16:09:20 -0400
Subject: [PATCH 023/129] Trajectory clamping
---
sitator/SiteTrajectory.py | 84 ++++++++++++++++++++++++++++++++++++++-
1 file changed, 82 insertions(+), 2 deletions(-)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index f7c2e93..4cf5ed8 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -2,6 +2,7 @@
from sitator.util import PBCCalculator
from sitator.visualization import SiteTrajectoryPlotter
+from sitator.util.progress import tqdm
import logging
logger = logging.getLogger(__name__)
@@ -49,7 +50,6 @@ def __getitem__(self, key):
confidences = None if self._confs is None else self._confs[key])
if not self._real_traj is None:
st.set_real_traj(self._real_traj[key])
-
return st
@property
@@ -131,10 +131,90 @@ def real_positions_for_site(self, site, return_confidences = False):
def compute_site_occupancies(self):
"""Computes site occupancies and adds site attribute `occupancies` to site_network."""
- occ = np.true_divide(np.bincount(self._traj[self._traj >= 0]), self.n_frames)
+ occ = np.true_divide(np.bincount(self._traj[self._traj >= 0], minlength = self._sn.n_sites), self.n_frames)
self.site_network.add_site_attribute('occupancies', occ)
return occ
+ def clamped_real_trajectory(self, clamp_mask = None, wrap = False, pass_through_unassigned = False):
+ """Create a real-space trajectory with the fixed site/static structure positions.
+
+ Generate a real-space trajectory where the atoms indicated in `clamp_mask` --
+ any mixture of static and mobile -- are clamped to: (1) the fixed position of
+ their current site, if mobile, or (2) the corresponding fixed position in
+ the `SiteNetwork`'s static structure, if static.
+
+ Atoms not indicated in `clamp_mask` will have their positions from `real_traj`
+ passed through.
+
+ Clamped positions will be in the unit cell; it is assumed that the real
+ trajectory is wrapped.
+
+ Args:
+ - clamp_mask (ndarray, len(sn.structure))
+ - wrap (bool, default: False): If True, all clamped positions will be in the unit cell.
+ - pass_through_unassigned (bool, default: False): If True, when a
+ mobile atom is supposed to be clamped but is unassigned at some
+ frame, its real-space position will be passed through from the
+ real trajectory. If False, an error will be raised.
+ Returns:
+ ndarray (n_frames x n_atoms x 3)
+ """
+ cell = self._sn.structure.cell
+ pbcc = PBCCalculator(cell)
+
+ n_atoms = len(self._sn.structure)
+ if clamp_mask is None:
+ clamp_mask = np.ones(shape = n_atoms, dtype = np.bool)
+ if self._real_traj is None and not np.all(clamp_mask):
+ raise RuntimeError("This `SiteTrajectory` has no real-space trajectory, but the given clamp mask leaves some atoms unclamped.")
+
+ clamptrj = np.empty(shape = (self.n_frames, n_atoms, 3))
+ # Pass through unclamped positions
+ if not np.all(clamp_mask):
+ clamptrj[:, ~clamp_mask, :] = self._real_traj[:, ~clamp_mask, :]
+ # Clamp static atoms
+ static_clamp = clamp_mask & self._sn.static_mask
+ clamptrj[:, static_clamp, :] = self._sn.structure.get_positions()[static_clamp]
+ # Clamp mobile atoms
+ mobile_clamp = clamp_mask & self._sn.mobile_mask
+ selected_sitetraj = self._traj[:, mobile_clamp]
+ mobile_clamp_indexes = np.where(mobile_clamp)[0]
+ if not pass_through_unassigned and np.min(selected_sitetraj) < 0:
+ raise RuntimeError("The mobile atoms indicated for clamping are unassigned at some point during the trajectory and `pass_through_unassigned` is set to False. Try `assign_to_last_known_site()`?")
+
+ if wrap:
+ for frame_i in tqdm(range(len(clamptrj))):
+ for mobile_i in mobile_clamp_indexes:
+ at_site = self._traj[frame_i, mobile_i]
+ if at_site == SiteTrajectory.SITE_UNKNOWN: # we already know that this means pass_through_unassigned = True
+ clamptrj[frame_i, mobile_i] = self._real_traj[frame_i, mobile_i]
+ continue
+ clamptrj[frame_i, mobile_i] = self._sn.centers[at_site]
+ else:
+ buf = np.empty(shape = (1, 3))
+ site_pt = np.empty(shape = 3)
+ for frame_i in tqdm(range(len(clamptrj))):
+ for mobile_i in mobile_clamp_indexes:
+ buf[:, :] = self._real_traj[frame_i, mobile_i]
+ at_site = self._traj[frame_i, mobile_i]
+ if at_site == SiteTrajectory.SITE_UNKNOWN: # we already know that this means pass_through_unassigned = True
+ clamptrj[frame_i, mobile_i] = self._real_traj[frame_i, mobile_i]
+ continue
+ site_pt[:] = self._sn.centers[at_site]
+ pbcc.wrap_point(site_pt)
+ pbcc.wrap_points(buf)
+ site_mic = pbcc.min_image(buf[0], site_pt)
+ site_mic = [(site_mic // 10**(2 - i) % 10) - 1 for i in range(3)]
+ buf[:, :] = self._real_traj[frame_i, mobile_i]
+ pbcc.to_cell_coords(buf)
+ pt_in_image = np.floor(buf[0])
+ pt_in_image += site_mic
+ # np.dot(np.matrix(cell.T), pt_in_image) +
+ clamptrj[frame_i, mobile_i] = np.dot(pt_in_image, cell) + self._sn.centers[at_site]
+
+ return clamptrj
+
+
def assign_to_last_known_site(self, frame_threshold = 1):
"""Assign unassigned mobile particles to their last known site within
`frame_threshold` frames.
From 56053d44e9df79ba4755aa888dce21018fe7fa04 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 21 Jun 2019 17:55:38 -0400
Subject: [PATCH 024/129] Removed backports reference
---
sitator/SiteNetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index d4d9891..526e197 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -3,7 +3,7 @@
import re
import os
import tarfile
-from backports import tempfile
+import tempfile
import ase.io
From da31818a77ac6d4cd0c8be69b8eeac5f3d956b16 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 24 Jun 2019 10:13:52 -0400
Subject: [PATCH 025/129] Refactoring
---
sitator/SiteTrajectory.py | 84 ++----------------
sitator/misc/GenerateClampedTrajectory.py | 100 ++++++++++++++++++++++
2 files changed, 105 insertions(+), 79 deletions(-)
create mode 100644 sitator/misc/GenerateClampedTrajectory.py
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index 4cf5ed8..96b36cf 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -88,6 +88,7 @@ def site_network(self, value):
def real_trajectory(self):
return self._real_traj
+
def set_real_traj(self, real_traj):
"""Assocaite this SiteTrajectory with a trajectory of points in real space.
@@ -98,11 +99,13 @@ def set_real_traj(self, real_traj):
raise ValueError("real_traj of shape %s does not have expected shape %s" % (real_traj.shape, expected_shape))
self._real_traj = real_traj
+
def remove_real_traj(self):
"""Forget associated real trajectory."""
del self._real_traj
self._real_traj = None
+
def trajectory_for_particle(self, i, return_confidences = False):
"""Returns the array of sites particle i is assigned to over time."""
if return_confidences and self._confs is None:
@@ -112,6 +115,7 @@ def trajectory_for_particle(self, i, return_confidences = False):
else:
return self._traj[:, i]
+
def real_positions_for_site(self, site, return_confidences = False):
if self._real_traj is None:
raise ValueError("This SiteTrajectory has no real trajectory")
@@ -129,91 +133,13 @@ def real_positions_for_site(self, site, return_confidences = False):
else:
return pts
+
def compute_site_occupancies(self):
"""Computes site occupancies and adds site attribute `occupancies` to site_network."""
occ = np.true_divide(np.bincount(self._traj[self._traj >= 0], minlength = self._sn.n_sites), self.n_frames)
self.site_network.add_site_attribute('occupancies', occ)
return occ
- def clamped_real_trajectory(self, clamp_mask = None, wrap = False, pass_through_unassigned = False):
- """Create a real-space trajectory with the fixed site/static structure positions.
-
- Generate a real-space trajectory where the atoms indicated in `clamp_mask` --
- any mixture of static and mobile -- are clamped to: (1) the fixed position of
- their current site, if mobile, or (2) the corresponding fixed position in
- the `SiteNetwork`'s static structure, if static.
-
- Atoms not indicated in `clamp_mask` will have their positions from `real_traj`
- passed through.
-
- Clamped positions will be in the unit cell; it is assumed that the real
- trajectory is wrapped.
-
- Args:
- - clamp_mask (ndarray, len(sn.structure))
- - wrap (bool, default: False)
- - pass_through_unassigned (bool, default: False): If True, when a
- mobile atom is supposed to be clamped but is unassigned at some
- frame, its real-space position will be passed through from the
- real trajectory. If False, an error will be raised.
- Returns:
- ndarray (n_frames x n_atoms x 3)
- """
- cell = self._sn.structure.cell
- pbcc = PBCCalculator(cell)
-
- n_atoms = len(self._sn.structure)
- if clamp_mask is None:
- clamp_mask = np.ones(shape = n_atoms, dtype = np.bool)
- if self._real_traj is None and not np.all(clamp_mask):
- raise RuntimeError("This `SiteTrajectory` has no real-space trajectory, but the given clamp mask leaves some atoms unclamped.")
-
- clamptrj = np.empty(shape = (self.n_frames, n_atoms, 3))
- # Pass through unclamped positions
- if not np.all(clamp_mask):
- clamptrj[:, ~clamp_mask, :] = self._real_traj[:, ~clamp_mask, :]
- # Clamp static atoms
- static_clamp = clamp_mask & self._sn.static_mask
- clamptrj[:, static_clamp, :] = self._sn.structure.get_positions()[static_clamp]
- # Clamp mobile atoms
- mobile_clamp = clamp_mask & self._sn.mobile_mask
- selected_sitetraj = self._traj[:, mobile_clamp]
- mobile_clamp_indexes = np.where(mobile_clamp)[0]
- if not pass_through_unassigned and np.min(selected_sitetraj) < 0:
- raise RuntimeError("The mobile atoms indicated for clamping are unassigned at some point during the trajectory and `pass_through_unassigned` is set to False. Try `assign_to_last_known_site()`?")
-
- if wrap:
- for frame_i in tqdm(range(len(clamptrj))):
- for mobile_i in mobile_clamp_indexes:
- at_site = self._traj[frame_i, mobile_i]
- if at_site == SiteTrajectory.SITE_UNKNOWN: # we already know that this means pass_through_unassigned = True
- clamptrj[frame_i, mobile_i] = self._real_traj[frame_i, mobile_i]
- continue
- clamptrj[frame_i, mobile_i] = self._sn.centers[at_site]
- else:
- buf = np.empty(shape = (1, 3))
- site_pt = np.empty(shape = 3)
- for frame_i in tqdm(range(len(clamptrj))):
- for mobile_i in mobile_clamp_indexes:
- buf[:, :] = self._real_traj[frame_i, mobile_i]
- at_site = self._traj[frame_i, mobile_i]
- if at_site == SiteTrajectory.SITE_UNKNOWN: # we already know that this means pass_through_unassigned = True
- clamptrj[frame_i, mobile_i] = self._real_traj[frame_i, mobile_i]
- continue
- site_pt[:] = self._sn.centers[at_site]
- pbcc.wrap_point(site_pt)
- pbcc.wrap_points(buf)
- site_mic = pbcc.min_image(buf[0], site_pt)
- site_mic = [(site_mic // 10**(2 - i) % 10) - 1 for i in range(3)]
- buf[:, :] = self._real_traj[frame_i, mobile_i]
- pbcc.to_cell_coords(buf)
- pt_in_image = np.floor(buf[0])
- pt_in_image += site_mic
- # np.dot(np.matrix(cell.T), pt_in_image) +
- clamptrj[frame_i, mobile_i] = np.dot(pt_in_image, cell) + self._sn.centers[at_site]
-
- return clamptrj
-
def assign_to_last_known_site(self, frame_threshold = 1):
"""Assign unassigned mobile particles to their last known site within
diff --git a/sitator/misc/GenerateClampedTrajectory.py b/sitator/misc/GenerateClampedTrajectory.py
new file mode 100644
index 0000000..87626d0
--- /dev/null
+++ b/sitator/misc/GenerateClampedTrajectory.py
@@ -0,0 +1,100 @@
+import numpy as np
+
+from sitator import SiteTrajectory
+from sitator.util import PBCCalculator
+from sitator.util.progress import tqdm
+
+
+class GenerateClampedTrajectory(object):
+ """Create a real-space trajectory with the fixed site/static structure positions.
+
+ Generate a real-space trajectory where the atoms are clamped to the fixed
+ positions of the current site/their fixed static position.
+
+ Args:
+ - wrap (bool, default: False): If True, all clamped positions will be in
+ the unit cell; if False, the clamped position will be the minimum
+ image of the clamped position with respect to the real-space position.
+ (This can generate a clamped, unwrapped real-space trajectory
+ from an unwrapped real space trajectory.)
+ - pass_through_unassigned (bool, default: False): If True, when a
+ mobile atom is supposed to be clamped but is unassigned at some
+ frame, its real-space position will be passed through from the
+ real trajectory. If False, an error will be raised.
+ """
+ def __init__(self, wrap = False, pass_through_unassigned = False):
+ self.wrap = wrap
+ self.pass_through_unassigned = pass_through_unassigned
+
+
+ def run(self, st, clamp_mask = None):
+ """Create a real-space trajectory with the fixed site/static structure positions.
+
+ Generate a real-space trajectory where the atoms indicated in `clamp_mask` --
+ any mixture of static and mobile -- are clamped to: (1) the fixed position of
+ their current site, if mobile, or (2) the corresponding fixed position in
+ the `SiteNetwork`'s static structure, if static.
+
+ Atoms not indicated in `clamp_mask` will have their positions from `real_traj`
+ passed through.
+
+ Args:
+ - clamp_mask (ndarray, len(sn.structure))
+ Returns:
+ ndarray (n_frames x n_atoms x 3)
+ """
+ wrap = st.wrap
+ pass_through_unassigned = st.pass_through_unassigned
+ cell = st._sn.structure.cell
+ pbcc = PBCCalculator(cell)
+
+ n_atoms = len(st._sn.structure)
+ if clamp_mask is None:
+ clamp_mask = np.ones(shape = n_atoms, dtype = np.bool)
+ if st._real_traj is None and not np.all(clamp_mask):
+ raise RuntimeError("This `SiteTrajectory` has no real-space trajectory, but the given clamp mask leaves some atoms unclamped.")
+
+ clamptrj = np.empty(shape = (st.n_frames, n_atoms, 3))
+ # Pass through unclamped positions
+ if not np.all(clamp_mask):
+ clamptrj[:, ~clamp_mask, :] = st._real_traj[:, ~clamp_mask, :]
+ # Clamp static atoms
+ static_clamp = clamp_mask & st._sn.static_mask
+ clamptrj[:, static_clamp, :] = st._sn.structure.get_positions()[static_clamp]
+ # Clamp mobile atoms
+ mobile_clamp = clamp_mask & st._sn.mobile_mask
+ selected_sitetraj = st._traj[:, mobile_clamp]
+ mobile_clamp_indexes = np.where(mobile_clamp)[0]
+ if not pass_through_unassigned and np.min(selected_sitetraj) < 0:
+ raise RuntimeError("The mobile atoms indicated for clamping are unassigned at some point during the trajectory and `pass_through_unassigned` is set to False. Try `assign_to_last_known_site()`?")
+
+ if wrap:
+ for frame_i in tqdm(range(len(clamptrj))):
+ for mobile_i in mobile_clamp_indexes:
+ at_site = st._traj[frame_i, mobile_i]
+ if at_site == SiteTrajectory.SITE_UNKNOWN: # we already know that this means pass_through_unassigned = True
+ clamptrj[frame_i, mobile_i] = st._real_traj[frame_i, mobile_i]
+ continue
+ clamptrj[frame_i, mobile_i] = st._sn.centers[at_site]
+ else:
+ buf = np.empty(shape = (1, 3))
+ site_pt = np.empty(shape = 3)
+ for frame_i in tqdm(range(len(clamptrj))):
+ for mobile_i in mobile_clamp_indexes:
+ buf[:, :] = st._real_traj[frame_i, mobile_i]
+ at_site = st._traj[frame_i, mobile_i]
+ if at_site == SiteTrajectory.SITE_UNKNOWN: # we already know that this means pass_through_unassigned = True
+ clamptrj[frame_i, mobile_i] = st._real_traj[frame_i, mobile_i]
+ continue
+ site_pt[:] = st._sn.centers[at_site]
+ pbcc.wrap_point(site_pt)
+ pbcc.wrap_points(buf)
+ site_mic = pbcc.min_image(buf[0], site_pt)
+ site_mic = [(site_mic // 10**(2 - i) % 10) - 1 for i in range(3)]
+ buf[:, :] = st._real_traj[frame_i, mobile_i]
+ pbcc.to_cell_coords(buf)
+ pt_in_image = np.floor(buf[0])
+ pt_in_image += site_mic
+ clamptrj[frame_i, mobile_i] = np.dot(pt_in_image, cell) + st._sn.centers[at_site]
+
+ return clamptrj
From 26cc0d88103af9a6b2872ae0a9c26f027c1dfbb4 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 24 Jun 2019 12:25:44 -0400
Subject: [PATCH 026/129] Added `RemoveShortJumps`
---
sitator/SiteTrajectory.py | 2 +
sitator/dynamics/RemoveShortJumps.py | 101 ++++++++++++++++++++++
sitator/dynamics/__init__.py | 2 +-
sitator/misc/GenerateClampedTrajectory.py | 4 +-
sitator/misc/__init__.py | 1 +
5 files changed, 107 insertions(+), 3 deletions(-)
create mode 100644 sitator/dynamics/RemoveShortJumps.py
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index 96b36cf..c2c6056 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -88,6 +88,8 @@ def site_network(self, value):
def real_trajectory(self):
return self._real_traj
+ def copy(self):
+ return self[:]
def set_real_traj(self, real_traj):
"""Assocaite this SiteTrajectory with a trajectory of points in real space.
diff --git a/sitator/dynamics/RemoveShortJumps.py b/sitator/dynamics/RemoveShortJumps.py
new file mode 100644
index 0000000..99bf662
--- /dev/null
+++ b/sitator/dynamics/RemoveShortJumps.py
@@ -0,0 +1,101 @@
+import numpy as np
+
+from sitator import SiteTrajectory
+
+import logging
+logger = logging.getLogger(__name__)
+
+class RemoveShortJumps(object):
+ """Remove "short" jumps in a SiteTrajectory.
+
+ Remove jumps where the residence at the target is less than some threshold
+ and, optionally, only where the mobile atom returns to the site it originally
+ jumped from.
+
+ Args:
+ - only_returning_jumps (bool, default: True): If True, only short jumps
+ where the mobile atom returns to its initial site will be removed.
+ """
+ def __init__(self, only_returning_jumps = True):
+ self.only_returning_jumps = only_returning_jumps
+
+
+ def run(self,
+ st,
+ threshold):
+ """Returns a copy of `st` with short jumps removed.
+
+ Args:
+ - st (SiteTrajectory): Unassigned particles are treated as being at their last known site.
+ - threshold (int): The largest number of frames the mobile atom
+ can spend at a site while the jump is still considered short.
+ """
+ n_mobile = st.site_network.n_mobile
+ n_frames = st.n_frames
+ n_sites = st.site_network.n_sites
+
+ previous_site = np.full(shape = n_mobile, fill_value = -2, dtype = np.int)
+ last_known = np.empty(shape = n_mobile, dtype = np.int)
+ np.copyto(last_known, st.traj[0])
+ # Everything is at its first position for at least one frame by definition
+ time_at_current = np.ones(shape = n_mobile, dtype = np.int)
+
+ framebuf = np.empty(shape = st.traj.shape[1:], dtype = st.traj.dtype)
+
+ out = st.traj.copy()
+
+ n_problems = 0
+ n_short_jumps = 0
+
+ for i, frame in enumerate(st.traj):
+ # -- Deal with unassigned
+ # Don't screw up the SiteTrajectory
+ np.copyto(framebuf, frame)
+ frame = framebuf
+
+ unassigned = frame == SiteTrajectory.SITE_UNKNOWN
+ # Reassign unassigned
+ frame[unassigned] = last_known[unassigned]
+ fknown = frame >= 0
+
+ if np.any(~fknown):
+ logger.warning("At frame %i, %i uncorrectable unassigned particles" % (i, np.sum(~fknown)))
+ # -- Update stats
+
+ jumped = (frame != last_known) & fknown
+ problems = last_known[jumped] == -1
+ jumped[np.where(jumped)[0][problems]] = False
+ n_problems += np.sum(problems)
+
+ jump_froms = last_known[jumped]
+ jump_tos = frame[jumped]
+
+ # For all that didn't jump, increment time at current
+ time_at_current[~jumped] += 1
+ # For all that did, check if short
+ short_mask = time_at_current[jumped] <= threshold
+ if self.only_returning_jumps:
+ short_mask &= jump_tos == previous_site[jumped]
+ # Remove short jumps
+ for sj_atom in np.arange(n_mobile)[jumped][short_mask]:
+ #print("atom %s removing %i -> %i (%i) -> %i" % (sj_atom, previous_site[sj_atom], last_known[sj_atom], time_at_current[sj_atom], frame[sj_atom]))
+ n_short_jumps += 1
+ out[i - time_at_current[sj_atom]:i+1, sj_atom] = previous_site[sj_atom]
+
+ previous_site[jumped] = last_known[jumped]
+
+ # Reset for those that jumped
+ time_at_current[jumped] = 1
+
+ # Update last known assignment for anything that has one
+ last_known[~unassigned] = frame[~unassigned]
+
+ if n_problems != 0:
+ logger.warning("Came across %i times where assignment and last known assignment were unassigned." % n_problems)
+ logger.info("Removed %i short jumps" % n_short_jumps)
+ self.n_short_jumps = n_short_jumps
+
+ st = st.copy()
+ st._traj = out
+
+ return st
diff --git a/sitator/dynamics/__init__.py b/sitator/dynamics/__init__.py
index a78471a..e3232aa 100644
--- a/sitator/dynamics/__init__.py
+++ b/sitator/dynamics/__init__.py
@@ -1,3 +1,3 @@
from .JumpAnalysis import JumpAnalysis
-
from .MergeSitesByDynamics import MergeSitesByDynamics
+from .RemoveShortJumps import RemoveShortJumps
diff --git a/sitator/misc/GenerateClampedTrajectory.py b/sitator/misc/GenerateClampedTrajectory.py
index 87626d0..64575ea 100644
--- a/sitator/misc/GenerateClampedTrajectory.py
+++ b/sitator/misc/GenerateClampedTrajectory.py
@@ -43,8 +43,8 @@ def run(self, st, clamp_mask = None):
Returns:
ndarray (n_frames x n_atoms x 3)
"""
- wrap = st.wrap
- pass_through_unassigned = st.pass_through_unassigned
+ wrap = self.wrap
+ pass_through_unassigned = self.pass_through_unassigned
cell = st._sn.structure.cell
pbcc = PBCCalculator(cell)
diff --git a/sitator/misc/__init__.py b/sitator/misc/__init__.py
index 5a7095e..647ed08 100644
--- a/sitator/misc/__init__.py
+++ b/sitator/misc/__init__.py
@@ -2,3 +2,4 @@
from .NAvgsPerSite import NAvgsPerSite
from .GenerateAroundSites import GenerateAroundSites
from .SiteVolumes import SiteVolumes
+from .GenerateClampedTrajectory import GenerateClampedTrajectory
From dd29745a0269ea52372a5981b27f276760079b38 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 25 Jun 2019 14:48:26 -0400
Subject: [PATCH 027/129] Fixed true periodic pathway analysis
---
sitator/dynamics/DiffusionPathwayAnalysis.py | 72 ++++++++++++++++----
1 file changed, 60 insertions(+), 12 deletions(-)
diff --git a/sitator/dynamics/DiffusionPathwayAnalysis.py b/sitator/dynamics/DiffusionPathwayAnalysis.py
index de2801a..9d740b9 100644
--- a/sitator/dynamics/DiffusionPathwayAnalysis.py
+++ b/sitator/dynamics/DiffusionPathwayAnalysis.py
@@ -7,6 +7,7 @@
from scipy.sparse import lil_matrix
from scipy.sparse.csgraph import connected_components
+from sitator import SiteNetwork
from sitator.util import PBCCalculator
import logging
@@ -20,6 +21,10 @@ class DiffusionPathwayAnalysis(object):
for it to be considered connected.
:param int minimum_n_sites: The minimum number of sites that must be part of
a pathway for it to be considered as such.
+ :param bool true_periodic_pathways: Whether to return only true periodic
+ pathways, i.e., those that include a site and its own periodic image (and
+ are therefore conductive in the bulk), rather than all connected components.
+ If True, `minimum_n_sites` is NOT respected.
"""
NO_PATHWAY = -1
@@ -37,6 +42,15 @@ def __init__(self,
def run(self, sn, return_count = False):
"""
Expects a SiteNetwork that has had a JumpAnalysis run on it.
+
+ Adds information to `sn` in place.
+
+ Args:
+ - sn (SiteNetwork): Must have jump statistics from a `JumpAnalysis()`.
+ - return_count (bool, default: False): Return the number of connected
+ pathways.
+ Returns:
+ sn, [n_pathways]
"""
if not sn.has_attribute('n_ij'):
raise ValueError("SiteNetwork has no `n_ij`; run a JumpAnalysis on it first.")
@@ -67,32 +81,61 @@ def run(self, sn, return_count = False):
# is_pathway = np.ones(shape = n_ccs, dtype = np.bool)
# We have to check that the pathways include a site and its periodic
# image, and throw out those that don't
-
new_n_ccs = 1
new_ccs = np.zeros(shape = len(sn), dtype = np.int)
+ # Add a non-path (contains no sites, all False) so the broadcasting works
+ site_masks = [np.zeros(shape = len(sn), dtype = np.bool)]
+ #seen_mask = np.zeros(shape = len(sn), dtype = np.bool)
+
for pathway_i in np.arange(n_ccs):
path_mask = ccs == pathway_i
if not np.any(path_mask & mask_000):
+ # If the pathway is entirely outside the unit cell, we don't care
continue
- sitenums = np.where(path_mask)[0] % len(sn) # Get unit cell site numbers of all sites in pathway
- _, site_counts = np.unique(sitenums, return_counts = True)
- if np.sum(site_counts) <= len(site_counts):
- # Not a percolating path
+ # Sum along each site's periodic images, giving a count site-by-site
+ site_counts = np.sum(path_mask.reshape((-1, sn.n_sites)).astype(np.int), axis = 0)
+ if not np.any(site_counts > 1):
+ # Not percolating; doesn't contain any site and its periodic image.
+ print("Not percolating")
continue
- if len(site_counts) < self.minimum_n_sites:
- continue
-
- new_ccs[sitenums] = new_n_ccs
+ cur_site_mask = site_counts > 0
+
+ intersects_with = np.where(np.any(np.logical_and(site_masks, cur_site_mask), axis = 1))[0]
+ # Merge them:
+ if len(intersects_with) > 0:
+ path_mask = cur_site_mask | np.logical_or.reduce([site_masks[i] for i in intersects_with], axis = 0)
+ else:
+ path_mask = cur_site_mask
+ # Remove individual merged paths
+ for i in intersects_with:
+ del site_masks[i]
+ # Add new (super)path
+ site_masks.append(path_mask)
+ # if np.any(cur_site_mask & seen_mask):
+ # # We've seen this one before
+ # # This is OK because either they are connected, in which
+ # # case they aren't seperate components, or they include the
+ # # same site but AREN'T connected, in which case they must be
+ # # periodic images since otherwise they'd be connected.
+ # print('seen it')
+ # continue
+ # seen_mask |= cur_site_mask
+
+ new_ccs[path_mask] = new_n_ccs
new_n_ccs += 1
+ print(new_n_ccs)
+
n_ccs = new_n_ccs
ccs = new_ccs
+ # Only actually take the ones that were assigned to in the end
+ # This will deal with the ones that were merged.
is_pathway = np.in1d(np.arange(n_ccs), ccs)
- is_pathway[0] = False # Cause this was the "unassigned" value
+ is_pathway[0] = False # Cause this was the "unassigned" value, we initialized with zeros up above
else:
is_pathway = counts >= self.minimum_n_sites
@@ -132,13 +175,14 @@ def _build_mic_connmat(self, sn, connectivity_matrix):
assert n_images == 27
n_sites = len(sn)
- pos = sn.centers.copy() # TODO: copy not needed after reinstall of sitator!
+ pos = sn.centers #.copy() # TODO: copy not needed after reinstall of sitator!
n_total_sites = len(images) * n_sites
newmat = lil_matrix((n_total_sites, n_total_sites), dtype = np.bool)
mask_000 = np.zeros(shape = n_total_sites, dtype = np.bool)
index_000 = image_to_idex[111]
mask_000[index_000:index_000 + n_sites] = True
+ assert np.sum(mask_000) == len(sn)
pbcc = PBCCalculator(sn.structure.cell)
buf = np.empty(shape = 3)
@@ -150,8 +194,10 @@ def _build_mic_connmat(self, sn, connectivity_matrix):
if pbcc.min_image(pos[from_site], buf) == 111:
# If we're in the main image, keep the connection: it's internal
internal_mat[from_site, to_site] = True
+ #internal_mat[to_site, from_site] = True # fake FIXME
else:
external_connections.append((from_site, to_site))
+ #external_connections.append((to_site, from_site)) # FAKE FIXME
for image_idex, image in enumerate(images):
# Make the block diagonal
@@ -163,8 +209,10 @@ def _build_mic_connmat(self, sn, connectivity_matrix):
buf[:] = pos[to_site]
to_mic = pbcc.min_image(pos[from_site], buf)
to_in_image = image + [(to_mic // 10**(2 - i) % 10) - 1 for i in range(3)] # FIXME: is the -1 right
+ assert to_in_image is not None, "%s" % to_in_image
+ assert np.max(np.abs(to_in_image)) <= 2
if not np.any(np.abs(to_in_image) > 1):
- to_in_image = 100 * (to_in_image[0] + 1) + 10 * (to_in_image[1] + 1) + (to_in_image[2] + 1)
+ to_in_image = 100 * (to_in_image[0] + 1) + 10 * (to_in_image[1] + 1) + 1 * (to_in_image[2] + 1)
newmat[image_idex * n_sites + from_site,
image_to_idex[to_in_image] * n_sites + to_site] = True
From 38e0ca7591b6562f217c9c3b3fa1cf3d098bfaf3 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 25 Jun 2019 14:49:21 -0400
Subject: [PATCH 028/129] Cleaned up pathway analysis
---
sitator/dynamics/DiffusionPathwayAnalysis.py | 12 ------------
sitator/dynamics/__init__.py | 1 +
2 files changed, 1 insertion(+), 12 deletions(-)
diff --git a/sitator/dynamics/DiffusionPathwayAnalysis.py b/sitator/dynamics/DiffusionPathwayAnalysis.py
index 9d740b9..ca3687e 100644
--- a/sitator/dynamics/DiffusionPathwayAnalysis.py
+++ b/sitator/dynamics/DiffusionPathwayAnalysis.py
@@ -99,7 +99,6 @@ def run(self, sn, return_count = False):
site_counts = np.sum(path_mask.reshape((-1, sn.n_sites)).astype(np.int), axis = 0)
if not np.any(site_counts > 1):
# Not percolating; doesn't contain any site and its periodic image.
- print("Not percolating")
continue
cur_site_mask = site_counts > 0
@@ -115,21 +114,10 @@ def run(self, sn, return_count = False):
del site_masks[i]
# Add new (super)path
site_masks.append(path_mask)
- # if np.any(cur_site_mask & seen_mask):
- # # We've seen this one before
- # # This is OK because either they are connected, in which
- # # case they aren't seperate components, or they include the
- # # same site but AREN'T connected, in which case they must be
- # # periodic images since otherwise they'd be connected.
- # print('seen it')
- # continue
- # seen_mask |= cur_site_mask
new_ccs[path_mask] = new_n_ccs
new_n_ccs += 1
- print(new_n_ccs)
-
n_ccs = new_n_ccs
ccs = new_ccs
# Only actually take the ones that were assigned to in the end
diff --git a/sitator/dynamics/__init__.py b/sitator/dynamics/__init__.py
index e3232aa..900f7ca 100644
--- a/sitator/dynamics/__init__.py
+++ b/sitator/dynamics/__init__.py
@@ -1,3 +1,4 @@
from .JumpAnalysis import JumpAnalysis
from .MergeSitesByDynamics import MergeSitesByDynamics
from .RemoveShortJumps import RemoveShortJumps
+from .DiffusionPathwayAnalysis import DiffusionPathwayAnalysis
From 0ce351f8d863ae682487054d2e4ab3ef95952a21 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 25 Jun 2019 14:54:14 -0400
Subject: [PATCH 029/129] Improved docs and default parameters for
SiteNetworkPlotter
---
sitator/visualization/SiteNetworkPlotter.py | 34 ++++++++++++++-------
1 file changed, 23 insertions(+), 11 deletions(-)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 775f8b8..8a8cbc3 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -11,16 +11,28 @@
class SiteNetworkPlotter(object):
"""Plot a SiteNetwork.
- site_mappings defines how to show different properties. Each entry maps a
- visual aspect ('marker', 'color', 'size') to the name of a site attribute
- including 'site_type'.
-
- Likewise for edge_mappings, each key maps a visual property ('intensity', 'color',
- 'width', 'linestyle') to an edge attribute in the SiteNetwork.
-
Note that for edges, the average of the edge property for i -> j and j -> i
is often used for visual clarity; if your edge properties are not almost symmetric,
the visualization might not be useful.
+
+ Params:
+ - site_mappings (dict): defines how to show different properties. Each
+ entry maps a visual aspect ('marker', 'color', 'size') to the name
+ of a site attribute including 'site_type'.
+ - edge_mappings (dict): each key maps a visual property ('intensity',
+ 'color', 'width', 'linestyle') to an edge attribute in the SiteNetwork.
+ - markers (list of str): What `matplotlib` markers to use for sites.
+ - plot_points_params (dict): User options for plotting site points.
+ - minmax_linewidth (2-tuple): Minimum and maximum linewidth to use.
+ - minmax_edge_alpha (2-tuple): Similar, for edge line alphas.
+ - minmax_markersize (2-tuple): Similar, for markersize.
+ - min_color_threshold (float): Minimum (normalized) color intensity for
+ the corresponding line to be shown. Defaults to zero, i.e., all
+ nonzero edges will be drawn.
+ - min_width_threshold (float): Minimum normalized edge width for the
+ corresponding edge to be shown. Defaults to zero, i.e., all
+ nonzero edges will be drawn.
+ - title (str)
"""
DEFAULT_SITE_MAPPINGS = {
@@ -40,8 +52,8 @@ def __init__(self,
minmax_linewidth = (1.5, 7),
minmax_edge_alpha = (0.15, 0.75),
minmax_markersize = (80.0, 180.0),
- min_color_threshold = 0.005,
- min_width_threshold = 0.005,
+ min_color_threshold = 0.0,
+ min_width_threshold = 0.0,
title = ""):
self.site_mappings = site_mappings
self.edge_mappings = edge_mappings
@@ -205,9 +217,9 @@ def _plot_edges(self, sn, ax = None, *args, **kwargs):
if done_already[i, j]:
continue
# Ignore anything below the threshold
- if all_cs[i, j] < self.min_color_threshold:
+ if all_cs[i, j] <= self.min_color_threshold:
continue
- if do_widths and all_linewidths[i, j] < self.min_width_threshold:
+ if do_widths and all_linewidths[i, j] <= self.min_width_threshold:
continue
segment = np.empty(shape = (2, 3), dtype = centers.dtype)
From 48d75079807596f27ce68c65ca9f3e5187c2fe4e Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jun 2019 11:04:22 -0400
Subject: [PATCH 030/129] Added environment variables for paths to external
tools
---
README.md | 2 ++
sitator/site_descriptors/backend/quip.py | 4 +++-
sitator/voronoi/VoronoiSiteGenerator.py | 6 +++++-
3 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 11909e4..35abef7 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,8 @@ All individual classes and parameters are documented with docstrings in the sour
`sitator` uses the `tqdm.autonotebook` tool to automatically produce the correct fancy progress bars for terminals and iPython notebooks. To disable all progress bars, run with the environment variable `SITATOR_PROGRESSBAR` set to `false`.
+The `SITATOR_ZEO_PATH` and `SITATOR_QUIP_PATH` environment variables can set the default paths to the Zeo++ `network` and QUIP `quip` executables, respectively.
+
## License
This software is made available under the MIT License. See `LICENSE` for more details.
diff --git a/sitator/site_descriptors/backend/quip.py b/sitator/site_descriptors/backend/quip.py
index 576675a..d97459c 100644
--- a/sitator/site_descriptors/backend/quip.py
+++ b/sitator/site_descriptors/backend/quip.py
@@ -4,6 +4,8 @@
import numpy as np
+import os
+
import ase
from tempfile import NamedTemporaryFile
@@ -16,7 +18,7 @@
'atom_sigma' : 0.4
}
-def quip_soap_backend(soap_params = {}, quip_path = 'quip'):
+def quip_soap_backend(soap_params = {}, quip_path = os.getenv("SITATOR_QUIP_PATH", default = 'quip')):
def backend(sn, soap_mask, tracer_atomic_number, environment_list):
soap_opts = dict(DEFAULT_SOAP_PARAMS)
diff --git a/sitator/voronoi/VoronoiSiteGenerator.py b/sitator/voronoi/VoronoiSiteGenerator.py
index 544ec16..1f5f3c3 100644
--- a/sitator/voronoi/VoronoiSiteGenerator.py
+++ b/sitator/voronoi/VoronoiSiteGenerator.py
@@ -1,6 +1,8 @@
import numpy as np
+import os
+
from sitator import SiteNetwork
from sitator.util import Zeopy
@@ -12,7 +14,9 @@ class VoronoiSiteGenerator(object):
and should typically be, False.
"""
- def __init__(self, zeopp_path = "network", radial = False):
+ def __init__(self,
+ zeopp_path = os.getenv("SITATOR_ZEO_PATH", default = "network"),
+ radial = False):
self._radial = radial
self._zeopy = Zeopy(zeopp_path)
From d2efc540f56ea458920022dc69e8ed9d2b7f7fe4 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jun 2019 15:12:30 -0400
Subject: [PATCH 031/129] Bugfixes
---
sitator/visualization/SiteNetworkPlotter.py | 23 +++++++++++++--------
1 file changed, 14 insertions(+), 9 deletions(-)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 8a8cbc3..ce40be2 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -73,9 +73,10 @@ def __init__(self,
def __call__(self, sn, *args, **kwargs):
# -- Plot actual SiteNetwork --
l = [(plot_atoms, {'atoms' : sn.static_structure})]
- l += self._site_layers(sn, self.plot_points_params)
+ site_layer, normalization_params = self._site_layers(sn, self.plot_points_params)
+ l += site_layer
- l += self._plot_edges(sn, *args, **kwargs)
+ l += self._plot_edges(sn, site_params = normalization_params, *args, **kwargs)
# -- Some visual clean up --
ax = kwargs['ax']
@@ -97,7 +98,7 @@ def __call__(self, sn, *args, **kwargs):
# -- Put it all together --
layers(*l, **kwargs)
- def _site_layers(self, sn, plot_points_params):
+ def _site_layers(self, sn, plot_points_params, size_minmax = None, color_minmax = None):
pts_arrays = {'points' : sn.centers}
pts_params = {'cmap' : 'rainbow'}
@@ -112,11 +113,15 @@ def _site_layers(self, sn, plot_points_params):
markers = val.copy()
elif key == 'color':
pts_arrays['c'] = val.copy()
- pts_params['norm'] = matplotlib.colors.Normalize(vmin = np.min(val), vmax = np.max(val))
+ if color_minmax is None:
+ color_minmax = (np.min(val), np.max(val))
+ pts_params['norm'] = matplotlib.colors.Normalize(vmin = color_minmax[0], vmax = color_minmax[1])
elif key == 'size':
+ if size_minmax is None:
+ size_minmax = (np.min(val), np.max(val))
s = val.copy()
- s += np.min(s)
- s /= np.max(s)
+ s -= size_minmax[0]
+ s /= size_minmax[1] - size_minmax[0]
s *= self.minmax_markersize[1]
s += self.minmax_markersize[0]
pts_arrays['s'] = s
@@ -153,9 +158,9 @@ def _site_layers(self, sn, plot_points_params):
d.update(pts_params)
pts_layers.append((plot_points, d))
- return pts_layers
+ return pts_layers, {'size_minmax' : size_minmax, 'color_minmax' : color_minmax}
- def _plot_edges(self, sn, ax = None, *args, **kwargs):
+ def _plot_edges(self, sn, site_params = {}, ax = None, *args, **kwargs):
if not 'intensity' in self.edge_mappings:
return []
@@ -285,7 +290,7 @@ def _plot_edges(self, sn, ax = None, *args, **kwargs):
sn2.update_centers(np.asarray(sites_to_plot_positions))
pts_params = dict(self.plot_points_params)
pts_params['alpha'] = 0.2
- return self._site_layers(sn2, pts_params)
+ return self._site_layers(sn2, pts_params, **site_params)
else:
return []
else:
From 77f991e7297ba9870a6b2ba4f956c67f0e5f93f2 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jun 2019 15:40:46 -0400
Subject: [PATCH 032/129] Fix multiple visualization bugs
---
sitator/visualization/SiteNetworkPlotter.py | 34 ++++++++++++---------
1 file changed, 19 insertions(+), 15 deletions(-)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index ce40be2..0650396 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -73,10 +73,8 @@ def __init__(self,
def __call__(self, sn, *args, **kwargs):
# -- Plot actual SiteNetwork --
l = [(plot_atoms, {'atoms' : sn.static_structure})]
- site_layer, normalization_params = self._site_layers(sn, self.plot_points_params)
- l += site_layer
-
- l += self._plot_edges(sn, site_params = normalization_params, *args, **kwargs)
+ l += self._site_layers(sn, self.plot_points_params)
+ l += self._plot_edges(sn, *args, **kwargs)
# -- Some visual clean up --
ax = kwargs['ax']
@@ -98,7 +96,7 @@ def __call__(self, sn, *args, **kwargs):
# -- Put it all together --
layers(*l, **kwargs)
- def _site_layers(self, sn, plot_points_params, size_minmax = None, color_minmax = None):
+ def _site_layers(self, sn, plot_points_params, same_normalization = False):
pts_arrays = {'points' : sn.centers}
pts_params = {'cmap' : 'rainbow'}
@@ -113,12 +111,14 @@ def _site_layers(self, sn, plot_points_params, size_minmax = None, color_minmax
markers = val.copy()
elif key == 'color':
pts_arrays['c'] = val.copy()
- if color_minmax is None:
- color_minmax = (np.min(val), np.max(val))
+ if not same_normalization:
+ self._color_minmax = (np.min(val), np.max(val))
+ color_minmax = self._color_minmax
pts_params['norm'] = matplotlib.colors.Normalize(vmin = color_minmax[0], vmax = color_minmax[1])
elif key == 'size':
- if size_minmax is None:
- size_minmax = (np.min(val), np.max(val))
+ if not same_normalization:
+ self._size_minmax = (np.min(val), np.max(val))
+ size_minmax = self._size_minmax
s = val.copy()
s -= size_minmax[0]
s /= size_minmax[1] - size_minmax[0]
@@ -136,10 +136,12 @@ def _site_layers(self, sn, plot_points_params, size_minmax = None, color_minmax
else:
markers = self._make_discrete(markers)
unique_markers = np.unique(markers)
- marker_i = 0
+ if len(unique_markers) > len(self.markers):
+ raise ValueError("Too many distinct values of the site property mapped to markers (there are %i) for the %i markers in `self.markers`" % (len(unique_markers), len(self.markers)))
+ if not same_normalization:
+ self._marker_table = dict(zip(unique_markers, self.markers[:len(unique_markers)]))
for um in unique_markers:
- marker_layers[SiteNetworkPlotter.DEFAULT_MARKERS[marker_i]] = (markers == um)
- marker_i += 1
+ marker_layers[self._marker_table[um]] = (markers == um)
# -- Do plot
# If no color info provided, a fallback
@@ -158,9 +160,9 @@ def _site_layers(self, sn, plot_points_params, size_minmax = None, color_minmax
d.update(pts_params)
pts_layers.append((plot_points, d))
- return pts_layers, {'size_minmax' : size_minmax, 'color_minmax' : color_minmax}
+ return pts_layers
- def _plot_edges(self, sn, site_params = {}, ax = None, *args, **kwargs):
+ def _plot_edges(self, sn, ax = None, *args, **kwargs):
if not 'intensity' in self.edge_mappings:
return []
@@ -267,6 +269,8 @@ def _plot_edges(self, sn, site_params = {}, ax = None, *args, **kwargs):
# Group colors
if do_groups:
for i in range(len(cs)):
+ if groups[i] >= len(SiteNetworkPlotter.EDGE_GROUP_COLORS) - 1:
+ raise ValueError("Too many groups, not enough group colors")
lccolors[i] = matplotlib.colors.to_rgba(SiteNetworkPlotter.EDGE_GROUP_COLORS[groups[i]])
else:
lccolors[:] = matplotlib.colors.to_rgba(SiteNetworkPlotter.EDGE_GROUP_COLORS[0])
@@ -290,7 +294,7 @@ def _plot_edges(self, sn, site_params = {}, ax = None, *args, **kwargs):
sn2.update_centers(np.asarray(sites_to_plot_positions))
pts_params = dict(self.plot_points_params)
pts_params['alpha'] = 0.2
- return self._site_layers(sn2, pts_params, **site_params)
+ return self._site_layers(sn2, pts_params, same_normalization = True)
else:
return []
else:
From e839587e61d423f54fe142e7981bf2c9a4cd7682 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jun 2019 17:13:45 -0400
Subject: [PATCH 033/129] Allow recomputing occupancies
---
sitator/SiteTrajectory.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index c2c6056..fbaa7cf 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -139,6 +139,8 @@ def real_positions_for_site(self, site, return_confidences = False):
def compute_site_occupancies(self):
"""Computes site occupancies and adds site attribute `occupancies` to site_network."""
occ = np.true_divide(np.bincount(self._traj[self._traj >= 0], minlength = self._sn.n_sites), self.n_frames)
+ if self.site_network.has_attribute('occupancies'):
+ self.site_network.remove_attribute('occupancies')
self.site_network.add_site_attribute('occupancies', occ)
return occ
From 69779a88a750797ca4a002f29a4f2ee8270ffb90 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jun 2019 17:14:52 -0400
Subject: [PATCH 034/129] SOAP improvements
---
sitator/site_descriptors/SOAP.py | 43 ++++++++++++---------
sitator/site_descriptors/backend/dscribe.py | 28 +++++++-------
2 files changed, 39 insertions(+), 32 deletions(-)
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index 3a05fe9..e9c07df 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -2,12 +2,14 @@
import numpy as np
from abc import ABCMeta, abstractmethod
-from sitator.SiteNetwork import SiteNetwork
-from sitator.SiteTrajectory import SiteTrajectory
+from sitator import SiteNetwork, SiteTrajectory
from sitator.util.progress import tqdm
from ase.data import atomic_numbers
+import logging
+logger = logging.getLogger(__name__)
+
class SOAP(object, metaclass=ABCMeta):
"""Abstract base class for computing SOAP vectors in a SiteNetwork.
@@ -31,7 +33,9 @@ class SOAP(object, metaclass=ABCMeta):
:param func backend: A function that can be called with `sn, soap_mask, tracer_atomic_number, environment_list` as
parameters, returning a function that, given the current soap structure
along with tracer atoms, returns SOAP vectors in a numpy array. (i.e.
- its signature is `soap(structure, positions)`)
+ its signature is `soap(structure, positions)`). The returned function
+ can also have a property, `n_dim`, giving the length of a single SOAP
+ vector.
"""
from .backend.quip import quip_soap_backend as backend_quip
@@ -72,7 +76,6 @@ def __init__(self, tracer_atomic_number, environment = None,
else:
self._environment = None
-
def get_descriptors(self, stn):
"""
Get the descriptors.
@@ -149,7 +152,9 @@ class SOAPCenters(SOAP):
Requires a SiteNetwork as input.
"""
def _get_descriptors(self, sn, structure, tracer_atomic_number, soap_mask, soaper):
- assert isinstance(sn, SiteNetwork), "SOAPCenters requires a SiteNetwork, not `%s`" % sn
+ if isinstance(sn, SiteTrajectory):
+ sn = sn.site_network
+ assert isinstance(sn, SiteNetwork), "SOAPCenters requires a SiteNetwork or SiteTrajectory, not `%s`" % sn
pts = sn.centers
@@ -248,25 +253,30 @@ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soa
mob_indices = np.where(site_trajectory.site_network.mobile_mask)[0]
# real_traj is the real space positions, site_traj the site trajectory
# (i.e. for every mobile species the site index)
- # I load into new variable, only the steps I need (memory???)
real_traj = site_trajectory._real_traj[::self._stepsize]
site_traj = site_trajectory.traj[::self._stepsize]
# Now, I need to allocate the output
# so for each site, I count how much data there is!
- counts = np.array([np.count_nonzero(site_traj==site_idx) for site_idx in range(nsit)], dtype=int)
+ # Add one to deal with -1 -> 0, then just ignore the count for 0
+ counts = np.bincount(site_traj.reshape(-1) + 1, minlength = nsit)[1:]
if self._averaging is not None:
averaging = self._averaging
else:
averaging = int(np.floor(np.mean(counts) / self._avg_desc_per_site))
- nr_of_descs = counts // averaging
+ if averaging == 0:
+ logger.warning("Asking for too many average descriptors per site; got averaging = 0; setting averaging = 1")
+ averaging = 1
- if np.any(nr_of_descs == 0):
- raise ValueError("You are asking too much, averaging with {} gives a problem".format(averaging))
+ nr_of_descs = counts // averaging
+ insufficient = nr_of_descs == 0
+ if np.any(insufficient):
+ logger.warning("You're asking to average %i SOAP vectors, but at this stepsize, %i sites are insufficiently occupied. Num occ./averaging: %s" % (averaging, np.sum(insufficient), counts[insufficient] / averaging))
+ nr_of_descs = np.maximum(nr_of_descs, 1) # If it's 0, just make one with whatever we've got
# This is where I load the descriptor:
- descs = np.zeros((np.sum(nr_of_descs), self.n_dim))
+ descs = np.zeros((np.sum(nr_of_descs), soaper.n_dim))
# An array that tells me the index I'm at for each site type
desc_index = [np.sum(nr_of_descs[:i]) for i in range(len(nr_of_descs))]
@@ -276,7 +286,7 @@ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soa
blocked = np.empty(nsit, dtype=bool)
blocked[:] = False
- for site_traj_t, pos in tqdm(zip(site_traj, real_traj), desc="SOAP"):
+ for site_traj_t, pos in zip(tqdm(site_traj, desc="SOAP"), real_traj):
# I update the host lattice positions here, once for every timestep
structure.positions[:] = pos[soap_mask]
@@ -284,17 +294,12 @@ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soa
if site_idx >= 0 and not blocked[site_idx]:
# Now, for every lithium that has been associated to a site of index site_idx,
# I take my structure and load the position of this mobile atom:
- # calc_connect to calculated distance
-# structure.calc_connect()
#There should only be one descriptor, since there should only be one mobile
- # I also divide by averaging, to avoid getting into large numbers.
-# soapv = self._soaper.calc(structure)['descriptor'][0] / self._averaging
- soapv = soaper(structure, [pos[mob_indices[mob_idx]]])
+ soapv = soaper(structure, [pos[mob_indices[mob_idx]]])[0]
- #~ soapv ,_,_ = get_fingerprints([structure], d)
# So, now I need to figure out where to load the soapv into desc
idx_to_add_desc = desc_index[site_idx]
- descs[idx_to_add_desc, :] += soapv[0] / averaging
+ descs[idx_to_add_desc, :] += soapv / averaging
count_of_site[site_idx] += 1
# Now, if the count reaches the averaging I want, I augment
if count_of_site[site_idx] == averaging:
diff --git a/sitator/site_descriptors/backend/dscribe.py b/sitator/site_descriptors/backend/dscribe.py
index ed7b140..66fc5d5 100644
--- a/sitator/site_descriptors/backend/dscribe.py
+++ b/sitator/site_descriptors/backend/dscribe.py
@@ -6,7 +6,8 @@
'l_max' : 6, 'n_max' : 6,
'atom_sigma' : 0.4,
'rbf' : 'gto',
- 'crossover' : False
+ 'crossover' : False,
+ 'periodic' : True,
}
def dscribe_soap_backend(soap_params = {}):
@@ -16,23 +17,24 @@ def dscribe_soap_backend(soap_params = {}):
soap_opts.update(soap_params)
def backend(sn, soap_mask, tracer_atomic_number, environment_list):
+ soap = SOAP(
+ species = environment_list,
+ crossover = soap_opts['crossover'],
+ rcut = soap_opts['cutoff'],
+ nmax = soap_opts['n_max'],
+ lmax = soap_opts['l_max'],
+ rbf = soap_opts['rbf'],
+ sigma = soap_opts['atom_sigma'],
+ periodic = soap_opts['periodic'],
+ sparse = False
+ )
def dscribe_soap(structure, positions):
- soap = SOAP(
- species = environment_list,
- crossover = soap_opts['crossover'],
- rcut = soap_opts['cutoff'],
- nmax = soap_opts['n_max'],
- lmax = soap_opts['l_max'],
- rbf = soap_opts['rbf'],
- sigma = soap_opts['atom_sigma'],
- periodic = np.all(structure.pbc),
- sparse = False
- )
-
out = soap.create(structure, positions = positions).astype(np.float)
return out
+ dscribe_soap.n_dim = soap.get_number_of_features()
+
return dscribe_soap
return backend
From 0aea0e328b302914732cfd23095da07726d53fb7 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jun 2019 17:43:08 -0400
Subject: [PATCH 035/129] Vectorized SOAPDescriptorAverages
---
sitator/site_descriptors/SOAP.py | 42 +++++++++++++++-----------------
1 file changed, 20 insertions(+), 22 deletions(-)
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index e9c07df..f1922dd 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -265,6 +265,7 @@ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soa
averaging = self._averaging
else:
averaging = int(np.floor(np.mean(counts) / self._avg_desc_per_site))
+ logger.debug("Will average %i SOAP vectors for every output vector" % averaging)
if averaging == 0:
logger.warning("Asking for too many average descriptors per site; got averaging = 0; setting averaging = 1")
@@ -275,39 +276,36 @@ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soa
if np.any(insufficient):
logger.warning("You're asking to average %i SOAP vectors, but at this stepsize, %i sites are insufficiently occupied. Num occ./averaging: %s" % (averaging, np.sum(insufficient), counts[insufficient] / averaging))
nr_of_descs = np.maximum(nr_of_descs, 1) # If it's 0, just make one with whatever we've got
+ logger.debug("Minimum # of descriptors/site: %i; maximum: %i" % (np.min(nr_of_descs), np.max(nr_of_descs)))
# This is where I load the descriptor:
descs = np.zeros((np.sum(nr_of_descs), soaper.n_dim))
# An array that tells me the index I'm at for each site type
- desc_index = [np.sum(nr_of_descs[:i]) for i in range(len(nr_of_descs))]
- max_index = [np.sum(nr_of_descs[:i+1]) for i in range(len(nr_of_descs))]
+ desc_index = np.asarray([np.sum(nr_of_descs[:i]) for i in range(len(nr_of_descs))])
+ max_index = np.asarray([np.sum(nr_of_descs[:i+1]) for i in range(len(nr_of_descs))])
count_of_site = np.zeros(len(nr_of_descs), dtype=int)
- blocked = np.empty(nsit, dtype=bool)
- blocked[:] = False
+ allowed = np.ones(nsit, dtype = np.bool)
for site_traj_t, pos in zip(tqdm(site_traj, desc="SOAP"), real_traj):
# I update the host lattice positions here, once for every timestep
structure.positions[:] = pos[soap_mask]
- for mob_idx, site_idx in enumerate(site_traj_t):
- if site_idx >= 0 and not blocked[site_idx]:
- # Now, for every lithium that has been associated to a site of index site_idx,
- # I take my structure and load the position of this mobile atom:
- #There should only be one descriptor, since there should only be one mobile
- soapv = soaper(structure, [pos[mob_indices[mob_idx]]])[0]
-
- # So, now I need to figure out where to load the soapv into desc
- idx_to_add_desc = desc_index[site_idx]
- descs[idx_to_add_desc, :] += soapv / averaging
- count_of_site[site_idx] += 1
- # Now, if the count reaches the averaging I want, I augment
- if count_of_site[site_idx] == averaging:
- desc_index[site_idx] += 1
- count_of_site[site_idx] = 0
- # Now I check whether I have to block this site from accumulating more descriptors
- if max_index[site_idx] == desc_index[site_idx]:
- blocked[site_idx] = True
+ to_describe = (site_traj_t != SiteTrajectory.SITE_UNKNOWN) & allowed[site_traj_t]
+
+ if np.any(to_describe):
+ soaps = soaper(structure, pos[mob_indices[to_describe]])
+ soaps /= averaging
+
+ idx_to_add_desc = desc_index[site_traj_t[to_describe]]
+ descs[idx_to_add_desc] += soaps
+ count_of_site[site_traj_t[to_describe]] += 1
+
+ # Reset and increment full averages
+ full_average = count_of_site == averaging
+ desc_index[full_average] += 1
+ count_of_site[full_average] = 0
+ allowed[max_index == desc_index] = False
desc_to_site = np.repeat(list(range(nsit)), nr_of_descs)
return descs, desc_to_site
From a13944b09353a4694e5598daac8bd9fe296d8e0f Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jun 2019 18:34:31 -0400
Subject: [PATCH 036/129] Improved progress reporting
---
sitator/landmark/helpers.pyx | 2 +-
sitator/site_descriptors/SOAP.py | 2 +-
sitator/util/DotProdClassifier.pyx | 4 ++--
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/sitator/landmark/helpers.pyx b/sitator/landmark/helpers.pyx
index de508b9..1dd736e 100644
--- a/sitator/landmark/helpers.pyx
+++ b/sitator/landmark/helpers.pyx
@@ -47,7 +47,7 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
cdef Py_ssize_t landmark_dim = self._landmark_dimension
cdef Py_ssize_t current_landmark_i = 0
# Iterate through time
- for i, frame in enumerate(tqdm(frames, desc = "Frame")):
+ for i, frame in enumerate(tqdm(frames, desc = "Landmark Frame")):
static_positions = frame[sn.static_mask]
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index f1922dd..9ae782d 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -287,7 +287,7 @@ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soa
count_of_site = np.zeros(len(nr_of_descs), dtype=int)
allowed = np.ones(nsit, dtype = np.bool)
- for site_traj_t, pos in zip(tqdm(site_traj, desc="SOAP"), real_traj):
+ for site_traj_t, pos in zip(tqdm(site_traj, desc="SOAP Frame"), real_traj):
# I update the host lattice positions here, once for every timestep
structure.positions[:] = pos[soap_mask]
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index 6e03f17..d7d47de 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -99,7 +99,7 @@ class DotProdClassifier(object):
first_iter = True
- for iteration in xrange(self._max_iters):
+ for iteration in tqdm(xrange(self._max_iters), desc = "Clustering iter.", total = float('inf')):
# This iteration's centers
# The first sample is always its own cluster
cluster_centers[0] = old_centers[0]
@@ -107,7 +107,7 @@ class DotProdClassifier(object):
n_assigned_to[0] = old_n_assigned[0]
n_clusters = 1
# skip the first sample which has already been accounted for
- for i, vec in tqdm(zip(xrange(1, old_n_clusters), old_centers[1:old_n_clusters]), desc = "Iteration %i" % iteration):
+ for i, vec in zip(xrange(1, old_n_clusters), old_centers[1:old_n_clusters]):
assigned_to = -1
assigned_cosang = 0.0
From 217f8165701c9d496f6ba9c052804cc78774e8b3 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 27 Jun 2019 10:46:12 -0400
Subject: [PATCH 037/129] Allow dealing with unwrapped trajectories
---
sitator/landmark/LandmarkAnalysis.py | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index c9ffa2c..f7e525e 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -135,6 +135,9 @@ def run(self, sn, frames):
The input SiteNetwork is a network of predicted sites; it's sites will
be used as the "basis" for the landmark vectors.
+ Wraps a copy of `frames` into the unit cell; if you know `frames` is already
+ wrapped, set `do_wrap = False` to avoid the copy.
+
Takes a SiteNetwork and returns a SiteTrajectory.
"""
assert isinstance(sn, SiteNetwork)
@@ -155,6 +158,16 @@ def run(self, sn, frames):
# Create PBCCalculator
self._pbcc = PBCCalculator(sn.structure.cell)
+ # -- Step 0: Wrap to Unit Cell
+ orig_frames = frames # Keep a reference around
+ frames = frames.copy()
+ # Flatten to list of points for wrapping
+ orig_frame_shape = frames.shape
+ frames.shape = (orig_frame_shape[0] * orig_frame_shape[1], 3)
+ self._pbcc.wrap_points(frames)
+ # Back to list of frames
+ frames.shape = orig_frame_shape
+
# -- Step 1: Compute site-to-vertex distances
self._landmark_dimension = sn.n_sites
@@ -256,7 +269,7 @@ def run(self, sn, frames):
assert out_sn.vertices is None
out_st = SiteTrajectory(out_sn, lmk_lbls, lmk_confs)
- out_st.set_real_traj(frames)
+ out_st.set_real_traj(orig_frames)
self._has_run = True
return out_st
From a625caf5c6655c213f11b51aaa5cef91b6cf88fc Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 28 Jun 2019 13:33:51 -0400
Subject: [PATCH 038/129] Major refactor of merging; add MergeSitesByBarrier
---
sitator/dynamics/MergeSitesByDynamics.py | 83 +-----------
sitator/dynamics/__init__.py | 4 +-
sitator/landmark/pointmerge.py | 72 -----------
sitator/misc/MergeSitesByBarrier.py | 0
.../DiffusionPathwayAnalysis.py | 0
sitator/network/MergeSitesByBarrier.py | 120 ++++++++++++++++++
sitator/network/__init__.py | 3 +
sitator/network/merging.py | 113 +++++++++++++++++
8 files changed, 246 insertions(+), 149 deletions(-)
delete mode 100644 sitator/landmark/pointmerge.py
create mode 100644 sitator/misc/MergeSitesByBarrier.py
rename sitator/{dynamics => network}/DiffusionPathwayAnalysis.py (100%)
create mode 100644 sitator/network/MergeSitesByBarrier.py
create mode 100644 sitator/network/__init__.py
create mode 100644 sitator/network/merging.py
diff --git a/sitator/dynamics/MergeSitesByDynamics.py b/sitator/dynamics/MergeSitesByDynamics.py
index 732622c..b5613b6 100644
--- a/sitator/dynamics/MergeSitesByDynamics.py
+++ b/sitator/dynamics/MergeSitesByDynamics.py
@@ -1,24 +1,14 @@
import numpy as np
-from sitator import SiteNetwork, SiteTrajectory
from sitator.dynamics import JumpAnalysis
from sitator.util import PBCCalculator
+from sitator.network.merging import MergeSites
import logging
logger = logging.getLogger(__name__)
-class MergeSitesError(Exception):
- pass
-class MergedSitesTooDistantError(MergeSitesError):
- pass
-
-class TooFewMergedSitesError(MergeSitesError):
- pass
-
-
-
-class MergeSitesByDynamics(object):
+class MergeSitesByDynamics(MergeSites):
"""Merges sites using dynamical data.
Given a SiteTrajectory, merges sites using Markov Clustering.
@@ -48,6 +38,8 @@ def __init__(self,
iterlimit = 100,
markov_parameters = {}):
+ super().__init__(post_check_thresh_factor * distance_threshold)
+
if connectivity_matrix_generator is None:
connectivity_matrix_generator = MergeSitesByDynamics.connectivity_n_ij
assert callable(connectivity_matrix_generator)
@@ -106,18 +98,13 @@ def cfunc(sn):
dmat *= -0.5
np.exp(dmat, out = dmat)
- return sn.p_ij + jump_lag_coeff * jl + distance_coeff * dmat
+ return (sn.p_ij + jump_lag_coeff * jl) * (distance_coeff * dmat + (1 - distance_coeff))
return cfunc
# Real methods
- def run(self, st):
- """Takes a SiteTrajectory and returns a SiteTrajectory, including a new SiteNetwork."""
-
- if self.check_types and st.site_network.site_types is None:
- raise ValueError("Cannot run a check_types=True MergeSitesByDynamics on a SiteTrajectory without type information.")
-
+ def _get_sites_to_merge(self, st):
# -- Compute jump statistics
if not st.site_network.has_attribute('p_ij'):
ja = JumpAnalysis()
@@ -125,8 +112,6 @@ def run(self, st):
pbcc = PBCCalculator(st.site_network.structure.cell)
site_centers = st.site_network.centers
- if self.check_types:
- site_types = st.site_network.site_types
# -- Build connectivity_matrix
connectivity_matrix = self.connectivity_matrix_generator(st.site_network).copy()
@@ -162,61 +147,7 @@ def run(self, st):
# -- Do Markov Clustering
clusters = self._markov_clustering(connectivity_matrix, **self.markov_parameters)
-
- new_n_sites = len(clusters)
-
- logger.info("After merge there will be %i sites" % new_n_sites)
-
- if new_n_sites < np.sum(st.site_network.mobile_mask):
- raise TooFewMergedSitesError("There are %i mobile atoms in this system, but only %i sites after merge" % (np.sum(st.site_network.mobile_mask), new_n_sites))
-
- if self.check_types:
- new_types = np.empty(shape = new_n_sites, dtype = np.int)
-
- # -- Merge Sites
- new_centers = np.empty(shape = (new_n_sites, 3), dtype = st.site_network.centers.dtype)
- translation = np.empty(shape = st.site_network.n_sites, dtype = np.int)
- translation.fill(-1)
-
- for newsite in range(new_n_sites):
- mask = list(clusters[newsite])
- # Update translation table
- if np.any(translation[mask] != -1):
- # We've assigned a different cluster for this before... weird
- # degeneracy
- raise ValueError("Markov clustering tried to merge site(s) into more than one new site. This shouldn't happen.")
- translation[mask] = newsite
-
- to_merge = site_centers[mask]
-
- # Check distances
- if not self.post_check_thresh_factor is None:
- dists = pbcc.distances(to_merge[0], to_merge[1:])
- if not np.all(dists < self.post_check_thresh_factor * self.distance_threshold):
- raise MergedSitesTooDistantError("Markov clustering tried to merge sites more than %f * %f apart. Lower your distance_threshold?" % (self.post_check_thresh_factor, self.distance_threshold))
-
- # New site center
- new_centers[newsite] = pbcc.average(to_merge)
- if self.check_types:
- assert np.all(site_types[mask] == site_types[mask][0])
- new_types[newsite] = site_types[mask][0]
-
- newsn = st.site_network.copy()
- newsn.centers = new_centers
- if self.check_types:
- newsn.site_types = new_types
-
- newtraj = translation[st._traj]
- newtraj[st._traj == SiteTrajectory.SITE_UNKNOWN] = SiteTrajectory.SITE_UNKNOWN
-
- # It doesn't make sense to propagate confidence information through a
- # transform that might completely invalidate it
- newst = SiteTrajectory(newsn, newtraj, confidences = None)
-
- if not st.real_trajectory is None:
- newst.set_real_traj(st.real_trajectory)
-
- return newst
+ return clusters
def _markov_clustering(self,
diff --git a/sitator/dynamics/__init__.py b/sitator/dynamics/__init__.py
index 900f7ca..5515662 100644
--- a/sitator/dynamics/__init__.py
+++ b/sitator/dynamics/__init__.py
@@ -1,4 +1,6 @@
from .JumpAnalysis import JumpAnalysis
from .MergeSitesByDynamics import MergeSitesByDynamics
from .RemoveShortJumps import RemoveShortJumps
-from .DiffusionPathwayAnalysis import DiffusionPathwayAnalysis
+
+# For backwards compatability, since this used to be in this module
+from sitator.network import DiffusionPathwayAnalysis
diff --git a/sitator/landmark/pointmerge.py b/sitator/landmark/pointmerge.py
deleted file mode 100644
index c8dcc8c..0000000
--- a/sitator/landmark/pointmerge.py
+++ /dev/null
@@ -1,72 +0,0 @@
-
-import numpy as np
-
-from sitator.util.progress import tqdm
-
-def merge_points_soap_paths(tsoap,
- pbcc,
- points,
- connectivity_dict,
- threshold,
- n_steps = 5,
- sanity_check_cutoff = np.inf):
- """Merge points using SOAP paths method.
-
- :param SOAP tsoap: to compute SOAPs with.
- :param dict connectivity_dict: Maps a point index to a set of point indexes
- it is connected to, however defined.
- :param threshold: Similarity threshold, 0 < threshold <= 1
- """
-
- merge_sets = set()
-
- points_along = np.empty(shape = (n_steps, 3), dtype = np.float)
- step_vec_mult = np.linspace(0.0, 1.0, num = n_steps)[:, np.newaxis]
-
- for pt_idex in tqdm(connectivity_dict.keys()):
- merge_set = set()
- current_pts = [pt_idex]
- from_soap = None
- keep_going = True
- while keep_going:
- added_this_iter = set()
- for edge_from in current_pts:
- offset = pbcc.cell_centroid - points[edge_from]
- edge_from_pt = pbcc.cell_centroid
-
- for edge_to in connectivity_dict[edge_from] - merge_set:
- edge_to_pt = points[edge_to].copy()
- edge_to_pt += offset
- pbcc.wrap_point(edge_to_pt)
-
- step_vec = edge_to_pt - edge_from_pt
- edge_length = np.linalg.norm(step_vec)
-
- assert edge_length <= sanity_check_cutoff, "edge_length %s" % edge_length
-
- # Points along the line
- for i in range(n_steps):
- points_along[i] = step_vec
- points_along *= step_vec_mult
- points_along += edge_from_pt
- # Re-center back to original center
- points_along -= offset
- # Wrap back into original unit cell - the one frame_atoms has
- pbcc.wrap_points(points_along)
-
- merge = tsoap.soaps_similar_for_points(points_along, threshold = threshold)
-
- if merge:
- added_this_iter.add(edge_from)
- added_this_iter.add(edge_to)
-
- if len(added_this_iter) == 0:
- keep_going = False
- else:
- current_pts = added_this_iter - merge_set
- merge_set.update(added_this_iter)
-
- if len(merge_set) > 0:
- merge_sets.add(frozenset(merge_set))
-
- return merge_sets
diff --git a/sitator/misc/MergeSitesByBarrier.py b/sitator/misc/MergeSitesByBarrier.py
new file mode 100644
index 0000000..e69de29
diff --git a/sitator/dynamics/DiffusionPathwayAnalysis.py b/sitator/network/DiffusionPathwayAnalysis.py
similarity index 100%
rename from sitator/dynamics/DiffusionPathwayAnalysis.py
rename to sitator/network/DiffusionPathwayAnalysis.py
diff --git a/sitator/network/MergeSitesByBarrier.py b/sitator/network/MergeSitesByBarrier.py
new file mode 100644
index 0000000..0e243b2
--- /dev/null
+++ b/sitator/network/MergeSitesByBarrier.py
@@ -0,0 +1,120 @@
+import numpy as np
+
+import itertools
+
+from scipy.sparse.csgraph import connected_components
+
+from ase.calculators.calculator import all_changes
+
+from sitator.util import PBCCalculator
+from sitator.network.merging import MergeSites
+
+import logging
+logger = logging.getLogger(__name__)
+
+class MergeSitesByBarrier(MergeSites):
+ """Merge sites based on the energy barrier between them.
+
+ Uses a cheap coordinate driving system; this may not be sophisticated enough
+ for complex cases. For each pair of sites within the pairwise distance cutoff,
+ a linear spatial interpolation is applied to produce `n_driven_images`.
+ Two sites are considered mergable if their energies are within
+ `final_initial_energy_threshold` and the barrier between them is below
+ `barrier_threshold`. The barrier is defined as the maximum image energy minus
+ the average of the initial and final energy.
+
+ The energies of the mobile atom are calculated in a static lattice given
+ by `coordinating_mask`; if `None`, this is set to the system's `static_mask`.
+
+ For reasonable performance, `calculator` should be something simple like
+ `ase.calculators.lj.LennardJones`.
+
+ Takes species of first mobile atom as mobile species.
+
+ Args:
+ - calculator (ase.Calculator): For computing total potential energies.
+ - final_initial_energy_threshold (float, eV): The maximum difference in
+ energies between two sites for them to be mergable.
+ - barrier_threshold (float, eV): The barrier value above which two sites
+ are not mergable.
+ - n_driven_images (int, default: None): The number of evenly distributed
+ driven images to use.
+ - maximum_pairwise_distance (float, Angstrom): The maximum distance
+ between two sites for them to be considered for merging.
+ - maximum_merge_distance (float, Angstrom): The maxiumum pairwise distance
+ among a group of sites chosed to be merged.
+ """
+ def __init__(self,
+ calculator,
+ final_initial_energy_threshold,
+ barrier_threshold,
+ n_driven_images = None,
+ maximum_pairwise_distance = 2,
+ maximum_merge_distance = 2):
+ super().__init__(maximum_merge_distance)
+ self.final_initial_energy_threshold = final_initial_energy_threshold
+ self.barrier_threshold = barrier_threshold
+ self.maximum_pairwise_distance = maximum_pairwise_distance
+ self.n_driven_images = n_driven_images
+ self.calculator = calculator
+
+
+ def _get_sites_to_merge(self, st, coordinating_mask = None):
+ sn = st.site_network
+ pos = sn.centers
+ if coordinating_mask is None:
+ coordinating_mask = sn.static_mask
+ else:
+ assert not np.any(coordinating_mask & sn.mobile_mask)
+ # -- Build images
+ mobile_idex = np.where(sn.mobile_mask)[0][0]
+ one_mobile_structure = sn.structure[coordinating_mask]
+ one_mobile_structure.extend(sn.structure[mobile_idex])
+ mobile_idex = -1
+ #images = [one_mobile_structure.copy() for _ in range(self.n_driven_images)]
+ interpolation_coeffs = np.linspace(0, 1, self.n_driven_images)
+ energies = np.empty(shape = self.n_driven_images)
+
+ # -- Decide on pairs to check
+ pbcc = PBCCalculator(sn.structure.cell)
+ dists = pbcc.pairwise_distances(pos)
+ # At the start, all within distance cutoff are mergable
+ mergable = dists <= self.maximum_pairwise_distance
+
+ # -- Check pairs' barriers
+ # Symmetric, and diagonal is trivially true. Combinations avoids those cases.
+ jbuf = pos[0].copy()
+ first_calculate = True
+ for i, j in itertools.combinations(range(sn.n_sites)):
+ jbuf[:] = pos[j]
+ # Get minimage
+ _ = pbcc.min_image(pos[i], jbuf)
+ # Do coordinate driving
+ vector = jbuf - pos[i]
+ for image_i in range(self.n_driven_images):
+ one_mobile_structure.positions[mobile_idex] = vector
+ one_mobile_structure.positions[mobile_idex] *= interpolation_coeffs[image_i]
+ one_mobile_structure.positions[mobile_idex] += pos[i]
+ energies[image_i] = self.calculator.calculate(atoms = one_mobile_structure,
+ properties = ['energy'],
+ system_changes = (all_changes if first_calculate else ['positions']))
+ first_calculate = False
+ # Check barrier
+ barrier_idex = np.argmax(energies)
+ if np.abs(energies[0] - energies[-1]) > self.final_initial_energy_threshold:
+ mergable[i, j] = mergable[j, i] = False
+ # Average the initial and final states for a baseline
+ baseline_energy = 0.5 * (energies[0] + energies[-1])
+ barrier_height = energies[barrier_idex] - baseline_energy
+ if barrier_height > self.barrier_threshold:
+ mergable[i, j] = mergable[j, i] = False
+
+ # Get mergable groups
+ n_merged_sites, labels = connected_components(mergable)
+ # MergeSites will check pairwise distances; we just need to make it the
+ # right format.
+ merge_groups = []
+ for lbl in range(n_merged_sites):
+ merge_groups.append(np.where(labels == lbl)[0])
+
+ return merge_groups
diff --git a/sitator/network/__init__.py b/sitator/network/__init__.py
new file mode 100644
index 0000000..590c51c
--- /dev/null
+++ b/sitator/network/__init__.py
@@ -0,0 +1,3 @@
+from .DiffusionPathwayAnalysis import DiffusionPathwayAnalysis
+
+from . import merging
diff --git a/sitator/network/merging.py b/sitator/network/merging.py
new file mode 100644
index 0000000..c713d33
--- /dev/null
+++ b/sitator/network/merging.py
@@ -0,0 +1,113 @@
+import numpy as np
+
+import abc
+
+from sitator import SiteNetwork, SiteTrajectory
+
+import logging
+logger = logging.getLogger(__name__)
+
+class MergeSitesError(Exception):
+ pass
+
+class MergedSitesTooDistantError(MergeSitesError):
+ pass
+
+class TooFewMergedSitesError(MergeSitesError):
+ pass
+
+
+class MergeSites(abc.ABC):
+ """Abstract base class for merging sites.
+
+ :param bool check_types: If True, only sites of the same type are candidates to
+ be merged; if false, type information is ignored. Merged sites will only
+ be assigned types if this is True.
+ """
+ def __init__(self,
+ check_types = True,
+ maximum_merge_distance = None):
+ self.check_types = check_types
+ self.maximum_merge_distance = maximum_merge_distance
+
+
+ def run(self, st, **kwargs):
+ """Takes a SiteTrajectory and returns a SiteTrajectory, including a new SiteNetwork."""
+
+ if self.check_types and st.site_network.site_types is None:
+ raise ValueError("Cannot run a check_types=True MergeSites on a SiteTrajectory without type information.")
+
+ # -- Compute jump statistics
+ pbcc = PBCCalculator(st.site_network.structure.cell)
+ site_centers = st.site_network.centers
+ if self.check_types:
+ site_types = st.site_network.site_types
+
+ clusters = self._get_sites_to_merge(st, **kwargs)
+
+ new_n_sites = len(clusters)
+
+ logger.info("After merge there will be %i sites for %i mobile particles" % (new_n_sites, st.site_network.n_mobile))
+
+ if new_n_sites < st.site_network.n_mobile:
+ raise TooFewMergedSitesError("There are %i mobile atoms in this system, but only %i sites after merge" % (np.sum(st.site_network.mobile_mask), new_n_sites))
+
+ if self.check_types:
+ new_types = np.empty(shape = new_n_sites, dtype = np.int)
+
+ # -- Merge Sites
+ new_centers = np.empty(shape = (new_n_sites, 3), dtype = st.site_network.centers.dtype)
+ translation = np.empty(shape = st.site_network.n_sites, dtype = np.int)
+ translation.fill(-1)
+
+ for newsite in range(new_n_sites):
+ mask = list(clusters[newsite])
+ # Update translation table
+ if np.any(translation[mask] != -1):
+ # We've assigned a different cluster for this before... weird
+ # degeneracy
+ raise ValueError("Site merging tried to merge site(s) into more than one new site. This shouldn't happen.")
+ translation[mask] = newsite
+
+ to_merge = site_centers[mask]
+
+ # Check distances
+ if not self.maximum_merge_distance is None:
+ dists = pbcc.distances(to_merge[0], to_merge[1:])
+ if not np.all(dists <= self.maximum_merge_distance):
+ raise MergedSitesTooDistantError("Markov clustering tried to merge sites more than %f * %f apart. Lower your distance_threshold?" % (self.post_check_thresh_factor, self.distance_threshold))
+
+ # New site center
+ new_centers[newsite] = pbcc.average(to_merge)
+ if self.check_types:
+ assert np.all(site_types[mask] == site_types[mask][0])
+ new_types[newsite] = site_types[mask][0]
+
+ newsn = st.site_network.copy()
+ newsn.centers = new_centers
+ if self.check_types:
+ newsn.site_types = new_types
+
+ newtraj = translation[st._traj]
+ newtraj[st._traj == SiteTrajectory.SITE_UNKNOWN] = SiteTrajectory.SITE_UNKNOWN
+
+ # It doesn't make sense to propagate confidence information through a
+ # transform that might completely invalidate it
+ newst = SiteTrajectory(newsn, newtraj, confidences = None)
+
+ if not st.real_trajectory is None:
+ newst.set_real_traj(st.real_trajectory)
+
+ return newst
+
+ @abc.abstractmethod
+ def _get_sites_to_merge(self, st, **kwargs):
+ """Get the groups of sites to merge.
+
+ Returns a list of list/tuples each containing the numbers of sites to be merged.
+ There should be no overlap, and every site should be mentioned in at most
+ one site merging group.
+
+ If not mentioned in any, the site will disappear.
+ """
+ pass
From 9634c492cbd1e5e79c159664b702f2bf706c7a81 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 1 Jul 2019 16:04:42 -0400
Subject: [PATCH 039/129] Continued merging improvements
---
sitator/dynamics/MergeSitesByDynamics.py | 7 ++-
sitator/network/MergeSitesByBarrier.py | 58 ++++++++++++++++--------
sitator/network/__init__.py | 1 +
sitator/network/merging.py | 5 +-
4 files changed, 47 insertions(+), 24 deletions(-)
diff --git a/sitator/dynamics/MergeSitesByDynamics.py b/sitator/dynamics/MergeSitesByDynamics.py
index b5613b6..556f408 100644
--- a/sitator/dynamics/MergeSitesByDynamics.py
+++ b/sitator/dynamics/MergeSitesByDynamics.py
@@ -38,7 +38,10 @@ def __init__(self,
iterlimit = 100,
markov_parameters = {}):
- super().__init__(post_check_thresh_factor * distance_threshold)
+ super().__init__(
+ maximum_merge_distance = post_check_thresh_factor * distance_threshold,
+ check_types = check_types
+ )
if connectivity_matrix_generator is None:
connectivity_matrix_generator = MergeSitesByDynamics.connectivity_n_ij
@@ -106,7 +109,7 @@ def cfunc(sn):
def _get_sites_to_merge(self, st):
# -- Compute jump statistics
- if not st.site_network.has_attribute('p_ij'):
+ if not st.site_network.has_attribute('n_ij'):
ja = JumpAnalysis()
ja.run(st)
diff --git a/sitator/network/MergeSitesByBarrier.py b/sitator/network/MergeSitesByBarrier.py
index 0e243b2..5d7879e 100644
--- a/sitator/network/MergeSitesByBarrier.py
+++ b/sitator/network/MergeSitesByBarrier.py
@@ -1,12 +1,14 @@
import numpy as np
import itertools
+import math
from scipy.sparse.csgraph import connected_components
from ase.calculators.calculator import all_changes
from sitator.util import PBCCalculator
+from sitator.util.progress import tqdm
from sitator.network.merging import MergeSites
import logging
@@ -33,34 +35,43 @@ class MergeSitesByBarrier(MergeSites):
Args:
- calculator (ase.Calculator): For computing total potential energies.
- - final_initial_energy_threshold (float, eV): The maximum difference in
- energies between two sites for them to be mergable.
- barrier_threshold (float, eV): The barrier value above which two sites
are not mergable.
- n_driven_images (int, default: None): The number of evenly distributed
driven images to use.
- maximum_pairwise_distance (float, Angstrom): The maximum distance
between two sites for them to be considered for merging.
+ - minimum_jumps_mergable (int): The minimum number of observed jumps
+ between two sites for their merging to be considered. Setting this
+ higher can avoid unnecessary computations.
- maximum_merge_distance (float, Angstrom): The maxiumum pairwise distance
among a group of sites chosed to be merged.
"""
def __init__(self,
calculator,
- final_initial_energy_threshold,
barrier_threshold,
- n_driven_images = None,
+ n_driven_images = 20,
maximum_pairwise_distance = 2,
- maximum_merge_distance = 2):
- super().__init__(maximum_merge_distance)
- self.final_initial_energy_threshold = final_initial_energy_threshold
+ minimum_jumps_mergable = 1,
+ maximum_merge_distance = 2,
+ **kwargs):
+ super().__init__(maximum_merge_distance = maximum_merge_distance, **kwargs)
self.barrier_threshold = barrier_threshold
self.maximum_pairwise_distance = maximum_pairwise_distance
+ self.minimum_jumps_mergable = minimum_jumps_mergable
+ assert n_driven_images >= 3, "Must have at least initial, transition, and final."
self.n_driven_images = n_driven_images
self.calculator = calculator
def _get_sites_to_merge(self, st, coordinating_mask = None):
sn = st.site_network
+
+ # -- Compute jump statistics
+ if not sn.has_attribute('n_ij'):
+ ja = JumpAnalysis()
+ ja.run(st)
+
pos = sn.centers
if coordinating_mask is None:
coordinating_mask = sn.static_mask
@@ -71,7 +82,7 @@ def _get_sites_to_merge(self, st, coordinating_mask = None):
one_mobile_structure = sn.structure[coordinating_mask]
one_mobile_structure.extend(sn.structure[mobile_idex])
mobile_idex = -1
- #images = [one_mobile_structure.copy() for _ in range(self.n_driven_images)]
+ one_mobile_structure.set_calculator(self.calculator)
interpolation_coeffs = np.linspace(0, 1, self.n_driven_images)
energies = np.empty(shape = self.n_driven_images)
@@ -80,12 +91,15 @@ def _get_sites_to_merge(self, st, coordinating_mask = None):
dists = pbcc.pairwise_distances(pos)
# At the start, all within distance cutoff are mergable
mergable = dists <= self.maximum_pairwise_distance
+ mergable &= sn.n_ij >= self.minimum_jumps_mergable
# -- Check pairs' barriers
# Symmetric, and diagonal is trivially true. Combinations avoids those cases.
jbuf = pos[0].copy()
first_calculate = True
- for i, j in itertools.combinations(range(sn.n_sites)):
+ mergable_pairs = (p for p in itertools.combinations(range(sn.n_sites), r = 2) if mergable[p] or mergable[p[1], p[0]])
+ n_mergable = (np.sum(mergable) - sn.n_sites) // 2
+ for i, j in tqdm(mergable_pairs, total = n_mergable):
jbuf[:] = pos[j]
# Get minimage
_ = pbcc.min_image(pos[i], jbuf)
@@ -95,22 +109,26 @@ def _get_sites_to_merge(self, st, coordinating_mask = None):
one_mobile_structure.positions[mobile_idex] = vector
one_mobile_structure.positions[mobile_idex] *= interpolation_coeffs[image_i]
one_mobile_structure.positions[mobile_idex] += pos[i]
- energies[image_i] = self.calculator.calculate(atoms = one_mobile_structure,
- properties = ['energy'],
- system_changes = (all_changes if first_calculate else ['positions']))
+ energies[image_i] = one_mobile_structure.get_potential_energy()
first_calculate = False
# Check barrier
barrier_idex = np.argmax(energies)
- if np.abs(energies[0] - energies[-1]) > self.final_initial_energy_threshold:
- mergable[i, j] = mergable[j, i] = False
- # Average the initial and final states for a baseline
- baseline_energy = 0.5 * (energies[0] + energies[-1])
- barrier_height = energies[barrier_idex] - baseline_energy
- if barrier_height > self.barrier_threshold:
- mergable[i, j] = mergable[j, i] = False
+ forward_barrier = energies[barrier_idex] - energies[0]
+ backward_barrier = energies[barrier_idex] - energies[-1]
+ # If it's an actual maximum (barrier) between them, then we want to
+ # check its height
+ if barrier_idex != 0 and barrier_idex != self.n_driven_images - 1:
+ mergable[i, j] = forward_barrier <= self.barrier_threshold
+ mergable[j, i] = backward_barrier <= self.barrier_threshold
+ # Otherwise, if there's no maximum between them, they are in the same
+ # basin.
# Get mergable groups
- n_merged_sites, labels = connected_components(mergable)
+ n_merged_sites, labels = connected_components(
+ mergable,
+ directed = True,
+ connection = 'strong'
+ )
# MergeSites will check pairwise distances; we just need to make it the
# right format.
merge_groups = []
diff --git a/sitator/network/__init__.py b/sitator/network/__init__.py
index 590c51c..0dd8174 100644
--- a/sitator/network/__init__.py
+++ b/sitator/network/__init__.py
@@ -1,3 +1,4 @@
from .DiffusionPathwayAnalysis import DiffusionPathwayAnalysis
from . import merging
+from .MergeSitesByBarrier import MergeSitesByBarrier
diff --git a/sitator/network/merging.py b/sitator/network/merging.py
index c713d33..4e359fc 100644
--- a/sitator/network/merging.py
+++ b/sitator/network/merging.py
@@ -2,6 +2,7 @@
import abc
+from sitator.util import PBCCalculator
from sitator import SiteNetwork, SiteTrajectory
import logging
@@ -47,7 +48,7 @@ def run(self, st, **kwargs):
new_n_sites = len(clusters)
- logger.info("After merge there will be %i sites for %i mobile particles" % (new_n_sites, st.site_network.n_mobile))
+ logger.info("After merging %i sites there will be %i sites for %i mobile particles" % (len(site_centers), new_n_sites, st.site_network.n_mobile))
if new_n_sites < st.site_network.n_mobile:
raise TooFewMergedSitesError("There are %i mobile atoms in this system, but only %i sites after merge" % (np.sum(st.site_network.mobile_mask), new_n_sites))
@@ -75,7 +76,7 @@ def run(self, st, **kwargs):
if not self.maximum_merge_distance is None:
dists = pbcc.distances(to_merge[0], to_merge[1:])
if not np.all(dists <= self.maximum_merge_distance):
- raise MergedSitesTooDistantError("Markov clustering tried to merge sites more than %f * %f apart. Lower your distance_threshold?" % (self.post_check_thresh_factor, self.distance_threshold))
+ raise MergedSitesTooDistantError("Markov clustering tried to merge sites more than %.2f apart. Lower your distance_threshold?" % self.maximum_merge_distance)
# New site center
new_centers[newsite] = pbcc.average(to_merge)
From d074fe055162be982c48aa67aa00515feb2cfc41 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 2 Jul 2019 10:25:34 -0400
Subject: [PATCH 040/129] Added `MergeSitesByThreshold`
---
sitator/dynamics/MergeSitesByThreshold.py | 58 +++++++++++++++++++++++
sitator/dynamics/__init__.py | 1 +
2 files changed, 59 insertions(+)
create mode 100644 sitator/dynamics/MergeSitesByThreshold.py
diff --git a/sitator/dynamics/MergeSitesByThreshold.py b/sitator/dynamics/MergeSitesByThreshold.py
new file mode 100644
index 0000000..04ed522
--- /dev/null
+++ b/sitator/dynamics/MergeSitesByThreshold.py
@@ -0,0 +1,58 @@
+import numpy as np
+
+import operator
+
+from scipy.sparse.csgraph import connected_components
+
+from sitator.network.merging import MergeSites
+
+
+class MergeSitesByThreshold(MergeSites):
+ """Merge sites using a strict threshold on any edge property.
+
+ Takes the edge property matrix given by `attrname`, applies `relation` to it
+ with `threshold`, and merges all connected components in the graph represented
+ by the resulting boolean adjacency matrix.
+
+ Threshold is given by a keyword argument to `run()`.
+
+ Args:
+ - attrname (str): Name of the edge attribute to merge on.
+ - relation (func, default: operator.ge): The relation to use for the
+ thresholding.
+ - directed, connection (bool, str): Parameters for scipy.sparse.csgraph's
+ `connected_components`.
+ - **kwargs: Passed to `MergeSites`.
+ """
+ def __init__(self,
+ attrname,
+ relation = operator.ge,
+ directed = True,
+ connection = 'strong',
+ **kwargs):
+ self.attrname = attrname
+ self.relation = relation
+ self.directed = directed
+ self.connection = connection
+ super().__init__(**kwargs)
+
+
+ def _get_sites_to_merge(self, st, threshold = 0):
+ sn = st.site_network
+
+ attrmat = getattr(sn, self.attrname)
+ assert attrmat.shape == (sn.n_sites, sn.n_sites), "`attrname` doesn't seem to indicate an edge property."
+
+ # Get mergable groups
+ n_merged_sites, labels = connected_components(
+ self.relation(attrmat, threshold),
+ directed = self.directed,
+ connection = self.connection
+ )
+ # MergeSites will check pairwise distances; we just need to make it the
+ # right format.
+ merge_groups = []
+ for lbl in range(n_merged_sites):
+ merge_groups.append(np.where(labels == lbl)[0])
+
+ return merge_groups
diff --git a/sitator/dynamics/__init__.py b/sitator/dynamics/__init__.py
index 5515662..0a284fd 100644
--- a/sitator/dynamics/__init__.py
+++ b/sitator/dynamics/__init__.py
@@ -1,5 +1,6 @@
from .JumpAnalysis import JumpAnalysis
from .MergeSitesByDynamics import MergeSitesByDynamics
+from .MergeSitesByThreshold import MergeSitesByThreshold
from .RemoveShortJumps import RemoveShortJumps
# For backwards compatability, since this used to be in this module
From abf864ac1cd0f25d50537d46568081abfc94cd4a Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 2 Jul 2019 16:23:58 -0400
Subject: [PATCH 041/129] Added `jumps()` generator function
---
sitator/SiteTrajectory.py | 37 +++++++++++++++++++++++++++++++++++++
1 file changed, 37 insertions(+)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index fbaa7cf..7bd13e4 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -212,6 +212,43 @@ def assign_to_last_known_site(self, frame_threshold = 1):
return res
+ def jumps(self, unknown_as_jump = False):
+ """Generator to iterate over all jumps in the trajectory.
+
+ A jump is considered to occur "at the frame" when it first achieves its
+ new site. Ex:
+ Frame 0: Atom 1 at site 4 --> Frame 1: Atom 1 at site 5
+ will yield a jump (1, 1, 4, 5).
+
+ Yields tuples of the form:
+
+ (frame_number, mobile_atom_number, from_site, to_site)
+
+ Args:
+ - unknown_as_jump (bool): If True, moving from a site to unknown
+ (or vice versa) is considered a jump; if False, unassigned mobile
+ atoms are considered to be at their last known sites.
+ """
+ traj = self.traj
+ n_mobile = self.site_network.n_mobile
+ assert n_mobile == traj.shape[1]
+ last_known = traj[0].copy()
+ known = np.ones(shape = len(last_known), dtype = np.bool)
+ jumped = np.zeros(shape = len(last_known), dtype = np.bool)
+ for frame_i in range(1, self.n_frames):
+ if not unknown_as_jump:
+ np.not_equal(traj[frame_i], SiteTrajectory.SITE_UNKNOWN, out = known)
+
+ np.not_equal(traj[frame_i], last_known, out = jumped)
+ jumped &= known # Must be currently known to have jumped
+
+ for atom_i in range(n_mobile):
+ if jumped[atom_i]:
+ yield frame_i, atom_i, last_known[atom_i], traj[frame_i, atom_i]
+
+ last_known[known] = traj[frame_i, known]
+
+
# ---- Plotting code
def plot_frame(self, *args, **kwargs):
if self._default_plotter is None:
From 135f477d3d77eb53ce90203adff2bcd830c66529 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 2 Jul 2019 16:24:14 -0400
Subject: [PATCH 042/129] IO Cleanup
---
sitator/SiteNetwork.py | 77 ++-------------------------------
sitator/misc/oldio.py | 82 +++++++++++++++++++++++++++++++++++
sitator/plotting.py | 97 ------------------------------------------
3 files changed, 85 insertions(+), 171 deletions(-)
create mode 100644 sitator/misc/oldio.py
delete mode 100644 sitator/plotting.py
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 526e197..8027638 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -104,78 +104,6 @@ def __getitem__(self, key):
return sn
- _STRUCT_FNAME = "structure.xyz"
- _SMASK_FNAME = "static_mask.npy"
- _MMASK_FNAME = "mobile_mask.npy"
- _MAIN_FNAMES = ['centers', 'vertices', 'site_types']
-
- def save(self, file):
- """Save this SiteNetwork to a tar archive."""
- with tempfile.TemporaryDirectory() as tmpdir:
- # -- Write the structure
- ase.io.write(os.path.join(tmpdir, self._STRUCT_FNAME), self.structure, parallel = False)
- # -- Write masks
- np.save(os.path.join(tmpdir, self._SMASK_FNAME), self.static_mask)
- np.save(os.path.join(tmpdir, self._MMASK_FNAME), self.mobile_mask)
- # -- Write what we have
- for arrname in self._MAIN_FNAMES:
- if not getattr(self, arrname) is None:
- np.save(os.path.join(tmpdir, "%s.npy" % arrname), getattr(self, arrname))
- # -- Write all site/edge attributes
- for atype, attrs in zip(("site_attr", "edge_attr"), (self._site_attrs, self._edge_attrs)):
- for attr in attrs:
- np.save(os.path.join(tmpdir, "%s-%s.npy" % (atype, attr)), attrs[attr])
- # -- Write final archive
- with tarfile.open(file, mode = 'w:gz', format = tarfile.PAX_FORMAT) as outf:
- outf.add(tmpdir, arcname = "")
-
- @classmethod
- def from_file(cls, file):
- """Load a SiteNetwork from a tar file/file descriptor."""
- all_others = {}
- site_attrs = {}
- edge_attrs = {}
- structure = None
- with tarfile.open(file, mode = 'r:gz', format = tarfile.PAX_FORMAT) as input:
- # -- Load everything
- for member in input.getmembers():
- if member.name == '':
- continue
- f = input.extractfile(member)
- if member.name == cls._STRUCT_FNAME:
- with tempfile.TemporaryDirectory() as tmpdir:
- input.extract(member, path = tmpdir)
- structure = ase.io.read(os.path.join(tmpdir, member.name), format = 'xyz')
- else:
- basename = os.path.splitext(os.path.basename(member.name))[0]
- data = np.load(f)
- if basename.startswith("site_attr"):
- site_attrs[basename.split('-')[1]] = data
- elif basename.startswith("edge_attr"):
- edge_attrs[basename.split('-')[1]] = data
- else:
- all_others[basename] = data
-
- # Create SiteNetwork
- assert not structure is None
- assert all(k in all_others for k in ("static_mask", "mobile_mask")), "Malformed SiteNetwork file."
- sn = SiteNetwork(structure,
- all_others['static_mask'],
- all_others['mobile_mask'])
- if 'centers' in all_others:
- sn.centers = all_others['centers']
- for key in all_others:
- if key in ('centers', 'static_mask', 'mobile_mask'):
- continue
- setattr(sn, key, all_others[key])
-
- assert all(len(sa) == sn.n_sites for sa in site_attrs.values())
- assert all(ea.shape == (sn.n_sites, sn.n_sites) for ea in edge_attrs.values())
- sn._site_attrs = site_attrs
- sn._edge_attrs = edge_attrs
-
- return sn
-
def of_type(self, stype):
"""Returns a "view" to this SiteNetwork with only sites of a certain type."""
if self._types is None:
@@ -272,9 +200,10 @@ def remove_attribute(self, attr):
raise AttributeError("This SiteNetwork has no site or edge attribute `%s`" % attr)
def __getattr__(self, attrkey):
- if attrkey in self._site_attrs:
+ v = vars(self)
+ if '_site_attrs' in v and attrkey in self._site_attrs:
return self._site_attrs[attrkey]
- elif attrkey in self._edge_attrs:
+ elif '_edge_attrs' in v and attrkey in self._edge_attrs:
return self._edge_attrs[attrkey]
else:
raise AttributeError("This SiteNetwork has no site or edge attribute `%s`" % attrkey)
diff --git a/sitator/misc/oldio.py b/sitator/misc/oldio.py
new file mode 100644
index 0000000..ddd7ce2
--- /dev/null
+++ b/sitator/misc/oldio.py
@@ -0,0 +1,82 @@
+import numpy as np
+
+import tempfile
+import tarfile
+import os
+
+import ase
+import ase.io
+
+from sitator import SiteNetwork
+
+_STRUCT_FNAME = "structure.xyz"
+_SMASK_FNAME = "static_mask.npy"
+_MMASK_FNAME = "mobile_mask.npy"
+_MAIN_FNAMES = ['centers', 'vertices', 'site_types']
+
+def save(sn, file):
+ """Save this SiteNetwork to a tar archive."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # -- Write the structure
+ ase.io.write(os.path.join(tmpdir, _STRUCT_FNAME), sn.structure, parallel = False)
+ # -- Write masks
+ np.save(os.path.join(tmpdir, _SMASK_FNAME), sn.static_mask)
+ np.save(os.path.join(tmpdir, _MMASK_FNAME), sn.mobile_mask)
+ # -- Write what we have
+ for arrname in _MAIN_FNAMES:
+ if not getattr(sn, arrname) is None:
+ np.save(os.path.join(tmpdir, "%s.npy" % arrname), getattr(sn, arrname))
+ # -- Write all site/edge attributes
+ for atype, attrs in zip(("site_attr", "edge_attr"), (sn._site_attrs, sn._edge_attrs)):
+ for attr in attrs:
+ np.save(os.path.join(tmpdir, "%s-%s.npy" % (atype, attr)), attrs[attr])
+ # -- Write final archive
+ with tarfile.open(file, mode = 'w:gz', format = tarfile.PAX_FORMAT) as outf:
+ outf.add(tmpdir, arcname = "")
+
+
+def from_file(file):
+ """Load a SiteNetwork from a tar file/file descriptor."""
+ all_others = {}
+ site_attrs = {}
+ edge_attrs = {}
+ structure = None
+ with tarfile.open(file, mode = 'r:gz', format = tarfile.PAX_FORMAT) as input:
+ # -- Load everything
+ for member in input.getmembers():
+ if member.name == '':
+ continue
+ f = input.extractfile(member)
+ if member.name == _STRUCT_FNAME:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ input.extract(member, path = tmpdir)
+ structure = ase.io.read(os.path.join(tmpdir, member.name), format = 'xyz')
+ else:
+ basename = os.path.splitext(os.path.basename(member.name))[0]
+ data = np.load(f)
+ if basename.startswith("site_attr"):
+ site_attrs[basename.split('-')[1]] = data
+ elif basename.startswith("edge_attr"):
+ edge_attrs[basename.split('-')[1]] = data
+ else:
+ all_others[basename] = data
+
+ # Create SiteNetwork
+ assert not structure is None
+ assert all(k in all_others for k in ("static_mask", "mobile_mask")), "Malformed SiteNetwork file."
+ sn = SiteNetwork(structure,
+ all_others['static_mask'],
+ all_others['mobile_mask'])
+ if 'centers' in all_others:
+ sn.centers = all_others['centers']
+ for key in all_others:
+ if key in ('centers', 'static_mask', 'mobile_mask'):
+ continue
+ setattr(sn, key, all_others[key])
+
+ assert all(len(sa) == sn.n_sites for sa in site_attrs.values())
+ assert all(ea.shape == (sn.n_sites, sn.n_sites) for ea in edge_attrs.values())
+ sn._site_attrs = site_attrs
+ sn._edge_attrs = edge_attrs
+
+ return sn
diff --git a/sitator/plotting.py b/sitator/plotting.py
deleted file mode 100644
index cff27bd..0000000
--- a/sitator/plotting.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import numpy as np
-
-import matplotlib.pyplot as plt
-import matplotlib
-from mpl_toolkits.mplot3d import Axes3D
-
-import ase
-
-from samos.analysis.jumps.voronoi import collapse_into_unit_cell
-
-DEFAULT_COLORS = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'] * 2
-
-# From https://stackoverflow.com/questions/13685386/matplotlib-equal-unit-length-with-equal-aspect-ratio-z-axis-is-not-equal-to
-def set_axes_equal(ax):
- '''Make axes of 3D plot have equal scale so that spheres appear as spheres,
- cubes as cubes, etc.. This is one possible solution to Matplotlib's
- ax.set_aspect('equal') and ax.axis('equal') not working for 3D.
-
- Input
- ax: a matplotlib axis, e.g., as output from plt.gca().
- '''
-
- x_limits = ax.get_xlim3d()
- y_limits = ax.get_ylim3d()
- z_limits = ax.get_zlim3d()
-
- x_range = abs(x_limits[1] - x_limits[0])
- x_middle = np.mean(x_limits)
- y_range = abs(y_limits[1] - y_limits[0])
- y_middle = np.mean(y_limits)
- z_range = abs(z_limits[1] - z_limits[0])
- z_middle = np.mean(z_limits)
-
- # The plot bounding box is a sphere in the sense of the infinity
- # norm, hence I call half the max range the plot radius.
- plot_radius = 0.5*max([x_range, y_range, z_range])
-
- ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
- ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
- ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])
-
-def plot_atoms(atms, species,
- pts=None, pts_cs = None, pts_marker = 'x',
- cell = None,
- hide_species = (),
- wrap = False,
- title = ""):
- fig = plt.figure()
- ax = fig.add_subplot(111, projection="3d")
- cs = {
- 'Li' : "blue", 'O' : "red", 'Ta' : "gray", 'Ge' : "darkgray", 'P' : "orange",
- 'point' : 'black'
- }
-
- if wrap and not cell is None:
- atms = np.asarray([collapse_into_unit_cell(pt, cell) for pt in atms])
- if not pts is None:
- pts = np.asarray([collapse_into_unit_cell(pt, cell) for pt in pts])
-
- for s in hide_species:
- cs[s] = 'none'
-
- ax.scatter(atms[:,0],
- atms[:,1],
- atms[:,2],
- c = [cs.get(e, 'gray') for e in species],
- s = [10.0*(ase.data.atomic_numbers[s])**0.5 for s in species])
-
-
- if not pts is None:
- c = None
- if pts_cs is None:
- c = cs['point']
- else:
- c = pts_cs
- ax.scatter(pts[:,0],
- pts[:,1],
- pts[:,2],
- marker = pts_marker,
- c = c,
- cmap=matplotlib.cm.Dark2)
-
- if not cell is None:
- for cvec in cell:
- cvec = np.array([[0, 0, 0], cvec])
- ax.plot(cvec[:,0],
- cvec[:,1],
- cvec[:,2],
- color = "gray",
- alpha=0.5,
- linestyle="--")
-
- ax.set_title(title)
-
- set_axes_equal(ax)
-
- return ax
From c0802ebcb841fe0ec396f2ed39b90c5da9517ef7 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 5 Jul 2019 14:44:40 -0400
Subject: [PATCH 043/129] Some quick docs cleanup
---
README.md | 3 +--
sitator/site_descriptors/SiteTypeAnalysis.py | 2 +-
2 files changed, 2 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index 35abef7..19eb832 100644
--- a/README.md
+++ b/README.md
@@ -25,8 +25,7 @@ If you use `sitator` in your research, please consider citing this paper. The Bi
* The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does not have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.)
* **Site Type Analysis**
* For computing SOAP vectors: the `quip` binary from [QUIP](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) **or** the [`DScribe`](https://singroup.github.io/dscribe/index.html) Python library.
-
- The Python 2.7 bindings for QUIP (`quippy`) are **not** required. Generally, `DScribe` is much simpler to install than QUIP. **Please note**, however, that the SOAP descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on the system you are analyzing.
+ * The Python 2.7 bindings for QUIP (`quippy`) are **not** required. Generally, `DScribe` is much simpler to install than QUIP. **Please note**, however, that the SOAP descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on the system you are analyzing.
After downloading, the package is installed with `pip`:
diff --git a/sitator/site_descriptors/SiteTypeAnalysis.py b/sitator/site_descriptors/SiteTypeAnalysis.py
index afeccce..ee16489 100644
--- a/sitator/site_descriptors/SiteTypeAnalysis.py
+++ b/sitator/site_descriptors/SiteTypeAnalysis.py
@@ -19,7 +19,7 @@
raise ImportError("SiteTypeAnalysis requires the `pydpc` package")
class SiteTypeAnalysis(object):
- """Cluster sites into types using a descriptor and DPCLUS.
+ """Cluster sites into types using a descriptor and Density Peak Clustering.
-- descriptor --
Some kind of object implementing:
From 121e7837b03ca0c9f87429b1096b86a1404356c8 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 5 Jul 2019 14:45:02 -0400
Subject: [PATCH 044/129] Fix deletion order bug
---
sitator/network/DiffusionPathwayAnalysis.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/sitator/network/DiffusionPathwayAnalysis.py b/sitator/network/DiffusionPathwayAnalysis.py
index ca3687e..16d1de6 100644
--- a/sitator/network/DiffusionPathwayAnalysis.py
+++ b/sitator/network/DiffusionPathwayAnalysis.py
@@ -110,8 +110,9 @@ def run(self, sn, return_count = False):
else:
path_mask = cur_site_mask
# Remove individual merged paths
- for i in intersects_with:
- del site_masks[i]
+ # Going in reverse order means indexes don't become invalid as deletes happen
+ for i in sorted(intersects_with, reverse=True):
+ del site_masks[i]
# Add new (super)path
site_masks.append(path_mask)
From 0abaf4da933177e92bd0081b84681efdd1e936d0 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 5 Jul 2019 18:47:12 -0400
Subject: [PATCH 045/129] Corrected counting bugs
---
sitator/site_descriptors/SOAP.py | 23 +++++++++++++++--------
1 file changed, 15 insertions(+), 8 deletions(-)
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index 9ae782d..9a86305 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -258,8 +258,10 @@ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soa
# Now, I need to allocate the output
# so for each site, I count how much data there is!
- # Add one to deal with -1 -> 0, then just ignore the count for 0
- counts = np.bincount(site_traj.reshape(-1) + 1, minlength = nsit)[1:]
+ counts = np.zeros(shape = nsit + 1, dtype = np.int)
+ for frame in site_traj:
+        counts[frame] += 1 # A duplicate site index in `frame` will still only cause a single addition due to NumPy fancy-indexing semantics.
+ counts = counts[:-1]
if self._averaging is not None:
averaging = self._averaging
@@ -275,10 +277,13 @@ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soa
insufficient = nr_of_descs == 0
if np.any(insufficient):
logger.warning("You're asking to average %i SOAP vectors, but at this stepsize, %i sites are insufficiently occupied. Num occ./averaging: %s" % (averaging, np.sum(insufficient), counts[insufficient] / averaging))
+ averagings = np.full(shape = len(nr_of_descs), fill_value = averaging)
+ averagings[insufficient] = counts[insufficient]
nr_of_descs = np.maximum(nr_of_descs, 1) # If it's 0, just make one with whatever we've got
+ assert np.all(nr_of_descs >= 1)
logger.debug("Minimum # of descriptors/site: %i; maximum: %i" % (np.min(nr_of_descs), np.max(nr_of_descs)))
# This is where I load the descriptor:
- descs = np.zeros((np.sum(nr_of_descs), soaper.n_dim))
+ descs = np.zeros(shape = (np.sum(nr_of_descs), soaper.n_dim))
# An array that tells me the index I'm at for each site type
desc_index = np.asarray([np.sum(nr_of_descs[:i]) for i in range(len(nr_of_descs))])
@@ -294,18 +299,20 @@ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soa
to_describe = (site_traj_t != SiteTrajectory.SITE_UNKNOWN) & allowed[site_traj_t]
if np.any(to_describe):
+ sites_to_describe = site_traj_t[to_describe]
soaps = soaper(structure, pos[mob_indices[to_describe]])
- soaps /= averaging
-
- idx_to_add_desc = desc_index[site_traj_t[to_describe]]
+ soaps /= averagings[sites_to_describe][:, np.newaxis]
+ idx_to_add_desc = desc_index[sites_to_describe]
descs[idx_to_add_desc] += soaps
- count_of_site[site_traj_t[to_describe]] += 1
+ count_of_site[sites_to_describe] += 1
# Reset and increment full averages
- full_average = count_of_site == averaging
+ full_average = count_of_site == averagings
desc_index[full_average] += 1
count_of_site[full_average] = 0
allowed[max_index == desc_index] = False
+ assert not np.any(allowed) # We should have maxed out all of them after processing all frames.
+
desc_to_site = np.repeat(list(range(nsit)), nr_of_descs)
return descs, desc_to_site
From f378ca3f283438e9e9b56845571af48a2b2bb5fb Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 5 Jul 2019 18:47:49 -0400
Subject: [PATCH 046/129] Exposed merging groups to user
---
sitator/network/merging.py | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
diff --git a/sitator/network/merging.py b/sitator/network/merging.py
index 4e359fc..31a2c27 100644
--- a/sitator/network/merging.py
+++ b/sitator/network/merging.py
@@ -24,12 +24,19 @@ class MergeSites(abc.ABC):
:param bool check_types: If True, only sites of the same type are candidates to
be merged; if false, type information is ignored. Merged sites will only
be assigned types if this is True.
+ :param float maximum_merge_distance: Maximum distance between two sites
+ that are in a merge group, above which an error will be raised.
+ :param bool set_merged_into: If True, a site attribute `"merged_into"` will
+ be added to the original `SiteNetwork` indicating which new site
+ each old site was merged into.
"""
def __init__(self,
check_types = True,
- maximum_merge_distance = None):
+ maximum_merge_distance = None,
+ set_merged_into = False):
self.check_types = check_types
self.maximum_merge_distance = maximum_merge_distance
+ self.set_merged_into = set_merged_into
def run(self, st, **kwargs):
@@ -99,6 +106,11 @@ def run(self, st, **kwargs):
if not st.real_trajectory is None:
newst.set_real_traj(st.real_trajectory)
+ if self.set_merged_into:
+ if st.site_network.has_attribute("merged_into"):
+ st.site_network.remove_attribute("merged_into")
+ st.site_network.add_site_attribute("merged_into", translation)
+
return newst
@abc.abstractmethod
From 0c0ff6db4af6e5f4ce9de8064aba84f2149bfcb3 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 5 Jul 2019 18:48:08 -0400
Subject: [PATCH 047/129] Refactored multiple occupancy checking into
SiteTrajectory
---
sitator/SiteTrajectory.py | 32 +++++++++++++++++++-
sitator/landmark/LandmarkAnalysis.py | 44 ++++++++++------------------
2 files changed, 47 insertions(+), 29 deletions(-)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index 7bd13e4..c375443 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -137,7 +137,11 @@ def real_positions_for_site(self, site, return_confidences = False):
def compute_site_occupancies(self):
- """Computes site occupancies and adds site attribute `occupancies` to site_network."""
+ """Computes site occupancies and adds site attribute `occupancies` to site_network.
+
+ In cases of multiple occupancy, this will be higher than the number of
+ frames in which the site is occupied and could be over 1.0.
+ """
occ = np.true_divide(np.bincount(self._traj[self._traj >= 0], minlength = self._sn.n_sites), self.n_frames)
if self.site_network.has_attribute('occupancies'):
self.site_network.remove_attribute('occupancies')
@@ -145,6 +149,32 @@ def compute_site_occupancies(self):
return occ
+ def check_multiple_occupancy(self, max_mobile_per_site = 1):
+        """Count cases of "multiple occupancy" where more than one mobile atom shares the same site at the same time.
+
+ These cases usually indicate bad site analysis.
+
+ Returns:
+ - n_multiple_assignments (int): the total number of multiple assignment
+ incidents.
+            - avg_mobile_per_site (float): the average number of mobile atoms per occupied site.
+ """
+ from sitator.landmark.errors import MultipleOccupancyError
+ n_more_than_ones = 0
+ avg_mobile_per_site = 0
+ divisor = 0
+ for frame_i, site_frame in enumerate(self._traj):
+ _, counts = np.unique(site_frame[site_frame >= 0], return_counts = True)
+ count_msk = counts > max_mobile_per_site
+ if np.any(count_msk):
+ raise MultipleOccupancyError("%i mobile particles were assigned to only %i site(s) (%s) at frame %i." % (np.sum(counts[count_msk]), np.sum(count_msk), np.where(count_msk)[0], frame_i))
+ n_more_than_ones += np.sum(counts > 1)
+ avg_mobile_per_site += np.sum(counts)
+ divisor += len(counts)
+ avg_mobile_per_site /= divisor
+ return n_more_than_ones, avg_mobile_per_site
+
+
def assign_to_last_known_site(self, frame_threshold = 1):
"""Assign unassigned mobile particles to their last known site within
`frame_threshold` frames.
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index f7e525e..0122c59 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -1,12 +1,10 @@
import numpy as np
-from sitator.landmark import MultipleOccupancyError
from sitator.util import PBCCalculator
from sitator.util.progress import tqdm
import sys
-
import importlib
import tempfile
@@ -116,7 +114,7 @@ def __init__(self,
@property
def cutoff(self):
- return self._cutoff
+ return self._cutoff
@analysis_result
def landmark_vectors(self):
@@ -233,23 +231,6 @@ def run(self, sn, frames):
logging.info(" Identified %i sites with assignment counts %s" % (n_sites, cluster_counts))
- # Check that multiple particles are never assigned to one site at the
- # same time, cause that would be wrong.
- n_more_than_ones = 0
- avg_mobile_per_site = 0
- divisor = 0
- for frame_i, site_frame in enumerate(lmk_lbls):
- _, counts = np.unique(site_frame[site_frame >= 0], return_counts = True)
- count_msk = counts > self.max_mobile_per_site
- if np.any(count_msk):
- raise MultipleOccupancyError("%i mobile particles were assigned to only %i site(s) (%s) at frame %i." % (np.sum(counts[count_msk]), np.sum(count_msk), np.where(count_msk)[0], frame_i))
- n_more_than_ones += np.sum(counts > 1)
- avg_mobile_per_site += np.sum(counts)
- divisor += len(counts)
-
- self.n_multiple_assignments = n_more_than_ones
- self.avg_mobile_per_site = avg_mobile_per_site / float(divisor)
-
# -- Do output
# - Compute site centers
site_centers = np.empty(shape = (n_sites, 3), dtype = frames.dtype)
@@ -269,6 +250,13 @@ def run(self, sn, frames):
assert out_sn.vertices is None
out_st = SiteTrajectory(out_sn, lmk_lbls, lmk_confs)
+
+ # Check that multiple particles are never assigned to one site at the
+ # same time, cause that would be wrong.
+ self.n_multiple_assignments, self.avg_mobile_per_site = out_st.check_multiple_occupancy(
+ max_mobile_per_site = self.max_mobile_per_site
+ )
+
out_st.set_real_traj(orig_frames)
self._has_run = True
@@ -277,11 +265,11 @@ def run(self, sn, frames):
# -------- "private" methods --------
def _do_peak_evening(self):
- if self._peak_evening == 'none':
- return
- elif self._peak_evening == 'clip':
- lvec_peaks = np.max(self._landmark_vectors, axis = 1)
- # Clip all peaks to the lowest "normal" (stdev.) peak
- lvec_clip = np.mean(lvec_peaks) - np.std(lvec_peaks)
- # Do the clipping
- self._landmark_vectors[self._landmark_vectors > lvec_clip] = lvec_clip
+ if self._peak_evening == 'none':
+ return
+ elif self._peak_evening == 'clip':
+ lvec_peaks = np.max(self._landmark_vectors, axis = 1)
+ # Clip all peaks to the lowest "normal" (stdev.) peak
+ lvec_clip = np.mean(lvec_peaks) - np.std(lvec_peaks)
+ # Do the clipping
+ self._landmark_vectors[self._landmark_vectors > lvec_clip] = lvec_clip
From fc69ea2006a9bc36bfb092cb5022a3ba026a16ba Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 8 Jul 2019 22:18:49 -0400
Subject: [PATCH 048/129] Initial new coordination number code
---
sitator/SiteNetwork.py | 21 ++++++
sitator/dynamics/RemoveUnoccupiedSites.py | 37 ++++++++++
sitator/misc/SiteVolumes.py | 73 ++++++++++++++++---
.../SiteCoordinationNumber.py | 38 ++++++++++
sitator/visualization/SiteNetworkPlotter.py | 2 +-
5 files changed, 159 insertions(+), 12 deletions(-)
create mode 100644 sitator/dynamics/RemoveUnoccupiedSites.py
create mode 100644 sitator/site_descriptors/SiteCoordinationNumber.py
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 8027638..4ca812e 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -5,6 +5,7 @@
import tarfile
import tempfile
+import ase
import ase.io
import matplotlib
@@ -114,6 +115,26 @@ def of_type(self, stype):
return self[self._types == stype]
+ def get_structure_with_sites(self, site_atomic_number = None):
+ """Get an `ase.Atoms` with the sites included.
+
+ Args:
+ - site_atomic_number: If `None`, the species of the first mobile atom
+ will be used.
+ Returns:
+ ase.Atoms and final `site_atomic_number`
+ """
+ out = self.static_structure.copy()
+ if site_atomic_number is None:
+ site_atomic_number = self.structure.get_atomic_numbers()[mobile_mask][0]
+ numbers = np.full(len(self), site_atomic_number)
+ sites_atoms = ase.Atoms(
+ positions = self.centers,
+ numbers = numbers
+ )
+ out.extend(sites_atoms)
+ return out, site_atomic_number
+
@property
def n_sites(self):
if self._centers is None:
diff --git a/sitator/dynamics/RemoveUnoccupiedSites.py b/sitator/dynamics/RemoveUnoccupiedSites.py
new file mode 100644
index 0000000..4502025
--- /dev/null
+++ b/sitator/dynamics/RemoveUnoccupiedSites.py
@@ -0,0 +1,37 @@
+import numpy as np
+
+from sitator import SiteTrajectory
+
+class RemoveUnoccupiedSites(object):
+ def __init__(self):
+ pass
+
+ def run(self, st):
+ """
+ """
+ assert isinstance(st, SiteTrajectory)
+
+ old_sn = st.site_network
+
+ seen_mask = np.zeros(shape = old_sn.n_sites, dtype = np.bool)
+
+ for frame in st.traj:
+ seen_mask[frame] = True
+
+ n_new_sites = np.sum(seen_mask)
+ translation = np.empty(shape = old_sn.n_sites, dtype = np.int)
+ translation[seen_mask] = np.arange(n_new_sites)
+ translation[~seen_mask] = -4321
+
+ newtraj = translation[st.traj.reshape(-1)]
+ newtraj.shape = st.traj.shape
+
+ newsn = old_sn[seen_mask]
+
+ new_st = SiteTrajectory(
+ site_network = newsn,
+ particle_assignments = newtraj
+ )
+ if st.real_trajectory is not None:
+ new_st.set_real_traj(st.real_trajectory)
+ return new_st
diff --git a/sitator/misc/SiteVolumes.py b/sitator/misc/SiteVolumes.py
index 5095dae..bd87169 100644
--- a/sitator/misc/SiteVolumes.py
+++ b/sitator/misc/SiteVolumes.py
@@ -10,14 +10,26 @@
logger = logging.getLogger(__name__)
class SiteVolumes(object):
- """Computes the volumes of convex hulls around all positions associated with a site.
+ """Compute the volumes of sites."""
+ def __init__(self):
+ pass
- Adds the `site_volumes` and `site_surface_areas` attributes to the SiteNetwork.
- """
- def __init__(self, n_recenterings = 8):
- self.n_recenterings = n_recenterings
- def run(self, st):
+ def compute_accessable_volumes(self, st, n_recenterings = 8):
+ """Computes the volumes of convex hulls around all positions associated with a site.
+
+ Uses the shift-and-wrap trick for dealing with periodicity, so sites that
+ take up the majority of the unit cell may give bogus results.
+
+ Adds the `accessable_site_volumes` attribute to the SiteNetwork.
+
+ Args:
+ - st (SiteTrajectory)
+ - n_recenterings (int): How many different recenterings to try (the
+ algorithm will recenter around n of the points and take the minimal
+            resulting volume; this deals with cases where there is an outlier
+            such that recentering around it gives very bad results.)
+ """
vols = np.empty(shape = st.site_network.n_sites, dtype = np.float)
areas = np.empty(shape = st.site_network.n_sites, dtype = np.float)
@@ -30,16 +42,16 @@ def run(self, st):
vol = np.inf
area = None
- for i in range(self.n_recenterings):
+ for i in range(n_recenterings):
# Recenter
- offset = pbcc.cell_centroid - pos[int(i * (len(pos)/self.n_recenterings))]
+ offset = pbcc.cell_centroid - pos[int(i * (len(pos)/n_recenterings))]
pos += offset
pbcc.wrap_points(pos)
try:
hull = ConvexHull(pos)
except QhullError as qhe:
- logging.warning("For site %i, iter %i: %s" % (site, i, qhe))
+ logger.warning("For site %i, iter %i: %s" % (site, i, qhe))
vols[site] = np.nan
areas[site] = np.nan
continue
@@ -51,5 +63,44 @@ def run(self, st):
vols[site] = vol
areas[site] = area
- st.site_network.add_site_attribute('site_volumes', vols)
- st.site_network.add_site_attribute('site_surface_areas', areas)
+ st.site_network.add_site_attribute('accessable_site_volumes', vols)
+
+
+ def compute_volumes(self, sn):
+        """Computes the volume of the convex hull defined by each site's static vertices.
+
+ Requires vertex information in the SiteNetwork.
+
+ Adds the `site_volumes` and `site_surface_areas` attributes.
+
+ Args:
+ - sn (SiteNetwork)
+ """
+ if sn.vertices is None:
+ raise ValueError("SiteNetwork must have verticies to compute volumes!")
+
+ vols = np.empty(shape = st.site_network.n_sites, dtype = np.float)
+ areas = np.empty(shape = st.site_network.n_sites, dtype = np.float)
+
+ pbcc = PBCCalculator(st.site_network.structure.cell)
+
+ for site in range(st.site_network.n_sites):
+ pos = sn.static_structure.positions[sn.vertices[site]]
+ assert pos.flags['OWNDATA'] # It should since we're indexing with index lists
+ # Recenter
+ offset = pbcc.cell_centroid - sn.centers[site]
+ pos += offset
+ pbcc.wrap_points(pos)
+
+ hull = ConvexHull(pos)
+ vols[site] = hull.volume
+ areas[site] = hull.area
+
+ sn.add_site_attribute('site_volumes', vols)
+ sn.add_site_attribute('site_surface_areas', areas)
+
+
+ def run(self, st):
+        """For backwards compatibility.
+ """
+ self.compute_accessable_volumes(st)
diff --git a/sitator/site_descriptors/SiteCoordinationNumber.py b/sitator/site_descriptors/SiteCoordinationNumber.py
new file mode 100644
index 0000000..bd53218
--- /dev/null
+++ b/sitator/site_descriptors/SiteCoordinationNumber.py
@@ -0,0 +1,38 @@
+import numpy as np
+
+try:
+ from pymatgen.io.ase import AseAtomsAdaptor
+ import pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder as cgf
+ has_pymatgen = True
+except ImportError:
+ has_pymatgen = False
+
+
+class SiteCoordinationAnalysis(object):
+ """Determine site types based on local coordination environments.
+
+ Determine site types using the method from the following paper:
+
+ David Waroquiers, Xavier Gonze, Gian-Marco Rignanese, Cathrin Welker-Nieuwoudt, Frank Rosowski, Michael Goebel, Stephan Schenk, Peter Degelmann, Rute Andre, Robert Glaum, and Geoffroy Hautier,
+ “Statistical analysis of coordination environments in oxides”,
+ Chem. Mater., 2017, 29 (19), pp 8346–8360, DOI: 10.1021/acs.chemmater.7b02766
+
+    as implemented in `pymatgen`'s `pymatgen.analysis.chemenv.coordination_environments`.
+
+ Args:
+ **kwargs: passed to `compute_structure_environments`.
+ """
+ def __init__(self, **kwargs):
+ if not has_pymatgen:
+ raise ImportError("Pymatgen (or a recent enough version including `pymatgen.analysis.chemenv.coordination_environments`) cannot be imported.")
+ self._kwargs = kwargs
+
+ def run(self, sn):
+ site_struct, site_species = sn.get_structure_with_sites()
+ pymat_struct = AseAtomsAdaptor.get_structure(site_struct)
+ lgf = cgf.LocalGeometryFinder()
+ struct_envs = lgf.compute_structure_environments(
+ structure = pymat_struct,
+ indicies = np.where(sn.mobile_mask)[0],
+ only_cations = False,
+ )
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 0650396..f1d14a2 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -98,7 +98,7 @@ def __call__(self, sn, *args, **kwargs):
def _site_layers(self, sn, plot_points_params, same_normalization = False):
pts_arrays = {'points' : sn.centers}
- pts_params = {'cmap' : 'rainbow'}
+ pts_params = {'cmap' : 'copper'}
# -- Apply mapping
# - other mappings
From a248d8f5f688c25e4fc7d5a6a8124b379bc81ca5 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 9 Jul 2019 10:27:51 -0400
Subject: [PATCH 049/129] Added missing import
---
sitator/landmark/LandmarkAnalysis.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 0122c59..be3a3dc 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -10,6 +10,7 @@
from . import helpers
from sitator import SiteNetwork, SiteTrajectory
+from . import MultipleOccupancyError
import logging
logger = logging.getLogger(__name__)
From 6b4b39ebd66a1df33cf893a5a85c066eeae7fcfd Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 9 Jul 2019 13:15:43 -0400
Subject: [PATCH 050/129] Initial MCL for Landmark
---
sitator/dynamics/MergeSitesByDynamics.py | 62 +-----
sitator/landmark/LandmarkAnalysis.py | 4 +-
sitator/landmark/cluster/mcl.py | 43 ++++
sitator/landmark/cluster/pca.py | 39 ++++
sitator/landmark/helpers.pyx | 26 +++
sitator/util/DotProdClassifier.pyx | 256 ++++++++++++-----------
sitator/util/mcl.py | 57 +++++
7 files changed, 302 insertions(+), 185 deletions(-)
create mode 100644 sitator/landmark/cluster/mcl.py
create mode 100644 sitator/landmark/cluster/pca.py
create mode 100644 sitator/util/mcl.py
diff --git a/sitator/dynamics/MergeSitesByDynamics.py b/sitator/dynamics/MergeSitesByDynamics.py
index 556f408..e1f1b49 100644
--- a/sitator/dynamics/MergeSitesByDynamics.py
+++ b/sitator/dynamics/MergeSitesByDynamics.py
@@ -3,6 +3,7 @@
from sitator.dynamics import JumpAnalysis
from sitator.util import PBCCalculator
from sitator.network.merging import MergeSites
+from sitator.util.mcl import markov_clustering
import logging
logger = logging.getLogger(__name__)
@@ -35,7 +36,6 @@ def __init__(self,
distance_threshold = 1.0,
post_check_thresh_factor = 1.5,
check_types = True,
- iterlimit = 100,
markov_parameters = {}):
super().__init__(
@@ -149,63 +149,5 @@ def _get_sites_to_merge(self, st):
" This may or may not be a problem; but if `distance_threshold` is low, consider raising it." % n_alarming_ignored_edges)
# -- Do Markov Clustering
- clusters = self._markov_clustering(connectivity_matrix, **self.markov_parameters)
+ clusters = markov_clustering(connectivity_matrix, **self.markov_parameters)
return clusters
-
-
- def _markov_clustering(self,
- transition_matrix,
- expansion = 2,
- inflation = 2,
- pruning_threshold = 0.00001):
- """
- See https://micans.org/mcl/.
-
- Because we're dealing with matrixes that are stochastic already,
- there's no need to add artificial loop values.
-
- Implementation inspired by https://github.com/GuyAllard/markov_clustering
- """
-
- assert transition_matrix.shape[0] == transition_matrix.shape[1]
-
- m1 = transition_matrix.copy()
-
- # Normalize (though it should be close already)
- m1 /= np.sum(m1, axis = 0)
-
- allcols = np.arange(m1.shape[1])
-
- converged = False
- for i in range(self.iterlimit):
- # -- Expansion
- m2 = np.linalg.matrix_power(m1, expansion)
- # -- Inflation
- np.power(m2, inflation, out = m2)
- m2 /= np.sum(m2, axis = 0)
- # -- Prune
- to_prune = m2 < pruning_threshold
- # Exclude the max of every column
- to_prune[np.argmax(m2, axis = 0), allcols] = False
- m2[to_prune] = 0.0
- # -- Check converged
- if np.allclose(m1, m2):
- converged = True
- logger.info("Markov Clustering converged in %i iterations" % i)
- break
-
- m1[:] = m2
-
- if not converged:
- raise ValueError("Markov Clustering couldn't converge in %i iterations" % self.iterlimit)
-
- # -- Get clusters
- attractors = m2.diagonal().nonzero()[0]
-
- clusters = set()
-
- for a in attractors:
- cluster = tuple(m2[a].nonzero()[0])
- clusters.add(cluster)
-
- return list(clusters)
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 0122c59..3a941a8 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -210,7 +210,9 @@ def run(self, sn, frames):
self._do_peak_evening()
# - Cluster -
- cluster_func = importlib.import_module("..cluster." + self._cluster_algo, package = __name__).do_landmark_clustering
+ clustermod = importlib.import_module("..cluster." + self._cluster_algo, package = __name__)
+ importlib.reload(clustermod)
+ cluster_func = clustermod.do_landmark_clustering
cluster_counts, lmk_lbls, lmk_confs = \
cluster_func(self._landmark_vectors,
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
new file mode 100644
index 0000000..5bd45f5
--- /dev/null
+++ b/sitator/landmark/cluster/mcl.py
@@ -0,0 +1,43 @@
+import numpy as np
+
+from sitator.util.progress import tqdm
+from sitator.util.mcl import markov_clustering
+from sitator.util import DotProdClassifier
+from ..helpers import _cross_correlation_matrix
+
+import logging
+logger = logging.getLogger(__name__)
+
+DEFAULT_PARAMS = {
+ 'assignment_threshold' : 0.9
+}
+
+def do_landmark_clustering(landmark_vectors,
+ clustering_params,
+ min_samples,
+ verbose = False):
+ tmp = DEFAULT_PARAMS.copy()
+ tmp.update(clustering_params)
+ clustering_params = tmp
+
+ graph = _cross_correlation_matrix(landmark_vectors)
+
+ # -- Cluster Landmarks
+ clusters = markov_clustering(graph) # **clustering_params
+ n_clusters = len(clusters)
+ centers = np.zeros(shape = (n_clusters, landmark_vectors.shape[1]))
+ for i, cluster in enumerate(clusters):
+ centers[i, list(cluster)] = 1.0 # Set the peaks
+
+ landmark_classifier = \
+ DotProdClassifier(threshold = np.nan, # We're not fitting
+ min_samples = min_samples)
+
+ landmark_classifier.set_cluster_centers(centers)
+
+ lmk_lbls, lmk_confs = \
+ landmark_classifier.fit_predict(landmark_vectors,
+ predict_threshold = clustering_params['assignment_threshold'],
+ verbose = verbose)
+
+ return landmark_classifier.cluster_counts, lmk_lbls, lmk_confs
diff --git a/sitator/landmark/cluster/pca.py b/sitator/landmark/cluster/pca.py
new file mode 100644
index 0000000..b58d2bb
--- /dev/null
+++ b/sitator/landmark/cluster/pca.py
@@ -0,0 +1,39 @@
+import numpy as np
+
+from sitator.util.progress import tqdm
+from sitator.util import DotProdClassifier
+
+from sklearn.decomposition import IncrementalPCA
+
+import logging
+logger = logging.getLogger(__name__)
+
+DEFAULT_PARAMS = {
+ 'clustering_threshold' : 0.9
+ 'assignment_threshold' : 0.9
+}
+
+def do_landmark_clustering(landmark_vectors,
+ clustering_params,
+ min_samples,
+ verbose = False):
+ tmp = DEFAULT_PARAMS.copy()
+ tmp.update(clustering_params)
+ clustering_params = tmp
+
+ pca = IncrementalPCA()
+ pca.fit(landmark_vectors)
+ keep_n_clusters = np.where(np.cumsum(pca.explained_variance_ratio_) >= clustering_params['clustering_threshold'])[0][0]
+
+ landmark_classifier = \
+ DotProdClassifier(threshold = np.nan, # We're not fitting
+ min_samples = min_samples)
+
+ landmark_classifier.set_cluster_centers(pca.components_[:keep_n_clusters])
+
+ lmk_lbls, lmk_confs = \
+ landmark_classifier.fit_predict(landmark_vectors,
+ predict_threshold = clustering_params['assignment_threshold'],
+ verbose = verbose)
+
+ return landmark_classifier.cluster_counts, lmk_lbls, lmk_confs
diff --git a/sitator/landmark/helpers.pyx b/sitator/landmark/helpers.pyx
index 1dd736e..0c0e0bc 100644
--- a/sitator/landmark/helpers.pyx
+++ b/sitator/landmark/helpers.pyx
@@ -9,6 +9,32 @@ from sitator.landmark import StaticLatticeError, ZeroLandmarkError
ctypedef double precision
+# This is nearly a covariance matrix...
+def _cross_correlation_matrix(const precision [:, :] lvecs):
+ n_lvecs = len(lvecs)
+ n_components = lvecs.shape[1]
+ # -- Construct similarity matrix
+ graph_np = np.zeros(shape = (n_components, n_components))
+ divisors_np = np.zeros(shape = n_components)
+ cdef precision [:] divisors = divisors_np
+ cdef precision [:, :] graph = graph_np
+ cdef precision coeff
+
+ # The matrix is the ensemble average cross correlations between all
+ # landmark vector components (landmarks).
+ for lvec_idex in xrange(n_lvecs):
+ for component in xrange(n_components):
+ coeff = lvecs[lvec_idex, component]
+ if coeff > 0:
+ for i in range(n_components):
+ graph[component, i] += coeff * lvecs[lvec_idex, i]
+ divisors[component] += 1
+
+ #graph /= n_lvecs
+ graph_np /= divisors_np
+
+ return graph_np
+
def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_for_zeros = True, tqdm = lambda i: i, logger = None):
if self._landmark_dimension is None:
raise ValueError("_fill_landmark_vectors called before Voronoi!")
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index d7d47de..56f7548 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -23,20 +23,25 @@ class OneValueListlike(object):
return self.value
class DotProdClassifier(object):
+ """Assign vectors to clusters indicated by a representative vector using a cosine metric.
+
+ Cluster centers can be given through `set_cluster_centers()` or approximated
+ using the custom method described in the appendix of the main landmark
+ analysis paper (`fit_centers()`).
+
+ :param float threshold: Similarity threshold for joining a cluster.
+ In cos-of-angle-between-vectors (i.e. 1 is exactly the same, 0 is orthogonal)
+ :param int max_converge_iters: Maximum number of iterations. If the algorithm hasn't converged
+ by then, it will exit with a warning.
+ :param int|float min_samples: filter out clusters with low sample counts.
+ If an int, filters out clusters with fewer samples than this.
+ If a float, filters out clusters with fewer than floor(min_samples * n_assigned_samples)
+ samples assigned to them.
+ """
def __init__(self,
threshold = 0.9,
max_converge_iters = 10,
min_samples = 1):
- """
- :param float threshold: Similarity threshold for joining a cluster.
- In cos-of-angle-between-vectors (i.e. 1 is exactly the same, 0 is orthogonal)
- :param int max_converge_iters: Maximum number of iterations. If the algorithm hasn't converged
- by then, it will exit with a warning.
- :param int|float min_samples: filter out clusters with low sample counts.
- If an int, filters out clusters with fewer samples than this.
- If a float, filters out clusters with fewer than floor(min_samples * n_assigned_samples)
- samples assigned to them.
- """
self._threshold = threshold
self._max_iters = max_converge_iters
self._min_samples = min_samples
@@ -48,6 +53,9 @@ class DotProdClassifier(object):
def cluster_centers(self):
return self._cluster_centers
+ def set_cluster_centers(self, centers):
+ self._cluster_centers = centers
+
@property
def cluster_counts(self):
return self._cluster_counts
@@ -70,22 +78,132 @@ class DotProdClassifier(object):
if predict_threshold is None:
predict_threshold = self._threshold
+ if self._cluster_centers is None:
+ self.fit_centers(X)
+
+ # Run a predict now:
+ labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold)
+
+ total_n_assigned = np.sum(labels >= 0)
+
+ # -- filter out low counts
+ if not self._min_samples is None:
+ self._cluster_counts = np.bincount(labels[labels >= 0])
+
+ assert len(self._cluster_counts) == len(self._cluster_centers)
+
+ min_samples = None
+ if isinstance(self._min_samples, numbers.Integral):
+ min_samples = self._min_samples
+ elif isinstance(self._min_samples, numbers.Real):
+ min_samples = int(np.floor(self._min_samples * total_n_assigned))
+ else:
+ raise ValueError("Invalid value `%s` for min_samples; must be integral or float." % self._min_samples)
+
+ count_mask = self._cluster_counts >= min_samples
+
+ self._cluster_centers = self._cluster_centers[count_mask]
+ self._cluster_counts = self._cluster_counts[count_mask]
+
+ if len(self._cluster_centers) == 0:
+ # Then we removed everything...
+ raise ValueError("`min_samples` too large; all %i clusters under threshold." % len(count_mask))
+
+ logger.info("DotProdClassifier: %i/%i assignment counts below threshold %s (%s); %i clusters remain." % \
+ (np.sum(~count_mask), len(count_mask), self._min_samples, min_samples, len(self._cluster_counts)))
+
+ # Do another predict -- this could be more efficient, but who cares?
+ labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold)
+
+ if return_info:
+ info = {
+ 'clusters_below_min_samples' : np.sum(~count_mask)
+ }
+ return labels, confs, info
+ else:
+ return labels, confs
+
+ def predict(self, X, return_confidences = False, threshold = None, verbose = True, ignore_zeros = True):
+ """Return a predicted cluster label for vectors X.
+
+ :param float threshold: alternate threshold. Defaults to None, when self.threshold
+ is used.
+
+ :returns: an array of labels. -1 indicates no assignment.
+ :returns: an array of confidences in assignments. Normalized
+ values from 0 (no confidence, no label) to 1 (identical to cluster center).
+ """
+
+ assert len(X.shape) == 2, "Data must be 2D."
+
+ if not X.shape[1] == (self._featuredim):
+ raise TypeError("X has wrong dimension %s; should be (%i)" % (X.shape, self._featuredim))
+
+ labels = np.empty(shape = len(X), dtype = np.int)
+
+ if threshold is None:
+ threshold = self._threshold
+
+ confidences = None
+ if return_confidences:
+ confidences = np.empty(shape = len(X), dtype = np.float)
+
+ zeros_count = 0
+
+ center_norms = np.linalg.norm(self._cluster_centers, axis = 1)
+ normed_centers = self._cluster_centers.copy()
+ normed_centers /= center_norms[:, np.newaxis]
+
+ # preallocate buffers
+ diffs = np.empty(shape = len(center_norms), dtype = np.float)
+
+ for i, x in enumerate(tqdm(X, desc = "Sample")):
+
+ if np.all(x == 0):
+ if ignore_zeros:
+ labels[i] = -1
+ zeros_count += 1
+ continue
+ else:
+ raise ValueError("Data %i is all zeros!" % i)
+
+ # diffs = np.sum(x * self._cluster_centers, axis = 1)
+ # diffs /= np.linalg.norm(x) * center_norms
+ np.dot(normed_centers, x, out = diffs)
+ diffs /= np.linalg.norm(x)
+ #diffs /= center_norms
+
+ assigned_to = np.argmax(diffs)
+ assignment_confidence = diffs[assigned_to]
+
+ if assignment_confidence < threshold:
+ assigned_to = -1
+ assignment_confidence = 0.0
+
+ labels[i] = assigned_to
+ confidences[i] = assignment_confidence
+
+ if zeros_count > 0:
+ logger.warning("Encountered %i zero vectors during prediction" % zeros_count)
+
+ if return_confidences:
+ return labels, confidences
+ else:
+ return labels
+
+ def fit_centers(self, X):
# Essentially hierarchical clustering that stops when no cluster *centers*
# are more similar than the threshold.
-
labels = np.empty(shape = len(X), dtype = np.int)
labels.fill(-1)
# Start with each sample as a cluster
- #old_centers = np.copy(X)
- #old_n_assigned = np.ones(shape = len(X), dtype = np.int)
# For memory's sake, no copying
old_centers = X
old_n_assigned = OneValueListlike(value = 1, length = len(X))
old_n_clusters = len(X)
# -- Classification loop
-
# Maximum number of iterations
last_n_sites = -1
did_converge = False
@@ -190,113 +308,3 @@ class DotProdClassifier(object):
raise ValueError("Clustering did not converge after %i iterations" % (self._max_iters))
self._cluster_centers = np.asarray(cluster_centers[:n_clusters])
-
- # Run a predict now:
- labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold)
-
- total_n_assigned = np.sum(labels >= 0)
-
- # -- filter out low counts
- if not self._min_samples is None:
- self._cluster_counts = np.bincount(labels[labels >= 0])
-
- assert len(self._cluster_counts) == len(self._cluster_centers)
-
- min_samples = None
- if isinstance(self._min_samples, numbers.Integral):
- min_samples = self._min_samples
- elif isinstance(self._min_samples, numbers.Real):
- min_samples = int(np.floor(self._min_samples * total_n_assigned))
- else:
- raise ValueError("Invalid value `%s` for min_samples; must be integral or float." % self._min_samples)
-
- count_mask = self._cluster_counts >= min_samples
-
- self._cluster_centers = self._cluster_centers[count_mask]
- self._cluster_counts = self._cluster_counts[count_mask]
-
- if len(self._cluster_centers) == 0:
- # Then we removed everything...
- raise ValueError("`min_samples` too large; all %i clusters under threshold." % len(count_mask))
-
- logger.info("DotProdClassifier: %i/%i assignment counts below threshold %s (%s); %i clusters remain." % \
- (np.sum(~count_mask), len(count_mask), self._min_samples, min_samples, len(self._cluster_counts)))
-
- # Do another predict -- this could be more efficient, but who cares?
- labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold)
-
- if return_info:
- info = {
- 'clusters_below_min_samples' : np.sum(~count_mask)
- }
- return labels, confs, info
- else:
- return labels, confs
-
- def predict(self, X, return_confidences = False, threshold = None, verbose = True, ignore_zeros = True):
- """Return a predicted cluster label for vectors X.
-
- :param float threshold: alternate threshold. Defaults to None, when self.threshold
- is used.
-
- :returns: an array of labels. -1 indicates no assignment.
- :returns: an array of confidences in assignments. Normalzied
- values from 0 (no confidence, no label) to 1 (identical to cluster center).
- """
-
- assert len(X.shape) == 2, "Data must be 2D."
-
- if not X.shape[1] == (self._featuredim):
- raise TypeError("X has wrong dimension %s; should be (%i)" % (X.shape, self._featuredim))
-
- labels = np.empty(shape = len(X), dtype = np.int)
-
- if threshold is None:
- threshold = self._threshold
-
- confidences = None
- if return_confidences:
- confidences = np.empty(shape = len(X), dtype = np.float)
-
- zeros_count = 0
-
- center_norms = np.linalg.norm(self._cluster_centers, axis = 1)
- normed_centers = self._cluster_centers.copy()
- normed_centers /= center_norms[:, np.newaxis]
-
- # preallocate buffers
- diffs = np.empty(shape = len(center_norms), dtype = np.float)
-
- for i, x in enumerate(tqdm(X, desc = "Sample")):
-
- if np.all(x == 0):
- if ignore_zeros:
- labels[i] = -1
- zeros_count += 1
- continue
- else:
- raise ValueError("Data %i is all zeros!" % i)
-
- # diffs = np.sum(x * self._cluster_centers, axis = 1)
- # diffs /= np.linalg.norm(x) * center_norms
- np.dot(normed_centers, x, out = diffs)
- diffs /= np.linalg.norm(x)
- #diffs /= center_norms
-
- assigned_to = np.argmax(diffs)
- assignment_confidence = diffs[assigned_to]
-
- if assignment_confidence < threshold:
- assigned_to = -1
- assignment_confidence = 0.0
-
- labels[i] = assigned_to
- confidences[i] = assignment_confidence
-
- if zeros_count > 0:
- logger.warning("Encountered %i zero vectors during prediction" % zeros_count)
-
- if return_confidences:
- return labels, confidences
- else:
- return labels
diff --git a/sitator/util/mcl.py b/sitator/util/mcl.py
new file mode 100644
index 0000000..214bc87
--- /dev/null
+++ b/sitator/util/mcl.py
@@ -0,0 +1,57 @@
+import numpy as np
+
+def markov_clustering(transition_matrix,
+ expansion = 2,
+ inflation = 2,
+ pruning_threshold = 0.00001,
+ iterlimit = 100):
+ """
+ See https://micans.org/mcl/.
+
+ Because we're dealing with matrices that are stochastic already,
+ there's no need to add artificial loop values.
+
+ Implementation inspired by https://github.com/GuyAllard/markov_clustering
+ """
+
+ assert transition_matrix.shape[0] == transition_matrix.shape[1]
+
+ m1 = transition_matrix.copy()
+
+ # Normalize (though it should be close already)
+ m1 /= np.sum(m1, axis = 0)
+
+ allcols = np.arange(m1.shape[1])
+
+ converged = False
+ for i in range(iterlimit):
+ # -- Expansion
+ m2 = np.linalg.matrix_power(m1, expansion)
+ # -- Inflation
+ np.power(m2, inflation, out = m2)
+ m2 /= np.sum(m2, axis = 0)
+ # -- Prune
+ to_prune = m2 < pruning_threshold
+ # Exclude the max of every column
+ to_prune[np.argmax(m2, axis = 0), allcols] = False
+ m2[to_prune] = 0.0
+ # -- Check converged
+ if np.allclose(m1, m2):
+ converged = True
+ break
+
+ m1[:] = m2
+
+ if not converged:
+ raise ValueError("Markov Clustering couldn't converge in %i iterations" % iterlimit)
+
+ # -- Get clusters
+ attractors = m2.diagonal().nonzero()[0]
+
+ clusters = set()
+
+ for a in attractors:
+ cluster = tuple(m2[a].nonzero()[0])
+ clusters.add(cluster)
+
+ return list(clusters)
From 6dd4441d58dad3aab3db64659e4a304ddce5e486 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 9 Jul 2019 18:49:44 -0400
Subject: [PATCH 051/129] Working MCL Landmark Clustering
---
sitator/landmark/LandmarkAnalysis.py | 1 +
sitator/landmark/cluster/mcl.py | 23 +++++++++++++---
sitator/landmark/cluster/pca.py | 39 ----------------------------
sitator/landmark/helpers.pyx | 26 -------------------
sitator/util/DotProdClassifier.pyx | 2 +-
5 files changed, 21 insertions(+), 70 deletions(-)
delete mode 100644 sitator/landmark/cluster/pca.py
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 3a941a8..aa5b2b2 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -10,6 +10,7 @@
from . import helpers
from sitator import SiteNetwork, SiteTrajectory
+from .errors import MultipleOccupancyError
import logging
logger = logging.getLogger(__name__)
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index 5bd45f5..aeecd1f 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -5,6 +5,8 @@
from sitator.util import DotProdClassifier
from ..helpers import _cross_correlation_matrix
+from sklearn.covariance import empirical_covariance
+
import logging
logger = logging.getLogger(__name__)
@@ -12,22 +14,35 @@
'assignment_threshold' : 0.9
}
+def cov2corr( A ):
+ """
+ covariance matrix to correlation matrix.
+ """
+ d = np.sqrt(A.diagonal())
+ A = ((A.T/d).T)/d
+ return A
+
def do_landmark_clustering(landmark_vectors,
clustering_params,
min_samples,
- verbose = False):
+ verbose):
tmp = DEFAULT_PARAMS.copy()
tmp.update(clustering_params)
clustering_params = tmp
- graph = _cross_correlation_matrix(landmark_vectors)
+ cor = empirical_covariance(landmark_vectors)
+ cor = cov2corr(cor)
+ graph = np.clip(cor, 0, None)
+
+ predict_threshold = clustering_params.pop('assignment_threshold')
# -- Cluster Landmarks
- clusters = markov_clustering(graph) # **clustering_params
+ clusters = markov_clustering(graph, **clustering_params)
n_clusters = len(clusters)
centers = np.zeros(shape = (n_clusters, landmark_vectors.shape[1]))
for i, cluster in enumerate(clusters):
centers[i, list(cluster)] = 1.0 # Set the peaks
+ #centers[i] = np.sum(cor[list(cluster)], axis = 0)
landmark_classifier = \
DotProdClassifier(threshold = np.nan, # We're not fitting
@@ -37,7 +52,7 @@ def do_landmark_clustering(landmark_vectors,
lmk_lbls, lmk_confs = \
landmark_classifier.fit_predict(landmark_vectors,
- predict_threshold = clustering_params['assignment_threshold'],
+ predict_threshold = predict_threshold,
verbose = verbose)
return landmark_classifier.cluster_counts, lmk_lbls, lmk_confs
diff --git a/sitator/landmark/cluster/pca.py b/sitator/landmark/cluster/pca.py
deleted file mode 100644
index b58d2bb..0000000
--- a/sitator/landmark/cluster/pca.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import numpy as np
-
-from sitator.util.progress import tqdm
-from sitator.util import DotProdClassifier
-
-from sklearn.decomposition import IncrementalPCA
-
-import logging
-logger = logging.getLogger(__name__)
-
-DEFAULT_PARAMS = {
- 'clustering_threshold' : 0.9
- 'assignment_threshold' : 0.9
-}
-
-def do_landmark_clustering(landmark_vectors,
- clustering_params,
- min_samples,
- verbose = False):
- tmp = DEFAULT_PARAMS.copy()
- tmp.update(clustering_params)
- clustering_params = tmp
-
- pca = IncrementalPCA()
- pca.fit(landmark_vectors)
- keep_n_clusters = np.where(np.cumsum(pca.explained_variance_ratio_) >= clustering_params['clustering_threshold'])[0][0]
-
- landmark_classifier = \
- DotProdClassifier(threshold = np.nan, # We're not fitting
- min_samples = min_samples)
-
- landmark_classifier.set_cluster_centers(pca.components_[:keep_n_clusters])
-
- lmk_lbls, lmk_confs = \
- landmark_classifier.fit_predict(landmark_vectors,
- predict_threshold = clustering_params['assignment_threshold'],
- verbose = verbose)
-
- return landmark_classifier.cluster_counts, lmk_lbls, lmk_confs
diff --git a/sitator/landmark/helpers.pyx b/sitator/landmark/helpers.pyx
index 0c0e0bc..1dd736e 100644
--- a/sitator/landmark/helpers.pyx
+++ b/sitator/landmark/helpers.pyx
@@ -9,32 +9,6 @@ from sitator.landmark import StaticLatticeError, ZeroLandmarkError
ctypedef double precision
-# This is nearly a covariance matrix...
-def _cross_correlation_matrix(const precision [:, :] lvecs):
- n_lvecs = len(lvecs)
- n_components = lvecs.shape[1]
- # -- Construct similarity matrix
- graph_np = np.zeros(shape = (n_components, n_components))
- divisors_np = np.zeros(shape = n_components)
- cdef precision [:] divisors = divisors_np
- cdef precision [:, :] graph = graph_np
- cdef precision coeff
-
- # The matrix is the ensemble average cross correlations between all
- # landmark vector components (landmarks).
- for lvec_idex in xrange(n_lvecs):
- for component in xrange(n_components):
- coeff = lvecs[lvec_idex, component]
- if coeff > 0:
- for i in range(n_components):
- graph[component, i] += coeff * lvecs[lvec_idex, i]
- divisors[component] += 1
-
- #graph /= n_lvecs
- graph_np /= divisors_np
-
- return graph_np
-
def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_for_zeros = True, tqdm = lambda i: i, logger = None):
if self._landmark_dimension is None:
raise ValueError("_fill_landmark_vectors called before Voronoi!")
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index 56f7548..2982886 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -88,7 +88,7 @@ class DotProdClassifier(object):
# -- filter out low counts
if not self._min_samples is None:
- self._cluster_counts = np.bincount(labels[labels >= 0])
+ self._cluster_counts = np.bincount(labels[labels >= 0], minlength = len(self._cluster_centers))
assert len(self._cluster_counts) == len(self._cluster_centers)
From 7f0d3cbffbc9c2d706fa1f63f19096c62c327031 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 9 Jul 2019 22:57:59 -0400
Subject: [PATCH 052/129] Added option of whether to norm dot product metric
---
sitator/landmark/cluster/mcl.py | 4 ++--
sitator/util/DotProdClassifier.pyx | 23 ++++++++++++-----------
2 files changed, 14 insertions(+), 13 deletions(-)
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index aeecd1f..f1d0d50 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -41,8 +41,7 @@ def do_landmark_clustering(landmark_vectors,
n_clusters = len(clusters)
centers = np.zeros(shape = (n_clusters, landmark_vectors.shape[1]))
for i, cluster in enumerate(clusters):
- centers[i, list(cluster)] = 1.0 # Set the peaks
- #centers[i] = np.sum(cor[list(cluster)], axis = 0)
+ centers[i, list(cluster)] = 1 / len(cluster) # Set the peaks
landmark_classifier = \
DotProdClassifier(threshold = np.nan, # We're not fitting
@@ -53,6 +52,7 @@ def do_landmark_clustering(landmark_vectors,
lmk_lbls, lmk_confs = \
landmark_classifier.fit_predict(landmark_vectors,
predict_threshold = predict_threshold,
+ predict_normed = False,
verbose = verbose)
return landmark_classifier.cluster_counts, lmk_lbls, lmk_confs
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index 2982886..bf4a231 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -64,7 +64,7 @@ class DotProdClassifier(object):
def n_clusters(self):
return len(self._cluster_counts)
- def fit_predict(self, X, verbose = True, predict_threshold = None, return_info = False):
+ def fit_predict(self, X, verbose = True, predict_threshold = None, predict_normed = True, return_info = False):
""" Fit the data vectors X and return their cluster labels.
"""
@@ -82,7 +82,7 @@ class DotProdClassifier(object):
self.fit_centers(X)
# Run a predict now:
- labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold)
+ labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold, predict_normed = predict_normed)
total_n_assigned = np.sum(labels >= 0)
@@ -113,7 +113,7 @@ class DotProdClassifier(object):
(np.sum(~count_mask), len(count_mask), self._min_samples, min_samples, len(self._cluster_counts)))
# Do another predict -- this could be more efficient, but who cares?
- labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold)
+ labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold, predict_normed = predict_normed)
if return_info:
info = {
@@ -123,7 +123,7 @@ class DotProdClassifier(object):
else:
return labels, confs
- def predict(self, X, return_confidences = False, threshold = None, verbose = True, ignore_zeros = True):
+ def predict(self, X, return_confidences = False, threshold = None, predict_normed = True, verbose = True, ignore_zeros = True):
"""Return a predicted cluster label for vectors X.
:param float threshold: alternate threshold. Defaults to None, when self.threshold
@@ -150,9 +150,12 @@ class DotProdClassifier(object):
zeros_count = 0
- center_norms = np.linalg.norm(self._cluster_centers, axis = 1)
- normed_centers = self._cluster_centers.copy()
- normed_centers /= center_norms[:, np.newaxis]
+ if predict_normed:
+ center_norms = np.linalg.norm(self._cluster_centers, axis = 1)
+ normed_centers = self._cluster_centers.copy()
+ normed_centers /= center_norms[:, np.newaxis]
+ else:
+ normed_centers = self._cluster_centers
# preallocate buffers
diffs = np.empty(shape = len(center_norms), dtype = np.float)
@@ -167,11 +170,9 @@ class DotProdClassifier(object):
else:
raise ValueError("Data %i is all zeros!" % i)
- # diffs = np.sum(x * self._cluster_centers, axis = 1)
- # diffs /= np.linalg.norm(x) * center_norms
np.dot(normed_centers, x, out = diffs)
- diffs /= np.linalg.norm(x)
- #diffs /= center_norms
+ if predict_normed:
+ diffs /= np.linalg.norm(x)
assigned_to = np.argmax(diffs)
assignment_confidence = diffs[assigned_to]
From e6db4d55c3fb7eb87f2b2d956e66233a6944aaed Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 10 Jul 2019 14:32:30 -0400
Subject: [PATCH 053/129] Text markers
---
sitator/visualization/SiteNetworkPlotter.py | 29 +++++++++++++++++----
1 file changed, 24 insertions(+), 5 deletions(-)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 0650396..e88bc16 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -4,6 +4,7 @@
import matplotlib
from mpl_toolkits.mplot3d.art3d import Line3DCollection
+from matplotlib.textpath import TextPath
from sitator.util import PBCCalculator
from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS, set_axes_equal
@@ -18,7 +19,9 @@ class SiteNetworkPlotter(object):
Params:
- site_mappings (dict): defines how to show different properties. Each
entry maps a visual aspect ('marker', 'color', 'size') to the name
- of a site attribute including 'site_type'.
+ of a site attribute including 'site_type'. The markers can also be
+ arbitrary text (key `"text"`) in which case the value can also be a
+ 2-tuple of an attribute name and a `%` format string.
- edge_mappings (dict): each key maps a visual property ('intensity',
'color', 'width', 'linestyle') to an edge attribute in the SiteNetwork.
- markers (list of str): What `matplotlib` markers to use for sites.
@@ -56,6 +59,7 @@ def __init__(self,
min_width_threshold = 0.0,
title = ""):
self.site_mappings = site_mappings
+ assert not ("marker" in site_mappings and "text" in site_mappings)
self.edge_mappings = edge_mappings
self.markers = markers
self.plot_points_params = plot_points_params
@@ -105,10 +109,21 @@ def _site_layers(self, sn, plot_points_params, same_normalization = False):
markers = None
for key in self.site_mappings:
- val = getattr(sn, self.site_mappings[key])
+ val = self.site_mappings[key]
+ if isinstance(val, tuple):
+ val, param = val
+ else:
+ param = None
+ val = getattr(sn, val)
if key == 'marker':
if not val is None:
markers = val.copy()
+ istextmarker = False
+ elif key == 'text':
+ istextmarker = True
+ format_str = "%s" if param is None else param
+ format_str = "$" + format_str + "$"
+ markers = val.copy()
elif key == 'color':
pts_arrays['c'] = val.copy()
if not same_normalization:
@@ -136,10 +151,14 @@ def _site_layers(self, sn, plot_points_params, same_normalization = False):
else:
markers = self._make_discrete(markers)
unique_markers = np.unique(markers)
- if len(unique_markers) > len(self.markers):
- raise ValueError("Too many distinct values of the site property mapped to markers (there are %i) for the %i markers in `self.markers`" % (len(unique_markers), len(self.markers)))
if not same_normalization:
- self._marker_table = dict(zip(unique_markers, self.markers[:len(unique_markers)]))
+ if istextmarker:
+ self._marker_table = dict(zip(unique_markers, (format_str % um for um in unique_markers)))
+ else:
+ if len(unique_markers) > len(self.markers):
+ raise ValueError("Too many distinct values of the site property mapped to markers (there are %i) for the %i markers in `self.markers`" % (len(unique_markers), len(self.markers)))
+ self._marker_table = dict(zip(unique_markers, self.markers[:len(unique_markers)]))
+
for um in unique_markers:
marker_layers[self._marker_table[um]] = (markers == um)
From a4d4c1c0317d9ddd4362ff1ec5d784c4b3535456 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 10 Jul 2019 21:29:12 -0400
Subject: [PATCH 054/129] Fixed jump accounting bug when there are many
unknowns
---
sitator/dynamics/JumpAnalysis.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sitator/dynamics/JumpAnalysis.py b/sitator/dynamics/JumpAnalysis.py
index 64725d6..86621d8 100644
--- a/sitator/dynamics/JumpAnalysis.py
+++ b/sitator/dynamics/JumpAnalysis.py
@@ -62,7 +62,7 @@ def run(self, st):
unassigned = frame == SiteTrajectory.SITE_UNKNOWN
# Reassign unassigned
frame[unassigned] = last_known[unassigned]
- fknown = frame >= 0
+ fknown = (frame >= 0) & (last_known >= 0)
if np.any(~fknown):
logger.warning(" at frame %i, %i uncorrectable unassigned particles" % (i, np.sum(~fknown)))
From ea859468a341e30f92d411e4c0c63fe7a5bc2be2 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 10 Jul 2019 22:06:07 -0400
Subject: [PATCH 055/129] Make sure empty clusters always removed
---
sitator/util/DotProdClassifier.pyx | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index bf4a231..116a024 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -99,6 +99,7 @@ class DotProdClassifier(object):
min_samples = int(np.floor(self._min_samples * total_n_assigned))
else:
raise ValueError("Invalid value `%s` for min_samples; must be integral or float." % self._min_samples)
+ min_samples = max(min_samples, 1)
count_mask = self._cluster_counts >= min_samples
@@ -158,7 +159,7 @@ class DotProdClassifier(object):
normed_centers = self._cluster_centers
# preallocate buffers
- diffs = np.empty(shape = len(center_norms), dtype = np.float)
+ diffs = np.empty(shape = len(normed_centers), dtype = np.float)
for i, x in enumerate(tqdm(X, desc = "Sample")):
From a07942af4e027791bff0d51d7194b86c80f865ea Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 10 Jul 2019 22:08:03 -0400
Subject: [PATCH 056/129] Eigenvector clustering centers
---
sitator/landmark/cluster/mcl.py | 26 ++++++++++++++++++++------
sitator/util/mcl.py | 3 +++
2 files changed, 23 insertions(+), 6 deletions(-)
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index f1d0d50..659ee31 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -19,6 +19,7 @@ def cov2corr( A ):
covariance matrix to correlation matrix.
"""
d = np.sqrt(A.diagonal())
+ d[d == 0] = np.inf # Forces correlations to zero where variance is 0
A = ((A.T/d).T)/d
return A
@@ -30,18 +31,31 @@ def do_landmark_clustering(landmark_vectors,
tmp.update(clustering_params)
clustering_params = tmp
- cor = empirical_covariance(landmark_vectors)
- cor = cov2corr(cor)
- graph = np.clip(cor, 0, None)
+ n_lmk = landmark_vectors.shape[1]
+
+ cov = empirical_covariance(landmark_vectors)
+ corr = cov2corr(cov)
+ graph = np.clip(corr, 0, None)
+ for i in range(n_lmk):
+ if graph[i, i] == 0: # i.e. no self correlation = 0 variance = landmark never seen
+ graph[i, i] = 1 # Needs a self loop for Markov clustering not to degenerate. Arbitrary value, shouldn't affect anyone else.
predict_threshold = clustering_params.pop('assignment_threshold')
# -- Cluster Landmarks
clusters = markov_clustering(graph, **clustering_params)
+ clusters = [list(c) for c in clusters]
n_clusters = len(clusters)
- centers = np.zeros(shape = (n_clusters, landmark_vectors.shape[1]))
+ centers = np.zeros(shape = (n_clusters, n_lmk))
for i, cluster in enumerate(clusters):
- centers[i, list(cluster)] = 1 / len(cluster) # Set the peaks
+ if len(cluster) == 1:
+ centers[i, cluster] = 1.0 # Eigenvec is trivial case; scale doesn't matter either.
+ else:
+ # PCA inspired:
+ eigenval, eigenvec = eigsh(cov[cluster][:, cluster], k = 1)
+ # abs cause all our data is in the first "octant"
+ centers[i, cluster] = np.abs(eigenvec.T)
+
landmark_classifier = \
DotProdClassifier(threshold = np.nan, # We're not fitting
@@ -52,7 +66,7 @@ def do_landmark_clustering(landmark_vectors,
lmk_lbls, lmk_confs = \
landmark_classifier.fit_predict(landmark_vectors,
predict_threshold = predict_threshold,
- predict_normed = False,
+ predict_normed = True,
verbose = verbose)
return landmark_classifier.cluster_counts, lmk_lbls, lmk_confs
diff --git a/sitator/util/mcl.py b/sitator/util/mcl.py
index 214bc87..6a975ab 100644
--- a/sitator/util/mcl.py
+++ b/sitator/util/mcl.py
@@ -16,6 +16,9 @@ def markov_clustering(transition_matrix,
assert transition_matrix.shape[0] == transition_matrix.shape[1]
+ # Check for nonzero diagonal -- self loops needed to avoid div by zero and NaNs
+ assert np.count_nonzero(transition_matrix.diagonal()) == len(transition_matrix)
+
m1 = transition_matrix.copy()
# Normalize (though it should be close already)
From e3e6a84f32304e15cbffdf2884b598f1fd02d312 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 11 Jul 2019 10:00:47 -0400
Subject: [PATCH 057/129] Cleanup
---
sitator/landmark/LandmarkAnalysis.py | 2 +-
sitator/landmark/cluster/dbscan.py | 71 ---------------------
sitator/landmark/cluster/mcl.py | 2 +
sitator/util/DotProdClassifier.pyx | 1 +
sitator/visualization/SiteNetworkPlotter.py | 1 -
5 files changed, 4 insertions(+), 73 deletions(-)
delete mode 100644 sitator/landmark/cluster/dbscan.py
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index aa5b2b2..4d1fef7 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -94,7 +94,7 @@ def __init__(self,
self._clustering_params = clustering_params
if not peak_evening in ['none', 'clip']:
- raise ValueError("Invalid value `%s` for peak_evening" % peak_evening)
+ raise ValueError("Invalid value `%s` for peak_evening" % peak_evening)
self._peak_evening = peak_evening
self.verbose = verbose
diff --git a/sitator/landmark/cluster/dbscan.py b/sitator/landmark/cluster/dbscan.py
deleted file mode 100644
index 769152b..0000000
--- a/sitator/landmark/cluster/dbscan.py
+++ /dev/null
@@ -1,71 +0,0 @@
-
-import numpy as np
-
-import numbers
-from sklearn.cluster import DBSCAN
-
-import logging
-logger = logging.getLogger(__name__)
-
-DEFAULT_PARAMS = {
- 'eps' : 0.05,
- 'min_samples' : 5,
- 'n_jobs' : -1
-}
-
-def do_landmark_clustering(landmark_vectors,
- clustering_params,
- min_samples,
- verbose = False):
- # `verbose` ignored.
-
- tmp = DEFAULT_PARAMS.copy()
- tmp.update(clustering_params)
- clustering_params = tmp
-
- landmark_classifier = \
- DBSCAN(eps = clustering_params['eps'],
- min_samples = clustering_params['min_samples'],
- n_jobs = clustering_params['n_jobs'],
- metric = 'cosine')
-
- lmk_lbls = \
- landmark_classifier.fit_predict(landmark_vectors)
-
- # - Filter low occupancy sites
- cluster_counts = np.bincount(lmk_lbls[lmk_lbls >= 0])
- n_assigned = np.sum(cluster_counts)
-
- min_n_samples_cluster = None
- if isinstance(min_samples, numbers.Integral):
- min_n_samples_cluster = min_samples
- elif isinstance(min_samples, numbers.Real):
- min_n_samples_cluster = int(np.floor(min_samples * n_assigned))
- else:
- raise ValueError("Invalid value `%s` for min_samples; must be integral or float." % self._min_samples)
-
- to_remove_mask = cluster_counts < min_n_samples_cluster
- to_remove = np.where(to_remove_mask)[0]
-
- trans_table = np.empty(shape = len(cluster_counts) + 1, dtype = np.int)
- # Map unknown to unknown
- trans_table[-1] = -1
- # Map removed to unknwon
- trans_table[:-1][to_remove_mask] = -1
- # Map known to rescaled known
- trans_table[:-1][~to_remove_mask] = np.arange(len(cluster_counts) - len(to_remove))
- # Do the remapping
- lmk_lbls = trans_table[lmk_lbls]
-
- logging.info("DBSCAN landmark: %i/%i assignment counts below threshold %f (%i); %i clusters remain." % \
- (len(to_remove), len(cluster_counts), min_samples, min_n_samples_cluster, len(cluster_counts) - len(to_remove)))
-
- # Remove counts
- cluster_counts = cluster_counts[~to_remove_mask]
-
- # There are no confidences with DBSCAN, so just give everything confidence 1
- # so as not to screw up later weighting.
- confs = np.ones(shape = lmk_lbls.shape, dtype = np.float)
- confs[lmk_lbls == -1] = 0.0
-
- return cluster_counts, lmk_lbls, confs
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index 659ee31..138e1a4 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -7,6 +7,8 @@
from sklearn.covariance import empirical_covariance
+from scipy.sparse.linalg import eigsh
+
import logging
logger = logging.getLogger(__name__)
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index 116a024..c7f08de 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -174,6 +174,7 @@ class DotProdClassifier(object):
np.dot(normed_centers, x, out = diffs)
if predict_normed:
diffs /= np.linalg.norm(x)
+ np.abs(diffs, out = diffs)
assigned_to = np.argmax(diffs)
assignment_confidence = diffs[assigned_to]
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index e88bc16..82e7689 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -4,7 +4,6 @@
import matplotlib
from mpl_toolkits.mplot3d.art3d import Line3DCollection
-from matplotlib.textpath import TextPath
from sitator.util import PBCCalculator
from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS, set_axes_equal
From a4323aa72e929a5856f3be18ed956bc6958bbc1f Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 11 Jul 2019 13:19:59 -0400
Subject: [PATCH 058/129] Improved site merging options
---
sitator/dynamics/MergeSitesByThreshold.py | 30 ++++++++++++++++++++++-
1 file changed, 29 insertions(+), 1 deletion(-)
diff --git a/sitator/dynamics/MergeSitesByThreshold.py b/sitator/dynamics/MergeSitesByThreshold.py
index 04ed522..4469222 100644
--- a/sitator/dynamics/MergeSitesByThreshold.py
+++ b/sitator/dynamics/MergeSitesByThreshold.py
@@ -4,6 +4,7 @@
from scipy.sparse.csgraph import connected_components
+from sitator.util import PBCCalculator
from sitator.network.merging import MergeSites
@@ -29,11 +30,15 @@ def __init__(self,
relation = operator.ge,
directed = True,
connection = 'strong',
+ distance_threshold = np.inf,
+ forbid_multiple_occupancy = False,
**kwargs):
self.attrname = attrname
self.relation = relation
self.directed = directed
self.connection = connection
+ self.distance_threshold = distance_threshold
+ self.forbid_multiple_occupancy = forbid_multiple_occupancy
super().__init__(**kwargs)
@@ -42,10 +47,33 @@ def _get_sites_to_merge(self, st, threshold = 0):
attrmat = getattr(sn, self.attrname)
assert attrmat.shape == (sn.n_sites, sn.n_sites), "`attrname` doesn't seem to indicate an edge property."
+ connmat = self.relation(attrmat, threshold)
+
+ # Apply distance threshold
+ if self.distance_threshold < np.inf:
+ pbcc = PBCCalculator(sn.structure.cell)
+ centers = sn.centers
+ for i in range(sn.n_sites):
+ dists = pbcc.distances(centers[i], centers[i + 1:])
+ js_too_far = np.where(dists > self.distance_threshold)[0]
+ js_too_far += i + 1
+
+ connmat[i, js_too_far] = False
+ connmat[js_too_far, i] = False # Symmetry
+
+ if self.forbid_multiple_occupancy:
+ n_mobile = sn.n_mobile
+ for frame in st.traj:
+ for mob in range(n_mobile):
+ # can't merge occupied site with other simulatanious occupied sites
+ connmat[frame[mob], frame] = False
+
+ # Everything is always mergable with itself.
+ np.fill_diagonal(connmat, True)
# Get mergable groups
n_merged_sites, labels = connected_components(
- self.relation(attrmat, threshold),
+ connmat,
directed = self.directed,
connection = self.connection
)
From bc09650fedc672186d51f188da7bcb78d8a11938 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 11 Jul 2019 14:33:10 -0400
Subject: [PATCH 059/129] Landmark vertices; merging vertices
---
sitator/SiteNetwork.py | 7 +++++++
sitator/landmark/LandmarkAnalysis.py | 25 ++++++++++++++++++-------
sitator/landmark/cluster/mcl.py | 24 ++++++++++++++++--------
sitator/network/merging.py | 8 ++++++++
sitator/util/DotProdClassifier.pyx | 3 ++-
5 files changed, 51 insertions(+), 16 deletions(-)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 8027638..22df97a 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -158,6 +158,13 @@ def vertices(self, value):
raise ValueError("Wrong # of vertices %i; expected %i" % (len(value), len(self._centers)))
self._vertices = value
+ @property
+ def number_of_vertices(self):
+ if self._vertices is None:
+ return None
+ else:
+ return [len(v) for v in self._vertices]
+
@property
def site_types(self):
if self._types is None:
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 4d1fef7..f4e1f99 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -211,16 +211,26 @@ def run(self, sn, frames):
self._do_peak_evening()
# - Cluster -
+ # FIXME: remove reload after development done
clustermod = importlib.import_module("..cluster." + self._cluster_algo, package = __name__)
importlib.reload(clustermod)
cluster_func = clustermod.do_landmark_clustering
- cluster_counts, lmk_lbls, lmk_confs = \
+ clustering = \
cluster_func(self._landmark_vectors,
clustering_params = self._clustering_params,
min_samples = self._minimum_site_occupancy / float(sn.n_mobile),
verbose = self.verbose)
+ if len(clustering) == 3:
+ cluster_counts, lmk_lbls, lmk_confs = clustering
+ landmark_clusters = None
+ elif len(clustering) == 4:
+ cluster_counts, lmk_lbls, lmk_confs, landmark_clusters = clustering
+ assert len(cluster_counts) == len(landmark_clusters)
+ else:
+ raise ValueError("Clustering function returned invalid result %s" % clustering)
+
logging.info(" Failed to assign %i%% of mobile particle positions to sites." % (100.0 * np.sum(lmk_lbls < 0) / float(len(lmk_lbls))))
# reshape lables and confidences
@@ -235,9 +245,9 @@ def run(self, sn, frames):
logging.info(" Identified %i sites with assignment counts %s" % (n_sites, cluster_counts))
# -- Do output
+ out_sn = sn.copy()
# - Compute site centers
site_centers = np.empty(shape = (n_sites, 3), dtype = frames.dtype)
-
for site in range(n_sites):
mask = lmk_lbls == site
pts = frames[:, sn.mobile_mask][mask]
@@ -245,12 +255,13 @@ def run(self, sn, frames):
site_centers[site] = self._pbcc.average(pts, weights = lmk_confs[mask])
else:
site_centers[site] = self._pbcc.average(pts)
-
- # Build output obejcts
- out_sn = sn.copy()
-
out_sn.centers = site_centers
- assert out_sn.vertices is None
+ # - If clustering gave us that, compute site vertices
+ if landmark_clusters is not None:
+ vertices = []
+ for lclust in landmark_clusters:
+ vertices.append(set.union(*[set(sn.vertices[l]) for l in lclust]))
+ out_sn.vertices = vertices
out_st = SiteTrajectory(out_sn, lmk_lbls, lmk_confs)
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index 138e1a4..eba0ed2 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -3,7 +3,6 @@
from sitator.util.progress import tqdm
from sitator.util.mcl import markov_clustering
from sitator.util import DotProdClassifier
-from ..helpers import _cross_correlation_matrix
from sklearn.covariance import empirical_covariance
@@ -13,7 +12,8 @@
logger = logging.getLogger(__name__)
DEFAULT_PARAMS = {
- 'assignment_threshold' : 0.9
+ 'inflation' : 4,
+ 'assignment_threshold' : 0.7,
}
def cov2corr( A ):
@@ -55,8 +55,7 @@ def do_landmark_clustering(landmark_vectors,
else:
# PCA inspired:
eigenval, eigenvec = eigsh(cov[cluster][:, cluster], k = 1)
- # abs cause all our data is in the first "octant"
- centers[i, cluster] = np.abs(eigenvec.T)
+ centers[i, cluster] = eigenvec.T
landmark_classifier = \
@@ -65,10 +64,19 @@ def do_landmark_clustering(landmark_vectors,
landmark_classifier.set_cluster_centers(centers)
- lmk_lbls, lmk_confs = \
+ lmk_lbls, lmk_confs, info = \
landmark_classifier.fit_predict(landmark_vectors,
predict_threshold = predict_threshold,
predict_normed = True,
- verbose = verbose)
-
- return landmark_classifier.cluster_counts, lmk_lbls, lmk_confs
+ verbose = verbose,
+ return_info = True)
+
+ msk = info['kept_clusters_mask']
+ clusters = [c for i, c in enumerate(clusters) if msk[i]] # Only need the ones above the threshold
+
+ return (
+ landmark_classifier.cluster_counts,
+ lmk_lbls,
+ lmk_confs,
+ clusters
+ )
diff --git a/sitator/network/merging.py b/sitator/network/merging.py
index 31a2c27..cf7444f 100644
--- a/sitator/network/merging.py
+++ b/sitator/network/merging.py
@@ -53,6 +53,7 @@ def run(self, st, **kwargs):
clusters = self._get_sites_to_merge(st, **kwargs)
+ old_n_sites = st.site_network.n_sites
new_n_sites = len(clusters)
logger.info("After merging %i sites there will be %i sites for %i mobile particles" % (len(site_centers), new_n_sites, st.site_network.n_mobile))
@@ -62,6 +63,9 @@ def run(self, st, **kwargs):
if self.check_types:
new_types = np.empty(shape = new_n_sites, dtype = np.int)
+ merge_verts = st.site_network.vertices is not None
+ if merge_verts:
+ new_verts = []
# -- Merge Sites
new_centers = np.empty(shape = (new_n_sites, 3), dtype = st.site_network.centers.dtype)
@@ -90,11 +94,15 @@ def run(self, st, **kwargs):
if self.check_types:
assert np.all(site_types[mask] == site_types[mask][0])
new_types[newsite] = site_types[mask][0]
+ if merge_verts:
+ new_verts.append(set.union(*[set(st.site_network.vertices[i]) for i in mask]))
newsn = st.site_network.copy()
newsn.centers = new_centers
if self.check_types:
newsn.site_types = new_types
+ if merge_verts:
+ newsn.vertices = new_verts
newtraj = translation[st._traj]
newtraj[st._traj == SiteTrajectory.SITE_UNKNOWN] = SiteTrajectory.SITE_UNKNOWN
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index c7f08de..31c3417 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -118,7 +118,8 @@ class DotProdClassifier(object):
if return_info:
info = {
- 'clusters_below_min_samples' : np.sum(~count_mask)
+ 'clusters_below_min_samples' : np.sum(~count_mask),
+ 'kept_clusters_mask' : count_mask
}
return labels, confs, info
else:
From 58d56ff8d368e6cc10875142a7ef9aeca53ddd1c Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 11 Jul 2019 14:33:29 -0400
Subject: [PATCH 060/129] Corrected behaviour for "forbid multiple occupancy"
merging
---
sitator/dynamics/MergeSitesByThreshold.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/sitator/dynamics/MergeSitesByThreshold.py b/sitator/dynamics/MergeSitesByThreshold.py
index 4469222..258f560 100644
--- a/sitator/dynamics/MergeSitesByThreshold.py
+++ b/sitator/dynamics/MergeSitesByThreshold.py
@@ -64,9 +64,10 @@ def _get_sites_to_merge(self, st, threshold = 0):
if self.forbid_multiple_occupancy:
n_mobile = sn.n_mobile
for frame in st.traj:
- for mob in range(n_mobile):
+ frame = [s for s in frame if s >= 0]
+ for site in frame: # only known
# can't merge occupied site with other simulatanious occupied sites
- connmat[frame[mob], frame] = False
+ connmat[site, frame] = False
# Everything is always mergable with itself.
np.fill_diagonal(connmat, True)
From d7e9cd823b57fdaa131c4b40f5a918b4a4966ae0 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 12 Jul 2019 14:59:55 -0400
Subject: [PATCH 061/129] Pymatgen ChemEnv for site type analysis
---
sitator/SiteNetwork.py | 5 +-
.../SiteCoordinationEnvironment.py | 85 +++++++++++++++++++
.../SiteCoordinationNumber.py | 38 ---------
sitator/site_descriptors/SiteTypeAnalysis.py | 2 +-
sitator/visualization/SiteNetworkPlotter.py | 3 +-
5 files changed, 91 insertions(+), 42 deletions(-)
create mode 100644 sitator/site_descriptors/SiteCoordinationEnvironment.py
delete mode 100644 sitator/site_descriptors/SiteCoordinationNumber.py
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 64954f9..e396529 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -122,7 +122,7 @@ def get_structure_with_sites(self, site_atomic_number = None):
- site_atomic_number: If `None`, the species of the first mobile atom
will be used.
Returns:
- ase.Atoms and final `site_atomic_number`
+ ase.Atoms, indices of sites in the returned structure, and final `site_atomic_number`
"""
out = self.static_structure.copy()
if site_atomic_number is None:
@@ -132,8 +132,9 @@ def get_structure_with_sites(self, site_atomic_number = None):
positions = self.centers,
numbers = numbers
)
+ site_idexes = len(out) + np.arange(self.n_sites)
out.extend(sites_atoms)
- return out, site_atomic_number
+ return out, site_idexes, site_atomic_number
@property
def n_sites(self):
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
new file mode 100644
index 0000000..973f8ad
--- /dev/null
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -0,0 +1,85 @@
+import numpy as np
+
+from sitator.util.progress import tqdm
+
+try:
+ from pymatgen.io.ase import AseAtomsAdaptor
+ import pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder as cgf
+ from pymatgen.analysis.chemenv.coordination_environments.structure_environments import \
+ LightStructureEnvironments
+ from pymatgen.analysis.chemenv.utils.defs_utils import AdditionalConditions
+ has_pymatgen = True
+except ImportError:
+ has_pymatgen = False
+
+
+class SiteCoordinationEnvironment(object):
+ """Determine site types based on local coordination environments.
+
+ Determine site types using the method from the following paper:
+
+ David Waroquiers, Xavier Gonze, Gian-Marco Rignanese, Cathrin Welker-Nieuwoudt, Frank Rosowski, Michael Goebel, Stephan Schenk, Peter Degelmann, Rute Andre, Robert Glaum, and Geoffroy Hautier,
+ “Statistical analysis of coordination environments in oxides”,
+ Chem. Mater., 2017, 29 (19), pp 8346–8360, DOI: 10.1021/acs.chemmater.7b02766
+
+ as implemented in `pymatgen`'s `pymatgen.analysis.chemenv.coordination_environments`.
+
+ Args:
+ **kwargs: passed to `compute_structure_environments`.
+ """
+ def __init__(self, **kwargs):
+ if not has_pymatgen:
+ raise ImportError("Pymatgen (or a recent enough version including `pymatgen.analysis.chemenv.coordination_environments`) cannot be imported.")
+ self._kwargs = kwargs
+
+ def run(self, sn):
+ # -- Determine local environments
+ # Get an ASE structure with a single mobile atom that we'll move around
+ site_struct, idexes, site_species = sn[0:1].get_structure_with_sites()
+ pymat_struct = AseAtomsAdaptor.get_structure(site_struct)
+ lgf = cgf.LocalGeometryFinder()
+ index = idexes[0]
+
+ coord_envs = []
+ vertices = []
+
+ # Do this once.
+ # __init__ here defaults to disabling structure refinement, so all this
+ # method is doing is making a copy of the structure and setting some
+ # variables to None.
+ lgf.setup_structure(structure = pymat_struct)
+
+ for site in tqdm(range(sn.n_sites), desc = 'Site'):
+ # Update the position of the site
+ lgf.structure[index].coords = sn.centers[site]
+ # Compute structure environments for the site
+ struct_envs = lgf.compute_structure_environments(only_indices = [index])
+ struct_envs = LightStructureEnvironments.from_structure_environments(
+ strategy=cgf.LocalGeometryFinder.DEFAULT_STRATEGY,
+ structure_environments=struct_envs
+ )
+ # Store the results
+ # We take the first environment for each site since it's the most likely
+ coord_envs.append(struct_envs.coordination_environments[index][0])
+ vertices.append(
+ [n['index'] for n in struct_envs.neighbors_sets[index][0].neighb_indices_and_images]
+ )
+
+ del lgf
+ del struct_envs
+
+ # -- Postprocess
+ # TODO: allow user to ask for full fractional breakdown
+ unique_envs = list(set(env['ce_symbol'] for env in coord_envs))
+ site_types = np.array([unique_envs.index(env['ce_symbol']) for env in coord_envs])
+ # The closer to 1 this is, the better
+ site_type_confidences = np.array([env['ce_fraction'] for env in coord_envs])
+ coordination_numbers = np.array([int(env['ce_symbol'].split(':')[1]) for env in coord_envs])
+ assert np.all(coordination_numbers == [len(v) for v in vertices])
+
+ sn.site_types = site_types
+ sn.vertices = vertices
+ sn.add_site_attribute("site_type_confidences", site_type_confidences)
+ sn.add_site_attribute("coordination_numbers", coordination_numbers)
+
+ return sn
diff --git a/sitator/site_descriptors/SiteCoordinationNumber.py b/sitator/site_descriptors/SiteCoordinationNumber.py
deleted file mode 100644
index bd53218..0000000
--- a/sitator/site_descriptors/SiteCoordinationNumber.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import numpy as np
-
-try:
- from pymatgen.io.ase import AseAtomsAdaptor
- import pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder as cgf
- has_pymatgen = True
-except ImportError:
- has_pymatgen = False
-
-
-class SiteCoordinationAnalysis(object):
- """Determine site types based on local coordination environments.
-
- Determine site types using the method from the following paper:
-
- David Waroquiers, Xavier Gonze, Gian-Marco Rignanese, Cathrin Welker-Nieuwoudt, Frank Rosowski, Michael Goebel, Stephan Schenk, Peter Degelmann, Rute Andre, Robert Glaum, and Geoffroy Hautier,
- “Statistical analysis of coordination environments in oxides”,
- Chem. Mater., 2017, 29 (19), pp 8346–8360, DOI: 10.1021/acs.chemmater.7b02766
-
- as implement in `pymatgen`'s `pymatgen.analysis.chemenv.coordination_environments`.
-
- Args:
- **kwargs: passed to `compute_structure_environments`.
- """
- def __init__(self, **kwargs):
- if not has_pymatgen:
- raise ImportError("Pymatgen (or a recent enough version including `pymatgen.analysis.chemenv.coordination_environments`) cannot be imported.")
- self._kwargs = kwargs
-
- def run(self, sn):
- site_struct, site_species = sn.get_structure_with_sites()
- pymat_struct = AseAtomsAdaptor.get_structure(site_struct)
- lgf = cgf.LocalGeometryFinder()
- struct_envs = lgf.compute_structure_environments(
- structure = pymat_struct,
- indicies = np.where(sn.mobile_mask)[0],
- only_cations = False,
- )
diff --git a/sitator/site_descriptors/SiteTypeAnalysis.py b/sitator/site_descriptors/SiteTypeAnalysis.py
index ee16489..caeb4b4 100644
--- a/sitator/site_descriptors/SiteTypeAnalysis.py
+++ b/sitator/site_descriptors/SiteTypeAnalysis.py
@@ -19,7 +19,7 @@
raise ImportError("SiteTypeAnalysis requires the `pydpc` package")
class SiteTypeAnalysis(object):
- """Cluster sites into types using a descriptor and Density Peak Clustering.
+ """Cluster sites into types using a continuous descriptor and Density Peak Clustering.
-- descriptor --
Some kind of object implementing:
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index fefd0be..e582cc2 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -148,7 +148,8 @@ def _site_layers(self, sn, plot_points_params, same_normalization = False):
# Just one layer with all points and one marker
marker_layers[SiteNetworkPlotter.DEFAULT_MARKERS[0]] = np.ones(shape = sn.n_sites, dtype = np.bool)
else:
- markers = self._make_discrete(markers)
+ if not istextmarker:
+ markers = self._make_discrete(markers)
unique_markers = np.unique(markers)
if not same_normalization:
if istextmarker:
From 866f5c0ea6a22dcdaeebeb36f64cf2b05bf0f6cf Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 12 Jul 2019 15:02:13 -0400
Subject: [PATCH 062/129] Cleanup
---
sitator/misc/__init__.py | 1 -
sitator/{misc => site_descriptors}/SiteVolumes.py | 0
sitator/site_descriptors/__init__.py | 4 ++--
3 files changed, 2 insertions(+), 3 deletions(-)
rename sitator/{misc => site_descriptors}/SiteVolumes.py (100%)
diff --git a/sitator/misc/__init__.py b/sitator/misc/__init__.py
index 647ed08..8240d62 100644
--- a/sitator/misc/__init__.py
+++ b/sitator/misc/__init__.py
@@ -1,5 +1,4 @@
from .NAvgsPerSite import NAvgsPerSite
from .GenerateAroundSites import GenerateAroundSites
-from .SiteVolumes import SiteVolumes
from .GenerateClampedTrajectory import GenerateClampedTrajectory
diff --git a/sitator/misc/SiteVolumes.py b/sitator/site_descriptors/SiteVolumes.py
similarity index 100%
rename from sitator/misc/SiteVolumes.py
rename to sitator/site_descriptors/SiteVolumes.py
diff --git a/sitator/site_descriptors/__init__.py b/sitator/site_descriptors/__init__.py
index 050a51e..fd5e7af 100644
--- a/sitator/site_descriptors/__init__.py
+++ b/sitator/site_descriptors/__init__.py
@@ -1,3 +1,3 @@
from .SiteTypeAnalysis import SiteTypeAnalysis
-
-from .SOAP import SOAPCenters, SOAPSampledCenters, SOAPDescriptorAverages, SOAP
+from .SiteCoordinationEnvironment import SiteCoordinationEnvironment
+from .SiteVolumes import SiteVolumes
From 71e9e9a91740be503940ab9d0c40d7a8765d9140 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 12 Jul 2019 15:44:34 -0400
Subject: [PATCH 063/129] Update README with `pymatgen` dependency; bugfix
---
README.md | 15 ++++++++-------
sitator/SiteNetwork.py | 2 +-
.../SiteCoordinationEnvironment.py | 4 ++--
3 files changed, 11 insertions(+), 10 deletions(-)
diff --git a/README.md b/README.md
index 19eb832..3b4d5a6 100644
--- a/README.md
+++ b/README.md
@@ -21,15 +21,16 @@ If you use `sitator` in your research, please consider citing this paper. The Bi
`sitator` is built for Python >=3.2 (the older version supports Python 2.7). We recommend the use of a virtual environment (`virtualenv`, `conda`, etc.). `sitator` has a number of optional dependencies that enable various features:
- * **Landmark Analysis**
- * The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does not have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.)
- * **Site Type Analysis**
- * For computing SOAP vectors: the `quip` binary from [QUIP](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) **or** the [`DScribe`](https://singroup.github.io/dscribe/index.html) Python library.
- * The Python 2.7 bindings for QUIP (`quippy`) are **not** required. Generally, `DScribe` is much simpler to install than QUIP. **Please note**, however, that the SOAP descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on the system you are analyzing.
+* Landmark Analysis
+ * The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does not have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.)
+* Site Type Analysis
+ * For SOAP-based site types: either the `quip` binary from [QUIP](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) **or** the [`DScribe`](https://singroup.github.io/dscribe/index.html) Python library.
+ * The Python 2.7 bindings for QUIP (`quippy`) are **not** required. Generally, `DScribe` is much simpler to install than QUIP. **Please note**, however, that the SOAP descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on the system you are analyzing.
+ * For coordination environment analysis (`sitator.site_descriptors.SiteCoordinationEnvironment`), we integrate the `pymatgen.analysis.chemenv` package; a somewhat recent installation of `pymatgen` is required.
After downloading, the package is installed with `pip`:
-```
+```bash
# git clone ... OR unzip ... OR ...
cd sitator
pip install .
@@ -37,7 +38,7 @@ pip install .
To enable site type analysis, add the `[SiteTypeAnalysis]` option (this adds two dependencies -- Python packages `pydpc` and `dscribe`):
-```
+```bash
pip install ".[SiteTypeAnalysis]"
```
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index e396529..6aec2db 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -126,7 +126,7 @@ def get_structure_with_sites(self, site_atomic_number = None):
"""
out = self.static_structure.copy()
if site_atomic_number is None:
- site_atomic_number = self.structure.get_atomic_numbers()[mobile_mask][0]
+ site_atomic_number = self.structure.get_atomic_numbers()[self.mobile_mask][0]
numbers = np.full(len(self), site_atomic_number)
sites_atoms = ase.Atoms(
positions = self.centers,
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
index 973f8ad..b199bad 100644
--- a/sitator/site_descriptors/SiteCoordinationEnvironment.py
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -34,7 +34,7 @@ def __init__(self, **kwargs):
def run(self, sn):
# -- Determine local environments
- # Get an ASE structure with a single mobile atom that we'll move around
+ # Get an ASE structure with a single mobile site that we'll move around
site_struct, idexes, site_species = sn[0:1].get_structure_with_sites()
pymat_struct = AseAtomsAdaptor.get_structure(site_struct)
lgf = cgf.LocalGeometryFinder()
@@ -53,7 +53,7 @@ def run(self, sn):
# Update the position of the site
lgf.structure[index].coords = sn.centers[site]
# Compute structure environments for the site
- struct_envs = lgf.compute_structure_environments(only_indices = [index])
+ struct_envs = lgf.compute_structure_environments(only_indices = [index], **self._kwargs)
struct_envs = LightStructureEnvironments.from_structure_environments(
strategy=cgf.LocalGeometryFinder.DEFAULT_STRATEGY,
structure_environments=struct_envs
From ddad73e57c2282dc4bb79454bfb3da986dfcb663 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 12 Jul 2019 16:37:56 -0400
Subject: [PATCH 064/129] Coordination environment logging
---
sitator/SiteNetwork.py | 2 +-
.../site_descriptors/SiteCoordinationEnvironment.py | 10 +++++++++-
2 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 6aec2db..c0a7d41 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -134,7 +134,7 @@ def get_structure_with_sites(self, site_atomic_number = None):
)
site_idexes = len(out) + np.arange(self.n_sites)
out.extend(sites_atoms)
- return out, site_idexes, ite_atomic_number
+ return out, site_idexes, site_atomic_number
@property
def n_sites(self):
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
index b199bad..0e37a0d 100644
--- a/sitator/site_descriptors/SiteCoordinationEnvironment.py
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -12,6 +12,9 @@
except ImportError:
has_pymatgen = False
+import logging
+logger = logging.getLogger(__name__)
+
class SiteCoordinationEnvironment(object):
"""Determine site types based on local coordination environments.
@@ -43,6 +46,7 @@ def run(self, sn):
coord_envs = []
vertices = []
+ logger.info("Running site coordination environment analysis...")
# Do this once.
# __init__ here defaults to disabling structure refinement, so all this
# method is doing is making a copy of the structure and setting some
@@ -73,10 +77,14 @@ def run(self, sn):
unique_envs = list(set(env['ce_symbol'] for env in coord_envs))
site_types = np.array([unique_envs.index(env['ce_symbol']) for env in coord_envs])
# The closer to 1 this is, the better
- site_type_confidences = np.array([unique_envs.index(env[0]['ce_fraction']) for env in coord_envs])
+ site_type_confidences = np.array([env['ce_fraction'] for env in coord_envs])
coordination_numbers = np.array([int(env['ce_symbol'].split(':')[1]) for env in coord_envs])
assert np.all(coordination_numbers == [len(v) for v in vertices])
+ n_types = len(unique_envs)
+ logger.info((" " + "Type {:<2} " * n_types).format(*range(n_types)))
+ logger.info(("# of sites " + "{:<8}" * n_types).format(*np.bincount(site_types)))
+
sn.site_types = site_types
sn.vertices = vertices
sn.add_site_attribute("site_type_confidences", site_type_confidences)
From fc4788c41f03bc017774697c3ee5bdc8c86e5495 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 12 Jul 2019 16:39:51 -0400
Subject: [PATCH 065/129] Minor visualization cleanups
---
sitator/visualization/SiteNetworkPlotter.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index e582cc2..2c5a1b8 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -53,7 +53,7 @@ def __init__(self,
plot_points_params = {},
minmax_linewidth = (1.5, 7),
minmax_edge_alpha = (0.15, 0.75),
- minmax_markersize = (80.0, 180.0),
+ minmax_markersize = (20.0, 80.0),
min_color_threshold = 0.0,
min_width_threshold = 0.0,
title = ""):
@@ -101,7 +101,7 @@ def __call__(self, sn, *args, **kwargs):
def _site_layers(self, sn, plot_points_params, same_normalization = False):
pts_arrays = {'points' : sn.centers}
- pts_params = {'cmap' : 'copper'}
+ pts_params = {'cmap' : 'cividis'}
# -- Apply mapping
# - other mappings
@@ -166,6 +166,9 @@ def _site_layers(self, sn, plot_points_params, same_normalization = False):
# If no color info provided, a fallback
if not 'color' in pts_params and not 'c' in pts_arrays:
pts_params['color'] = 'k'
+ # If no color info provided, a fallback
+ if not 's' in pts_params and not 's' in pts_arrays:
+ pts_params['s'] = sum(self.minmax_markersize) / 2
# Add user options for `plot_points`
pts_params.update(plot_points_params)
From 21bce6e6f3e888397f9e97dba679acc35369897a Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 12 Jul 2019 16:53:35 -0400
Subject: [PATCH 066/129] Fixed corner case in ideal site volumes
---
sitator/site_descriptors/SiteVolumes.py | 38 +++++++++++++++++++------
1 file changed, 30 insertions(+), 8 deletions(-)
diff --git a/sitator/site_descriptors/SiteVolumes.py b/sitator/site_descriptors/SiteVolumes.py
index bd87169..4b76b8a 100644
--- a/sitator/site_descriptors/SiteVolumes.py
+++ b/sitator/site_descriptors/SiteVolumes.py
@@ -3,16 +3,28 @@
from scipy.spatial import ConvexHull
from scipy.spatial.qhull import QhullError
-from sitator import SiteTrajectory
+from sitator import SiteNetwork, SiteTrajectory
from sitator.util import PBCCalculator
import logging
logger = logging.getLogger(__name__)
+class InsufficientCoordinatingAtomsError(Exception):
+ pass
+
class SiteVolumes(object):
- """Compute the volumes of sites."""
- def __init__(self):
- pass
+ """Compute the volumes of sites.
+
+ Args:
+ - error_on_insufficient_coord (bool, default: True): To compute an
+ ideal site volume (`compute_volumes()`), at least 4 coordinating
+ atoms (because we are in 3D space) must be specified in `vertices`.
+ If True, an error will be thrown when a site with less than four
+ vertices is encountered; if False, a volume of 0 and surface area
+ of NaN will be returned.
+ """
+ def __init__(self, error_on_insufficient_coord = True):
+ self.error_on_insufficient_coord = error_on_insufficient_coord
def compute_accessable_volumes(self, st, n_recenterings = 8):
@@ -30,6 +42,7 @@ def compute_accessable_volumes(self, st, n_recenterings = 8):
resulting volume; this deals with cases where there is one outlier
where recentering around it gives very bad results.)
"""
+ assert isinstance(st, SiteTrajectory)
vols = np.empty(shape = st.site_network.n_sites, dtype = np.float)
areas = np.empty(shape = st.site_network.n_sites, dtype = np.float)
@@ -76,16 +89,25 @@ def compute_volumes(self, sn):
Args:
- sn (SiteNetwork)
"""
+ assert isinstance(sn, SiteNetwork)
if sn.vertices is None:
raise ValueError("SiteNetwork must have verticies to compute volumes!")
- vols = np.empty(shape = st.site_network.n_sites, dtype = np.float)
- areas = np.empty(shape = st.site_network.n_sites, dtype = np.float)
+ vols = np.empty(shape = sn.n_sites, dtype = np.float)
+ areas = np.empty(shape = sn.n_sites, dtype = np.float)
- pbcc = PBCCalculator(st.site_network.structure.cell)
+ pbcc = PBCCalculator(sn.structure.cell)
- for site in range(st.site_network.n_sites):
+ for site in range(sn.n_sites):
pos = sn.static_structure.positions[sn.vertices[site]]
+ if len(pos) < 4:
+ if self.error_on_insufficient_coord:
+ raise InsufficientCoordinatingAtomsError("Site %i had only %i vertices (less than needed 4)" % (site, len(pos)))
+ else:
+ vols[site] = 0
+ areas[site] = np.nan
+ continue
+
assert pos.flags['OWNDATA'] # It should since we're indexing with index lists
# Recenter
offset = pbcc.cell_centroid - sn.centers[site]
From e720b95cfbc55630b61610fdf126d89a28308d2e Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 12 Jul 2019 16:59:18 -0400
Subject: [PATCH 067/129] Removed unused QVoronoi interface
---
sitator/util/qvoronoi.py | 144 ---------------------------------------
1 file changed, 144 deletions(-)
delete mode 100644 sitator/util/qvoronoi.py
diff --git a/sitator/util/qvoronoi.py b/sitator/util/qvoronoi.py
deleted file mode 100644
index aaa878e..0000000
--- a/sitator/util/qvoronoi.py
+++ /dev/null
@@ -1,144 +0,0 @@
-
-import tempfile
-import subprocess
-import sys
-
-import re
-from collections import OrderedDict
-
-import numpy as np
-
-from sklearn.neighbors import KDTree
-
-from sitator.util import PBCCalculator
-
-def periodic_voronoi(structure, logfile = sys.stdout):
- """
- :param ASE.Atoms structure:
- """
-
- pbcc = PBCCalculator(structure.cell)
-
- # Make a 3x3x3 supercell
- supercell = structure.repeat((3, 3, 3))
-
- qhull_output = None
-
- logfile.write("Qvoronoi ---")
-
- # Run qhull
- with tempfile.NamedTemporaryFile('w',
- prefix = 'qvor',
- suffix='.in', delete = False) as infile, \
- tempfile.NamedTemporaryFile('r',
- prefix = 'qvor',
- suffix='.out',
- delete=True) as outfile:
- # -- Write input file --
- infile.write("3\n") # num of dimensions
- infile.write("%i\n" % len(supercell)) # num of points
- np.savetxt(infile, supercell.get_positions(), fmt = '%.16f')
- infile.flush()
-
- cmdline = ["qvoronoi", "TI", infile.name, "FF", "Fv", "TO", outfile.name]
- process = subprocess.Popen(cmdline, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
- retcode = process.wait()
- logfile.write(process.stdout.read())
- if retcode != 0:
- raise RuntimeError("qvoronoi returned exit code %i" % retcode)
-
- qhull_output = outfile.read()
-
- facets_regex = re.compile(
- """
- -[ \t](?P<facet>f[0-9]+) [\n]
- [ \t]*-[ ]flags: .* [\n]
- [ \t]*-[ ]normal: .* [\n]
- [ \t]*-[ ]offset: .* [\n]
- [ \t]*-[ ]center:(?P<center>([ ][\-]?[0-9]*[\.]?[0-9]*(e[-?[0-9]+)?){3}) [ \t] [\n]
- [ \t]*-[ ]vertices:(?P<vertices>([ ]p[0-9]+\(v[0-9]+\))+) [ \t]? [\n]
- [ \t]*-[ ]neighboring[ ]facets:(?P<neighbors>([ ]f[0-9]+)+)
- """, re.X | re.M)
-
- vertices_re = re.compile('(?<=p)[0-9]+')
-
- # Allocate stuff
- centers = []
- vertices = []
- facet_indexes_taken = set()
-
- facet_index_to_our_index = {}
- all_facets_centers = []
-
- # ---- Read facets
- facet_index = -1
- next_our_index = 0
- for facet_match in facets_regex.finditer(qhull_output):
- center = np.asarray(list(map(float, facet_match.group('center').split())))
- facet_index += 1
-
- all_facets_centers.append(center)
-
- if not pbcc.is_in_image_of_cell(center, (1, 1, 1)):
- continue
-
- verts = list(map(int, vertices_re.findall(facet_match.group('vertices'))))
- verts_in_main_cell = tuple(v % len(structure) for v in verts)
-
- facet_indexes_taken.add(facet_index)
-
- centers.append(center)
- vertices.append(verts_in_main_cell)
-
- facet_index_to_our_index[facet_index] = next_our_index
-
- next_our_index += 1
-
- end_of_facets = facet_match.end()
-
- facet_count = facet_index + 1
-
- logfile.write(" qhull gave %i vertices; kept %i" % (facet_count, len(centers)))
-
- # ---- Read ridges
- qhull_output_after_facets = qhull_output[end_of_facets:].strip()
- ridge_re = re.compile('^\d+ \d+ \d+(?P<verts>( \d+)+)$', re.M)
-
- ridges = [
- [int(v) for v in match.group('verts').split()]
- for match in ridge_re.finditer(qhull_output_after_facets)
- ]
- # only take ridges with at least 1 facet in main unit cell.
- ridges = [
- r for r in ridges if any(f in facet_indexes_taken for f in r)
- ]
-
- # shift centers back into normal unit cell
- centers -= np.sum(structure.cell, axis = 0)
-
- nearest_center = KDTree(centers)
-
- ridges_in_main_cell = set()
- threw_out = 0
- for r in ridges:
- ridge_centers = np.asarray([all_facets_centers[f] for f in r if f < len(all_facets_centers)])
- if not pbcc.all_in_unit_cell(ridge_centers):
- continue
-
- pbcc.wrap_points(ridge_centers)
- dists, ridge_centers_in_main = nearest_center.query(ridge_centers, return_distance = True)
-
- if np.any(dists > 0.00001):
- threw_out += 1
- continue
-
- assert ridge_centers_in_main.shape == (len(ridge_centers), 1), "%s" % ridge_centers_in_main.shape
- ridge_centers_in_main = ridge_centers_in_main[:,0]
-
- ridges_in_main_cell.add(frozenset(ridge_centers_in_main))
-
- logfile.write(" Threw out %i ridges" % threw_out)
-
- logfile.flush()
-
- return centers, vertices, ridges_in_main_cell
From c68dc8179ab425da3d6d5b5a75fea048987f1492 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 12 Jul 2019 17:09:36 -0400
Subject: [PATCH 068/129] Save coordination environment symbols in SiteNetwork
---
sitator/site_descriptors/SiteCoordinationEnvironment.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
index 0e37a0d..ebb97df 100644
--- a/sitator/site_descriptors/SiteCoordinationEnvironment.py
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -87,6 +87,7 @@ def run(self, sn):
sn.site_types = site_types
sn.vertices = vertices
+ sn.add_site_attribute("coordination_environments", [env['ce_symbol'] for env in coord_envs])
sn.add_site_attribute("site_type_confidences", site_type_confidences)
sn.add_site_attribute("coordination_numbers", coordination_numbers)
From bc7a9ff0728e58b5f8047eba9a125e1398f8c364 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 15 Jul 2019 10:19:08 -0400
Subject: [PATCH 069/129] Only anion-cation bonds in coordination analysis
---
sitator/SiteNetwork.py | 5 +-
.../SiteCoordinationEnvironment.py | 54 ++++++++++++++-----
2 files changed, 46 insertions(+), 13 deletions(-)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index c0a7d41..1d9c6fd 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -118,6 +118,9 @@ def of_type(self, stype):
def get_structure_with_sites(self, site_atomic_number = None):
"""Get an `ase.Atoms` with the sites included.
+ Sites are appended to the static structure; the first `np.sum(static_mask)`
+ atoms in the returned object are the static structure.
+
Args:
- site_atomic_number: If `None`, the species of the first mobile atom
will be used.
@@ -134,7 +137,7 @@ def get_structure_with_sites(self, site_atomic_number = None):
)
site_idexes = len(out) + np.arange(self.n_sites)
out.extend(sites_atoms)
- return out, site_idexes, site_atomic_number
+ return out, site_atomic_number
@property
def n_sites(self):
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
index ebb97df..273f919 100644
--- a/sitator/site_descriptors/SiteCoordinationEnvironment.py
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -8,6 +8,7 @@
from pymatgen.analysis.chemenv.coordination_environments.structure_environments import \
LightStructureEnvironments
from pymatgen.analysis.chemenv.utils.defs_utils import AdditionalConditions
+ from pymatgen.analysis.bond_valence import BVAnalyzer
has_pymatgen = True
except ImportError:
has_pymatgen = False
@@ -30,22 +31,41 @@ class SiteCoordinationEnvironment(object):
Args:
**kwargs: passed to `compute_structure_environments`.
"""
- def __init__(self, **kwargs):
+ def __init__(self, guess_ionic_bonds = True, **kwargs):
if not has_pymatgen:
raise ImportError("Pymatgen (or a recent enough version including `pymatgen.analysis.chemenv.coordination_environments`) cannot be imported.")
self._kwargs = kwargs
+ self._guess_ionic_bonds = guess_ionic_bonds
def run(self, sn):
# -- Determine local environments
# Get an ASE structure with a single mobile site that we'll move around
- site_struct, idexes, site_species = sn[0:1].get_structure_with_sites()
+ site_struct, site_species = sn[0:1].get_structure_with_sites()
pymat_struct = AseAtomsAdaptor.get_structure(site_struct)
lgf = cgf.LocalGeometryFinder()
- index = idexes[0]
+ site_atom_index = len(site_struct) - 1
coord_envs = []
vertices = []
+ valences = 'undefined'
+ if self._guess_ionic_bonds:
+ sim_struct = AseAtomsAdaptor.get_structure(sn.structure)
+ valences = np.zeros(shape = len(site_struct), dtype = np.int)
+ bv = BVAnalyzer()
+ try:
+ struct_valences = np.asarray(bv.get_valences(sim_struct))
+ except ValueError as ve:
+ logger.warning("Failed to compute bond valences: %s" % ve)
+ else:
+ valences = np.zeros(shape = len(site_struct), dtype = np.int)
+ valences[:site_atom_index] = struct_valences[sn.static_mask]
+ mob_val = struct_valences[sn.mobile_mask]
+ if np.any(mob_val != mob_val[0]):
+ logger.warning("Mobile atom estimated valences (%s) not uniform; arbitrarily taking first." % mob_val)
+ valences[site_atom_index] = mob_val[0]
+ valences = list(valences)
+
logger.info("Running site coordination environment analysis...")
# Do this once.
# __init__ here defaults to disabling structure refinement, so all this
@@ -55,19 +75,29 @@ def run(self, sn):
for site in tqdm(range(sn.n_sites), desc = 'Site'):
# Update the position of the site
- lgf.structure[index].coords = sn.centers[site]
+ lgf.structure[site_atom_index].coords = sn.centers[site]
# Compute structure environments for the site
- struct_envs = lgf.compute_structure_environments(only_indices = [index], **self._kwargs)
+ struct_envs = lgf.compute_structure_environments(
+ only_indices = [site_atom_index],
+ valences = valences,
+ additional_conditions = [AdditionalConditions.ONLY_ANION_CATION_BONDS],
+ **self._kwargs
+ )
struct_envs = LightStructureEnvironments.from_structure_environments(
- strategy=cgf.LocalGeometryFinder.DEFAULT_STRATEGY,
- structure_environments=struct_envs
+ strategy = cgf.LocalGeometryFinder.DEFAULT_STRATEGY,
+ structure_environments = struct_envs
)
# Store the results
# We take the first environment for each site since it's the most likely
- coord_envs.append(struct_envs.coordination_environments[index][0])
- vertices.append(
- [n['index'] for n in struct_envs.neighbors_sets[index][0].neighb_indices_and_images]
- )
+ this_site_envs = struct_envs.coordination_environments[site_atom_index]
+ if len(this_site_envs) > 0:
+ coord_envs.append(this_site_envs[0])
+ vertices.append(
+ [n['index'] for n in struct_envs.neighbors_sets[site_atom_index][0].neighb_indices_and_images]
+ )
+ else:
+ coord_envs.append({'ce_symbol' : 'BAD:0', 'ce_fraction' : 0.})
+ vertices.append([])
del lgf
del struct_envs
@@ -82,7 +112,7 @@ def run(self, sn):
assert np.all(coordination_numbers == [len(v) for v in vertices])
n_types = len(unique_envs)
- logger.info((" " + "Type {:<2} " * n_types).format(*range(n_types)))
+ logger.info(("Type " + "{:<8}" * n_types).format(*unique_envs))
logger.info(("# of sites " + "{:<8}" * n_types).format(*np.bincount(site_types)))
sn.site_types = site_types
From d30aa3ec757b2fcb7e7a976e5d170128acaa1853 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 15 Jul 2019 12:02:06 -0400
Subject: [PATCH 070/129] Weighted spatial merge averaging
---
sitator/network/merging.py | 13 +++++++++++--
1 file changed, 11 insertions(+), 2 deletions(-)
diff --git a/sitator/network/merging.py b/sitator/network/merging.py
index cf7444f..5b6e631 100644
--- a/sitator/network/merging.py
+++ b/sitator/network/merging.py
@@ -29,14 +29,18 @@ class MergeSites(abc.ABC):
:param bool set_merged_into: If True, a site attribute `"merged_into"` will
be added to the original `SiteNetwork` indicating which new site
each old site was merged into.
+ :param bool weighted_spatial_average: If True, the spatial average giving
+ the position of the merged site will be weighted by occupancy.
"""
def __init__(self,
check_types = True,
maximum_merge_distance = None,
- set_merged_into = False):
+ set_merged_into = False,
+ weighted_spatial_average = True):
self.check_types = check_types
self.maximum_merge_distance = maximum_merge_distance
self.set_merged_into = set_merged_into
+ self.weighted_spatial_average = weighted_spatial_average
def run(self, st, **kwargs):
@@ -90,7 +94,12 @@ def run(self, st, **kwargs):
raise MergedSitesTooDistantError("Markov clustering tried to merge sites more than %.2f apart. Lower your distance_threshold?" % self.maximum_merge_distance)
# New site center
- new_centers[newsite] = pbcc.average(to_merge)
+ if self.weighted_spatial_average:
+     occs = st.site_network.occupancies[mask]
+     new_centers[newsite] = pbcc.average(to_merge, weights = occs)
+ else:
+     new_centers[newsite] = pbcc.average(to_merge)
+
if self.check_types:
assert np.all(site_types[mask] == site_types[mask][0])
new_types[newsite] = site_types[mask][0]
From a92f70166b21ee308291966c94f5ac50f0d8ce6c Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 15 Jul 2019 12:02:23 -0400
Subject: [PATCH 071/129] Handle sets of vertices without error
---
sitator/site_descriptors/SiteVolumes.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sitator/site_descriptors/SiteVolumes.py b/sitator/site_descriptors/SiteVolumes.py
index 4b76b8a..1025672 100644
--- a/sitator/site_descriptors/SiteVolumes.py
+++ b/sitator/site_descriptors/SiteVolumes.py
@@ -99,7 +99,7 @@ def compute_volumes(self, sn):
pbcc = PBCCalculator(sn.structure.cell)
for site in range(sn.n_sites):
- pos = sn.static_structure.positions[sn.vertices[site]]
+ pos = sn.static_structure.positions[list(sn.vertices[site])]
if len(pos) < 4:
if self.error_on_insufficient_coord:
raise InsufficientCoordinatingAtomsError("Site %i had only %i vertices (less than needed 4)" % (site, len(pos)))
From 47b8f3045f735860777e218bd94f790438ac41fa Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 14:59:12 -0400
Subject: [PATCH 072/129] Better short jump removal
---
sitator/dynamics/RemoveShortJumps.py | 53 +++++++++++++++++++----
sitator/dynamics/RemoveUnoccupiedSites.py | 25 +++++++++--
sitator/dynamics/__init__.py | 2 +
3 files changed, 67 insertions(+), 13 deletions(-)
diff --git a/sitator/dynamics/RemoveShortJumps.py b/sitator/dynamics/RemoveShortJumps.py
index 99bf662..de27438 100644
--- a/sitator/dynamics/RemoveShortJumps.py
+++ b/sitator/dynamics/RemoveShortJumps.py
@@ -1,6 +1,9 @@
import numpy as np
+from collections import defaultdict
+
from sitator import SiteTrajectory
+from sitator.dynamics import RemoveUnoccupiedSites
import logging
logger = logging.getLogger(__name__)
@@ -16,13 +19,17 @@ class RemoveShortJumps(object):
- only_returning_jumps (bool, default: True): If True, only short jumps
where the mobile atom returns to its initial site will be removed.
"""
- def __init__(self, only_returning_jumps = True):
+ def __init__(self,
+ only_returning_jumps = True,
+ remove_unoccupied_sites = True):
self.only_returning_jumps = only_returning_jumps
+ self.remove_unoccupied_sites = remove_unoccupied_sites
def run(self,
st,
- threshold):
+ threshold,
+ return_stats = False):
"""Returns a copy of `st` with short jumps removed.
Args:
@@ -47,7 +54,12 @@ def run(self,
n_problems = 0
n_short_jumps = 0
+ # Dict of lists [sum_jump_times, n_short_jumps]
+ short_jump_info = defaultdict(lambda: [0, 0])
+
for i, frame in enumerate(st.traj):
+ if i == 0:
+ continue
# -- Deal with unassigned
# Don't screw up the SiteTrajectory
np.copyto(framebuf, frame)
@@ -63,8 +75,10 @@ def run(self,
# -- Update stats
jumped = (frame != last_known) & fknown
- problems = last_known[jumped] == -1
- jumped[np.where(jumped)[0][problems]] = False
+ #problems = last_known[jumped] == -1
+ #jumped[np.where(jumped)[0][problems]] = False
+ problems = last_known == -1
+ jumped[problems] = False
n_problems += np.sum(problems)
jump_froms = last_known[jumped]
@@ -78,9 +92,13 @@ def run(self,
short_mask &= jump_tos == previous_site[jumped]
# Remove short jumps
for sj_atom in np.arange(n_mobile)[jumped][short_mask]:
- #print("atom %s removing %i -> %i (%i) -> %i" % (sj_atom, previous_site[sj_atom], last_known[sj_atom], time_at_current[sj_atom], frame[sj_atom]))
+ # Bookkeeping
+ sjkey = (previous_site[sj_atom], last_known[sj_atom], frame[sj_atom])
+ short_jump_info[sjkey][0] += time_at_current[sj_atom]
+ short_jump_info[sjkey][1] += 1
n_short_jumps += 1
- out[i - time_at_current[sj_atom]:i+1, sj_atom] = previous_site[sj_atom]
+ # Remove short jump
+ out[i - time_at_current[sj_atom]:i, sj_atom] = previous_site[sj_atom]
previous_site[jumped] = last_known[jumped]
@@ -93,9 +111,26 @@ def run(self,
if n_problems != 0:
logger.warning("Came across %i times where assignment and last known assignment were unassigned." % n_problems)
logger.info("Removed %i short jumps" % n_short_jumps)
- self.n_short_jumps = n_short_jumps
+ # Do average
+ for k in short_jump_info.keys():
+ short_jump_info[k][0] /= short_jump_info[k][1]
+ logger.info(
+ "Short jump statistics:\n" +
+ "\n".join(
+ " removed {1[1]:3}x {0[0]:2} -> {0[1]:2} -> {0[2]:2}; avg. residence at {0[1]:2} of {1[0]} frames".format(
+ k, v
+ ) for k, v in short_jump_info.items()
+ )
+ )
st = st.copy()
st._traj = out
-
- return st
+ if self.remove_unoccupied_sites:
+ # Removing short jumps could have made some sites completely unoccupied
+ st = RemoveUnoccupiedSites().run(st)
+ st.site_network.clear_attributes()
+
+ if return_stats:
+ return st, short_jump_info
+ else:
+ return st
diff --git a/sitator/dynamics/RemoveUnoccupiedSites.py b/sitator/dynamics/RemoveUnoccupiedSites.py
index 4502025..bb5a6e8 100644
--- a/sitator/dynamics/RemoveUnoccupiedSites.py
+++ b/sitator/dynamics/RemoveUnoccupiedSites.py
@@ -2,26 +2,40 @@
from sitator import SiteTrajectory
+import logging
+logger = logging.getLogger(__name__)
+
+
class RemoveUnoccupiedSites(object):
def __init__(self):
pass
- def run(self, st):
+ def run(self, st, return_kept_sites = False):
"""
+
+ Can return `st` unmodified if all sites are at some point occupied.
+
"""
assert isinstance(st, SiteTrajectory)
old_sn = st.site_network
- seen_mask = np.zeros(shape = old_sn.n_sites, dtype = np.bool)
+ # Allow for the -1 to affect nothing
+ seen_mask = np.zeros(shape = old_sn.n_sites + 1, dtype = np.bool)
for frame in st.traj:
seen_mask[frame] = True
+ if np.all(seen_mask[:-1]):
+ return st
+
+ seen_mask = seen_mask[:-1]
+
+ logger.info("Removing unoccupied sites %s" % np.where(~seen_mask)[0])
n_new_sites = np.sum(seen_mask)
translation = np.empty(shape = old_sn.n_sites, dtype = np.int)
translation[seen_mask] = np.arange(n_new_sites)
- translation[~seen_mask] = -4321
+ translation[~seen_mask] = SiteTrajectory.SITE_UNKNOWN
newtraj = translation[st.traj.reshape(-1)]
newtraj.shape = st.traj.shape
@@ -34,4 +48,7 @@ def run(self, st):
)
if st.real_trajectory is not None:
new_st.set_real_traj(st.real_trajectory)
- return new_st
+ if return_kept_sites:
+ return new_st, np.where(seen_mask)
+ else:
+ return new_st
diff --git a/sitator/dynamics/__init__.py b/sitator/dynamics/__init__.py
index 0a284fd..cf0760e 100644
--- a/sitator/dynamics/__init__.py
+++ b/sitator/dynamics/__init__.py
@@ -2,6 +2,8 @@
from .MergeSitesByDynamics import MergeSitesByDynamics
from .MergeSitesByThreshold import MergeSitesByThreshold
from .RemoveShortJumps import RemoveShortJumps
+from .RemoveUnoccupiedSites import RemoveUnoccupiedSites
+from .AverageVibrationalFrequency import AverageVibrationalFrequency
# For backwards compatability, since this used to be in this module
from sitator.network import DiffusionPathwayAnalysis
From af51cca39fd5d4e473b3cb2525a2626788acd057 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 14:59:26 -0400
Subject: [PATCH 073/129] Added AverageVibrationalFrequency
---
.../dynamics/AverageVibrationalFrequency.py | 54 +++++++++++++++++++
1 file changed, 54 insertions(+)
create mode 100644 sitator/dynamics/AverageVibrationalFrequency.py
diff --git a/sitator/dynamics/AverageVibrationalFrequency.py b/sitator/dynamics/AverageVibrationalFrequency.py
new file mode 100644
index 0000000..726a983
--- /dev/null
+++ b/sitator/dynamics/AverageVibrationalFrequency.py
@@ -0,0 +1,54 @@
+import numpy as np
+
+class AverageVibrationalFrequency(object):
+ """Compute the average vibrational frequency of indicated atoms in a trajectory.
+
+ Uses the method described in section 2.2 of this paper:
+
+ Klerk, Niek J.J. de, Eveline van der Maas, and Marnix Wagemaker.
+ “Analysis of Diffusion in Solid-State Electrolytes through MD Simulations,
+ Improvement of the Li-Ion Conductivity in β-Li3PS4 as an Example.”
+ ACS Applied Energy Materials 1, no. 7 (July 23, 2018): 3230–42.
+ https://doi.org/10.1021/acsaem.8b00457.
+
+ """
+ def __init__(self,
+ min_frequency = 0,
+ max_frequency = np.inf):
+ # Always want to exclude DC frequency
+ assert min_frequency >= 0
+ self.min_frequency = min_frequency
+ self.max_frequency = max_frequency
+
+ def compute_avg_vibrational_freq(self, traj, mask, return_stdev = False):
+ """Compute the average vibrational frequency.
+
+ Args:
+ - traj (ndarray n_frames x n_atoms x 3)
+ - mask (ndarray n_atoms bool): which atoms to average over.
+ Returns:
+ A frequency in units of (timestep)^-1
+ """
+ speeds = traj[1:, mask]
+ speeds -= traj[:-1, mask]
+ speeds = np.linalg.norm(speeds, axis = 2)
+
+ freqs = np.fft.rfftfreq(speeds.shape[0])
+ fmask = (freqs > self.min_frequency) & (freqs < self.max_frequency)
+ assert np.any(fmask), "Trajectory too short?"
+ freqs = freqs[fmask]
+
+        # de Klerk et al. do an average over the atom-by-atom averages
+ n_mob = speeds.shape[1]
+ avg_freqs = np.empty(shape = n_mob)
+
+ for mob in range(n_mob):
+ ps = np.abs(np.fft.rfft(speeds[:, mob])) ** 2
+ avg_freqs[mob] = np.average(freqs, weights = ps[fmask])
+
+ avg_freq = np.mean(avg_freqs)
+
+ if return_stdev:
+ return avg_freq, np.std(avg_freqs)
+ else:
+ return avg_freq
From 32ab480ab380fe1c57e51115c3bddf61ee1033b0 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 15:03:49 -0400
Subject: [PATCH 074/129] "Computed" SiteNetwork attributes
---
sitator/SiteNetwork.py | 54 ++++++++++++++++-------
sitator/SiteTrajectory.py | 6 ++-
sitator/dynamics/RemoveShortJumps.py | 2 +-
sitator/dynamics/RemoveUnoccupiedSites.py | 1 +
4 files changed, 43 insertions(+), 20 deletions(-)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 1d9c6fd..79f56d1 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -14,19 +14,33 @@
class SiteNetwork(object):
"""A network of mobile particle sites in a static lattice.
- Stores the locations of sites (`centers`), their defining static atoms (`vertices`),
- and their "types" (`site_types`).
+ Stores the locations of sites (`centers`) for some indicated mobile atoms
+ (`mobile_mask`) in a structure (`structure`). Optionally includes their
+ defining static atoms (`vertices`) and "types" (`site_types`).
Arbitrary data can also be associated with each site and with each edge
between sites. Site data can be any array of length n_sites; edge data can be
any matrix of shape (n_sites, n_sites) where entry i, j is the value for the
- edge from site i to site j.
+ edge from site i to site j (edge attributes can be asymmetric).
+
+ Attributes can be marked as "computed"; this is a hint that the attribute
+ was computed based on a `SiteTrajectory`. Most `sitator` algorithms that
+ modify/process `SiteTrajectory`s will clear "computed" attrbutes,
+ assuming that they are invalidated by the changes to the `SiteTrajectory`.
Attributes:
centers (ndarray): (n_sites, 3) coordinates of each site.
vertices (list, optional): list of lists of indexes of static atoms defining each
site.
site_types (ndarray, optional): (n_sites,) values grouping sites into types.
+
+ Args:
+ structure (Atoms): an ASE/Quippy ``Atoms`` containging whatever atoms exist
+ in the MD trajectory.
+ static_mask (ndarray): Boolean mask indicating which atoms make up the
+ host lattice.
+ mobile_mask (ndarray): Boolean mask indicating which atoms' movement we
+ are interested in.
"""
ATTR_NAME_REGEX = re.compile("^[a-zA-Z][a-zA-Z0-9_]*$")
@@ -35,16 +49,6 @@ def __init__(self,
structure,
static_mask,
mobile_mask):
- """
- Args:
- structure (Atoms): an ASE/Quippy ``Atoms`` containging whatever atoms exist
- in the MD trajectory.
- static_mask (ndarray): Boolean mask indicating which atoms make up the
- host lattice.
- mobile_mask (ndarray): Boolean mask indicating which atoms' movement we
- are interested in.
- """
-
assert static_mask.ndim == mobile_mask.ndim == 1, "The masks must be one-dimensional"
assert len(structure) == len(static_mask) == len(mobile_mask), "The masks must have the same length as the # of atoms in the strucutre."
@@ -69,12 +73,16 @@ def __init__(self,
self._site_attrs = {}
self._edge_attrs = {}
+ self._attr_computed = {}
- def copy(self):
+ def copy(self, with_computed = True):
"""Returns a (shallowish) copy of self."""
# Use a mask to force a copy
msk = np.ones(shape = self.n_sites, dtype = np.bool)
- return self[msk]
+ sn = self[msk]
+ if not with_computed:
+ sn.clear_computed_attributes()
+ return sn
def __len__(self):
return self.n_sites
@@ -164,6 +172,7 @@ def centers(self, value):
self._types = None
self._site_attrs = {}
self._edge_attrs = {}
+ self._attr_computed = {}
# Set centers
self._centers = value
@@ -231,6 +240,15 @@ def remove_attribute(self, attr):
else:
raise AttributeError("This SiteNetwork has no site or edge attribute `%s`" % attr)
+ def clear_attributes(self):
+ self._site_attrs = {}
+ self._edge_attrs = {}
+
+ def clear_computed_attributes(self):
+ for k, computed in self._attr_computed.items():
+ if computed:
+ self.remove_attribute(k)
+
def __getattr__(self, attrkey):
v = vars(self)
if '_site_attrs' in v and attrkey in self._site_attrs:
@@ -277,21 +295,23 @@ def get_edge(self, edge):
return out
- def add_site_attribute(self, name, attr):
+ def add_site_attribute(self, name, attr, computed = True):
self._check_name(name)
attr = np.asarray(attr)
if not attr.shape[0] == self.n_sites:
raise ValueError("Attribute array has only %i entries; need one for all %i sites." % (len(attr), self.n_sites))
self._site_attrs[name] = attr
+ self._attr_computed[name] = computed
- def add_edge_attribute(self, name, attr):
+ def add_edge_attribute(self, name, attr, computed = True):
self._check_name(name)
attr = np.asarray(attr)
if not (attr.shape[0] == attr.shape[1] == self.n_sites):
raise ValueError("Attribute matrix has shape %s; need first two dimensions to be %i" % (attr.shape, self.n_sites))
self._edge_attrs[name] = attr
+ self._attr_computed[name] = computed
def _check_name(self, name):
if not SiteNetwork.ATTR_NAME_REGEX.match(name):
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index c375443..6b40cda 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -88,8 +88,10 @@ def site_network(self, value):
def real_trajectory(self):
return self._real_traj
- def copy(self):
- return self[:]
+ def copy(self, with_computed = True):
+ st = self[:]
+ st.site_network = st.site_network.copy(with_computed = with_computed)
+ return st
def set_real_traj(self, real_traj):
"""Assocaite this SiteTrajectory with a trajectory of points in real space.
diff --git a/sitator/dynamics/RemoveShortJumps.py b/sitator/dynamics/RemoveShortJumps.py
index de27438..8d00713 100644
--- a/sitator/dynamics/RemoveShortJumps.py
+++ b/sitator/dynamics/RemoveShortJumps.py
@@ -123,7 +123,7 @@ def run(self,
)
)
- st = st.copy()
+ st = st.copy(with_computed = False)
st._traj = out
if self.remove_unoccupied_sites:
# Removing short jumps could have made some sites completely unoccupied
diff --git a/sitator/dynamics/RemoveUnoccupiedSites.py b/sitator/dynamics/RemoveUnoccupiedSites.py
index bb5a6e8..c0175b4 100644
--- a/sitator/dynamics/RemoveUnoccupiedSites.py
+++ b/sitator/dynamics/RemoveUnoccupiedSites.py
@@ -40,6 +40,7 @@ def run(self, st, return_kept_sites = False):
newtraj = translation[st.traj.reshape(-1)]
newtraj.shape = st.traj.shape
+ # We don't clear computed attributes since nothing is changing for other sites.
newsn = old_sn[seen_mask]
new_st = SiteTrajectory(
From f18149ac2b166f15d431616e96f22a65110dc099 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 15:09:12 -0400
Subject: [PATCH 075/129] Colormap change (again)
---
sitator/visualization/SiteNetworkPlotter.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 2c5a1b8..48def0a 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -101,7 +101,7 @@ def __call__(self, sn, *args, **kwargs):
def _site_layers(self, sn, plot_points_params, same_normalization = False):
pts_arrays = {'points' : sn.centers}
- pts_params = {'cmap' : 'cividis'}
+ pts_params = {'cmap' : 'winter'}
# -- Apply mapping
# - other mappings
From 927a1864fdc835b3628d8be3fc7f530a3436cbde Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 15:42:55 -0400
Subject: [PATCH 076/129] Import fix
---
sitator/dynamics/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sitator/dynamics/__init__.py b/sitator/dynamics/__init__.py
index cf0760e..ac88173 100644
--- a/sitator/dynamics/__init__.py
+++ b/sitator/dynamics/__init__.py
@@ -1,8 +1,8 @@
from .JumpAnalysis import JumpAnalysis
from .MergeSitesByDynamics import MergeSitesByDynamics
from .MergeSitesByThreshold import MergeSitesByThreshold
-from .RemoveShortJumps import RemoveShortJumps
from .RemoveUnoccupiedSites import RemoveUnoccupiedSites
+from .RemoveShortJumps import RemoveShortJumps
from .AverageVibrationalFrequency import AverageVibrationalFrequency
# For backwards compatability, since this used to be in this module
From 09a15f902177f28292304b81dc7c26387ceb9ece Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 17:53:19 -0400
Subject: [PATCH 077/129] Added Documentation (#7)
---
README.md | 3 +
docs/Makefile | 20 ++++
docs/make.bat | 35 ++++++
docs/source/conf.py | 54 +++++++++
docs/source/index.rst | 20 ++++
docs/source/modules.rst | 7 ++
docs/source/sitator.descriptors.rst | 15 +++
docs/source/sitator.dynamics.rst | 40 +++++++
docs/source/sitator.landmark.cluster.rst | 30 +++++
docs/source/sitator.landmark.rst | 42 +++++++
docs/source/sitator.misc.rst | 34 ++++++
docs/source/sitator.network.rst | 25 +++++
docs/source/sitator.rst | 38 +++++++
.../sitator.site_descriptors.backend.rst | 30 +++++
docs/source/sitator.site_descriptors.rst | 43 ++++++++
docs/source/sitator.util.rst | 61 +++++++++++
docs/source/sitator.visualization.rst | 40 +++++++
docs/source/sitator.voronoi.rst | 10 ++
sitator/SiteNetwork.py | 87 ++++++++++++---
sitator/SiteTrajectory.py | 80 ++++++++++----
sitator/descriptors/ConfigurationalEntropy.py | 15 ++-
.../dynamics/AverageVibrationalFrequency.py | 15 ++-
sitator/dynamics/JumpAnalysis.py | 41 ++++---
sitator/dynamics/MergeSitesByDynamics.py | 6 +-
sitator/dynamics/MergeSitesByThreshold.py | 16 +--
sitator/dynamics/RemoveShortJumps.py | 11 +-
sitator/dynamics/RemoveUnoccupiedSites.py | 8 +-
sitator/landmark/LandmarkAnalysis.py | 103 +++++++++---------
sitator/landmark/errors.py | 9 +-
sitator/misc/GenerateAroundSites.py | 7 +-
sitator/misc/GenerateClampedTrajectory.py | 16 +--
sitator/misc/MergeSitesByBarrier.py | 0
sitator/misc/NAvgsPerSite.py | 14 ++-
sitator/network/DiffusionPathwayAnalysis.py | 13 +--
sitator/network/MergeSitesByBarrier.py | 25 +++--
sitator/network/merging.py | 14 +--
sitator/site_descriptors/SOAP.py | 44 ++++----
.../SiteCoordinationEnvironment.py | 27 ++++-
sitator/site_descriptors/SiteTypeAnalysis.py | 25 ++++-
sitator/site_descriptors/SiteVolumes.py | 21 ++--
sitator/util/DotProdClassifier.pyx | 23 ++--
sitator/util/PBCCalculator.pyx | 17 +--
sitator/util/RecenterTrajectory.pyx | 2 +-
sitator/util/elbow.py | 2 +-
sitator/util/mcl.py | 2 +-
sitator/util/zeo.py | 4 +-
sitator/visualization/SiteNetworkPlotter.py | 24 ++--
.../visualization/SiteTrajectoryPlotter.py | 20 ++++
.../VoronoiSiteGenerator.py => voronoi.py} | 12 +-
sitator/voronoi/__init__.py | 2 -
50 files changed, 1002 insertions(+), 250 deletions(-)
create mode 100644 docs/Makefile
create mode 100644 docs/make.bat
create mode 100644 docs/source/conf.py
create mode 100644 docs/source/index.rst
create mode 100644 docs/source/modules.rst
create mode 100644 docs/source/sitator.descriptors.rst
create mode 100644 docs/source/sitator.dynamics.rst
create mode 100644 docs/source/sitator.landmark.cluster.rst
create mode 100644 docs/source/sitator.landmark.rst
create mode 100644 docs/source/sitator.misc.rst
create mode 100644 docs/source/sitator.network.rst
create mode 100644 docs/source/sitator.rst
create mode 100644 docs/source/sitator.site_descriptors.backend.rst
create mode 100644 docs/source/sitator.site_descriptors.rst
create mode 100644 docs/source/sitator.util.rst
create mode 100644 docs/source/sitator.visualization.rst
create mode 100644 docs/source/sitator.voronoi.rst
delete mode 100644 sitator/misc/MergeSitesByBarrier.py
rename sitator/{voronoi/VoronoiSiteGenerator.py => voronoi.py} (73%)
delete mode 100644 sitator/voronoi/__init__.py
diff --git a/README.md b/README.md
index 3b4d5a6..4a9f991 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,9 @@ pip install ".[SiteTypeAnalysis]"
Two example Jupyter notebooks for conducting full landmark analyses of LiAlSiO4 and Li12La3Zr2O12, including data files, can be found [on Materials Cloud](https://archive.materialscloud.org/2019.0008/).
+`sitator` generally assumes units of femtoseconds for time, Angstroms for space,
+and Cartesian (not crystal) coordinates.
+
All individual classes and parameters are documented with docstrings in the source code.
## Global Options
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..d0c3cbf
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 0000000..6247f7e
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 0000000..af7b379
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,54 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# http://www.sphinx-doc.org/en/master/config
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+# import os
+# import sys
+# sys.path.insert(0, os.path.abspath('.'))
+
+
+# -- Project information -----------------------------------------------------
+
+project = 'sitator'
+copyright = '2019, Alby Musaelian'
+author = 'Alby Musaelian'
+
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.napoleon'
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'alabaster'
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 0000000..e2f4d80
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,20 @@
+.. sitator documentation master file, created by
+ sphinx-quickstart on Tue Jul 16 15:15:41 2019.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+Welcome to sitator's documentation!
+===================================
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git a/docs/source/modules.rst b/docs/source/modules.rst
new file mode 100644
index 0000000..244d585
--- /dev/null
+++ b/docs/source/modules.rst
@@ -0,0 +1,7 @@
+sitator
+=======
+
+.. toctree::
+ :maxdepth: 4
+
+ sitator
diff --git a/docs/source/sitator.descriptors.rst b/docs/source/sitator.descriptors.rst
new file mode 100644
index 0000000..aeb7b2e
--- /dev/null
+++ b/docs/source/sitator.descriptors.rst
@@ -0,0 +1,15 @@
+sitator.descriptors package
+===========================
+
+Module contents
+---------------
+
+.. automodule:: sitator.descriptors
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.descriptors.ConfigurationalEntropy
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.dynamics.rst b/docs/source/sitator.dynamics.rst
new file mode 100644
index 0000000..c2510ac
--- /dev/null
+++ b/docs/source/sitator.dynamics.rst
@@ -0,0 +1,40 @@
+sitator.dynamics package
+========================
+
+Module contents
+---------------
+
+.. automodule:: sitator.dynamics
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.AverageVibrationalFrequency
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.JumpAnalysis
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.MergeSitesByDynamics
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.MergeSitesByThreshold
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.RemoveShortJumps
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.RemoveUnoccupiedSites
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.landmark.cluster.rst b/docs/source/sitator.landmark.cluster.rst
new file mode 100644
index 0000000..c332d86
--- /dev/null
+++ b/docs/source/sitator.landmark.cluster.rst
@@ -0,0 +1,30 @@
+sitator.landmark.cluster package
+================================
+
+Submodules
+----------
+
+sitator.landmark.cluster.dotprod module
+---------------------------------------
+
+.. automodule:: sitator.landmark.cluster.dotprod
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.landmark.cluster.mcl module
+-----------------------------------
+
+.. automodule:: sitator.landmark.cluster.mcl
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.landmark.cluster
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.landmark.rst b/docs/source/sitator.landmark.rst
new file mode 100644
index 0000000..8bf446a
--- /dev/null
+++ b/docs/source/sitator.landmark.rst
@@ -0,0 +1,42 @@
+sitator.landmark package
+========================
+
+Subpackages
+-----------
+
+.. toctree::
+
+ sitator.landmark.cluster
+
+Submodules
+----------
+
+sitator.landmark.errors module
+------------------------------
+
+.. automodule:: sitator.landmark.errors
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.landmark.helpers module
+-------------------------------
+
+.. automodule:: sitator.landmark.helpers
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.landmark
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.landmark.LandmarkAnalysis
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.misc.rst b/docs/source/sitator.misc.rst
new file mode 100644
index 0000000..3b441d0
--- /dev/null
+++ b/docs/source/sitator.misc.rst
@@ -0,0 +1,34 @@
+sitator.misc package
+====================
+
+sitator.misc.oldio module
+-------------------------
+
+.. automodule:: sitator.misc.oldio
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.misc
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.misc.GenerateAroundSites
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.misc.GenerateClampedTrajectory
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.misc.NAvgsPerSite
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.network.rst b/docs/source/sitator.network.rst
new file mode 100644
index 0000000..be7fb28
--- /dev/null
+++ b/docs/source/sitator.network.rst
@@ -0,0 +1,25 @@
+sitator.network package
+=======================
+
+Module contents
+---------------
+
+.. automodule:: sitator.network
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.network.DiffusionPathwayAnalysis
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.network.MergeSitesByBarrier
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.network.merging
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.rst b/docs/source/sitator.rst
new file mode 100644
index 0000000..8662328
--- /dev/null
+++ b/docs/source/sitator.rst
@@ -0,0 +1,38 @@
+sitator package
+===============
+
+Subpackages
+-----------
+
+.. toctree::
+
+ sitator.descriptors
+ sitator.dynamics
+ sitator.landmark
+ sitator.misc
+ sitator.network
+ sitator.site_descriptors
+ sitator.util
+ sitator.visualization
+ sitator.voronoi
+
+Submodules
+----------
+
+Module contents
+---------------
+
+.. automodule:: sitator
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.SiteNetwork
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.SiteTrajectory
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.site_descriptors.backend.rst b/docs/source/sitator.site_descriptors.backend.rst
new file mode 100644
index 0000000..a7afe2f
--- /dev/null
+++ b/docs/source/sitator.site_descriptors.backend.rst
@@ -0,0 +1,30 @@
+sitator.site\_descriptors.backend package
+=========================================
+
+Submodules
+----------
+
+sitator.site\_descriptors.backend.dscribe module
+------------------------------------------------
+
+.. automodule:: sitator.site_descriptors.backend.dscribe
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.site\_descriptors.backend.quip module
+---------------------------------------------
+
+.. automodule:: sitator.site_descriptors.backend.quip
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.site_descriptors.backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.site_descriptors.rst b/docs/source/sitator.site_descriptors.rst
new file mode 100644
index 0000000..4275695
--- /dev/null
+++ b/docs/source/sitator.site_descriptors.rst
@@ -0,0 +1,43 @@
+sitator.site\_descriptors package
+=================================
+
+Subpackages
+-----------
+
+.. toctree::
+
+ sitator.site_descriptors.backend
+
+Submodules
+----------
+
+sitator.site\_descriptors.SOAP module
+-------------------------------------
+
+.. automodule:: sitator.site_descriptors.SOAP
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: sitator.site_descriptors
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.site_descriptors.SiteCoordinationEnvironment
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.site_descriptors.SiteTypeAnalysis
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.site_descriptors.SiteVolumes
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.util.rst b/docs/source/sitator.util.rst
new file mode 100644
index 0000000..1f95c20
--- /dev/null
+++ b/docs/source/sitator.util.rst
@@ -0,0 +1,61 @@
+sitator.util package
+====================
+
+Submodules
+----------
+
+sitator.util.elbow module
+-------------------------
+
+.. automodule:: sitator.util.elbow
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.util.mcl module
+-----------------------
+
+.. automodule:: sitator.util.mcl
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.util.progress module
+----------------------------
+
+.. automodule:: sitator.util.progress
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.util.zeo module
+-----------------------
+
+.. automodule:: sitator.util.zeo
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.util
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.util.DotProdClassifier
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.util.PBCCalculator
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.util.RecenterTrajectory
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.visualization.rst b/docs/source/sitator.visualization.rst
new file mode 100644
index 0000000..9596867
--- /dev/null
+++ b/docs/source/sitator.visualization.rst
@@ -0,0 +1,40 @@
+sitator.visualization package
+=============================
+
+Submodules
+----------
+
+sitator.visualization.atoms module
+----------------------------------
+
+.. automodule:: sitator.visualization.atoms
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.visualization.common module
+-----------------------------------
+
+.. automodule:: sitator.visualization.common
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.visualization
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.visualization.SiteNetworkPlotter
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.visualization.SiteTrajectoryPlotter
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.voronoi.rst b/docs/source/sitator.voronoi.rst
new file mode 100644
index 0000000..adad362
--- /dev/null
+++ b/docs/source/sitator.voronoi.rst
@@ -0,0 +1,10 @@
+sitator.voronoi package
+=======================
+
+Module contents
+---------------
+
+.. automodule:: sitator.voronoi
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 79f56d1..ed89aab 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -14,9 +14,9 @@
class SiteNetwork(object):
"""A network of mobile particle sites in a static lattice.
- Stores the locations of sites (`centers`) for some indicated mobile atoms
- (`mobile_mask`) in a structure (`structure`). Optionally includes their
- defining static atoms (`vertices`) and "types" (`site_types`).
+ Stores the locations of sites (``centers``) for some indicated mobile atoms
+ (``mobile_mask``) in a structure (``structure``). Optionally includes their
+ defining static atoms (``vertices``) and "types" (``site_types``).
Arbitrary data can also be associated with each site and with each edge
between sites. Site data can be any array of length n_sites; edge data can be
@@ -24,9 +24,9 @@ class SiteNetwork(object):
edge from site i to site j (edge attributes can be asymmetric).
Attributes can be marked as "computed"; this is a hint that the attribute
- was computed based on a `SiteTrajectory`. Most `sitator` algorithms that
- modify/process `SiteTrajectory`s will clear "computed" attrbutes,
- assuming that they are invalidated by the changes to the `SiteTrajectory`.
+ was computed based on a ``SiteTrajectory``. Most ``sitator`` algorithms that
+    modify/process ``SiteTrajectory``s will clear "computed" attributes,
+ assuming that they are invalidated by the changes to the ``SiteTrajectory``.
Attributes:
centers (ndarray): (n_sites, 3) coordinates of each site.
@@ -35,7 +35,7 @@ class SiteNetwork(object):
site_types (ndarray, optional): (n_sites,) values grouping sites into types.
Args:
- structure (Atoms): an ASE/Quippy ``Atoms`` containging whatever atoms exist
+        structure (ase.Atoms): an ASE ``Atoms`` containing whatever atoms exist
in the MD trajectory.
static_mask (ndarray): Boolean mask indicating which atoms make up the
host lattice.
@@ -76,7 +76,14 @@ def __init__(self,
self._attr_computed = {}
def copy(self, with_computed = True):
- """Returns a (shallowish) copy of self."""
+ """Returns a (shallowish) copy of self.
+
+ Args:
+ with_computed (bool): If ``False``, attributes marked "computed" will
+ not be included in the copy.
+ Returns:
+ A ``SiteNetwork``.
+ """
# Use a mask to force a copy
msk = np.ones(shape = self.n_sites, dtype = np.bool)
sn = self[msk]
@@ -114,7 +121,13 @@ def __getitem__(self, key):
return sn
def of_type(self, stype):
- """Returns a "view" to this SiteNetwork with only sites of a certain type."""
+ """Returns a subset of this ``SiteNetwork`` with only sites of a certain type.
+
+ Args:
+ stype (int)
+ Returns:
+ A ``SiteNetwork``.
+ """
if self._types is None:
raise ValueError("This SiteNetwork has no type information.")
@@ -124,16 +137,16 @@ def of_type(self, stype):
return self[self._types == stype]
def get_structure_with_sites(self, site_atomic_number = None):
- """Get an `ase.Atoms` with the sites included.
+ """Get an ``ase.Atoms`` with the sites included.
- Sites are appended to the static structure; the first `np.sum(static_mask)`
+ Sites are appended to the static structure; the first ``np.sum(static_mask)``
atoms in the returned object are the static structure.
Args:
- - site_atomic_number: If `None`, the species of the first mobile atom
- will be used.
+ site_atomic_number: If ``None``, the species of the first mobile
+ atom will be used.
Returns:
- ase.Atoms, indices of sites in the returned structure, and final `site_atomic_number`
+ ``ase.Atoms`` and final ``site_atomic_number``
"""
out = self.static_structure.copy()
if site_atomic_number is None:
@@ -149,16 +162,19 @@ def get_structure_with_sites(self, site_atomic_number = None):
@property
def n_sites(self):
+ """The number of sites."""
if self._centers is None:
return 0
return len(self._centers)
@property
def n_total(self):
+ """The total number of atoms in the system."""
return len(self.static_mask)
@property
def centers(self):
+ """The positions of the sites."""
view = self._centers.view()
view.flags.writeable = False
return view
@@ -177,13 +193,18 @@ def centers(self, value):
self._centers = value
def update_centers(self, newcenters):
- """Update the SiteNetwork's centers *without* reseting all other information."""
+ """Update the ``SiteNetwork``'s centers *without* resetting all other information.
+
+ Args:
+ newcenters (ndarray): Must have same length as current number of sites.
+ """
if newcenters.shape != self._centers.shape:
raise ValueError("New `centers` must have same shape as old; try using the setter `.centers = ...`")
self._centers = newcenters
@property
def vertices(self):
+ """The static atoms defining each site."""
return self._vertices
@vertices.setter
@@ -194,6 +215,7 @@ def vertices(self, value):
@property
def number_of_vertices(self):
+ """The number of vertices of each site."""
if self._vertices is None:
return None
else:
@@ -201,6 +223,7 @@ def number_of_vertices(self):
@property
def site_types(self):
+ """The type IDs of each site."""
if self._types is None:
return None
view = self._types.view()
@@ -215,24 +238,40 @@ def site_types(self, value):
@property
def n_types(self):
+ """The number of site types in the ``SiteNetwork``."""
return len(np.unique(self.site_types))
@property
def types(self):
+ """The unique site type IDs in the ``SiteNetwork``."""
return np.unique(self.site_types)
@property
def site_attributes(self):
+ """The names of the ``SiteNetwork``'s site attributes."""
return list(self._site_attrs.keys())
@property
def edge_attributes(self):
+ """The names of the ``SiteNetwork``'s edge attributes."""
return list(self._edge_attrs.keys())
def has_attribute(self, attr):
+ """Whether the ``SiteNetwork`` has a given site or edge attribute.
+
+ Args:
+ attr (str)
+ Returns:
+ bool
+ """
return (attr in self._site_attrs) or (attr in self._edge_attrs)
def remove_attribute(self, attr):
+ """Remove a site or edge attribute.
+
+ Args:
+ attr (str)
+ """
if attr in self._site_attrs:
del self._site_attrs[attr]
elif attr in self._edge_attrs:
@@ -241,10 +280,12 @@ def remove_attribute(self, attr):
raise AttributeError("This SiteNetwork has no site or edge attribute `%s`" % attr)
def clear_attributes(self):
+ """Remove all site and edge attributes."""
self._site_attrs = {}
self._edge_attrs = {}
def clear_computed_attributes(self):
+ """Remove all attributes marked "computed"."""
for k, computed in self._attr_computed.items():
if computed:
self.remove_attribute(k)
@@ -296,6 +337,13 @@ def get_edge(self, edge):
return out
def add_site_attribute(self, name, attr, computed = True):
+ """Add a site attribute.
+
+ Args:
+ name (str)
+ attr (ndarray): Must be of length ``n_sites``.
+ computed (bool): Whether to mark this attribute as "computed".
+ """
self._check_name(name)
attr = np.asarray(attr)
if not attr.shape[0] == self.n_sites:
@@ -305,6 +353,13 @@ def add_site_attribute(self, name, attr, computed = True):
self._attr_computed[name] = computed
def add_edge_attribute(self, name, attr, computed = True):
+ """Add an edge attribute.
+
+ Args:
+ name (str)
+ attr (ndarray): Must be of shape ``(n_sites, n_sites)``.
+ computed (bool): Whether to mark this attribute as "computed".
+ """
self._check_name(name)
attr = np.asarray(attr)
if not (attr.shape[0] == attr.shape[1] == self.n_sites):
@@ -322,6 +377,6 @@ def _check_name(self, name):
raise ValueError("Attribute name `%s` reserved." % name)
def plot(self, *args, **kwargs):
- """Convenience method -- constructs a defualt SiteNetworkPlotter and calls it."""
+ """Convenience method -- constructs a default ``SiteNetworkPlotter`` and calls it."""
p = SiteNetworkPlotter(title = "Sites")
p(self, *args, **kwargs)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index 6b40cda..de1c02e 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -54,23 +54,27 @@ def __getitem__(self, key):
@property
def traj(self):
- """The underlying trajectory."""
+ """The site assignments over time."""
return self._traj
@property
def n_frames(self):
+ """The number of frames in the trajectory."""
return len(self._traj)
@property
def n_unassigned(self):
+ """The total number of times a mobile particle is unassigned."""
return np.sum(self._traj < 0)
@property
def n_assigned(self):
+ """The total number of times a mobile particle was assigned to a site."""
return self._sn.n_mobile * self.n_frames - self.n_unassigned
@property
def percent_unassigned(self):
+ """Proportion of particle positions that are unassigned over all time."""
return float(self.n_unassigned) / (self._sn.n_mobile * self.n_frames)
@property
@@ -86,9 +90,15 @@ def site_network(self, value):
@property
def real_trajectory(self):
+ """The real-space trajectory this ``SiteTrajectory`` is based on."""
return self._real_traj
def copy(self, with_computed = True):
+ """Return a copy.
+
+ Args:
+ with_computed (bool): See ``SiteNetwork.copy()``.
+ """
st = self[:]
st.site_network = st.site_network.copy(with_computed = with_computed)
return st
@@ -96,7 +106,10 @@ def copy(self, with_computed = True):
def set_real_traj(self, real_traj):
"""Assocaite this SiteTrajectory with a trajectory of points in real space.
- The trajectory is not copied, and should have shape (n_frames, n_total)
+ The trajectory is not copied.
+
+ Args:
+ real_traj (ndarray of shape (n_frames, n_total))
"""
expected_shape = (self.n_frames, self._sn.n_total, 3)
if not real_traj.shape == expected_shape:
@@ -111,7 +124,15 @@ def remove_real_traj(self):
def trajectory_for_particle(self, i, return_confidences = False):
- """Returns the array of sites particle i is assigned to over time."""
+ """Returns the array of sites particle i is assigned to over time.
+
+ Args:
+ i (int)
+ return_confidences (bool): If ``True``, also return the confidences
+ with which those assignments were made.
+ Returns:
+ ndarray (int) of length ``n_frames``[, ndarray (float) length ``n_frames``]
+ """
if return_confidences and self._confs is None:
raise ValueError("This SiteTrajectory has no confidences")
if return_confidences:
@@ -121,6 +142,17 @@ def trajectory_for_particle(self, i, return_confidences = False):
def real_positions_for_site(self, site, return_confidences = False):
+ """Get all real-space positions associated with a site.
+
+ Args:
+ site (int)
+ return_confidences (bool): If ``True``, the confidences with which
+ each real-space position was assigned to ``site`` are also
+ returned.
+
+ Returns:
+ ndarray (N, 3)[, ndarray (N)]
+ """
if self._real_traj is None:
raise ValueError("This SiteTrajectory has no real trajectory")
if return_confidences and self._confs is None:
@@ -139,10 +171,15 @@ def real_positions_for_site(self, site, return_confidences = False):
def compute_site_occupancies(self):
- """Computes site occupancies and adds site attribute `occupancies` to site_network.
+ """Computes site occupancies.
+
+ Adds site attribute ``occupancies`` to ``site_network``.
In cases of multiple occupancy, this will be higher than the number of
frames in which the site is occupied and could be over 1.0.
+
+ Returns:
+ ndarray of occupancies (length ``n_sites``)
"""
occ = np.true_divide(np.bincount(self._traj[self._traj >= 0], minlength = self._sn.n_sites), self.n_frames)
if self.site_network.has_attribute('occupancies'):
@@ -157,9 +194,8 @@ def check_multiple_occupancy(self, max_mobile_per_site = 1):
These cases usually indicate bad site analysis.
Returns:
- - n_multiple_assignments (int): the total number of multiple assignment
- incidents.
- - avg_mobile_per_site (float): the average number of mobile atoms
+ int: the total number of multiple assignment incidents; and
+ float: the average number of mobile atoms at any site at any one time.
"""
from sitator.landmark.errors import MultipleOccupancyError
n_more_than_ones = 0
@@ -178,10 +214,15 @@ def check_multiple_occupancy(self, max_mobile_per_site = 1):
def assign_to_last_known_site(self, frame_threshold = 1):
- """Assign unassigned mobile particles to their last known site within
- `frame_threshold` frames.
+ """Assign unassigned mobile particles to their last known site.
+
+ Args:
+ frame_threshold (int): The maximum number of frames between the last
+ known site and the present frame up to which the last known site
+ can be used.
- :returns: information dictionary of debugging/diagnostic information.
+ Returns:
+ information dictionary of debugging/diagnostic information.
"""
total_unknown = self.n_unassigned
@@ -248,18 +289,19 @@ def jumps(self, unknown_as_jump = False):
"""Generator to iterate over all jumps in the trajectory.
A jump is considered to occur "at the frame" when it first acheives its
- new site. Ex:
- Frame 0: Atom 1 at site 4 --> Frame 1: Atom 1 at site 5
- will yield a jump (1, 1, 4, 5).
-
- Yields tuples of the form:
+ new site. For example,
- (frame_number, mobile_atom_number, from_site, to_site)
+ - Frame 0: Atom 1 at site 4
+ - Frame 1: Atom 1 at site 5
+
+ will yield a jump ``(1, 1, 4, 5)``.
Args:
- - unknown_as_jump (bool): If True, moving from a site to unknown
- (or vice versa) is considered a jump; if False, unassigned mobile
- atoms are considered to be at their last known sites.
+ unknown_as_jump (bool): If ``True``, moving from a site to unknown
+ (or vice versa) is considered a jump; if ``False``, unassigned
+ mobile atoms are considered to be at their last known sites.
+ Yields:
+ tuple: (frame_number, mobile_atom_number, from_site, to_site)
"""
traj = self.traj
n_mobile = self.site_network.n_mobile
diff --git a/sitator/descriptors/ConfigurationalEntropy.py b/sitator/descriptors/ConfigurationalEntropy.py
index de61258..4008588 100644
--- a/sitator/descriptors/ConfigurationalEntropy.py
+++ b/sitator/descriptors/ConfigurationalEntropy.py
@@ -7,18 +7,23 @@
logger = logging.getLogger(__name__)
class ConfigurationalEntropy(object):
- """Compute the S~ configurational entropy.
+ """Compute the S~ (S tilde) configurational entropy.
If the SiteTrajectory lacks type information, the summation is taken over
the sites rather than the site types.
- Ref:
- Structural, Chemical, and Dynamical Frustration: Origins of Superionic Conductivity in closo-Borate Solid Electrolytes
- Kyoung E. Kweon, Joel B. Varley, Patrick Shea, Nicole Adelstein, Prateek Mehta, Tae Wook Heo, Terrence J. Udovic, Vitalie Stavila, and Brandon C. Wood
+ Reference:
+ Structural, Chemical, and Dynamical Frustration: Origins of Superionic
+ Conductivity in closo-Borate Solid Electrolytes
+
+ Kyoung E. Kweon, Joel B. Varley, Patrick Shea, Nicole Adelstein,
+ Prateek Mehta, Tae Wook Heo, Terrence J. Udovic, Vitalie Stavila,
+ and Brandon C. Wood
+
Chemistry of Materials 2017 29 (21), 9142-9153
DOI: 10.1021/acs.chemmater.7b02902
"""
- def __init__(self, acceptable_overshoot = 0.0001):
+ def __init__(self, acceptable_overshoot = 0.0):
self.acceptable_overshoot = acceptable_overshoot
def compute(self, st):
diff --git a/sitator/dynamics/AverageVibrationalFrequency.py b/sitator/dynamics/AverageVibrationalFrequency.py
index 726a983..88b6d34 100644
--- a/sitator/dynamics/AverageVibrationalFrequency.py
+++ b/sitator/dynamics/AverageVibrationalFrequency.py
@@ -6,11 +6,18 @@ class AverageVibrationalFrequency(object):
Uses the method described in section 2.2 of this paper:
Klerk, Niek J.J. de, Eveline van der Maas, and Marnix Wagemaker.
- “Analysis of Diffusion in Solid-State Electrolytes through MD Simulations,
- Improvement of the Li-Ion Conductivity in β-Li3PS4 as an Example.”
+
+ Analysis of Diffusion in Solid-State Electrolytes through MD Simulations,
+ Improvement of the Li-Ion Conductivity in β-Li3PS4 as an Example.
+
ACS Applied Energy Materials 1, no. 7 (July 23, 2018): 3230–42.
https://doi.org/10.1021/acsaem.8b00457.
+ Args:
+ min_frequency (float, units: timestep^-1): Compute mean frequency of power
+ spectrum above this frequency.
+ max_frequency (float, units: timestep^-1): Compute mean frequency of power
+ spectrum below this frequency.
"""
def __init__(self,
min_frequency = 0,
@@ -24,8 +31,8 @@ def compute_avg_vibrational_freq(self, traj, mask, return_stdev = False):
"""Compute the average vibrational frequency.
Args:
- - traj (ndarray n_frames x n_atoms x 3)
- - mask (ndarray n_atoms bool): which atoms to average over.
+ traj (ndarray n_frames x n_atoms x 3): An MD trajectory.
+ mask (ndarray n_atoms bool): Which atoms to average over.
Returns:
A frequency in units of (timestep)^-1
"""
diff --git a/sitator/dynamics/JumpAnalysis.py b/sitator/dynamics/JumpAnalysis.py
index 86621d8..639b852 100644
--- a/sitator/dynamics/JumpAnalysis.py
+++ b/sitator/dynamics/JumpAnalysis.py
@@ -12,13 +12,13 @@ class JumpAnalysis(object):
"""Given a SiteTrajectory, compute various statistics about the jumps it contains.
Adds these edge attributes to the SiteTrajectory's SiteNetwork:
- - `n_ij`: total number of jumps from i to j.
- - `p_ij`: being at i, the probability of jumping to j.
- - `jump_lag`: The average number of frames a particle spends at i before jumping
+ - ``n_ij``: total number of jumps from i to j.
+ - ``p_ij``: being at i, the probability of jumping to j.
+ - ``jump_lag``: The average number of frames a particle spends at i before jumping
to j. Can be +inf if no such jumps every occur.
And these site attributes:
- - `residence_times`: Avg. number of frames a particle spends at a site before jumping.
- - `total_corrected_residences`: Total number of frames when a particle was at the site,
+ - ``residence_times``: Avg. number of frames a particle spends at a site before jumping.
+ - ``total_corrected_residences``: Total number of frames when a particle was at the site,
*including* frames when an unassigned particle's last known site was this site.
"""
def __init__(self):
@@ -27,7 +27,13 @@ def __init__(self):
def run(self, st):
"""Run the analysis.
- Adds edge attributes to st's SiteNetwork and returns st.
+ Adds edge attributes to ``st``'s ``SiteNetwork``.
+
+ Args:
+ st (SiteTrajectory)
+
+ Returns:
+ ``st``
"""
assert isinstance(st, SiteTrajectory)
@@ -135,18 +141,23 @@ def run(self, st):
def jump_lag_by_type(self,
sn,
return_counts = False):
- """Given a SiteNetwork with jump_lag info, compute avg. residence times by type
+ """Given a SiteNetwork with jump_lag info, compute avg. residence times by type.
Computes the average number of frames a mobile particle spends at each
type of site before jumping to each other type of site.
- Returns an (n_types, n_types) matrix. If no jumps of a given type occured,
+ Args:
+ sn (SiteNetwork)
+ return_counts (bool): Whether to also return a matrix giving the
+ number of each type of jump that occurred.
+
+
+ Returns:
+ An (n_types, n_types) matrix. If no jumps of a given type occurred,
the corresponding entry is +inf.
- If ``return_counts``, then also returns a matrix giving the number of
- each type of jump that occured.
+ If ``return_counts``, two such matrices.
"""
-
if sn.site_types is None:
raise ValueError("SiteNetwork has no type information.")
@@ -182,13 +193,11 @@ def plot_jump_lag(self, sn, mode = 'site', min_n_events = 1, ax = None, fig = No
"""Plot the jump lag of a site network.
:param SiteNetwork sn:
- :param str mode: If 'site', show jump lag between individual sites.
- If 'type', show jump lag between types of sites (see :func:jump_lag_by_type)
- Default: 'site'
+ :param str mode: If ``'site'``, show jump lag between individual sites.
+ If ``'type'``, show jump lag between types of sites (see :func:jump_lag_by_type)
:param int min_n_events: Minimum number of jump events of a given type
(i -> j or type -> type) to show a jump lag. If a given jump has
- occured less than min_n_events times, no jump lag will be shown.
- Default: 1
+ occurred less than ``min_n_events`` times, no jump lag will be shown.
"""
if mode == 'site':
mat = np.copy(sn.jump_lag)
diff --git a/sitator/dynamics/MergeSitesByDynamics.py b/sitator/dynamics/MergeSitesByDynamics.py
index e1f1b49..c8950e5 100644
--- a/sitator/dynamics/MergeSitesByDynamics.py
+++ b/sitator/dynamics/MergeSitesByDynamics.py
@@ -75,10 +75,10 @@ def connectivity_jump_lag_biased(jump_lag_coeff = 1.0,
The jump lag and distance are processed through Gaussian functions with
the given sigmas (i.e. higher jump lag/larger distance => lower
connectivity value). These matrixes are then added to p_ij, with a prefactor
- of `jump_lag_coeff` and `distance_coeff`.
+ of ``jump_lag_coeff`` and ``distance_coeff``.
- Site pairs with jump lags greater than `jump_lag_cutoff` have their bias
- set to zero regardless of `jump_lag_sigma`. Defaults to `inf`.
+ Site pairs with jump lags greater than ``jump_lag_cutoff`` have their bias
+ set to zero regardless of ``jump_lag_sigma``. Defaults to ``inf``.
"""
def cfunc(sn):
jl = sn.jump_lag.copy()
diff --git a/sitator/dynamics/MergeSitesByThreshold.py b/sitator/dynamics/MergeSitesByThreshold.py
index 258f560..449c236 100644
--- a/sitator/dynamics/MergeSitesByThreshold.py
+++ b/sitator/dynamics/MergeSitesByThreshold.py
@@ -11,19 +11,19 @@
class MergeSitesByThreshold(MergeSites):
"""Merge sites using a strict threshold on any edge property.
- Takes the edge property matrix given by `attrname`, applys `relation` to it
- with `threshold`, and merges all connected components in the graph represented
+ Takes the edge property matrix given by ``attrname``, applies ``relation`` to it
+ with ``threshold``, and merges all connected components in the graph represented
by the resulting boolean adjacency matrix.
- Threshold is given by a keyword argument to `run()`.
+ Threshold is given by a keyword argument to ``run()``.
Args:
- - attrname (str): Name of the edge attribute to merge on.
- - relation (func, default: operator.ge): The relation to use for the
+ attrname (str): Name of the edge attribute to merge on.
+ relation (func, default: ``operator.ge``): The relation to use for the
thresholding.
- - directed, connection (bool, str): Parameters for scipy.sparse.csgraph's
- `connected_components`.
- - **kwargs: Passed to `MergeSites`.
+ directed, connection (bool, str): Parameters for ``scipy.sparse.csgraph``'s
+ ``connected_components``.
+ **kwargs: Passed to ``MergeSites``.
"""
def __init__(self,
attrname,
diff --git a/sitator/dynamics/RemoveShortJumps.py b/sitator/dynamics/RemoveShortJumps.py
index 8d00713..d53bfda 100644
--- a/sitator/dynamics/RemoveShortJumps.py
+++ b/sitator/dynamics/RemoveShortJumps.py
@@ -16,7 +16,7 @@ class RemoveShortJumps(object):
jumped from.
Args:
- - only_returning_jumps (bool, default: True): If True, only short jumps
+ only_returning_jumps (bool): If True, only short jumps
where the mobile atom returns to its initial site will be removed.
"""
def __init__(self,
@@ -30,12 +30,15 @@ def run(self,
st,
threshold,
return_stats = False):
- """Returns a copy of `st` with short jumps removed.
+ """Returns a copy of ``st`` with short jumps removed.
Args:
- - st (SiteTrajectory): Unassigned considered to be last known.
- - threshold (int): The largest number of frames the mobile atom
+ st (SiteTrajectory): Unassigned considered to be last known.
+ threshold (int): The largest number of frames the mobile atom
can spend at a site while the jump is still considered short.
+
+ Returns:
+ A ``SiteTrajectory``.
"""
n_mobile = st.site_network.n_mobile
n_frames = st.n_frames
diff --git a/sitator/dynamics/RemoveUnoccupiedSites.py b/sitator/dynamics/RemoveUnoccupiedSites.py
index c0175b4..e7dafb3 100644
--- a/sitator/dynamics/RemoveUnoccupiedSites.py
+++ b/sitator/dynamics/RemoveUnoccupiedSites.py
@@ -7,14 +7,18 @@
class RemoveUnoccupiedSites(object):
+ """Remove unoccupied sites."""
def __init__(self):
pass
def run(self, st, return_kept_sites = False):
"""
+ Args:
+ return_kept_sites (bool): If ``True``, a list of the sites from ``st``
+ that were kept will be returned.
- Can return `st` unmodified if all sites are at some point occupied.
-
+ Returns:
+ A ``SiteTrajectory``, or ``st`` itself if it has no unoccupied sites.
"""
assert isinstance(st, SiteTrajectory)
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index cf8681d..b2116ab 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -27,8 +27,49 @@ def wrapper(self, *args, **kwargs):
return wrapper
class LandmarkAnalysis(object):
- """Track a mobile species through a fixed lattice using landmark vectors."""
-
+ """Site analysis of mobile atoms in a static lattice with landmark analysis.
+
+ :param double cutoff_center: The midpoint for the logistic function used
+ as the landmark cutoff function. (unitless)
+ :param double cutoff_steepness: Steepness of the logistic cutoff function.
+ :param double minimum_site_occupancy = 0.1: Minimum occupancy (% of time occupied)
+ for a site to qualify as such.
+ :param dict clustering_params: Parameters for the chosen ``clustering_algorithm``.
+ :param str peak_evening: Whether and what kind of peak "evening" to apply;
+ that is, processing that makes all large peaks in the landmark vector
+ more similar in magnitude. This can help in site clustering.
+
+ Valid options: 'none', 'clip'
+ :param bool weighted_site_positions: When computing site positions, whether
+ to weight the average by assignment confidence.
+ :param bool check_for_zero_landmarks: Whether to check for and raise exceptions
+ when all-zero landmark vectors are computed.
+ :param float static_movement_threshold: (Angstrom) the maximum allowed
+ distance between an instantaneous static atom position and its ideal position.
+ :param bool dynamic_lattice_mapping: Whether to dynamically decide each
+ frame which static atom represents each average lattice position;
+ this allows the LandmarkAnalysis to deal with, say, a rare exchange of
+ two static atoms that does not change the structure of the lattice.
+
+ It does NOT allow LandmarkAnalysis to deal with lattices whose structures
+ actually change over the course of the trajectory.
+
+ In certain cases this is better dealt with by ``MergeSitesByDynamics``.
+ :param int max_mobile_per_site: The maximum number of mobile atoms that can
+ be assigned to a single site without throwing an error. Regardless of the
+ value, assignments of more than one mobile atom to a single site will
+ be recorded and reported.
+
+ Setting this to 2 can be necessary for very diffusive, liquid-like
+ materials at high temperatures.
+
+ Statistics related to this are reported in ``self.avg_mobile_per_site``
+ and ``self.n_multiple_assignments``.
+ :param bool force_no_memmap: if True, landmark vectors will be stored only in memory.
+ Only useful if access to landmark vectors after the analysis has run is desired.
+ :param bool verbose: Verbosity for the ``clustering_algorithm``. Other output
+ controlled through ``logging``.
+ """
def __init__(self,
clustering_algorithm = 'dotprod',
clustering_params = {},
@@ -44,49 +85,6 @@ def __init__(self,
max_mobile_per_site = 1,
force_no_memmap = False,
verbose = True):
- """
- :param double cutoff_center: The midpoint for the logistic function used
- as the landmark cutoff function. (unitless)
- :param double cutoff_steepness: Steepness of the logistic cutoff function.
- :param double minimum_site_occupancy = 0.1: Minimum occupancy (% of time occupied)
- for a site to qualify as such.
- :param dict clustering_params: Parameters for the chosen clustering_algorithm
- :param str peak_evening: Whether and what kind of peak "evening" to apply;
- that is, processing that makes all large peaks in the landmark vector
- more similar in magnitude. This can help in site clustering.
-
- Valid options: 'none', 'clip'
- :param bool weighted_site_positions: When computing site positions, whether
- to weight the average by assignment confidence.
- :param bool check_for_zero_landmarks: Whether to check for and raise exceptions
- when all-zero landmark vectors are computed.
- :param float static_movement_threshold: (Angstrom) the maximum allowed
- distance between an instantanous static atom position and it's ideal position.
- :param bool dynamic_lattice_mapping: Whether to dynamically decide each
- frame which static atom represents each average lattice position;
- this allows the LandmarkAnalysis to deal with, say, a rare exchage of
- two static atoms that does not change the structure of the lattice.
-
- It does NOT allow LandmarkAnalysis to deal with lattices whose structures
- actually change over the course of the trajectory.
-
- In certain cases this is better delt with by MergeSitesByDynamics.
- :param int max_mobile_per_site: The maximum number of mobile atoms that can
- be assigned to a single site without throwing an error. Regardless of the
- value, assignments of more than one mobile atom to a single site will
- be recorded and reported.
-
- Setting this to 2 can be necessary for very diffusive, liquid-like
- materials at high temperatures.
-
- Statistics related to this are reported in self.avg_mobile_per_site
- and self.n_multiple_assignments.
- :param bool force_no_memmap: if True, landmark vectors will be stored only in memory.
- Only useful if access to landmark vectors after the analysis has run is desired.
- :param bool verbose: If `True`, progress bars will be printed to stdout.
- Other output is handled seperately through the `logging` module.
- """
-
self._cutoff_midpoint = cutoff_midpoint
self._cutoff_steepness = cutoff_steepness
self._minimum_site_occupancy = minimum_site_occupancy
@@ -120,25 +118,30 @@ def cutoff(self):
@analysis_result
def landmark_vectors(self):
+ """Landmark vectors from the last invocation of ``run()``"""
view = self._landmark_vectors[:]
view.flags.writeable = False
return view
@analysis_result
def landmark_dimension(self):
+ """Number of components in a single landmark vector."""
return self._landmark_dimension
-
def run(self, sn, frames):
"""Run the landmark analysis.
- The input SiteNetwork is a network of predicted sites; it's sites will
+ The input ``SiteNetwork`` is a network of predicted sites; its sites will
be used as the "basis" for the landmark vectors.
- Wraps a copy of `frames` into the unit cell; if you know `frames` is already
- wrapped, set `do_wrap = False` to avoid the copy.
+ Wraps a copy of ``frames`` into the unit cell.
- Takes a SiteNetwork and returns a SiteTrajectory.
+ Args:
+ sn (SiteNetwork): The landmark basis. Each site is a landmark defined
+ by its vertex static atoms, as indicated by ``sn.vertices``.
+ (Typically from ``VoronoiSiteGenerator``.)
+ frames (ndarray n_frames x n_atoms x 3): A trajectory. Can be unwrapped;
+ a copy will be wrapped before the analysis.
"""
assert isinstance(sn, SiteNetwork)
diff --git a/sitator/landmark/errors.py b/sitator/landmark/errors.py
index a170cd3..dbc12e6 100644
--- a/sitator/landmark/errors.py
+++ b/sitator/landmark/errors.py
@@ -9,9 +9,7 @@ class StaticLatticeError(LandmarkAnalysisError):
lattice_atoms (list, optional): The indexes of the atoms in the static lattice that
caused the error.
frame (int, optional): The frame in the trajectory at which the error occured.
-
"""
-
TRY_RECENTERING_MSG = "Try recentering the input trajectory (sitator.util.RecenterTrajectory)"
def __init__(self, message, lattice_atoms = None, frame = None, try_recentering = False):
@@ -26,6 +24,12 @@ def __init__(self, message, lattice_atoms = None, frame = None, try_recentering
self.frame = frame
class ZeroLandmarkError(LandmarkAnalysisError):
+ """Error raised when a landmark vector containing only zeros is encountered.
+
+ Attributes:
+ mobile_index (int): Which mobile atom had the all-zero vector.
+ frame (int): At which frame it was encountered.
+ """
def __init__(self, mobile_index, frame):
message = "Encountered a zero landmark vector for mobile ion %i at frame %i. Try increasing `cutoff_midpoint` and/or decreasing `cutoff_steepness`." % (mobile_index, frame)
@@ -36,4 +40,5 @@ def __init__(self, mobile_index, frame):
self.frame = frame
class MultipleOccupancyError(LandmarkAnalysisError):
+ """Error raised when multiple mobile atoms are assigned to the same site."""
pass
diff --git a/sitator/misc/GenerateAroundSites.py b/sitator/misc/GenerateAroundSites.py
index 12a6138..c668239 100644
--- a/sitator/misc/GenerateAroundSites.py
+++ b/sitator/misc/GenerateAroundSites.py
@@ -4,7 +4,12 @@
from sitator.util import PBCCalculator
class GenerateAroundSites(object):
- """Generate n normally distributed sites around each input site"""
+ """Generate ``n`` normally distributed sites around each site.
+
+ Args:
+ n (int): How many sites to produce for each input site.
+ sigma (float): Standard deviation of the spatial Gaussian.
+ """
def __init__(self, n, sigma):
self.n = n
self.sigma = sigma
diff --git a/sitator/misc/GenerateClampedTrajectory.py b/sitator/misc/GenerateClampedTrajectory.py
index 64575ea..c06484c 100644
--- a/sitator/misc/GenerateClampedTrajectory.py
+++ b/sitator/misc/GenerateClampedTrajectory.py
@@ -12,12 +12,12 @@ class GenerateClampedTrajectory(object):
positions of the current site/their fixed static position.
Args:
- - wrap (bool, default: False): If True, all clamped positions will be in
- the unit cell; if False, the clamped position will be the minimum
+ wrap (bool): If ``True``, all clamped positions will be in
+ the unit cell; if ``False``, the clamped position will be the minimum
image of the clamped position with respect to the real-space position.
(This can generate a clamped, unwrapped real-space trajectory
from an unwrapped real space trajectory.)
- - pass_through_unassigned (bool, default: False): If True, when a
+ pass_through_unassigned (bool): If ``True``, when a
mobile atom is supposed to be clamped but is unassigned at some
frame, its real-space position will be passed through from the
real trajectory. If False, an error will be raised.
@@ -30,16 +30,16 @@ def __init__(self, wrap = False, pass_through_unassigned = False):
def run(self, st, clamp_mask = None):
"""Create a real-space trajectory with the fixed site/static structure positions.
- Generate a real-space trajectory where the atoms indicated in `clamp_mask` --
+ Generate a real-space trajectory where the atoms indicated in ``clamp_mask`` --
any mixture of static and mobile -- are clamped to: (1) the fixed position of
their current site, if mobile, or (2) the corresponding fixed position in
- the `SiteNetwork`'s static structure, if static.
+ the ``SiteNetwork``'s static structure, if static.
- Atoms not indicated in `clamp_mask` will have their positions from `real_traj`
- passed through.
+ Atoms not indicated in ``clamp_mask`` will have their positions from
+ ``real_traj`` passed through.
Args:
- - clamp_mask (ndarray, len(sn.structure))
+ clamp_mask (ndarray, len(sn.structure))
Returns:
ndarray (n_frames x n_atoms x 3)
"""
diff --git a/sitator/misc/MergeSitesByBarrier.py b/sitator/misc/MergeSitesByBarrier.py
deleted file mode 100644
index e69de29..0000000
diff --git a/sitator/misc/NAvgsPerSite.py b/sitator/misc/NAvgsPerSite.py
index 013c4e2..b4d1f01 100644
--- a/sitator/misc/NAvgsPerSite.py
+++ b/sitator/misc/NAvgsPerSite.py
@@ -5,15 +5,15 @@
from sitator.util import PBCCalculator
class NAvgsPerSite(object):
- """Given a SiteTrajectory, return a SiteNetwork containing n avg. positions per site.
+ """Given a ``SiteTrajectory``, return a ``SiteNetwork`` containing n avg. positions per site.
- The `types` of sites in the output are the index of the site in the input that generated
- that average.
+ The ``site_types`` of sites in the output are the index of the site in the
+ input that generated that average.
:param int n: How many averages to take
:param bool error_on_insufficient: Whether to throw an error if n points cannot
be provided for a site, or just take all that are available.
- :param bool weighted: Use SiteTrajectory confidences to weight the averages.
+ :param bool weighted: Use ``SiteTrajectory`` confidences to weight the averages.
"""
def __init__(self, n,
@@ -25,6 +25,12 @@ def __init__(self, n,
self.weighted = weighted
def run(self, st):
+ """
+ Args:
+ st (SiteTrajectory)
+ Returns:
+ A ``SiteNetwork``.
+ """
assert isinstance(st, SiteTrajectory)
if st.real_trajectory is None:
raise ValueError("SiteTrajectory must have associated real trajectory.")
diff --git a/sitator/network/DiffusionPathwayAnalysis.py b/sitator/network/DiffusionPathwayAnalysis.py
index 16d1de6..05439c5 100644
--- a/sitator/network/DiffusionPathwayAnalysis.py
+++ b/sitator/network/DiffusionPathwayAnalysis.py
@@ -23,8 +23,8 @@ class DiffusionPathwayAnalysis(object):
a pathway for it to be considered as such.
:param bool true_periodic_pathways: Whether only to return true periodic
pathways that include sites and their periodic images (i.e. conductive
- in the bulk) rather than just connected components. If True, `minimum_n_sites`
- is NOT respected.
+ in the bulk) rather than just connected components. If ``True``,
+ ``minimum_n_sites`` is NOT respected.
"""
NO_PATHWAY = -1
@@ -41,14 +41,13 @@ def __init__(self,
def run(self, sn, return_count = False):
"""
- Expects a SiteNetwork that has had a JumpAnalysis run on it.
+ Expects a ``SiteNetwork`` that has had a ``JumpAnalysis`` run on it.
- Adds information to `sn` in place.
+ Adds information to ``sn`` in place.
Args:
- - sn (SiteNetwork): Must have jump statistics from a `JumpAnalysis()`.
- - return_count (bool, default: False): Return the number of connected
- pathways.
+ sn (SiteNetwork): Must have jump statistics from a ``JumpAnalysis``.
+ return_count (bool): Return the number of connected pathways.
Returns:
sn, [n_pathways]
"""
diff --git a/sitator/network/MergeSitesByBarrier.py b/sitator/network/MergeSitesByBarrier.py
index 5d7879e..3865952 100644
--- a/sitator/network/MergeSitesByBarrier.py
+++ b/sitator/network/MergeSitesByBarrier.py
@@ -19,32 +19,33 @@ class MergeSitesByBarrier(MergeSites):
Uses a cheap coordinate driving system; this may not be sophisticated enough
for complex cases. For each pair of sites within the pairwise distance cutoff,
- a linear spatial interpolation is applied to produce `n_driven_images`.
+ a linear spatial interpolation is applied to produce ``n_driven_images``.
Two sites are considered mergable if their energies are within
- `final_initial_energy_threshold` and the barrier between them is below
- `barrier_threshold`. The barrier is defined as the maximum image energy minus
+ ``final_initial_energy_threshold`` and the barrier between them is below
+ ``barrier_threshold``. The barrier is defined as the maximum image energy minus
the average of the initial and final energy.
The energies of the mobile atom are calculated in a static lattice given
- by `coordinating_mask`; if `None`, this is set to the systems `static_mask`.
+ by ``coordinating_mask``; if ``None``, this is set to the system's
+ ``static_mask``.
- For resonable performance, `calculator` should be something simple like
- `ase.calculators.lj.LennardJones`.
+ For reasonable performance, ``calculator`` should be something simple like
+ ``ase.calculators.lj.LennardJones``.
Takes species of first mobile atom as mobile species.
Args:
- - calculator (ase.Calculator): For computing total potential energies.
- - barrier_threshold (float, eV): The barrier value above which two sites
+ calculator (ase.Calculator): For computing total potential energies.
+ barrier_threshold (float, eV): The barrier value above which two sites
are not mergable.
- - n_driven_images (int, default: None): The number of evenly distributed
+ n_driven_images (int, default: None): The number of evenly distributed
driven images to use.
- - maximum_pairwise_distance (float, Angstrom): The maximum distance
+ maximum_pairwise_distance (float, Angstrom): The maximum distance
between two sites for them to be considered for merging.
- - minimum_jumps_mergable (int): The minimum number of observed jumps
+ minimum_jumps_mergable (int): The minimum number of observed jumps
between two sites for their merging to be considered. Setting this
higher can avoid unnecessary computations.
- - maximum_merge_distance (float, Angstrom): The maxiumum pairwise distance
+ maximum_merge_distance (float, Angstrom): The maximum pairwise distance
among a group of sites chosed to be merged.
"""
def __init__(self,
diff --git a/sitator/network/merging.py b/sitator/network/merging.py
index 5b6e631..25f7a51 100644
--- a/sitator/network/merging.py
+++ b/sitator/network/merging.py
@@ -21,15 +21,15 @@ class TooFewMergedSitesError(MergeSitesError):
class MergeSites(abc.ABC):
"""Abstract base class for merging sites.
- :param bool check_types: If True, only sites of the same type are candidates to
- be merged; if false, type information is ignored. Merged sites will only
- be assigned types if this is True.
+ :param bool check_types: If ``True``, only sites of the same type are candidates to
+ be merged; if ``False``, type information is ignored. Merged sites will only
+ be assigned types if this is ``True``.
:param float maximum_merge_distance: Maximum distance between two sites
that are in a merge group, above which an error will be raised.
- :param bool set_merged_into: If True, a site attribute `"merged_into"` will
- be added to the original `SiteNetwork` indicating which new site
+ :param bool set_merged_into: If ``True``, a site attribute ``"merged_into"``
+ will be added to the original ``SiteNetwork`` indicating which new site
each old site was merged into.
- :param bool weighted_spatial_average: If True, the spatial average giving
+ :param bool weighted_spatial_average: If ``True``, the spatial average giving
the position of the merged site will be weighted by occupancy.
"""
def __init__(self,
@@ -44,7 +44,7 @@ def __init__(self,
def run(self, st, **kwargs):
- """Takes a SiteTrajectory and returns a SiteTrajectory, including a new SiteNetwork."""
+ """Takes a ``SiteTrajectory`` and returns a new ``SiteTrajectory``."""
if self.check_types and st.site_network.site_types is None:
raise ValueError("Cannot run a check_types=True MergeSites on a SiteTrajectory without type information.")
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index 9a86305..df0c128 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -20,21 +20,22 @@ class SOAP(object, metaclass=ABCMeta):
of the environment to consider. I.e. for Li2CO3, can be set to ['O'] or [8]
for oxygen only, or ['C', 'O'] / ['C', 8] / [6,8] if carbon and oxygen
are considered an environment.
- Defaults to `None`, in which case all non-mobile atoms are considered
+ Defaults to ``None``, in which case all non-mobile atoms are considered
regardless of species.
- :param soap_mask: Which atoms in the SiteNetwork's structure
+ :param soap_mask: Which atoms in the ``SiteNetwork``'s structure
to use in SOAP calculations.
Can be either a boolean mask ndarray or a tuple of species.
- If `None`, the entire static_structure of the SiteNetwork will be used.
+ If ``None``, the entire ``static_structure`` of the ``SiteNetwork`` will be used.
Mobile atoms cannot be used for the SOAP host structure.
Even not masked, species not considered in environment will be not accounted for.
- For ideal performance: Specify environment and soap_mask correctly!
+ For ideal performance: Specify environment and ``soap_mask`` correctly!
:param dict soap_params = {}: Any custom SOAP params.
- :param func backend: A function that can be called with `sn, soap_mask, tracer_atomic_number, environment_list` as
+ :param func backend: A function that can be called with
+ ``sn, soap_mask, tracer_atomic_number, environment_list`` as
parameters, returning a function that, given the current soap structure
along with tracer atoms, returns SOAP vectors in a numpy array. (i.e.
- its signature is `soap(structure, positions)`). The returned function
- can also have a property, `n_dim`, giving the length of a single SOAP
+ its signature is ``soap(structure, positions)``). The returned function
+ can also have a property, ``n_dim``, giving the length of a single SOAP
vector.
"""
@@ -77,11 +78,13 @@ def __init__(self, tracer_atomic_number, environment = None,
self._environment = None
def get_descriptors(self, stn):
- """
- Get the descriptors.
- :param stn: A valid instance of SiteTrajectory or SiteNetwork
- :returns: an array of descriptor vectors and an equal length array of
- labels indicating which descriptors correspond to which sites.
+ """Get the descriptors.
+
+ Args:
+ stn (SiteTrajectory or SiteNetwork)
+ Returns:
+ An array of descriptor vectors and an equal length array of labels
+ indicating which descriptors correspond to which sites.
"""
# Build SOAP host structure
if isinstance(stn, SiteTrajectory):
@@ -149,7 +152,7 @@ def _get_descriptors(self, stn, structure, tracer_atomic_number, soaper):
class SOAPCenters(SOAP):
"""Compute the SOAPs of the site centers in the fixed host structure.
- Requires a SiteNetwork as input.
+ Requires a ``SiteNetwork`` as input.
"""
def _get_descriptors(self, sn, structure, tracer_atomic_number, soap_mask, soaper):
if isinstance(sn, SiteTrajectory):
@@ -164,21 +167,20 @@ def _get_descriptors(self, sn, structure, tracer_atomic_number, soap_mask, soape
class SOAPSampledCenters(SOAPCenters):
- """Compute the SOAPs of representative points for each site, as determined by `sampling_transform`.
+ """Compute the SOAPs of representative points for each site, as determined by ``sampling_transform``.
- Takes either a SiteNetwork or SiteTrajectory as input; requires that
- `sampling_transform` produce a SiteNetwork where `site_types` indicates
- which site in the original SiteNetwork/SiteTrajectory it was sampled from.
+ Takes either a ``SiteNetwork`` or ``SiteTrajectory`` as input; requires that
+ ``sampling_transform`` produce a ``SiteNetwork`` where ``site_types`` indicates
+ which site in the original ``SiteNetwork``/``SiteTrajectory`` it was sampled from.
- Typical sampling transforms are `sitator.misc.NAvgsPerSite` (for a SiteTrajectory)
- and `sitator.misc.GenerateAroundSites` (for a SiteNetwork).
+ Typical sampling transforms are ``sitator.misc.NAvgsPerSite`` (for a ``SiteTrajectory``)
+ and ``sitator.misc.GenerateAroundSites`` (for a ``SiteNetwork``).
"""
def __init__(self, *args, **kwargs):
self.sampling_transform = kwargs.pop('sampling_transform', 1)
super(SOAPSampledCenters, self).__init__(*args, **kwargs)
def get_descriptors(self, stn):
-
# Do sampling
sampled = self.sampling_transform.run(stn)
assert isinstance(sampled, SiteNetwork), "Sampling transform returned `%s`, not a SiteNetwork" % sampled
@@ -203,7 +205,7 @@ class SOAPDescriptorAverages(SOAP):
:param int stepsize: Stride (in frames) when computing SOAPs. Default 1.
:param int averaging: Number of SOAP vectors to average for each output vector.
- :param int avg_descriptors_per_site: Can be specified instead of `averaging`.
+ :param int avg_descriptors_per_site: Can be specified instead of ``averaging``.
Specifies the _average_ number of average SOAP vectors to compute for each
site. This does not guerantee that number of SOAP vectors for any site,
rather, it allows a trajectory-size agnostic way to specify approximately
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
index 273f919..87be7db 100644
--- a/sitator/site_descriptors/SiteCoordinationEnvironment.py
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -22,14 +22,27 @@ class SiteCoordinationEnvironment(object):
Determine site types using the method from the following paper:
- David Waroquiers, Xavier Gonze, Gian-Marco Rignanese, Cathrin Welker-Nieuwoudt, Frank Rosowski, Michael Goebel, Stephan Schenk, Peter Degelmann, Rute Andre, Robert Glaum, and Geoffroy Hautier,
- “Statistical analysis of coordination environments in oxides”,
+ David Waroquiers, Xavier Gonze, Gian-Marco Rignanese,
+ Cathrin Welker-Nieuwoudt, Frank Rosowski, Michael Goebel, Stephan Schenk,
+ Peter Degelmann, Rute Andre, Robert Glaum, and Geoffroy Hautier
+
+ Statistical analysis of coordination environments in oxides
+
Chem. Mater., 2017, 29 (19), pp 8346–8360, DOI: 10.1021/acs.chemmater.7b02766
- as implement in `pymatgen`'s `pymatgen.analysis.chemenv.coordination_environments`.
+ as implemented in ``pymatgen``'s
+ ``pymatgen.analysis.chemenv.coordination_environments``.
+
+ Adds three site attributes:
+ - ``coordination_environments``: The name of the coordination environment,
+ as returned by ``pymatgen``. Example: ``"T:4"`` (tetrahedral, coordination
+ of 4).
+ - ``site_type_confidences``: The ``ce_fraction`` of the best match chemical
+ environment (from 0 to 1).
+ - ``coordination_numbers``: The coordination number of the site.
Args:
- **kwargs: passed to `compute_structure_environments`.
+ **kwargs: passed to ``compute_structure_environments``.
"""
def __init__(self, guess_ionic_bonds = True, **kwargs):
if not has_pymatgen:
@@ -38,6 +51,12 @@ def __init__(self, guess_ionic_bonds = True, **kwargs):
self._guess_ionic_bonds = guess_ionic_bonds
def run(self, sn):
+ """
+ Args:
+ sn (SiteNetwork)
+ Returns:
+ ``sn``, with type information.
+ """
# -- Determine local environments
# Get an ASE structure with a single mobile site that we'll move around
site_struct, site_species = sn[0:1].get_structure_with_sites()
diff --git a/sitator/site_descriptors/SiteTypeAnalysis.py b/sitator/site_descriptors/SiteTypeAnalysis.py
index caeb4b4..c77c3bc 100644
--- a/sitator/site_descriptors/SiteTypeAnalysis.py
+++ b/sitator/site_descriptors/SiteTypeAnalysis.py
@@ -21,11 +21,20 @@
class SiteTypeAnalysis(object):
"""Cluster sites into types using a continuous descriptor and Density Peak Clustering.
- -- descriptor --
- Some kind of object implementing:
- - get_descriptors(site_traj or site_network): returns an array of descriptor vectors
- of dimension (M, n_dim) and an array of length M indicating which
- descriptor vectors correspond to which sites in (site_traj.)site_network.
+ Computes descriptor vectors, processes them with Principal Component Analysis,
+ and then clusters using Density Peak Clustering.
+
+ Args:
+ descriptor (object): Must implement ``get_descriptors(st|sn)``, which
+ returns an array of descriptor vectors of dimension (M, n_dim) and
+ an array of length M indicating which descriptor vectors correspond
+ to which sites in (``site_traj.``)``site_network``.
+ min_pca_variance (float): The minimum proportion of the total variance
+ that the taken principal components of the descriptor must explain.
+ min_pca_dimensions (int): Force taking at least this many principal
+ components.
+ n_site_types_max (int): Maximum number of clusters. Must be set reasonably
+ for the automatic selection of cluster number to work.
"""
def __init__(self, descriptor,
min_pca_variance = 0.9, min_pca_dimensions = 2,
@@ -38,6 +47,12 @@ def __init__(self, descriptor,
self._n_dvecs = None
def run(self, descriptor_input, **kwargs):
+ """
+ Args:
+ descriptor_input (SiteNetwork or SiteTrajectory)
+ Returns:
+ ``SiteNetwork``
+ """
if not self._n_dvecs is None:
raise ValueError("Can't run SiteTypeAnalysis more than once!")
diff --git a/sitator/site_descriptors/SiteVolumes.py b/sitator/site_descriptors/SiteVolumes.py
index 1025672..0353aa5 100644
--- a/sitator/site_descriptors/SiteVolumes.py
+++ b/sitator/site_descriptors/SiteVolumes.py
@@ -16,11 +16,11 @@ class SiteVolumes(object):
"""Compute the volumes of sites.
Args:
- - error_on_insufficient_coord (bool, default: True): To compute an
- ideal site volume (`compute_volumes()`), at least 4 coordinating
- atoms (because we are in 3D space) must be specified in `vertices`.
- If True, an error will be thrown when a site with less than four
- vertices is encountered; if False, a volume of 0 and surface area
+ error_on_insufficient_coord (bool): To compute an ideal
+ site volume (``compute_volumes()``), at least 4 coordinating
+ atoms (because we are in 3D space) must be specified in ``vertices``.
+ If ``True``, an error will be thrown when a site with less than four
+ vertices is encountered; if ``False``, a volume of 0 and surface area
of NaN will be returned.
"""
def __init__(self, error_on_insufficient_coord = True):
@@ -33,11 +33,11 @@ def compute_accessable_volumes(self, st, n_recenterings = 8):
Uses the shift-and-wrap trick for dealing with periodicity, so sites that
take up the majority of the unit cell may give bogus results.
- Adds the `accessable_site_volumes` attribute to the SiteNetwork.
+ Adds the ``accessable_site_volumes`` attribute to the ``SiteNetwork``.
Args:
- - st (SiteTrajectory)
- - n_recenterings (int): How many different recenterings to try (the
+ st (SiteTrajectory)
+ n_recenterings (int): How many different recenterings to try (the
algorithm will recenter around n of the points and take the minimal
resulting volume; this deals with cases where there is one outlier
where recentering around it gives very bad results.)
@@ -84,7 +84,7 @@ def compute_volumes(self, sn):
Requires vertex information in the SiteNetwork.
- Adds the `site_volumes` and `site_surface_areas` attributes.
+ Adds the ``site_volumes`` and ``site_surface_areas`` attributes.
Args:
- sn (SiteNetwork)
@@ -123,6 +123,5 @@ def compute_volumes(self, sn):
def run(self, st):
- """For backwards compatability.
- """
+ """For backwards compatibility."""
self.compute_accessable_volumes(st)
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index 31c3417..51a61c9 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -25,18 +25,19 @@ class OneValueListlike(object):
class DotProdClassifier(object):
"""Assign vectors to clusters indicated by a representative vector using a cosine metric.
- Cluster centers can be given through `set_cluster_centers()` or approximated
+ Cluster centers can be given through ``set_cluster_centers()`` or approximated
using the custom method described in the appendix of the main landmark
- analysis paper (`fit_centers()`).
-
- :param float threshold: Similarity threshold for joining a cluster.
- In cos-of-angle-between-vectors (i.e. 1 is exactly the same, 0 is orthogonal)
- :param int max_converge_iters: Maximum number of iterations. If the algorithm hasn't converged
- by then, it will exit with a warning.
- :param int|float min_samples: filter out clusters with low sample counts.
- If an int, filters out clusters with fewer samples than this.
- If a float, filters out clusters with fewer than floor(min_samples * n_assigned_samples)
- samples assigned to them.
+ analysis paper (``fit_centers()``).
+
+ Args:
+ threshold (float): Similarity threshold for joining a cluster.
+ In cos-of-angle-between-vectors (i.e. 1 is exactly the same, 0 is orthogonal)
+ max_converge_iters (int): Maximum number of iterations. If the algorithm
+ hasn't converged by then, it will exit with a warning.
+ min_samples (float or int): filter out clusters with low sample counts.
+ If an int, filters out clusters with fewer samples than this.
+ If a float, filters out clusters with fewer than
+ floor(min_samples * n_assigned_samples) samples assigned to them.
"""
def __init__(self,
threshold = 0.9,
diff --git a/sitator/util/PBCCalculator.pyx b/sitator/util/PBCCalculator.pyx
index a4022e6..07f3381 100644
--- a/sitator/util/PBCCalculator.pyx
+++ b/sitator/util/PBCCalculator.pyx
@@ -11,7 +11,7 @@ ctypedef double precision
ctypedef double cell_precision
cdef class PBCCalculator(object):
- """Performs calculations on collections of 3D points under PBC"""
+ """Performs calculations on collections of 3D points under PBC."""
cdef cell_precision [:, :] _cell_mat_array
cdef cell_precision [:, :] _cell_mat_inverse_array
@@ -40,7 +40,7 @@ cdef class PBCCalculator(object):
cpdef pairwise_distances(self, pts):
- """Compute the pairwise distance matrix of `pts` with itself.
+ """Compute the pairwise distance matrix of ``pts`` with itself.
:returns ndarray (len(pts), len(pts)): distances
"""
@@ -56,10 +56,11 @@ cdef class PBCCalculator(object):
cpdef distances(self, pt1, pts2, in_place = False, out = None):
- """Compute the Euclidean distances from pt1 to all points in pts2, using
- shift-and-wrap.
+ """
+ Compute the Euclidean distances from ``pt1`` to all points in
+ ``pts2``, using shift-and-wrap.
- Makes a copy of pts2 unless in_place == True.
+ Makes a copy of ``pts2`` unless ``in_place == True``.
:returns ndarray len(pts2): distances
"""
@@ -248,11 +249,11 @@ cdef class PBCCalculator(object):
cpdef int min_image(self, const precision [:] ref, precision [:] pt):
- """Find the minimum image of `pt` relative to `ref`. In place in pt.
+ """Find the minimum image of ``pt`` relative to ``ref``. In place in pt.
Uses the brute force algorithm for correctness; returns the minimum image.
- Assumes that `ref` and `pt` are already in the *same* cell (though not
+ Assumes that ``ref`` and ``pt`` are already in the *same* cell (though not
necessarily the <0,0,0> cell -- any periodic image will do).
:returns int[3] minimg: Which image was the minimum image.
@@ -327,7 +328,7 @@ cdef class PBCCalculator(object):
cpdef void wrap_points(self, precision [:, :] points):
- """Wrap `points` into a unit cell, IN PLACE. 3D only.
+ """Wrap ``points`` into a unit cell, IN PLACE. 3D only.
"""
assert points.shape[1] == 3, "Points must be 3D"
diff --git a/sitator/util/RecenterTrajectory.pyx b/sitator/util/RecenterTrajectory.pyx
index c70430f..1e3a7c9 100644
--- a/sitator/util/RecenterTrajectory.pyx
+++ b/sitator/util/RecenterTrajectory.pyx
@@ -18,7 +18,7 @@ class RecenterTrajectory(object):
``static_mask``, IN PLACE.
Args:
- structure (ASE Atoms): An atoms representing the structure of the
+ structure (ase.Atoms): An atoms representing the structure of the
simulation.
static_mask (ndarray): Boolean mask indicating which atoms to recenter on
positions (ndarray): (n_frames, n_atoms, 3), modified in place
diff --git a/sitator/util/elbow.py b/sitator/util/elbow.py
index ec038ac..078bec9 100644
--- a/sitator/util/elbow.py
+++ b/sitator/util/elbow.py
@@ -2,7 +2,7 @@
# See discussion around this question: https://stackoverflow.com/questions/2018178/finding-the-best-trade-off-point-on-a-curve/2022348#2022348
def index_of_elbow(points):
- """Returns the index of the "elbow" in points.
+ """Returns the index of the "elbow" in ``points``.
Decently fast and pretty approximate. Performs worse with disproportionately
long "flat" tails. For example, in a dataset with a nearly right-angle elbow,
diff --git a/sitator/util/mcl.py b/sitator/util/mcl.py
index 6a975ab..784cf7a 100644
--- a/sitator/util/mcl.py
+++ b/sitator/util/mcl.py
@@ -5,7 +5,7 @@ def markov_clustering(transition_matrix,
inflation = 2,
pruning_threshold = 0.00001,
iterlimit = 100):
- """
+ """Compute the Markov Clustering of a graph.
See https://micans.org/mcl/.
Because we're dealing with matrixes that are stochastic already,
diff --git a/sitator/util/zeo.py b/sitator/util/zeo.py
index 338e6e0..19bddf0 100644
--- a/sitator/util/zeo.py
+++ b/sitator/util/zeo.py
@@ -22,7 +22,7 @@
# TODO: benchmark CUC vs CIF
class Zeopy(object):
- """A wrapper for the Zeo++ `network` tool.
+ """A wrapper for the Zeo++ ``network`` tool.
:warning: Do not use a single instance of Zeopy in parallel.
"""
@@ -30,7 +30,7 @@ class Zeopy(object):
def __init__(self, path_to_zeo):
"""Create a Zeopy.
- :param str path_to_zeo: Path to the `network` executable.
+ :param str path_to_zeo: Path to the ``network`` executable.
"""
if not (os.path.exists(path_to_zeo) and os.access(path_to_zeo, os.X_OK)):
raise ValueError("`%s` doesn't seem to be the path to an executable file." % path_to_zeo)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 48def0a..54d2985 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -9,32 +9,32 @@
from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS, set_axes_equal
class SiteNetworkPlotter(object):
- """Plot a SiteNetwork.
+ """Plot a ``SiteNetwork``.
Note that for edges, the average of the edge property for i -> j and j -> i
is often used for visual clarity; if your edge properties are not almost symmetric,
the visualization might not be useful.
- Params:
- - site_mappings (dict): defines how to show different properties. Each
+ Args:
+ site_mappings (dict): defines how to show different properties. Each
entry maps a visual aspect ('marker', 'color', 'size') to the name
of a site attribute including 'site_type'. The markers can also be
arbitrary text (key `"text"`) in which case the value can also be a
2-tuple of an attribute name and a `%` format string.
- - edge_mappings (dict): each key maps a visual property ('intensity',
+ edge_mappings (dict): each key maps a visual property ('intensity',
'color', 'width', 'linestyle') to an edge attribute in the SiteNetwork.
- - markers (list of str): What `matplotlib` markers to use for sites.
- - plot_points_params (dict): User options for plotting site points.
- - minmax_linewidth (2-tuple): Minimum and maximum linewidth to use.
- - minmax_edge_alpha (2-tuple): Similar, for edge line alphas.
- - minmax_markersize (2-tuple): Similar, for markersize.
- - min_color_threshold (float): Minimum (normalized) color intensity for
+ markers (list of str): What `matplotlib` markers to use for sites.
+ plot_points_params (dict): User options for plotting site points.
+ minmax_linewidth (2-tuple): Minimum and maximum linewidth to use.
+ minmax_edge_alpha (2-tuple): Similar, for edge line alphas.
+ minmax_markersize (2-tuple): Similar, for markersize.
+ min_color_threshold (float): Minimum (normalized) color intensity for
the corresponding line to be shown. Defaults to zero, i.e., all
nonzero edges will be drawn.
- - min_width_threshold (float): Minimum normalized edge width for the
+ min_width_threshold (float): Minimum normalized edge width for the
corresponding edge to be shown. Defaults to zero, i.e., all
nonzero edges will be drawn.
- - title (str)
+ title (str): Title for the figure.
"""
DEFAULT_SITE_MAPPINGS = {
diff --git a/sitator/visualization/SiteTrajectoryPlotter.py b/sitator/visualization/SiteTrajectoryPlotter.py
index 7649c8b..484850f 100644
--- a/sitator/visualization/SiteTrajectoryPlotter.py
+++ b/sitator/visualization/SiteTrajectoryPlotter.py
@@ -7,8 +7,16 @@
class SiteTrajectoryPlotter(object):
+ """Produce various plots of a ``SiteTrajectory``."""
+
@plotter(is3D = True)
def plot_frame(self, st, frame, **kwargs):
+ """Plot sites and instantaneous positions from a given frame.
+
+ Args:
+ st (SiteTrajectory)
+ frame (int)
+ """
sites_of_frame = np.unique(st._traj[frame])
frame_sn = st._sn[sites_of_frame]
@@ -26,6 +34,12 @@ def plot_frame(self, st, frame, **kwargs):
@plotter(is3D = True)
def plot_site(self, st, site, **kwargs):
+ """Plot all real space positions associated with a site.
+
+ Args:
+ st (SiteTrajectory)
+ site (int)
+ """
pbcc = PBCCalculator(st._sn.structure.cell)
pts = st.real_positions_for_site(site).copy()
offset = pbcc.cell_centroid - pts[3]
@@ -54,6 +68,12 @@ def plot_site(self, st, site, **kwargs):
@plotter(is3D = False)
def plot_particle_trajectory(self, st, particle, ax = None, fig = None, **kwargs):
+ """Plot the sites occupied by a mobile particle over time.
+
+ Args:
+ st (SiteTrajectory)
+ particle (int)
+ """
types = not st._sn.site_types is None
if types:
type_height_percent = 0.1
diff --git a/sitator/voronoi/VoronoiSiteGenerator.py b/sitator/voronoi.py
similarity index 73%
rename from sitator/voronoi/VoronoiSiteGenerator.py
rename to sitator/voronoi.py
index 1f5f3c3..5b94dbf 100644
--- a/sitator/voronoi/VoronoiSiteGenerator.py
+++ b/sitator/voronoi.py
@@ -9,9 +9,9 @@
class VoronoiSiteGenerator(object):
"""Given an empty SiteNetwork, use the Voronoi decomposition to predict/generate sites.
- :param str zeopp_path: Path to the Zeo++ `network` executable
+ :param str zeopp_path: Path to the Zeo++ ``network`` executable
:param bool radial: Whether to use the radial Voronoi transform. Defaults to,
- and should typically be, False.
+ and should typically be, ``False``.
"""
def __init__(self,
@@ -21,7 +21,13 @@ def __init__(self,
self._zeopy = Zeopy(zeopp_path)
def run(self, sn):
- """SiteNetwork -> SiteNetwork"""
+ """
+ Args:
+ sn (SiteNetwork): Any sites will be ignored; needed for structure
+ and static mask.
+ Returns:
+ A ``SiteNetwork``.
+ """
assert isinstance(sn, SiteNetwork)
with self._zeopy:
diff --git a/sitator/voronoi/__init__.py b/sitator/voronoi/__init__.py
deleted file mode 100644
index 6a5d16a..0000000
--- a/sitator/voronoi/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-
-from .VoronoiSiteGenerator import VoronoiSiteGenerator
From ee72fa3d69506bf2915320ba98ba89f97faae560 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 18:08:37 -0400
Subject: [PATCH 078/129] Versioning and requirements
---
requirements.txt | 7 +++++++
setup.py | 2 +-
2 files changed, 8 insertions(+), 1 deletion(-)
create mode 100644 requirements.txt
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e719012
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+numpy
+Cython
+scipy
+matplotlib
+ase
+tqdm
+scikit-learn
diff --git a/setup.py b/setup.py
index 11024c9..a61275f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
import numpy as np
setup(name = 'sitator',
- version = '1.0.1',
+ version = '2.0.0',
description = 'Unsupervised landmark analysis for jump detection in molecular dynamics simulations.',
download_url = "https://github.com/Linux-cpp-lisp/sitator",
author = 'Alby Musaelian',
From d6f42fb07a80dc49d4e8a912bdb45caf4fdfeb17 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 18:12:22 -0400
Subject: [PATCH 079/129] Allow import when dependency isn't met
---
sitator/site_descriptors/SiteTypeAnalysis.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/sitator/site_descriptors/SiteTypeAnalysis.py b/sitator/site_descriptors/SiteTypeAnalysis.py
index c77c3bc..26efa8f 100644
--- a/sitator/site_descriptors/SiteTypeAnalysis.py
+++ b/sitator/site_descriptors/SiteTypeAnalysis.py
@@ -13,10 +13,12 @@
import logging
logger = logging.getLogger(__name__)
+has_pydpc = False
try:
import pydpc
+ has_pydpc = True
except ImportError:
- raise ImportError("SiteTypeAnalysis requires the `pydpc` package")
+ pass
class SiteTypeAnalysis(object):
"""Cluster sites into types using a continuous descriptor and Density Peak Clustering.
@@ -39,6 +41,9 @@ class SiteTypeAnalysis(object):
def __init__(self, descriptor,
min_pca_variance = 0.9, min_pca_dimensions = 2,
n_site_types_max = 20):
+ if not has_pydpc:
+ raise ImportError("SiteTypeAnalysis requires the `pydpc` package")
+
self.descriptor = descriptor
self.min_pca_variance = min_pca_variance
self.min_pca_dimensions = min_pca_dimensions
From ee430ba93a0bc967e7ad3595a4fea62f68656158 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 18:14:06 -0400
Subject: [PATCH 080/129] Docs fix
---
docs/source/conf.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index af7b379..3d1da40 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -40,6 +40,7 @@
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
+master_doc = 'index'
# -- Options for HTML output -------------------------------------------------
From 357d9e168c77a191b08ab59807f8047c3bda21ed Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 18:17:11 -0400
Subject: [PATCH 081/129] Docs link
---
README.md | 2 +-
docs/source/conf.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 4a9f991..980810a 100644
--- a/README.md
+++ b/README.md
@@ -49,7 +49,7 @@ Two example Jupyter notebooks for conducting full landmark analyses of LiAlSiO4
`sitator` generally assumes units of femtoseconds for time, Angstroms for space,
and Cartesian (not crystal) coordinates.
-All individual classes and parameters are documented with docstrings in the source code.
+Documentation can be found at [ReadTheDocs](https://sitator.readthedocs.io/en/py3/).
## Global Options
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 3d1da40..8e7ec1b 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -47,7 +47,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
-html_theme = 'alabaster'
+html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
From 4d5aff43e21366764d9731c762beb24126e2ad76 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 18:20:07 -0400
Subject: [PATCH 082/129] Removed "peak evening"
---
sitator/landmark/LandmarkAnalysis.py | 24 ------------------------
1 file changed, 24 deletions(-)
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index b2116ab..3dc646f 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -35,11 +35,6 @@ class LandmarkAnalysis(object):
:param double minimum_site_occupancy = 0.1: Minimum occupancy (% of time occupied)
for a site to qualify as such.
:param dict clustering_params: Parameters for the chosen ``clustering_algorithm``.
- :param str peak_evening: Whether and what kind of peak "evening" to apply;
- that is, processing that makes all large peaks in the landmark vector
- more similar in magnitude. This can help in site clustering.
-
- Valid options: 'none', 'clip'
:param bool weighted_site_positions: When computing site positions, whether
to weight the average by assignment confidence.
:param bool check_for_zero_landmarks: Whether to check for and raise exceptions
@@ -76,7 +71,6 @@ def __init__(self,
cutoff_midpoint = 1.5,
cutoff_steepness = 30,
minimum_site_occupancy = 0.01,
- peak_evening = 'none',
weighted_site_positions = True,
check_for_zero_landmarks = True,
static_movement_threshold = 1.0,
@@ -92,10 +86,6 @@ def __init__(self,
self._cluster_algo = clustering_algorithm
self._clustering_params = clustering_params
- if not peak_evening in ['none', 'clip']:
- raise ValueError("Invalid value `%s` for peak_evening" % peak_evening)
- self._peak_evening = peak_evening
-
self.verbose = verbose
self.check_for_zero_landmarks = check_for_zero_landmarks
self.weighted_site_positions = weighted_site_positions
@@ -211,8 +201,6 @@ def run(self, sn, frames):
# -- Step 3: Cluster landmark vectors
logger.info(" - clustering landmark vectors -")
- # - Preprocess -
- self._do_peak_evening()
# - Cluster -
# FIXME: remove reload after development done
@@ -279,15 +267,3 @@ def run(self, sn, frames):
self._has_run = True
return out_st
-
- # -------- "private" methods --------
-
- def _do_peak_evening(self):
- if self._peak_evening == 'none':
- return
- elif self._peak_evening == 'clip':
- lvec_peaks = np.max(self._landmark_vectors, axis = 1)
- # Clip all peaks to the lowest "normal" (stdev.) peak
- lvec_clip = np.mean(lvec_peaks) - np.std(lvec_peaks)
- # Do the clipping
- self._landmark_vectors[self._landmark_vectors > lvec_clip] = lvec_clip
From 16105edac1994a81ba6654beeef097fb8873034d Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 16 Jul 2019 18:42:37 -0400
Subject: [PATCH 083/129] Fix unknown in remove unoccupied
---
sitator/dynamics/RemoveShortJumps.py | 2 +-
sitator/dynamics/RemoveUnoccupiedSites.py | 9 ++++++---
2 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/sitator/dynamics/RemoveShortJumps.py b/sitator/dynamics/RemoveShortJumps.py
index d53bfda..b2f9114 100644
--- a/sitator/dynamics/RemoveShortJumps.py
+++ b/sitator/dynamics/RemoveShortJumps.py
@@ -120,7 +120,7 @@ def run(self,
logger.info(
"Short jump statistics:\n" +
"\n".join(
- " removed {1[1]:3}x {0[0]:2} -> {0[1]:2} -> {0[2]:2}; avg. residence at {0[1]:2} of {1[0]} frames".format(
+ " removed {1[1]:3}x {0[0]:2} -> {0[1]:2} -> {0[2]:2}; avg. residence at {0[1]:2} of {1[0]} frames".format(
k, v
) for k, v in short_jump_info.items()
)
diff --git a/sitator/dynamics/RemoveUnoccupiedSites.py b/sitator/dynamics/RemoveUnoccupiedSites.py
index e7dafb3..3a7d0df 100644
--- a/sitator/dynamics/RemoveUnoccupiedSites.py
+++ b/sitator/dynamics/RemoveUnoccupiedSites.py
@@ -37,13 +37,16 @@ def run(self, st, return_kept_sites = False):
logger.info("Removing unoccupied sites %s" % np.where(~seen_mask)[0])
n_new_sites = np.sum(seen_mask)
- translation = np.empty(shape = old_sn.n_sites, dtype = np.int)
- translation[seen_mask] = np.arange(n_new_sites)
- translation[~seen_mask] = SiteTrajectory.SITE_UNKNOWN
+ translation = np.empty(shape = old_sn.n_sites + 1, dtype = np.int)
+ translation[:-1][seen_mask] = np.arange(n_new_sites)
+ translation[:-1][~seen_mask] = -4321
+ translation[-1] = SiteTrajectory.SITE_UNKNOWN # Map unknown to unknown
newtraj = translation[st.traj.reshape(-1)]
newtraj.shape = st.traj.shape
+ assert -4321 not in newtraj
+
# We don't clear computed attributes since nothing is changing for other sites.
newsn = old_sn[seen_mask]
From baae9b713694241e3af329e4cf056dc696b80ebd Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Jul 2019 14:40:15 -0400
Subject: [PATCH 084/129] Optimized SiteTrajectory smoothing
---
setup.py | 5 +-
sitator/dynamics/SmoothSiteTrajectory.pyx | 86 +++++++++++++++++++++++
2 files changed, 89 insertions(+), 2 deletions(-)
create mode 100644 sitator/dynamics/SmoothSiteTrajectory.pyx
diff --git a/setup.py b/setup.py
index a61275f..89d74a7 100644
--- a/setup.py
+++ b/setup.py
@@ -12,8 +12,9 @@
packages = find_packages(),
ext_modules = cythonize([
"sitator/landmark/helpers.pyx",
- "sitator/util/*.pyx"
- ]),
+ "sitator/util/*.pyx",
+ "sitator/dynamics/*.pyx"
+ ], language_level = 3),
include_dirs=[np.get_include()],
install_requires = [
"numpy",
diff --git a/sitator/dynamics/SmoothSiteTrajectory.pyx b/sitator/dynamics/SmoothSiteTrajectory.pyx
new file mode 100644
index 0000000..34497ef
--- /dev/null
+++ b/sitator/dynamics/SmoothSiteTrajectory.pyx
@@ -0,0 +1,86 @@
+# cython: language_level=3
+
+import numpy as np
+
+from sitator import SiteTrajectory
+from sitator.dynamics import RemoveUnoccupiedSites
+
+import logging
+logger = logging.getLogger(__name__)
+
+ctypedef Py_ssize_t site_int
+
+class SmoothSiteTrajectory(object):
+ """"Smooth" a SiteTrajectory by applying a rolling mode.
+
+    For each mobile particle, the assignment at each frame is replaced by the
+ mode of its site assignments over some number of frames centered around it.
+ If the multiplicity of the mode is less than the threshold, the particle is
+ marked unassigned at that frame.
+
+ Can be thought of as a discrete lowpass filter.
+
+ Args:
+ remove_unoccupied_sites (bool): If True, sites that are unoccupied after
+ removing short jumps will be removed.
+ """
+ def __init__(self,
+ window_threshold_factor = 2.1,
+ remove_unoccupied_sites = True):
+ self.window_threshold_factor = window_threshold_factor
+ self.remove_unoccupied_sites = remove_unoccupied_sites
+
+ def run(self,
+ st,
+ threshold):
+ n_mobile = st.site_network.n_mobile
+ n_frames = st.n_frames
+ n_sites = st.site_network.n_sites
+
+ traj = st.traj
+ out = st.traj.copy()
+
+ window = self.window_threshold_factor * threshold
+ wleft, wright = int(np.floor(window / 2)), int(np.ceil(window / 2))
+
+ running_windowed_mode(traj, out, wleft, wright, threshold, n_sites)
+
+ st = st.copy(with_computed = False)
+ st._traj = out
+ if self.remove_unoccupied_sites:
+ # Removing short jumps could have made some sites completely unoccupied
+ st = RemoveUnoccupiedSites().run(st)
+ st.site_network.clear_attributes()
+
+ return st
+
+
+cpdef running_windowed_mode(site_int [:, :] traj,
+ site_int [:, :] out,
+ Py_ssize_t wleft,
+ Py_ssize_t wright,
+ Py_ssize_t threshold,
+ Py_ssize_t n_sites):
+ countbuf_np = np.zeros(shape = n_sites + 1, dtype = np.int)
+ cdef Py_ssize_t [:] countbuf = countbuf_np
+ cdef Py_ssize_t n_mobile = traj.shape[1]
+ cdef Py_ssize_t n_frames = traj.shape[0]
+ cdef site_int s_unknown = SiteTrajectory.SITE_UNKNOWN
+ cdef site_int winner
+ cdef Py_ssize_t best_count
+
+ for mob in range(n_mobile):
+ for frame in range(n_frames):
+ for wi in range(max(frame - wleft, 0), min(frame + wright, n_frames)):
+ countbuf[traj[wi, mob] + 1] += 1
+            winner = 0 # This is actually -1, so unknown by default
+ best_count = 0
+ for site in range(n_sites + 1):
+ if countbuf[site] > best_count:
+ winner = site
+ best_count = countbuf[site]
+ if best_count >= threshold:
+ out[frame, mob] = winner - 1
+ else:
+ out[frame, mob] = s_unknown
+ countbuf[:] = 0
From 369abcb7a81afeae5212ae6e0e03c0818e6ec478 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Jul 2019 14:41:18 -0400
Subject: [PATCH 085/129] Run-length RemoveShortJumps
---
sitator/dynamics/RemoveShortJumps.py | 197 ++++++++++++++++-----------
1 file changed, 120 insertions(+), 77 deletions(-)
diff --git a/sitator/dynamics/RemoveShortJumps.py b/sitator/dynamics/RemoveShortJumps.py
index b2f9114..8e85794 100644
--- a/sitator/dynamics/RemoveShortJumps.py
+++ b/sitator/dynamics/RemoveShortJumps.py
@@ -4,10 +4,28 @@
from sitator import SiteTrajectory
from sitator.dynamics import RemoveUnoccupiedSites
+from sitator.util import PBCCalculator
import logging
logger = logging.getLogger(__name__)
+# From https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi
+def rle(inarray):
+ """ run length encoding. Partial credit to R rle function.
+ Multi datatype arrays catered for including non Numpy
+ returns: tuple (runlengths, startpositions, values) """
+ ia = np.asarray(inarray) # force numpy
+ n = len(ia)
+ if n == 0:
+ return (None, None, None)
+ else:
+ y = np.array(ia[1:] != ia[:-1]) # pairwise unequal (string safe)
+ i = np.append(np.where(y), n - 1) # must include last element posi
+ z = np.diff(np.append(-1, i)) # run lengths
+ p = np.cumsum(np.append(0, z))[:-1] # positions
+ return (z, p, ia[i])
+
+
class RemoveShortJumps(object):
"""Remove "short" jumps in a SiteTrajectory.
@@ -15,112 +33,137 @@ class RemoveShortJumps(object):
and, optionally, only where the mobile atom returns to the site it originally
jumped from.
+ It only counts as a short jump if
+
Args:
only_returning_jumps (bool): If True, only short jumps
where the mobile atom returns to its initial site will be removed.
+ remove_unoccupied_sites (bool): If True, sites that are unoccupied after
+ removing short jumps will be removed.
+ replacement_function (callable): Callable that takes
+ ``(st, mobile_atom, from_site, start_frame, to_site, end_frame)`` and
+ returns either:
+ - A single site assignment with which the short jump will be replaced
+ - A timeseries of length ``end_frame - start_frame`` of site
+ assignments with which the short jump will be replaced.
+ If ``None``, defaults to ``RemoveShortJumps.replace_with_from``.
"""
def __init__(self,
only_returning_jumps = True,
- remove_unoccupied_sites = True):
+ remove_unoccupied_sites = True,
+ replacement_function = None):
self.only_returning_jumps = only_returning_jumps
self.remove_unoccupied_sites = remove_unoccupied_sites
-
+ if replacement_function is None:
+ replacement_function = RemoveShortJumps.replace_with_from
+ self.replacement_function = replacement_function
+
+ @staticmethod
+ def replace_with_from(st, mobile_atom, from_site, start_frame, to_site, end_frame):
+ """Replace a short jump with the site being jumped from."""
+ return from_site
+
+ @staticmethod
+ def replace_with_to(st, mobile_atom, from_site, start_frame, to_site, end_frame):
+ """Replace a short jump with the site being jumped to after the short jump."""
+ return to_site
+
+ @staticmethod
+ def replace_with_unknown(st, mobile_atom, from_site, start_frame, to_site, end_frame):
+ """Mark as unassigned during a short jump."""
+ return SiteTrajectory.SITE_UNKNOWN
+
+ @staticmethod
+ def replace_with_closer():
+ """Create function to replace short jump with closest site over time.
+
+ Assigns the positions during a short jump to whichever of the from
+ and to site it is closer to in real space.
+ """
+ pbcc = None
+ ptbuf = np.empty(shape = (2, 3))
+ distbuf = np.empty(shape = 2)
+ def replace_with_closer(st, mobile_atom, from_site, start_frame, to_site, end_frame):
+ if pbcc is None:
+ pbcc = PBCCalculator(st.site_network.structure.cell)
+ n_frames = end_frame - start_frame
+ out = np.empty(shape = n_frames)
+ for i in range(n_frames):
+ ptbuf[0] = st.site_network.centers[from_site]
+ ptbuf[1] = st.site_network.centers[to_site]
+ pbcc.distances(
+ st.real_trajectory[start_frame + i, mobile_atom],
+ ptbuf,
+ in_place = True,
+ out = distbuf
+ )
+ if distbuf[0] < distbuf[1]:
+ out[i] = from_site
+ else:
+ out[i] = to_site
+ return out
def run(self,
st,
threshold,
return_stats = False):
- """Returns a copy of ``st`` with short jumps removed.
-
- Args:
- st (SiteTrajectory): Unassigned considered to be last known.
- threshold (int): The largest number of frames the mobile atom
- can spend at a site while the jump is still considered short.
-
- Returns:
- A ``SiteTrajectory``.
- """
n_mobile = st.site_network.n_mobile
n_frames = st.n_frames
n_sites = st.site_network.n_sites
- previous_site = np.full(shape = n_mobile, fill_value = -2, dtype = np.int)
- last_known = np.empty(shape = n_mobile, dtype = np.int)
- np.copyto(last_known, st.traj[0])
- # Everything is at it's first position for at least one frame by definition
- time_at_current = np.ones(shape = n_mobile, dtype = np.int)
-
- framebuf = np.empty(shape = st.traj.shape[1:], dtype = st.traj.dtype)
+ st_no_un = st.copy(with_computed = False)
+ st_no_un.assign_to_last_known_site(frame_threshold = np.inf)
+ traj = st_no_un.traj
out = st.traj.copy()
- n_problems = 0
- n_short_jumps = 0
-
# Dict of lists [sum_jump_times, n_short_jumps]
short_jump_info = defaultdict(lambda: [0, 0])
- for i, frame in enumerate(st.traj):
- if i == 0:
- continue
- # -- Deal with unassigned
- # Don't screw up the SiteTrajectory
- np.copyto(framebuf, frame)
- frame = framebuf
-
- unassigned = frame == SiteTrajectory.SITE_UNKNOWN
- # Reassign unassigned
- frame[unassigned] = last_known[unassigned]
- fknown = frame >= 0
-
- if np.any(~fknown):
- logger.warning("At frame %i, %i uncorrectable unassigned particles" % (i, np.sum(~fknown)))
- # -- Update stats
-
- jumped = (frame != last_known) & fknown
- #problems = last_known[jumped] == -1
- #jumped[np.where(jumped)[0][problems]] = False
- problems = last_known == -1
- jumped[problems] = False
- n_problems += np.sum(problems)
-
- jump_froms = last_known[jumped]
- jump_tos = frame[jumped]
-
- # For all that didn't jump, increment time at current
- time_at_current[~jumped] += 1
- # For all that did, check if short
- short_mask = time_at_current[jumped] <= threshold
- if self.only_returning_jumps:
- short_mask &= jump_tos == previous_site[jumped]
- # Remove short jumps
- for sj_atom in np.arange(n_mobile)[jumped][short_mask]:
- # Bookkeeping
- sjkey = (previous_site[sj_atom], last_known[sj_atom], frame[sj_atom])
- short_jump_info[sjkey][0] += time_at_current[sj_atom]
- short_jump_info[sjkey][1] += 1
- n_short_jumps += 1
- # Remove short jump
- out[i - time_at_current[sj_atom]:i, sj_atom] = previous_site[sj_atom]
-
- previous_site[jumped] = last_known[jumped]
-
- # Reset for those that jumped
- time_at_current[jumped] = 1
-
- # Update last known assignment for anything that has one
- last_known[~unassigned] = frame[~unassigned]
-
- if n_problems != 0:
- logger.warning("Came across %i times where assignment and last known assignment were unassigned." % n_problems)
- logger.info("Removed %i short jumps" % n_short_jumps)
+ for mob in range(n_mobile):
+ runlen, start, runsites = rle(traj[:, mob])
+ # We pretend that the first and last run extend into infinity
+ # Think Fourier transforms
+ runlen[[0, -1]] = np.iinfo(runlen.dtype).max
+ last_long_enough = 0
+ short_start = None
+ short_from = None
+ short_transitionals = None
+ for runi in range(0, len(runlen)):
+ shortrun = runlen[runi] < threshold
+ if shortrun:
+ if short_start is None:
+ short_start = start[runi]
+ short_from = runsites[runi - 1]
+ short_transitionals = [runsites[runi]]
+ else:
+ short_transitionals.append(runsites[runi])
+ elif short_start is not None:
+ # Process short jump
+ short_to = runsites[runi]
+ short_end = start[runi]
+ # If we're only doing returning jumps, check that
+ do = (not self.only_returning_jumps) or (short_to == short_from)
+ if do:
+ sjkey = (short_from, tuple(short_transitionals), short_to)
+ short_jump_info[sjkey][0] += short_end - short_start
+ short_jump_info[sjkey][1] += 1
+ replace = self.replacement_function(
+ st, mob,
+ short_from, short_start,
+ short_to, short_end
+ )
+ out[short_start:short_end, mob] = replace
+ # Reset
+ short_start = None
+
# Do average
for k in short_jump_info.keys():
short_jump_info[k][0] /= short_jump_info[k][1]
logger.info(
"Short jump statistics:\n" +
"\n".join(
- " removed {1[1]:3}x {0[0]:2} -> {0[1]:2} -> {0[2]:2}; avg. residence at {0[1]:2} of {1[0]} frames".format(
+ " removed {1[1]:3}x {0[0]:2} -> {0[1]} -> {0[2]:2}; spent {1[0]:.1f} frames at {0[1]}".format(
k, v
) for k, v in short_jump_info.items()
)
From 3354b15a587853a8b5ded33dced19207b79db367 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Jul 2019 14:52:25 -0400
Subject: [PATCH 086/129] Removed broken RemoveShortJumps
---
sitator/dynamics/RemoveShortJumps.py | 182 ------------------
.../dynamics/ReplaceUnassignedPositions.py | 117 +++++++++++
2 files changed, 117 insertions(+), 182 deletions(-)
delete mode 100644 sitator/dynamics/RemoveShortJumps.py
create mode 100644 sitator/dynamics/ReplaceUnassignedPositions.py
diff --git a/sitator/dynamics/RemoveShortJumps.py b/sitator/dynamics/RemoveShortJumps.py
deleted file mode 100644
index 8e85794..0000000
--- a/sitator/dynamics/RemoveShortJumps.py
+++ /dev/null
@@ -1,182 +0,0 @@
-import numpy as np
-
-from collections import defaultdict
-
-from sitator import SiteTrajectory
-from sitator.dynamics import RemoveUnoccupiedSites
-from sitator.util import PBCCalculator
-
-import logging
-logger = logging.getLogger(__name__)
-
-# From https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi
-def rle(inarray):
- """ run length encoding. Partial credit to R rle function.
- Multi datatype arrays catered for including non Numpy
- returns: tuple (runlengths, startpositions, values) """
- ia = np.asarray(inarray) # force numpy
- n = len(ia)
- if n == 0:
- return (None, None, None)
- else:
- y = np.array(ia[1:] != ia[:-1]) # pairwise unequal (string safe)
- i = np.append(np.where(y), n - 1) # must include last element posi
- z = np.diff(np.append(-1, i)) # run lengths
- p = np.cumsum(np.append(0, z))[:-1] # positions
- return (z, p, ia[i])
-
-
-class RemoveShortJumps(object):
- """Remove "short" jumps in a SiteTrajectory.
-
- Remove jumps where the residence at the target is less than some threshold
- and, optionally, only where the mobile atom returns to the site it originally
- jumped from.
-
- It only counts as a short jump if
-
- Args:
- only_returning_jumps (bool): If True, only short jumps
- where the mobile atom returns to its initial site will be removed.
- remove_unoccupied_sites (bool): If True, sites that are unoccupied after
- removing short jumps will be removed.
- replacement_function (callable): Callable that takes
- ``(st, mobile_atom, from_site, start_frame, to_site, end_frame)`` and
- returns either:
- - A single site assignment with which the short jump will be replaced
- - A timeseries of length ``end_frame - start_frame`` of site
- assignments with which the short jump will be replaced.
- If ``None``, defaults to ``RemoveShortJumps.replace_with_from``.
- """
- def __init__(self,
- only_returning_jumps = True,
- remove_unoccupied_sites = True,
- replacement_function = None):
- self.only_returning_jumps = only_returning_jumps
- self.remove_unoccupied_sites = remove_unoccupied_sites
- if replacement_function is None:
- replacement_function = RemoveShortJumps.replace_with_from
- self.replacement_function = replacement_function
-
- @staticmethod
- def replace_with_from(st, mobile_atom, from_site, start_frame, to_site, end_frame):
- """Replace a short jump with the site being jumped from."""
- return from_site
-
- @staticmethod
- def replace_with_to(st, mobile_atom, from_site, start_frame, to_site, end_frame):
- """Replace a short jump with the site being jumped to after the short jump."""
- return to_site
-
- @staticmethod
- def replace_with_unknown(st, mobile_atom, from_site, start_frame, to_site, end_frame):
- """Mark as unassigned during a short jump."""
- return SiteTrajectory.SITE_UNKNOWN
-
- @staticmethod
- def replace_with_closer():
- """Create function to replace short jump with closest site over time.
-
- Assigns the positions during a short jump to whichever of the from
- and to site it is closer to in real space.
- """
- pbcc = None
- ptbuf = np.empty(shape = (2, 3))
- distbuf = np.empty(shape = 2)
- def replace_with_closer(st, mobile_atom, from_site, start_frame, to_site, end_frame):
- if pbcc is None:
- pbcc = PBCCalculator(st.site_network.structure.cell)
- n_frames = end_frame - start_frame
- out = np.empty(shape = n_frames)
- for i in range(n_frames):
- ptbuf[0] = st.site_network.centers[from_site]
- ptbuf[1] = st.site_network.centers[to_site]
- pbcc.distances(
- st.real_trajectory[start_frame + i, mobile_atom],
- ptbuf,
- in_place = True,
- out = distbuf
- )
- if distbuf[0] < distbuf[1]:
- out[i] = from_site
- else:
- out[i] = to_site
- return out
-
- def run(self,
- st,
- threshold,
- return_stats = False):
- n_mobile = st.site_network.n_mobile
- n_frames = st.n_frames
- n_sites = st.site_network.n_sites
-
- st_no_un = st.copy(with_computed = False)
- st_no_un.assign_to_last_known_site(frame_threshold = np.inf)
-
- traj = st_no_un.traj
- out = st.traj.copy()
-
- # Dict of lists [sum_jump_times, n_short_jumps]
- short_jump_info = defaultdict(lambda: [0, 0])
-
- for mob in range(n_mobile):
- runlen, start, runsites = rle(traj[:, mob])
- # We pretend that the first and last run extend into infinity
- # Think Fourier transforms
- runlen[[0, -1]] = np.iinfo(runlen.dtype).max
- last_long_enough = 0
- short_start = None
- short_from = None
- short_transitionals = None
- for runi in range(0, len(runlen)):
- shortrun = runlen[runi] < threshold
- if shortrun:
- if short_start is None:
- short_start = start[runi]
- short_from = runsites[runi - 1]
- short_transitionals = [runsites[runi]]
- else:
- short_transitionals.append(runsites[runi])
- elif short_start is not None:
- # Process short jump
- short_to = runsites[runi]
- short_end = start[runi]
- # If we're only doing returning jumps, check that
- do = (not self.only_returning_jumps) or (short_to == short_from)
- if do:
- sjkey = (short_from, tuple(short_transitionals), short_to)
- short_jump_info[sjkey][0] += short_end - short_start
- short_jump_info[sjkey][1] += 1
- replace = self.replacement_function(
- st, mob,
- short_from, short_start,
- short_to, short_end
- )
- out[short_start:short_end, mob] = replace
- # Reset
- short_start = None
-
- # Do average
- for k in short_jump_info.keys():
- short_jump_info[k][0] /= short_jump_info[k][1]
- logger.info(
- "Short jump statistics:\n" +
- "\n".join(
- " removed {1[1]:3}x {0[0]:2} -> {0[1]} -> {0[2]:2}; spent {1[0]:.1f} frames at {0[1]}".format(
- k, v
- ) for k, v in short_jump_info.items()
- )
- )
-
- st = st.copy(with_computed = False)
- st._traj = out
- if self.remove_unoccupied_sites:
- # Removing short jumps could have made some sites completely unoccupied
- st = RemoveUnoccupiedSites().run(st)
- st.site_network.clear_attributes()
-
- if return_stats:
- return st, short_jump_info
- else:
- return st
diff --git a/sitator/dynamics/ReplaceUnassignedPositions.py b/sitator/dynamics/ReplaceUnassignedPositions.py
new file mode 100644
index 0000000..1ee2a21
--- /dev/null
+++ b/sitator/dynamics/ReplaceUnassignedPositions.py
@@ -0,0 +1,117 @@
+import numpy as np
+
+from sitator import SiteTrajectory
+from sitator.util import PBCCalculator
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+# See https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi
+def rle(inarray):
+ """ run length encoding. Partial credit to R rle function.
+ Multi datatype arrays catered for including non Numpy
+ returns: tuple (runlengths, startpositions, values) """
+ ia = np.asarray(inarray) # force numpy
+ n = len(ia)
+ if n == 0:
+ return (None, None, None)
+ else:
+ y = np.array(ia[1:] != ia[:-1]) # pairwise unequal (string safe)
+ i = np.append(np.where(y), n - 1) # must include last element posi
+ z = np.diff(np.append(-1, i)) # run lengths
+ p = np.cumsum(np.append(0, z))[:-1] # positions
+ return(z, p, ia[i])
+
+class ReplaceUnassignedPositions(object):
+ """Fill in missing site assignments in a SiteTrajectory.
+
+ Args:
+ replacement_function (callable): Callable that takes
+ ``(st, mobile_atom, before_site, start_frame, after_site, end_frame)`` and
+ returns either:
+ - A single site assignment with which the unassigned will be replaced
+ - A timeseries of length ``end_frame - start_frame`` of site
+ assignments with which the unassigned will be replaced.
+ If ``None``, defaults to
+            ``ReplaceUnassignedPositions.replace_with_last_known``.
+ """
+ def __init__(self,
+ replacement_function = None):
+ if replacement_function is None:
+            replacement_function = ReplaceUnassignedPositions.replace_with_last_known
+ self.replacement_function = replacement_function
+
+ @staticmethod
+ def replace_with_last_known(st, mobile_atom, before_site, start_frame, after_site, end_frame):
+ """Replace unassigned with the last known site."""
+ return before_site
+
+ @staticmethod
+ def replace_with_next_known(st, mobile_atom, before_site, start_frame, after_site, end_frame):
+ """Replace unassigned with the next known site."""
+ return after_site
+
+ @staticmethod
+ def replace_with_closer():
+ """Create function to replace unknown with closest site over time.
+
+ Assigns each of the positions during an unassigned run to whichever of
+ the before and after sites it is closer to in real space.
+ """
+ pbcc = None
+ ptbuf = np.empty(shape = (2, 3))
+ distbuf = np.empty(shape = 2)
+ def replace_with_closer(st, mobile_atom, before_site, start_frame, after_site, end_frame):
+ if before_site == SiteTrajectory.SITE_UNKNOWN or \
+ after_site == SiteTrajectory.SITE_UNKNOWN:
+ return SiteTrajectory.SITE_UNKNOWN
+
+ if pbcc is None:
+ pbcc = PBCCalculator(st.site_network.structure.cell)
+ n_frames = end_frame - start_frame
+ out = np.empty(shape = n_frames)
+ for i in range(n_frames):
+ ptbuf[0] = st.site_network.centers[before_site]
+ ptbuf[1] = st.site_network.centers[after_site]
+ pbcc.distances(
+ st.real_trajectory[start_frame + i, mobile_atom],
+ ptbuf,
+ in_place = True,
+ out = distbuf
+ )
+ if distbuf[0] < distbuf[1]:
+ out[i] = before_site
+ else:
+ out[i] = after_site
+ return out
+
+
+ def run(self,
+ st):
+ n_mobile = st.site_network.n_mobile
+ n_frames = st.n_frames
+ n_sites = st.site_network.n_sites
+
+ traj = st.traj
+ out = st.traj.copy()
+
+ for mob in range(n_mobile):
+ runlen, start, runsites = rle(traj[:, mob])
+ for runi in range(len(runlen)):
+ if runsites[runi] == SiteTrajectory.SITE_UNKNOWN:
+ unknown_start = start[runi]
+ unknown_end = unknown_start + runlen[runi]
+ unknown_before = runsites[runi - 1] if runi > 0 else SiteTrajectory.SITE_UNKNOWN
+ unknown_after = runsites[runi + 1] if runi < len(runlen) - 1 else SiteTrajectory.SITE_UNKNOWN
+ replace = self.replacement_function(
+ st, mob,
+ unknown_before, unknown_start,
+ unknown_after, unknown_end
+ )
+ out[unknown_start:unknown_end, mob] = replace
+
+ st = st.copy(with_computed = False)
+ st._traj = out
+
+ return st
From 9da8054d3134d5db68eb016033245ac0801a332b Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Jul 2019 16:06:58 -0400
Subject: [PATCH 087/129] SiteTrajectory output improvements
---
sitator/SiteTrajectory.py | 4 ++--
sitator/visualization/SiteTrajectoryPlotter.py | 8 +++++++-
2 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index de1c02e..03686c9 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -226,7 +226,7 @@ def assign_to_last_known_site(self, frame_threshold = 1):
"""
total_unknown = self.n_unassigned
- logger.info("%i unassigned positions (%i%%); assigning unassigned mobile particles to last known positions within %i frames..." % (total_unknown, 100.0 * self.percent_unassigned, frame_threshold))
+ logger.info("%i unassigned positions (%i%%); assigning unassigned mobile particles to last known positions within %s frames..." % (total_unknown, 100.0 * self.percent_unassigned, frame_threshold))
last_known = np.empty(shape = self._sn.n_mobile, dtype = np.int)
last_known.fill(-1)
@@ -293,7 +293,7 @@ def jumps(self, unknown_as_jump = False):
- Frame 0: Atom 1 at site 4
- Frame 1: Atom 1 at site 5
-
+
will yield a jump ``(1, 1, 4, 5)``.
Args:
diff --git a/sitator/visualization/SiteTrajectoryPlotter.py b/sitator/visualization/SiteTrajectoryPlotter.py
index 484850f..56f11ab 100644
--- a/sitator/visualization/SiteTrajectoryPlotter.py
+++ b/sitator/visualization/SiteTrajectoryPlotter.py
@@ -100,7 +100,13 @@ def plot_particle_trajectory(self, st, particle, ax = None, fig = None, **kwargs
val = last_value if current_value == -1 else current_value
segments.append([[current_segment_start, last_value], [current_segment_start, val], [i, val]])
linestyles.append(':' if current_value == -1 else '-')
- colors.append('lightgray' if current_value == -1 else 'k')
+ if current_value == -1:
+ c = 'lightgray' # Unknown but reassigned
+ elif val == -1:
+ c = 'red' # Uncorrected unknown
+ else:
+ c = 'k' # Known
+ colors.append(c)
if types:
rxy = (current_segment_start, 0)
From bb827b360f110667088a1e01cf3a68fa3779576f Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Jul 2019 18:48:14 -0400
Subject: [PATCH 088/129] Determine directions of each periodic pathway
---
sitator/network/DiffusionPathwayAnalysis.py | 36 ++++++++++++++++-----
1 file changed, 28 insertions(+), 8 deletions(-)
diff --git a/sitator/network/DiffusionPathwayAnalysis.py b/sitator/network/DiffusionPathwayAnalysis.py
index 05439c5..390512f 100644
--- a/sitator/network/DiffusionPathwayAnalysis.py
+++ b/sitator/network/DiffusionPathwayAnalysis.py
@@ -39,7 +39,7 @@ def __init__(self,
self.connectivity_threshold = connectivity_threshold
self.minimum_n_sites = minimum_n_sites
- def run(self, sn, return_count = False):
+ def run(self, sn, return_count = False, return_direction = False):
"""
Expects a ``SiteNetwork`` that has had a ``JumpAnalysis`` run on it.
@@ -48,8 +48,11 @@ def run(self, sn, return_count = False):
Args:
sn (SiteNetwork): Must have jump statistics from a ``JumpAnalysis``.
return_count (bool): Return the number of connected pathways.
+ return_direction (bool): If True and `self.true_periodic_pathways`,
+ return for each pathway the direction matrix indicating which
+            directions it connects across periodic boundaries.
Returns:
- sn, [n_pathways]
+ sn, [n_pathways], [list of set of tuple]
"""
if not sn.has_attribute('n_ij'):
raise ValueError("SiteNetwork has no `n_ij`; run a JumpAnalysis on it first.")
@@ -68,7 +71,7 @@ def run(self, sn, return_count = False):
connectivity_matrix = sn.n_ij >= threshold
if self.true_periodic_pathways:
- connectivity_matrix, mask_000 = self._build_mic_connmat(sn, connectivity_matrix)
+ connectivity_matrix, mask_000, images = self._build_mic_connmat(sn, connectivity_matrix)
n_ccs, ccs = connected_components(connectivity_matrix,
directed = False, # even though the matrix is symmetric
@@ -85,7 +88,8 @@ def run(self, sn, return_count = False):
# Add a non-path (contains no sites, all False) so the broadcasting works
site_masks = [np.zeros(shape = len(sn), dtype = np.bool)]
- #seen_mask = np.zeros(shape = len(sn), dtype = np.bool)
+
+ pathway_dirs = [set()]
for pathway_i in np.arange(n_ccs):
path_mask = ccs == pathway_i
@@ -100,20 +104,32 @@ def run(self, sn, return_count = False):
# Not percolating; doesn't contain any site and its periodic image.
continue
+ pdirs = set()
+ for periodic_site in np.where(site_counts > 1)[0]:
+ at_images = images[path_mask[periodic_site::len(sn)]]
+ # The direction from 0 to 1 should be the same as any other pair.
+ # Cause periodic.
+ direction = (at_images[0] - at_images[1]) != 0
+ pdirs.add(tuple(direction))
+
cur_site_mask = site_counts > 0
intersects_with = np.where(np.any(np.logical_and(site_masks, cur_site_mask), axis = 1))[0]
# Merge them:
if len(intersects_with) > 0:
path_mask = cur_site_mask | np.logical_or.reduce([site_masks[i] for i in intersects_with], axis = 0)
+ pdirs = pdirs.union(*[pathway_dirs[i] for i in intersects_with])
+ print("Merge pdirs: %s" % pdirs)
else:
path_mask = cur_site_mask
# Remove individual merged paths
# Going in reverse order means indexes don't become invalid as deletes happen
for i in sorted(intersects_with, reverse=True):
del site_masks[i]
+ del pathway_dirs[i]
# Add new (super)path
site_masks.append(path_mask)
+ pathway_dirs.append(pdirs)
new_ccs[path_mask] = new_n_ccs
new_n_ccs += 1
@@ -124,6 +140,8 @@ def run(self, sn, return_count = False):
# This will deal with the ones that were merged.
is_pathway = np.in1d(np.arange(n_ccs), ccs)
is_pathway[0] = False # Cause this was the "unassigned" value, we initialized with zeros up above
+ pathway_dirs = [pd for i, pd in enumerate(pathway_dirs) if is_pathway[i]]
+ assert len(pathway_dirs) == np.sum(is_pathway)
else:
is_pathway = counts >= self.minimum_n_sites
@@ -147,10 +165,12 @@ def run(self, sn, return_count = False):
sn.add_site_attribute('site_diffusion_pathway', node_pathways)
sn.add_edge_attribute('edge_diffusion_pathway', outmat)
+ retval = [sn]
if return_count:
- return sn, n_pathway
- else:
- return sn
+ retval.append(n_pathway)
+ if return_direction:
+ retval.append(pathway_dirs)
+ return tuple(retval)
def _build_mic_connmat(self, sn, connectivity_matrix):
@@ -206,4 +226,4 @@ def _build_mic_connmat(self, sn, connectivity_matrix):
assert np.sum(newmat) >= n_images * np.sum(internal_mat) # Lowest it can be is if every one is internal
- return newmat, mask_000
+ return newmat, mask_000, images
From d153c9fa756a04324f8fc87659c71bdd4b949062 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Jul 2019 18:48:27 -0400
Subject: [PATCH 089/129] Show lattice vector labels in visualizations
---
sitator/visualization/atoms.py | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/sitator/visualization/atoms.py b/sitator/visualization/atoms.py
index 153efa0..b4f76cd 100644
--- a/sitator/visualization/atoms.py
+++ b/sitator/visualization/atoms.py
@@ -37,8 +37,19 @@ def plot_atoms(atoms, positions = None, hide_species = (), wrap = False, fig = N
all_cvecs = []
whos_left = set(range(len(atoms.cell)))
+ cvec_labels = ["$\\vec{a}$", "$\\vec{b}$", "$\\vec{c}$"]
for i, cvec1 in enumerate(atoms.cell):
all_cvecs.append(np.array([[0.0, 0.0, 0.0], cvec1]))
+ ax.text(
+ cvec1[0] * 0.25,
+ cvec1[1] * 0.25,
+ cvec1[2] * 0.25,
+ cvec_labels[i],
+ size = 9,
+ color = 'gray',
+ ha = 'left',
+ va = 'center'
+ )
for j, cvec2 in enumerate(atoms.cell[list(whos_left - {i})]):
all_cvecs.append(np.array([cvec1, cvec1 + cvec2]))
for i, cvec1 in enumerate(atoms.cell):
From f95c697164a648f248623c53cb757edc2d90834e Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Jul 2019 23:48:47 -0400
Subject: [PATCH 090/129] Import correction
---
sitator/dynamics/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sitator/dynamics/__init__.py b/sitator/dynamics/__init__.py
index ac88173..812c503 100644
--- a/sitator/dynamics/__init__.py
+++ b/sitator/dynamics/__init__.py
@@ -2,7 +2,7 @@
from .MergeSitesByDynamics import MergeSitesByDynamics
from .MergeSitesByThreshold import MergeSitesByThreshold
from .RemoveUnoccupiedSites import RemoveUnoccupiedSites
-from .RemoveShortJumps import RemoveShortJumps
+from .SmoothSiteTrajectory import SmoothSiteTrajectory
from .AverageVibrationalFrequency import AverageVibrationalFrequency
# For backwards compatability, since this used to be in this module
From 71ccb57d65be8084170ab9515dfa0906e02cfcea Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 18 Jul 2019 13:25:39 -0400
Subject: [PATCH 091/129] Catch degenerate site polyhedra
---
sitator/site_descriptors/SiteVolumes.py | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/sitator/site_descriptors/SiteVolumes.py b/sitator/site_descriptors/SiteVolumes.py
index 0353aa5..e065680 100644
--- a/sitator/site_descriptors/SiteVolumes.py
+++ b/sitator/site_descriptors/SiteVolumes.py
@@ -86,6 +86,8 @@ def compute_volumes(self, sn):
Adds the ``site_volumes`` and ``site_surface_areas`` attributes.
+ Volumes can be NaN for degenerate hulls/point sets on which QHull fails.
+
Args:
- sn (SiteNetwork)
"""
@@ -114,9 +116,14 @@ def compute_volumes(self, sn):
pos += offset
pbcc.wrap_points(pos)
- hull = ConvexHull(pos)
- vols[site] = hull.volume
- areas[site] = hull.area
+ try:
+ hull = ConvexHull(pos)
+ vols[site] = hull.volume
+ areas[site] = hull.area
+ except QhullError as qhe:
+ logger.warning("Had QHull failure when computing volume of site %i" % site)
+ vols[site] = np.nan
+ areas[site] = np.nan
sn.add_site_attribute('site_volumes', vols)
sn.add_site_attribute('site_surface_areas', areas)
From 1faf90f73ec5dd423ee3d5bee4a4bb7f51bd355d Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 18 Jul 2019 14:28:23 -0400
Subject: [PATCH 092/129] Fixed assertion error
---
sitator/network/DiffusionPathwayAnalysis.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/sitator/network/DiffusionPathwayAnalysis.py b/sitator/network/DiffusionPathwayAnalysis.py
index 390512f..7dc714c 100644
--- a/sitator/network/DiffusionPathwayAnalysis.py
+++ b/sitator/network/DiffusionPathwayAnalysis.py
@@ -119,7 +119,6 @@ def run(self, sn, return_count = False, return_direction = False):
if len(intersects_with) > 0:
path_mask = cur_site_mask | np.logical_or.reduce([site_masks[i] for i in intersects_with], axis = 0)
pdirs = pdirs.union(*[pathway_dirs[i] for i in intersects_with])
- print("Merge pdirs: %s" % pdirs)
else:
path_mask = cur_site_mask
# Remove individual merged paths
@@ -140,7 +139,7 @@ def run(self, sn, return_count = False, return_direction = False):
# This will deal with the ones that were merged.
is_pathway = np.in1d(np.arange(n_ccs), ccs)
is_pathway[0] = False # Cause this was the "unassigned" value, we initialized with zeros up above
- pathway_dirs = [pd for i, pd in enumerate(pathway_dirs) if is_pathway[i]]
+ pathway_dirs = pathway_dirs[1:] # Get rid of the dummy pathway's direction
assert len(pathway_dirs) == np.sum(is_pathway)
else:
is_pathway = counts >= self.minimum_n_sites
From 7199562ab04dab0bd188edfee2c1ad650dc12f2d Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 18 Jul 2019 14:48:54 -0400
Subject: [PATCH 093/129] More intelligently log unfixable unknowns
---
sitator/dynamics/JumpAnalysis.py | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/sitator/dynamics/JumpAnalysis.py b/sitator/dynamics/JumpAnalysis.py
index 639b852..945fd7d 100644
--- a/sitator/dynamics/JumpAnalysis.py
+++ b/sitator/dynamics/JumpAnalysis.py
@@ -70,15 +70,11 @@ def run(self, st):
frame[unassigned] = last_known[unassigned]
fknown = (frame >= 0) & (last_known >= 0)
- if np.any(~fknown):
- logger.warning(" at frame %i, %i uncorrectable unassigned particles" % (i, np.sum(~fknown)))
+ n_problems += np.sum(~fknown)
# -- Update stats
total_time_spent_at_site[frame[fknown]] += 1
jumped = (frame != last_known) & fknown
- problems = last_known[jumped] == -1
- jumped[np.where(jumped)[0][problems]] = False
- n_problems += np.sum(problems)
n_ij[last_known[fknown], frame[fknown]] += 1
From 1e85b2733ba537aa1c3e5c08732bff9ad35802b0 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 18 Jul 2019 16:06:05 -0400
Subject: [PATCH 094/129] Always give valences as list
---
sitator/site_descriptors/SiteCoordinationEnvironment.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
index 87be7db..39d1c2a 100644
--- a/sitator/site_descriptors/SiteCoordinationEnvironment.py
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -83,6 +83,7 @@ def run(self, sn):
if np.any(mob_val != mob_val[0]):
logger.warning("Mobile atom estimated valences (%s) not uniform; arbitrarily taking first." % mob_val)
valences[site_atom_index] = mob_val[0]
+ finally:
valences = list(valences)
logger.info("Running site coordination environment analysis...")
From 36ac2950d352d6a33556ec27663a0928426026fe Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 18 Jul 2019 16:06:13 -0400
Subject: [PATCH 095/129] Deal with uniform occupancies
---
sitator/visualization/SiteNetworkPlotter.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 54d2985..db1ae09 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -127,11 +127,15 @@ def _site_layers(self, sn, plot_points_params, same_normalization = False):
pts_arrays['c'] = val.copy()
if not same_normalization:
self._color_minmax = (np.min(val), np.max(val))
+ if self._color_minmax[0] == self._color_minmax[1]:
+ self._color_minmax[0] -= 1 # Just to avoid div by zero
color_minmax = self._color_minmax
pts_params['norm'] = matplotlib.colors.Normalize(vmin = color_minmax[0], vmax = color_minmax[1])
elif key == 'size':
if not same_normalization:
self._size_minmax = (np.min(val), np.max(val))
+ if self._size_minmax[0] == self._size_minmax[1]:
+ self._size_minmax[0] -= 1 # Just to avoid div by zero
size_minmax = self._size_minmax
s = val.copy()
s -= size_minmax[0]
From 1d8abe9bad7dc569a9720ee5565341abe8ce920b Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 18 Jul 2019 17:07:37 -0400
Subject: [PATCH 096/129] Fixed showing uncorrectable unassigned
---
sitator/visualization/SiteTrajectoryPlotter.py | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/sitator/visualization/SiteTrajectoryPlotter.py b/sitator/visualization/SiteTrajectoryPlotter.py
index 56f11ab..5ed244e 100644
--- a/sitator/visualization/SiteTrajectoryPlotter.py
+++ b/sitator/visualization/SiteTrajectoryPlotter.py
@@ -98,12 +98,16 @@ def plot_particle_trajectory(self, st, particle, ax = None, fig = None, **kwargs
for i, f in enumerate(traj):
if f != current_value or i == len(traj) - 1:
val = last_value if current_value == -1 else current_value
- segments.append([[current_segment_start, last_value], [current_segment_start, val], [i, val]])
+ if last_value == -1:
+ segments.append([[current_segment_start, val], [i, val]])
+ else:
+ segments.append([[current_segment_start, last_value], [current_segment_start, val], [i, val]])
linestyles.append(':' if current_value == -1 else '-')
if current_value == -1:
- c = 'lightgray' # Unknown but reassigned
- elif val == -1:
- c = 'red' # Uncorrected unknown
+ if val == -1:
+ c = 'red' # Uncorrected unknown
+ else:
+ c = 'lightgray' # Unknown but reassigned
else:
c = 'k' # Known
colors.append(c)
From cdb52e7e92c6b749c62c85dc358999b5aca9dc1b Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 18 Jul 2019 17:07:49 -0400
Subject: [PATCH 097/129] Improved centering scheme
---
sitator/util/PBCCalculator.pyx | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/sitator/util/PBCCalculator.pyx b/sitator/util/PBCCalculator.pyx
index 07f3381..57a0901 100644
--- a/sitator/util/PBCCalculator.pyx
+++ b/sitator/util/PBCCalculator.pyx
@@ -104,13 +104,18 @@ cdef class PBCCalculator(object):
Assumes that the points are relatively close (within a half unit cell)
together, and that the first point is not a particular outsider (the
- cell is centered at that point).
+ cell is centered at that point). If the average is weighted, the
+ maximally weighted point will be taken as the center.
Can be a weighted average with the semantics of :func:numpy.average.
"""
assert points.shape[1] == 3 and points.ndim == 2
- offset = self._cell_centroid - points[0]
+ center_about = 0
+ if weights is not None:
+ center_about = np.argmax(weights)
+
+ offset = self._cell_centroid - points[center_about]
ptbuf = points.copy()
From 6b096617fa469bdc3783f7f2f40813a4b58594ab Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 18 Jul 2019 19:19:58 -0400
Subject: [PATCH 098/129] Allow solving for site centers from representative
landmark vector
---
sitator/landmark/LandmarkAnalysis.py | 79 ++++++++++++++++++++++------
sitator/landmark/cluster/dotprod.py | 8 ++-
sitator/landmark/cluster/mcl.py | 17 +++---
3 files changed, 79 insertions(+), 25 deletions(-)
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 3dc646f..8351c69 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -34,9 +34,25 @@ class LandmarkAnalysis(object):
:param double cutoff_steepness: Steepness of the logistic cutoff function.
:param double minimum_site_occupancy = 0.1: Minimum occupancy (% of time occupied)
for a site to qualify as such.
+ :param str clustering_algorithm: The landmark clustering algorithm. ``sitator``
+ supplies two:
+ - ``"dotprod"``: The method described in our "Unsupervised landmark
+ analysis for jump detection in molecular dynamics simulations" paper.
+ - ``"mcl"``: A newer method we are developing.
:param dict clustering_params: Parameters for the chosen ``clustering_algorithm``.
- :param bool weighted_site_positions: When computing site positions, whether
- to weight the average by assignment confidence.
+ :param str site_centers_method: The method to use for computing the real
+ space positions of the sites. Options:
+ - ``SITE_CENTERS_REAL_UNWEIGHTED``: A spatial average of all real-space
+ mobile atom positions assigned to the site is taken.
+ - ``SITE_CENTERS_REAL_WEIGHTED``: A spatial average of all real-space
+ mobile atom positions assigned to the site is taken, weighted
+          by the confidences with which they were assigned to the site.
+ - ``SITE_CENTERS_REPRESENTATIVE_LANDMARK``: A spatial average over
+ all landmarks' centers is taken, weighted by the representative
+ or "typical" landmark vector at the site.
+ The "real" methods will generally be more faithful to the simulation,
+ but the representative landmark method can work better in cases with
+ short trajectories, producing a more "ideal" site location.
:param bool check_for_zero_landmarks: Whether to check for and raise exceptions
when all-zero landmark vectors are computed.
:param float static_movement_threshold: (Angstrom) the maximum allowed
@@ -65,13 +81,24 @@ class LandmarkAnalysis(object):
:param bool verbose: Verbosity for the ``clustering_algorithm``. Other output
controlled through ``logging``.
"""
+
+ SITE_CENTERS_REAL_UNWEIGHTED = 'real-unweighted'
+ SITE_CENTERS_REAL_WEIGHTED = 'real-weighted'
+ SITE_CENTERS_REPRESENTATIVE_LANDMARK = 'representative-landmark'
+
+ CLUSTERING_CLUSTER_SIZE = 'cluster-size'
+ CLUSTERING_LABELS = 'cluster-labels'
+ CLUSTERING_CONFIDENCES = 'cluster-confs'
+ CLUSTERING_LANDMARK_GROUPINGS = 'cluster-landmark-groupings'
+ CLUSTERING_REPRESENTATIVE_LANDMARKS = 'cluster-representative-lvecs'
+
def __init__(self,
clustering_algorithm = 'dotprod',
clustering_params = {},
cutoff_midpoint = 1.5,
cutoff_steepness = 30,
minimum_site_occupancy = 0.01,
- weighted_site_positions = True,
+ site_centers_method = SITE_CENTERS_REAL_WEIGHTED,
check_for_zero_landmarks = True,
static_movement_threshold = 1.0,
dynamic_lattice_mapping = False,
@@ -88,7 +115,7 @@ def __init__(self,
self.verbose = verbose
self.check_for_zero_landmarks = check_for_zero_landmarks
- self.weighted_site_positions = weighted_site_positions
+ self.site_centers_method = site_centers_method
self.dynamic_lattice_mapping = dynamic_lattice_mapping
self.relaxed_lattice_checks = relaxed_lattice_checks
@@ -214,14 +241,19 @@ def run(self, sn, frames):
min_samples = self._minimum_site_occupancy / float(sn.n_mobile),
verbose = self.verbose)
- if len(clustering) == 3:
- cluster_counts, lmk_lbls, lmk_confs = clustering
- landmark_clusters = None
- elif len(clustering) == 4:
- cluster_counts, lmk_lbls, lmk_confs, landmark_clusters = clustering
+ cluster_counts = clustering[LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE]
+ lmk_lbls = clustering[LandmarkAnalysis.CLUSTERING_LABELS]
+ lmk_confs = clustering[LandmarkAnalysis.CLUSTERING_CONFIDENCES]
+ if LandmarkAnalysis.CLUSTERING_LANDMARK_GROUPINGS in clustering:
+ landmark_clusters = clustering[LandmarkAnalysis.CLUSTERING_LANDMARK_GROUPINGS]
assert len(cluster_counts) == len(landmark_clusters)
else:
- raise ValueError("Clustering function returned invalid result %s" % clustering)
+ landmark_clusters = None
+ if LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS in clustering:
+ rep_lvecs = np.asarray(clustering[LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS])
+ assert rep_lvecs.shape == (len(cluster_counts), self._landmark_vectors.shape[1])
+ else:
+ rep_lvecs = None
logging.info(" Failed to assign %i%% of mobile particle positions to sites." % (100.0 * np.sum(lmk_lbls < 0) / float(len(lmk_lbls))))
@@ -240,13 +272,26 @@ def run(self, sn, frames):
out_sn = sn.copy()
# - Compute site centers
site_centers = np.empty(shape = (n_sites, 3), dtype = frames.dtype)
- for site in range(n_sites):
- mask = lmk_lbls == site
- pts = frames[:, sn.mobile_mask][mask]
- if self.weighted_site_positions:
- site_centers[site] = self._pbcc.average(pts, weights = lmk_confs[mask])
- else:
- site_centers[site] = self._pbcc.average(pts)
+ if self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REAL_WEIGHTED or \
+ self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REAL_UNWEIGHTED:
+ for site in range(n_sites):
+ mask = lmk_lbls == site
+ pts = frames[:, sn.mobile_mask][mask]
+ if self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REAL_WEIGHTED:
+ site_centers[site] = self._pbcc.average(pts, weights = lmk_confs[mask])
+ else:
+ site_centers[site] = self._pbcc.average(pts)
+ elif self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REPRESENTATIVE_LANDMARK:
+ if rep_lvecs is None:
+ raise ValueError("Chosen clustering method (with current parameters) didn't return representative landmark vectors; can't use SITE_CENTERS_REPRESENTATIVE_LANDMARK.")
+ for site in range(n_sites):
+ weights_nonzero = rep_lvecs[site] > 0
+ site_centers[site] = self._pbcc.average(
+ sn.centers[weights_nonzero],
+ weights = rep_lvecs[site, weights_nonzero]
+ )
+ else:
+ raise ValueError("Invalid site centers method '%s'" % self.site_centers_method)
out_sn.centers = site_centers
# - If clustering gave us that, compute site vertices
if landmark_clusters is not None:
diff --git a/sitator/landmark/cluster/dotprod.py b/sitator/landmark/cluster/dotprod.py
index 20cfc81..bd98296 100644
--- a/sitator/landmark/cluster/dotprod.py
+++ b/sitator/landmark/cluster/dotprod.py
@@ -1,5 +1,6 @@
from sitator.util import DotProdClassifier
+from sitator.landmark import LandmarkAnalysis
DEFAULT_PARAMS = {
'clustering_threshold' : 0.45,
@@ -23,4 +24,9 @@ def do_landmark_clustering(landmark_vectors,
predict_threshold = clustering_params['assignment_threshold'],
verbose = verbose)
- return landmark_classifier.cluster_counts, lmk_lbls, lmk_confs
+ return {
+ LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE : landmark_classifier.cluster_counts,
+ LandmarkAnalysis.CLUSTERING_LABELS: lmk_lbls,
+ LandmarkAnalysis.CLUSTERING_CONFIDENCES : lmk_confs,
+ LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS : landmark_classifier.cluster_centers
+ }
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index eba0ed2..0596327 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -3,6 +3,7 @@
from sitator.util.progress import tqdm
from sitator.util.mcl import markov_clustering
from sitator.util import DotProdClassifier
+from sitator.landmark import LandmarkAnalysis
from sklearn.covariance import empirical_covariance
@@ -73,10 +74,12 @@ def do_landmark_clustering(landmark_vectors,
msk = info['kept_clusters_mask']
clusters = [c for i, c in enumerate(clusters) if msk[i]] # Only need the ones above the threshold
-
- return (
- landmark_classifier.cluster_counts,
- lmk_lbls,
- lmk_confs,
- clusters
- )
+ centers = [c for i, c in enumerate(centers) if msk[i]] # Only need the ones above the threshold
+
+ return {
+ LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE : landmark_classifier.cluster_counts,
+ LandmarkAnalysis.CLUSTERING_LABELS : lmk_lbls,
+ LandmarkAnalysis.CLUSTERING_CONFIDENCES: lmk_confs,
+ LandmarkAnalysis.CLUSTERING_LANDMARK_GROUPINGS : clusters,
+ LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS : np.abs(centers)
+ }
From 161f6fdebf70d3802636e98862f6c9c9d6d925ad Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 19 Jul 2019 10:55:18 -0400
Subject: [PATCH 099/129] Improved representative center
---
sitator/landmark/cluster/mcl.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index 0596327..39a7f48 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -75,11 +75,17 @@ def do_landmark_clustering(landmark_vectors,
msk = info['kept_clusters_mask']
clusters = [c for i, c in enumerate(clusters) if msk[i]] # Only need the ones above the threshold
centers = [c for i, c in enumerate(centers) if msk[i]] # Only need the ones above the threshold
+ # If it's negative, their all negative and its in "quadrant III", so abs is safe
+ centers = np.abs(centers)
+ # The most important landmark contributes unity
+ centers /= np.max(centers, axis = 1)[:, np.newaxis]
+ # Exaggerate the contributions of the more maximal landmarks
+ centers = np.square(centers)
return {
LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE : landmark_classifier.cluster_counts,
LandmarkAnalysis.CLUSTERING_LABELS : lmk_lbls,
LandmarkAnalysis.CLUSTERING_CONFIDENCES: lmk_confs,
LandmarkAnalysis.CLUSTERING_LANDMARK_GROUPINGS : clusters,
- LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS : np.abs(centers)
+ LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS : centers
}
From 969d2d025e3dbfefe6ea10d7c5f9e86264987aff Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 19 Jul 2019 11:27:37 -0400
Subject: [PATCH 100/129] Coordination Number Site Types
---
.../SiteCoordinationEnvironment.py | 24 +++++++++++++++----
1 file changed, 20 insertions(+), 4 deletions(-)
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
index 39d1c2a..4fb3b64 100644
--- a/sitator/site_descriptors/SiteCoordinationEnvironment.py
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -42,13 +42,26 @@ class SiteCoordinationEnvironment(object):
- ``coordination_numbers``: The coordination number of the site.
Args:
+ guess_ionic_bonds (bool): If True, uses ``pymatgen``'s bond valence
+ analysis to guess valences and only consider ionic bonds for
+ neighbor analysis. Otherwise, or if it fails, all bonds are fair game.
+ full_chemenv_site_types (bool): If True, ``sitator`` site types on the
+ final ``SiteNetwork`` will be assigned based on unique chemical
+ environments, including shape. If False, they will be assigned
+ solely based on coordination number. Either way, both sets of information
+ are included in the ``SiteNetwork``, this just changes which determines
+ the ``site_types``.
**kwargs: passed to ``compute_structure_environments``.
"""
- def __init__(self, guess_ionic_bonds = True, **kwargs):
+ def __init__(self,
+ guess_ionic_bonds = True,
+ full_chemenv_site_types = False,
+ **kwargs):
if not has_pymatgen:
raise ImportError("Pymatgen (or a recent enough version including `pymatgen.analysis.chemenv.coordination_environments`) cannot be imported.")
self._kwargs = kwargs
self._guess_ionic_bonds = guess_ionic_bonds
+ self._full_chemenv_site_types = full_chemenv_site_types
def run(self, sn):
"""
@@ -124,20 +137,23 @@ def run(self, sn):
# -- Postprocess
# TODO: allow user to ask for full fractional breakdown
- unique_envs = list(set(env['ce_symbol'] for env in coord_envs))
- site_types = np.array([unique_envs.index(env['ce_symbol']) for env in coord_envs])
+ str_coord_environments = [env['ce_symbol'] for env in coord_envs]
# The closer to 1 this is, the better
site_type_confidences = np.array([env['ce_fraction'] for env in coord_envs])
coordination_numbers = np.array([int(env['ce_symbol'].split(':')[1]) for env in coord_envs])
assert np.all(coordination_numbers == [len(v) for v in vertices])
+ typearr = str_coord_environments if self._full_chemenv_site_types else coordination_numbers
+ unique_envs = list(set(typearr))
+ site_types = np.array([unique_envs.index(t) for t in typearr])
+
n_types = len(unique_envs)
logger.info(("Type " + "{:<8}" * n_types).format(*unique_envs))
logger.info(("# of sites " + "{:<8}" * n_types).format(*np.bincount(site_types)))
sn.site_types = site_types
sn.vertices = vertices
- sn.add_site_attribute("coordination_environments", [env['ce_symbol'] for env in coord_envs])
+ sn.add_site_attribute("coordination_environments", str_coord_environments)
sn.add_site_attribute("site_type_confidences", site_type_confidences)
sn.add_site_attribute("coordination_numbers", coordination_numbers)
From 1def69b89d69aec8afd8f95f2f710047300d40c2 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 22 Jul 2019 11:24:40 -0400
Subject: [PATCH 101/129] Mean-centered Landmark Clustering
---
sitator/landmark/cluster/mcl.py | 28 ++++++++++++++++++----------
1 file changed, 18 insertions(+), 10 deletions(-)
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index 39a7f48..6df2de2 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -35,8 +35,11 @@ def do_landmark_clustering(landmark_vectors,
clustering_params = tmp
n_lmk = landmark_vectors.shape[1]
-
- cov = empirical_covariance(landmark_vectors)
+ # Center landmark vectors
+ seen_ntimes = np.count_nonzero(landmark_vectors, axis = 0)
+ mean = np.mean(landmark_vectors, axis = 0)
+ landmark_vectors -= mean
+ cov = empirical_covariance(landmark_vectors, assume_centered = True)
corr = cov2corr(cov)
graph = np.clip(corr, 0, None)
for i in range(n_lmk):
@@ -47,7 +50,8 @@ def do_landmark_clustering(landmark_vectors,
# -- Cluster Landmarks
clusters = markov_clustering(graph, **clustering_params)
- clusters = [list(c) for c in clusters]
+ # Filter out single element clusters of landmarks that never appear.
+ clusters = [list(c) for c in clusters if seen_ntimes[c[0]] > 0]
n_clusters = len(clusters)
centers = np.zeros(shape = (n_clusters, n_lmk))
for i, cluster in enumerate(clusters):
@@ -72,15 +76,19 @@ def do_landmark_clustering(landmark_vectors,
verbose = verbose,
return_info = True)
+ # Shift landmark vectors back
+ landmark_vectors += mean
+
msk = info['kept_clusters_mask']
clusters = [c for i, c in enumerate(clusters) if msk[i]] # Only need the ones above the threshold
- centers = [c for i, c in enumerate(centers) if msk[i]] # Only need the ones above the threshold
- # If it's negative, their all negative and its in "quadrant III", so abs is safe
- centers = np.abs(centers)
- # The most important landmark contributes unity
- centers /= np.max(centers, axis = 1)[:, np.newaxis]
- # Exaggerate the contributions of the more maximal landmarks
- centers = np.square(centers)
+
+ # Find the average landmark vector at each site
+ centers = np.zeros(shape = (len(clusters), n_lmk))
+ mask = np.empty(shape = lmk_lbls.shape, dtype = np.bool)
+ for site in range(len(clusters)):
+ np.equal(lmk_lbls, site, out = mask)
+ centers[site] = np.average(landmark_vectors, weights = mask, axis = 0)
+
return {
LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE : landmark_classifier.cluster_counts,
From d7196a393304284de15457b549ed9102931696a8 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 22 Jul 2019 11:24:52 -0400
Subject: [PATCH 102/129] Add confidences getter
---
sitator/SiteTrajectory.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index 03686c9..b4cb4d8 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -57,6 +57,10 @@ def traj(self):
"""The site assignments over time."""
return self._traj
+ @property
+ def confidences(self):
+ return self._confs
+
@property
def n_frames(self):
"""The number of frames in the trajectory."""
From a78e015488eeaba72dcdce5373af973cebda0a92 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 22 Jul 2019 14:31:13 -0400
Subject: [PATCH 103/129] Small fixes
---
sitator/SiteTrajectory.py | 10 ++++++++++
sitator/visualization/SiteNetworkPlotter.py | 4 ++--
2 files changed, 12 insertions(+), 2 deletions(-)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index b4cb4d8..32c15f2 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -52,6 +52,16 @@ def __getitem__(self, key):
st.set_real_traj(self._real_traj[key])
return st
+ def __getstate__(self):
+ # Copy the object's state from self.__dict__ which contains
+ # all our instance attributes. Always use the dict.copy()
+ # method to avoid modifying the original state.
+ state = self.__dict__.copy()
+ # Don't want to pickle giant trajectories or uninteresting plotters
+ state['_real_traj'] = None
+ state['_default_plotter'] = None
+ return state
+
@property
def traj(self):
"""The site assignments over time."""
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index db1ae09..8c529e8 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -126,14 +126,14 @@ def _site_layers(self, sn, plot_points_params, same_normalization = False):
elif key == 'color':
pts_arrays['c'] = val.copy()
if not same_normalization:
- self._color_minmax = (np.min(val), np.max(val))
+ self._color_minmax = [np.min(val), np.max(val)]
if self._color_minmax[0] == self._color_minmax[1]:
self._color_minmax[0] -= 1 # Just to avoid div by zero
color_minmax = self._color_minmax
pts_params['norm'] = matplotlib.colors.Normalize(vmin = color_minmax[0], vmax = color_minmax[1])
elif key == 'size':
if not same_normalization:
- self._size_minmax = (np.min(val), np.max(val))
+ self._size_minmax = [np.min(val), np.max(val)]
if self._size_minmax[0] == self._size_minmax[1]:
self._size_minmax[0] -= 1 # Just to avoid div by zero
size_minmax = self._size_minmax
From dbd9add151ef1f1629f29435b570e4d959853ea8 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 22 Jul 2019 18:44:39 -0400
Subject: [PATCH 104/129] Weighted representative landmarks
---
sitator/landmark/cluster/mcl.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index 6df2de2..29931cd 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -83,11 +83,14 @@ def do_landmark_clustering(landmark_vectors,
clusters = [c for i, c in enumerate(clusters) if msk[i]] # Only need the ones above the threshold
# Find the average landmark vector at each site
+ weighted_reps = clustering_params.get('weighted_representative_landmarks', True)
centers = np.zeros(shape = (len(clusters), n_lmk))
- mask = np.empty(shape = lmk_lbls.shape, dtype = np.bool)
+ weights = np.empty(shape = lmk_lbls.shape)
for site in range(len(clusters)):
- np.equal(lmk_lbls, site, out = mask)
- centers[site] = np.average(landmark_vectors, weights = mask, axis = 0)
+ np.equal(lmk_lbls, site, out = weights)
+ if weighted_reps:
+ weights *= lmk_confs
+ centers[site] = np.average(landmark_vectors, weights = weights, axis = 0)
return {
From 74668758c3d6ee2ed7b0e9607fdef83336a26198 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 23 Jul 2019 00:33:05 -0400
Subject: [PATCH 105/129] Remove faulty mean shifting
---
sitator/landmark/cluster/mcl.py | 10 +++-------
1 file changed, 3 insertions(+), 7 deletions(-)
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index 29931cd..d4a2ea9 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -37,9 +37,7 @@ def do_landmark_clustering(landmark_vectors,
n_lmk = landmark_vectors.shape[1]
# Center landmark vectors
seen_ntimes = np.count_nonzero(landmark_vectors, axis = 0)
- mean = np.mean(landmark_vectors, axis = 0)
- landmark_vectors -= mean
- cov = empirical_covariance(landmark_vectors, assume_centered = True)
+ cov = empirical_covariance(landmark_vectors, assume_centered = False)
corr = cov2corr(cov)
graph = np.clip(corr, 0, None)
for i in range(n_lmk):
@@ -61,6 +59,7 @@ def do_landmark_clustering(landmark_vectors,
# PCA inspired:
eigenval, eigenvec = eigsh(cov[cluster][:, cluster], k = 1)
centers[i, cluster] = eigenvec.T
+ centers[i, cluster] /= np.sqrt(len(cluster))
landmark_classifier = \
@@ -72,13 +71,10 @@ def do_landmark_clustering(landmark_vectors,
lmk_lbls, lmk_confs, info = \
landmark_classifier.fit_predict(landmark_vectors,
predict_threshold = predict_threshold,
- predict_normed = True,
+ predict_normed = False,
verbose = verbose,
return_info = True)
- # Shift landmark vectors back
- landmark_vectors += mean
-
msk = info['kept_clusters_mask']
clusters = [c for i, c in enumerate(clusters) if msk[i]] # Only need the ones above the threshold
From a88894d37ddfdae03b85e5f1a5e4ee13c94ed414 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 24 Jul 2019 14:46:05 -0400
Subject: [PATCH 106/129] Basic correct normalization
---
sitator/landmark/cluster/dotprod.py | 1 +
sitator/landmark/cluster/mcl.py | 11 +++++++++--
2 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/sitator/landmark/cluster/dotprod.py b/sitator/landmark/cluster/dotprod.py
index bd98296..5b98e90 100644
--- a/sitator/landmark/cluster/dotprod.py
+++ b/sitator/landmark/cluster/dotprod.py
@@ -1,3 +1,4 @@
+"""Cluster landmark vectors using the custom online algorithm from the original paper."""
from sitator.util import DotProdClassifier
from sitator.landmark import LandmarkAnalysis
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index d4a2ea9..0c128f6 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -1,3 +1,5 @@
+"""Cluster landmarks into sites using Markov Clustering and then assign each landmark vector."""
+
import numpy as np
from sitator.util.progress import tqdm
@@ -37,7 +39,7 @@ def do_landmark_clustering(landmark_vectors,
n_lmk = landmark_vectors.shape[1]
# Center landmark vectors
seen_ntimes = np.count_nonzero(landmark_vectors, axis = 0)
- cov = empirical_covariance(landmark_vectors, assume_centered = False)
+ cov = np.dot(landmark_vectors.T, landmark_vectors) / landmark_vectors.shape[0]
corr = cov2corr(cov)
graph = np.clip(corr, 0, None)
for i in range(n_lmk):
@@ -52,6 +54,7 @@ def do_landmark_clustering(landmark_vectors,
clusters = [list(c) for c in clusters if seen_ntimes[c[0]] > 0]
n_clusters = len(clusters)
centers = np.zeros(shape = (n_clusters, n_lmk))
+ maxbuf = np.empty(shape = len(landmark_vectors))
for i, cluster in enumerate(clusters):
if len(cluster) == 1:
centers[i, cluster] = 1.0 # Eigenvec is trivial case; scale doesn't matter either.
@@ -59,7 +62,11 @@ def do_landmark_clustering(landmark_vectors,
# PCA inspired:
eigenval, eigenvec = eigsh(cov[cluster][:, cluster], k = 1)
centers[i, cluster] = eigenvec.T
- centers[i, cluster] /= np.sqrt(len(cluster))
+ np.dot(landmark_vectors, centers[i], out = maxbuf)
+ np.abs(maxbuf, out = maxbuf)
+ max_projection = np.max(maxbuf)
+ if max_projection > 0:
+ centers[i] /= np.max(maxbuf)
landmark_classifier = \
From 109996fdb46cfcc85fc762c71856f9b284d34d9a Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 24 Jul 2019 17:44:20 -0400
Subject: [PATCH 107/129] Best match representative lvec scaling
---
sitator/landmark/cluster/mcl.py | 46 ++++++++++++++++++++++++++-------
1 file changed, 36 insertions(+), 10 deletions(-)
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
index 0c128f6..741cfcf 100644
--- a/sitator/landmark/cluster/mcl.py
+++ b/sitator/landmark/cluster/mcl.py
@@ -1,4 +1,16 @@
-"""Cluster landmarks into sites using Markov Clustering and then assign each landmark vector."""
+"""Cluster landmarks into sites using Markov Clustering and then assign each landmark vector.
+
+Valid clustering params include:
+ - ``"assignment_threshold"`` (float between 0 and 1): The similarity threshold
+ below which a landmark vector will be marked unassigned.
+ - ``"good_site_normed_threshold"`` (float between 0 and 1): The minimum for
+ the cosine similarity between a good site's representative unit vector and
+ its best match landmark vector.
+ - ``"good_site_projected_threshold"`` (positive float): The minimum inner product
+ between a good site's representative unit vector and its best match
+ landmark vector.
+ - All other params are passed along to `sitator.util.mcl.markov_clustering`.
+"""
import numpy as np
@@ -47,6 +59,8 @@ def do_landmark_clustering(landmark_vectors,
graph[i, i] = 1 # Needs a self loop for Markov clustering not to degenerate. Arbitrary value, shouldn't affect anyone else.
predict_threshold = clustering_params.pop('assignment_threshold')
+ good_site_normed_threshold = clustering_params.pop('good_site_normed_threshold', predict_threshold)
+ good_site_project_thresh = clustering_params.pop('good_site_projected_threshold', predict_threshold)
# -- Cluster Landmarks
clusters = markov_clustering(graph, **clustering_params)
@@ -55,19 +69,31 @@ def do_landmark_clustering(landmark_vectors,
n_clusters = len(clusters)
centers = np.zeros(shape = (n_clusters, n_lmk))
maxbuf = np.empty(shape = len(landmark_vectors))
+ good_clusters = np.zeros(shape = n_clusters, dtype = np.bool)
for i, cluster in enumerate(clusters):
if len(cluster) == 1:
- centers[i, cluster] = 1.0 # Eigenvec is trivial case; scale doesn't matter either.
+ eigenvec = [1.0] # Eigenvec is trivial
else:
# PCA inspired:
- eigenval, eigenvec = eigsh(cov[cluster][:, cluster], k = 1)
- centers[i, cluster] = eigenvec.T
- np.dot(landmark_vectors, centers[i], out = maxbuf)
- np.abs(maxbuf, out = maxbuf)
- max_projection = np.max(maxbuf)
- if max_projection > 0:
- centers[i] /= np.max(maxbuf)
-
+ _, eigenvec = eigsh(cov[cluster][:, cluster], k = 1)
+ eigenvec = eigenvec.T
+ centers[i, cluster] = eigenvec
+ np.dot(landmark_vectors, centers[i], out = maxbuf)
+ np.abs(maxbuf, out = maxbuf)
+ best_match = np.argmax(maxbuf)
+ best_match_lvec = landmark_vectors[best_match]
+ best_match_dot = np.abs(np.dot(best_match_lvec, centers[i]))
+ best_match_dot_norm = best_match_dot / np.linalg.norm(best_match_lvec)
+ good_clusters[i] = best_match_dot_norm >= good_site_normed_threshold
+ good_clusters[i] &= best_match_dot >= good_site_project_thresh
+ centers[i] /= best_match_dot
+
+ logger.debug("Kept %i/%i landmark clusters as good sites" % (np.sum(good_clusters), len(good_clusters)))
+
+ # Filter out "bad" sites
+ clusters = [c for i, c in enumerate(clusters) if good_clusters[i]]
+ centers = centers[good_clusters]
+ n_clusters = len(clusters)
landmark_classifier = \
DotProdClassifier(threshold = np.nan, # We're not fitting
From e1b48c8ec98fb90a38eb69cb99cae31395f61192 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 24 Jul 2019 18:15:07 -0400
Subject: [PATCH 108/129] Added `site_ids` convenience property for plotting
---
sitator/SiteNetwork.py | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index ed89aab..edb4d9c 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -207,6 +207,11 @@ def vertices(self):
"""The static atoms defining each site."""
return self._vertices
+ @property
+ def site_ids(self):
+ """Convenience property giving the index of each site."""
+ return np.arange(self.n_sites)
+
@vertices.setter
def vertices(self, value):
if not len(value) == len(self._centers):
From b5c9bad3cdeab25e715e367ca7fcb7c287b890d0 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 25 Jul 2019 15:58:23 -0400
Subject: [PATCH 109/129] Removed redundant distance computations
---
sitator/util/PBCCalculator.pyx | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/sitator/util/PBCCalculator.pyx b/sitator/util/PBCCalculator.pyx
index 57a0901..cd6ee42 100644
--- a/sitator/util/PBCCalculator.pyx
+++ b/sitator/util/PBCCalculator.pyx
@@ -48,10 +48,14 @@ cdef class PBCCalculator(object):
buf = pts.copy()
- for i in xrange(len(pts)):
- self.distances(pts[i], buf, in_place = True, out = out[i])
+ for i in xrange(len(pts) - 1):
+ out[i, i] = 0
+ self.distances(pts[i], buf[i + 1:], in_place = True, out = out[i, i + 1:])
+ out[i + 1:, i] = out[i, i + 1:]
buf[:] = pts
+ out[len(pts) - 1, len(pts) - 1] = 0
+
return out
From 458e1e4abcc83a1562b2e7a647644fbd0c133ba8 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 25 Jul 2019 15:58:28 -0400
Subject: [PATCH 110/129] Inheritance fix
---
sitator/SiteNetwork.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index ed89aab..cbeb923 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -95,9 +95,13 @@ def __len__(self):
return self.n_sites
def __getitem__(self, key):
- sn = type(self)(self.structure,
- self.static_mask,
- self.mobile_mask)
+ sn = self.__new__(type(self))
+ SiteNetwork.__init__(
+ sn,
+ self.structure,
+ self.static_mask,
+ self.mobile_mask
+ )
if not self._centers is None:
sn.centers = self._centers[key]
From d56852b2152e35327f5e2b01b88432bf95d724d5 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 25 Jul 2019 17:08:18 -0400
Subject: [PATCH 111/129] Allow preallocated distance matrix buffer
---
sitator/util/PBCCalculator.pyx | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/sitator/util/PBCCalculator.pyx b/sitator/util/PBCCalculator.pyx
index cd6ee42..c841fd2 100644
--- a/sitator/util/PBCCalculator.pyx
+++ b/sitator/util/PBCCalculator.pyx
@@ -39,12 +39,13 @@ cdef class PBCCalculator(object):
return self._cell_centroid
- cpdef pairwise_distances(self, pts):
+ cpdef pairwise_distances(self, pts, out = None):
"""Compute the pairwise distance matrix of ``pts`` with itself.
:returns ndarray (len(pts), len(pts)): distances
"""
- out = np.empty(shape = (len(pts), len(pts)), dtype = pts.dtype)
+ if out is None:
+ out = np.empty(shape = (len(pts), len(pts)), dtype = pts.dtype)
buf = pts.copy()
From dc01bdccb25d71dbdaa240ece55a074cfcc3bd51 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 25 Jul 2019 17:19:27 -0400
Subject: [PATCH 112/129] Added feature to keep short transitional jumps while
removing failed attempts
---
sitator/dynamics/SmoothSiteTrajectory.pyx | 39 +++++++++++++++++++----
1 file changed, 32 insertions(+), 7 deletions(-)
diff --git a/sitator/dynamics/SmoothSiteTrajectory.pyx b/sitator/dynamics/SmoothSiteTrajectory.pyx
index 34497ef..379b2d7 100644
--- a/sitator/dynamics/SmoothSiteTrajectory.pyx
+++ b/sitator/dynamics/SmoothSiteTrajectory.pyx
@@ -15,20 +15,33 @@ class SmoothSiteTrajectory(object):
For each mobile particle, the assignmet at each frame is replaced by the
mode of its site assignments over some number of frames centered around it.
- If the multiplicity of the mode is less than the threshold, the particle is
- marked unassigned at that frame.
Can be thought of as a discrete lowpass filter.
+ The ``set_unassigned_under_threshold`` parameter allows the user to control
+ how the smoothing handles "transitions" vs. "attempts"; setting it to True,
+ the default, will mark as unassigned transitional moments where neither
+ the source nor destination site have a sufficient (``threshold``) majority
+ in the window, while setting it to False will maintain the assignment to a
+ transitional site.
+
Args:
+ window_threshold_factor (float): The total width of the rolling window,
+ in terms of the threshold.
remove_unoccupied_sites (bool): If True, sites that are unoccupied after
- removing short jumps will be removed.
+ the smoothing will be removed.
+ set_unassigned_under_threshold (bool): If True, if the multiplicity of
+ the mode is less than the threshold, the particle is marked
+ unassigned at that frame. If False, the particle's assignment will
+ not be modified.
"""
def __init__(self,
window_threshold_factor = 2.1,
- remove_unoccupied_sites = True):
+ remove_unoccupied_sites = True,
+ set_unassigned_under_threshold = True):
self.window_threshold_factor = window_threshold_factor
self.remove_unoccupied_sites = remove_unoccupied_sites
+ self.set_unassigned_under_threshold = set_unassigned_under_threshold
def run(self,
st,
@@ -43,7 +56,15 @@ class SmoothSiteTrajectory(object):
window = self.window_threshold_factor * threshold
wleft, wright = int(np.floor(window / 2)), int(np.ceil(window / 2))
- running_windowed_mode(traj, out, wleft, wright, threshold, n_sites)
+ running_windowed_mode(
+ traj,
+ out,
+ wleft,
+ wright,
+ threshold,
+ n_sites,
+ self.set_unassigned_under_threshold
+ )
st = st.copy(with_computed = False)
st._traj = out
@@ -60,7 +81,8 @@ cpdef running_windowed_mode(site_int [:, :] traj,
Py_ssize_t wleft,
Py_ssize_t wright,
Py_ssize_t threshold,
- Py_ssize_t n_sites):
+ Py_ssize_t n_sites,
+ bint replace_no_winner_unknown):
countbuf_np = np.zeros(shape = n_sites + 1, dtype = np.int)
cdef Py_ssize_t [:] countbuf = countbuf_np
cdef Py_ssize_t n_mobile = traj.shape[1]
@@ -82,5 +104,8 @@ cpdef running_windowed_mode(site_int [:, :] traj,
if best_count >= threshold:
out[frame, mob] = winner - 1
else:
- out[frame, mob] = s_unknown
+ if replace_no_winner_unknown:
+ out[frame, mob] = s_unknown
+ else:
+ out[frame, mob] = traj[frame, mob]
countbuf[:] = 0
From a82ee95f77255d338d949a96b6dd7e384f1923d4 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 26 Jul 2019 12:02:41 -0400
Subject: [PATCH 113/129] Converted GenerateClampedTrajectory to Cython for
performance
---
setup.py | 44 +++++++-----
...ctory.py => GenerateClampedTrajectory.pyx} | 68 +++++++++++++------
2 files changed, 75 insertions(+), 37 deletions(-)
rename sitator/misc/{GenerateClampedTrajectory.py => GenerateClampedTrajectory.pyx} (59%)
diff --git a/setup.py b/setup.py
index 89d74a7..9194cd1 100644
--- a/setup.py
+++ b/setup.py
@@ -1,33 +1,43 @@
from setuptools import setup, find_packages
from Cython.Build import cythonize
+import Cython.Compiler
import numpy as np
-setup(name = 'sitator',
- version = '2.0.0',
- description = 'Unsupervised landmark analysis for jump detection in molecular dynamics simulations.',
- download_url = "https://github.com/Linux-cpp-lisp/sitator",
- author = 'Alby Musaelian',
- license = "MIT",
- python_requires = '>=3.2',
- packages = find_packages(),
- ext_modules = cythonize([
+# Allows cimport'ing PBCCalculator
+Cython.Compiler.Options.cimport_from_pyx = True
+
+setup(
+ name = 'sitator',
+ version = '2.0.0',
+ description = 'Unsupervised landmark analysis for jump detection in molecular dynamics simulations.',
+ download_url = "https://github.com/Linux-cpp-lisp/sitator",
+ author = 'Alby Musaelian',
+ license = "MIT",
+ python_requires = '>=3.2',
+ packages = find_packages(),
+ ext_modules = cythonize(
+ [
"sitator/landmark/helpers.pyx",
"sitator/util/*.pyx",
- "sitator/dynamics/*.pyx"
- ], language_level = 3),
- include_dirs=[np.get_include()],
- install_requires = [
+ "sitator/dynamics/*.pyx",
+ "sitator/misc/*.pyx"
+ ],
+ language_level = 3
+ ),
+ include_dirs=[np.get_include()],
+ install_requires = [
"numpy",
"scipy",
"matplotlib",
"ase",
"tqdm",
"sklearn"
- ],
- extras_require = {
+ ],
+ extras_require = {
"SiteTypeAnalysis" : [
"pydpc",
"dscribe"
]
- },
- zip_safe = True)
+ },
+ zip_safe = True
+)
diff --git a/sitator/misc/GenerateClampedTrajectory.py b/sitator/misc/GenerateClampedTrajectory.pyx
similarity index 59%
rename from sitator/misc/GenerateClampedTrajectory.py
rename to sitator/misc/GenerateClampedTrajectory.pyx
index c06484c..d442664 100644
--- a/sitator/misc/GenerateClampedTrajectory.py
+++ b/sitator/misc/GenerateClampedTrajectory.pyx
@@ -1,9 +1,14 @@
+# cython: language_level=3
+
import numpy as np
from sitator import SiteTrajectory
-from sitator.util import PBCCalculator
+from sitator.util.PBCCalculator cimport PBCCalculator, precision
from sitator.util.progress import tqdm
+from libc.math cimport floor
+
+ctypedef Py_ssize_t site_int
class GenerateClampedTrajectory(object):
"""Create a real-space trajectory with the fixed site/static structure positions.
@@ -46,7 +51,7 @@ def run(self, st, clamp_mask = None):
wrap = self.wrap
pass_through_unassigned = self.pass_through_unassigned
cell = st._sn.structure.cell
- pbcc = PBCCalculator(cell)
+ cdef PBCCalculator pbcc = PBCCalculator(cell)
n_atoms = len(st._sn.structure)
if clamp_mask is None:
@@ -68,33 +73,56 @@ def run(self, st, clamp_mask = None):
if not pass_through_unassigned and np.min(selected_sitetraj) < 0:
raise RuntimeError("The mobile atoms indicated for clamping are unassigned at some point during the trajectory and `pass_through_unassigned` is set to False. Try `assign_to_last_known_site()`?")
+ cdef site_int at_site
+ cdef Py_ssize_t frame_i
+ cdef Py_ssize_t mobile_i
+ cdef Py_ssize_t [:] mobile_clamp_indexes_c = mobile_clamp_indexes
+ cdef precision [:, :] buf
+ cdef precision [:] site_pt
+ cdef site_int site_unknown = SiteTrajectory.SITE_UNKNOWN
+ cdef const site_int [:, :] sitetrj_c = st._traj
+ cdef precision [:, :, :] clamptrj_c = clamptrj
+ cdef const precision [:, :, :] realtrj_c = st._real_traj
+ cdef const precision [:, :] centers_c = st.site_network.centers
+ cdef int site_mic_int
+ cdef int [3] site_mic
+ cdef int [3] pt_in_image
+ cdef precision [:, :] centers_crystal_c
+ cdef Py_ssize_t dim
if wrap:
for frame_i in tqdm(range(len(clamptrj))):
- for mobile_i in mobile_clamp_indexes:
- at_site = st._traj[frame_i, mobile_i]
- if at_site == SiteTrajectory.SITE_UNKNOWN: # we already know that this means pass_through_unassigned = True
- clamptrj[frame_i, mobile_i] = st._real_traj[frame_i, mobile_i]
- continue
- clamptrj[frame_i, mobile_i] = st._sn.centers[at_site]
+ for mobile_i in mobile_clamp_indexes_c:
+ at_site = sitetrj_c[frame_i, mobile_i]
+ if at_site == site_unknown: # we already know that this means pass_through_unassigned = True
+ clamptrj_c[frame_i, mobile_i] = realtrj_c[frame_i, mobile_i]
+ else:
+ clamptrj_c[frame_i, mobile_i] = centers_c[at_site]
else:
buf = np.empty(shape = (1, 3))
site_pt = np.empty(shape = 3)
+ centers_crystal_c = st.site_network.centers.copy()
+ pbcc.to_cell_coords(centers_crystal_c)
for frame_i in tqdm(range(len(clamptrj))):
- for mobile_i in mobile_clamp_indexes:
- buf[:, :] = st._real_traj[frame_i, mobile_i]
- at_site = st._traj[frame_i, mobile_i]
- if at_site == SiteTrajectory.SITE_UNKNOWN: # we already know that this means pass_through_unassigned = True
- clamptrj[frame_i, mobile_i] = st._real_traj[frame_i, mobile_i]
+ for mobile_i in mobile_clamp_indexes_c:
+ buf[:, :] = realtrj_c[frame_i, mobile_i]
+ at_site = sitetrj_c[frame_i, mobile_i]
+ if at_site == site_unknown: # we already know that this means pass_through_unassigned = True
+ clamptrj_c[frame_i, mobile_i] = realtrj_c[frame_i, mobile_i]
continue
- site_pt[:] = st._sn.centers[at_site]
+ site_pt[:] = centers_c[at_site]
pbcc.wrap_point(site_pt)
pbcc.wrap_points(buf)
- site_mic = pbcc.min_image(buf[0], site_pt)
- site_mic = [(site_mic // 10**(2 - i) % 10) - 1 for i in range(3)]
- buf[:, :] = st._real_traj[frame_i, mobile_i]
+ site_mic_int = pbcc.min_image(buf[0], site_pt)
+ for dim in range(3):
+ site_mic[dim] = (site_mic_int // 10**(2 - dim) % 10) - 1
+ buf[:, :] = realtrj_c[frame_i, mobile_i]
pbcc.to_cell_coords(buf)
- pt_in_image = np.floor(buf[0])
- pt_in_image += site_mic
- clamptrj[frame_i, mobile_i] = np.dot(pt_in_image, cell) + st._sn.centers[at_site]
+ for dim in range(3):
+ pt_in_image[dim] = floor(buf[0, dim]) + site_mic[dim]
+ buf[0] = centers_crystal_c[at_site]
+ for dim in range(3):
+ buf[0, dim] += pt_in_image[dim]
+ pbcc.to_real_coords(buf)
+ clamptrj_c[frame_i, mobile_i] = buf[0]
return clamptrj
From bac8a50e7615341efd9f728220ff2f702c3c1096 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 26 Jul 2019 14:48:17 -0400
Subject: [PATCH 114/129] Improved README formatting
---
README.md | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index 980810a..c93db2c 100644
--- a/README.md
+++ b/README.md
@@ -44,13 +44,13 @@ pip install ".[SiteTypeAnalysis]"
## Examples and Documentation
-Two example Jupyter notebooks for conducting full landmark analyses of LiAlSiO4 and Li12La3Zr2O12, including data files, can be found [on Materials Cloud](https://archive.materialscloud.org/2019.0008/).
+Two example Jupyter notebooks for conducting full landmark analyses of LiAlSiO4 and Li12La3Zr2O12 as in our paper, including data files, can be found [on Materials Cloud](https://archive.materialscloud.org/2019.0008/).
+
+Full API documentation can be found at [ReadTheDocs](https://sitator.readthedocs.io/en/py3/).
`sitator` generally assumes units of femtoseconds for time, Angstroms for space,
and Cartesian (not crystal) coordinates.
-Documentation can be found at [ReadTheDocs](https://sitator.readthedocs.io/en/py3/).
-
## Global Options
`sitator` uses the `tqdm.autonotebook` tool to automatically produce the correct fancy progress bars for terminals and iPython notebooks. To disable all progress bars, run with the environment variable `SITATOR_PROGRESSBAR` set to `false`.
@@ -59,4 +59,4 @@ The `SITATOR_ZEO_PATH` and `SITATOR_QUIP_PATH` environment variables can set the
## License
-This software is made available under the MIT License. See `LICENSE` for more details.
+This software is made available under the MIT License. See [`LICENSE`](LICENSE) for more details.
From b10cfb5c49729b3111f47b1cf8daedf03ef30b03 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 31 Jul 2019 15:23:48 -0600
Subject: [PATCH 115/129] Improved jump iterators
---
sitator/SiteTrajectory.py | 42 +++++++++++++++++++++++++++++++++------
1 file changed, 36 insertions(+), 6 deletions(-)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index 32c15f2..2d74099 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -299,8 +299,8 @@ def assign_to_last_known_site(self, frame_threshold = 1):
return res
- def jumps(self, unknown_as_jump = False):
- """Generator to iterate over all jumps in the trajectory.
+ def jumps(self, **kwargs):
+ """Iterate over all jumps in the trajectory, jump by jump.
A jump is considered to occur "at the frame" when it first acheives its
new site. For example,
@@ -317,6 +317,39 @@ def jumps(self, unknown_as_jump = False):
Yields:
tuple: (frame_number, mobile_atom_number, from_site, to_site)
"""
+ n_mobile = self.site_network.n_mobile
+ for frame_i, jumped, last_known, frame in self._jumped_generator(**kwargs):
+ for atom_i in range(n_mobile):
+ if jumped[atom_i]:
+ yield frame_i, atom_i, last_known[atom_i], frame[atom_i]
+
+ def jumps_by_frame(self, **kwargs):
+ """Iterate over all jumps in the trajectory, frame by frame.
+
+ A jump is considered to occur "at the frame" when it first acheives its
+ new site. For example,
+
+ - Frame 0: Atom 1 at site 4
+ - Frame 1: Atom 1 at site 5
+
+ will yield a jump ``(1, 1, 4, 5)``.
+
+ Args:
+ unknown_as_jump (bool): If ``True``, moving from a site to unknown
+ (or vice versa) is considered a jump; if ``False``, unassigned
+ mobile atoms are considered to be at their last known sites.
+ Yields:
+ tuple: (frame_number, mob_that_jumped, from_sites, to_sites)
+ """
+ n_mobile = self.site_network.n_mobile
+ for frame_i, jumped, last_known, frame in self._jumped_generator(**kwargs):
+ yield frame_i, np.where(jumped)[0], last_known[jumped], frame[jumped]
+
+ def _jumped_generator(self, unknown_as_jump = False):
+ """Internal jump generator that does not create intermediate arrays.
+
+ Wrapped by convinience functions.
+ """
traj = self.traj
n_mobile = self.site_network.n_mobile
assert n_mobile == traj.shape[1]
@@ -330,13 +363,10 @@ def jumps(self, unknown_as_jump = False):
np.not_equal(traj[frame_i], last_known, out = jumped)
jumped &= known # Must be currently known to have jumped
- for atom_i in range(n_mobile):
- if jumped[atom_i]:
- yield frame_i, atom_i, last_known[atom_i], traj[frame_i, atom_i]
+ yield frame_i, jumped, last_known, traj[frame_i]
last_known[known] = traj[frame_i, known]
-
# ---- Plotting code
def plot_frame(self, *args, **kwargs):
if self._default_plotter is None:
From ec8dcba96afe09927ad521bf152728acb428b8f3 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 31 Jul 2019 15:28:20 -0600
Subject: [PATCH 116/129] Refactored exception structure
---
sitator/SiteTrajectory.py | 9 +++++++--
sitator/errors.py | 21 +++++++++++++++++++++
sitator/landmark/LandmarkAnalysis.py | 8 ++++++--
sitator/landmark/__init__.py | 2 +-
sitator/landmark/errors.py | 4 ----
5 files changed, 35 insertions(+), 9 deletions(-)
create mode 100644 sitator/errors.py
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index 2d74099..db602d0 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -211,7 +211,7 @@ def check_multiple_occupancy(self, max_mobile_per_site = 1):
int: the total number of multiple assignment incidents; and
float: the average number of mobile atoms at any site at any one time.
"""
- from sitator.landmark.errors import MultipleOccupancyError
+ from sitator.errors import MultipleOccupancyError
n_more_than_ones = 0
avg_mobile_per_site = 0
divisor = 0
@@ -219,7 +219,12 @@ def check_multiple_occupancy(self, max_mobile_per_site = 1):
_, counts = np.unique(site_frame[site_frame >= 0], return_counts = True)
count_msk = counts > max_mobile_per_site
if np.any(count_msk):
- raise MultipleOccupancyError("%i mobile particles were assigned to only %i site(s) (%s) at frame %i." % (np.sum(counts[count_msk]), np.sum(count_msk), np.where(count_msk)[0], frame_i))
+ first_multi_site = np.where(count_msk)[0][0]
+ raise MultipleOccupancyError(
+ mobile = np.where(site_frame == first_multi_site)[0],
+ site = first_multi_site,
+ frame = frame_i
+ )
n_more_than_ones += np.sum(counts > 1)
avg_mobile_per_site += np.sum(counts)
divisor += len(counts)
diff --git a/sitator/errors.py b/sitator/errors.py
new file mode 100644
index 0000000..06572a4
--- /dev/null
+++ b/sitator/errors.py
@@ -0,0 +1,21 @@
+
+class SiteAnaysisError(Exception):
+ """An error occuring as part of site analysis."""
+ pass
+
+class MultipleOccupancyError(SiteAnaysisError):
+ """Error raised when multiple mobile atoms are assigned to the same site at the same time."""
+ def __init__(self, mobile, site, frame):
+ super().__init__(
+ "Multiple mobile particles %s were assigned to site %i at frame %i." % (mobile, site, frame)
+ )
+ self.mobile_particles = mobile
+ self.site = site
+ self.frame = frame
+
+class InsufficientSitesError(SiteAnaysisError):
+ """Site detection/merging/etc. resulted in fewer sites than mobile particles."""
+ def __init__(self, verb, n_sites, n_mobile):
+ super().__init__("%s resulted in only %i sites for %i mobile particles." % (verb, n_sites, n_mobile))
+ self.n_sites = n_sites
+ self.n_mobile = n_mobile
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 8351c69..2ade8b8 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -10,7 +10,7 @@
from . import helpers
from sitator import SiteNetwork, SiteTrajectory
-from .errors import MultipleOccupancyError
+from sitator.errors import MultipleOccupancyError
import logging
@@ -264,7 +264,11 @@ def run(self, sn, frames):
n_sites = len(cluster_counts)
if n_sites < (sn.n_mobile / self.max_mobile_per_site):
- raise MultipleOccupancyError("There are %i mobile particles, but only identified %i sites. With %i max_mobile_per_site, this is an error. Check clustering_params." % (sn.n_mobile, n_sites, self.max_mobile_per_site))
+ raise InsufficientSitesError(
+ verb = "Landmark analysis",
+ n_sites = n_sites,
+ n_mobile = sn.n_mobile
+ )
logging.info(" Identified %i sites with assignment counts %s" % (n_sites, cluster_counts))
diff --git a/sitator/landmark/__init__.py b/sitator/landmark/__init__.py
index 2af5cf2..a084133 100644
--- a/sitator/landmark/__init__.py
+++ b/sitator/landmark/__init__.py
@@ -1,4 +1,4 @@
-from .errors import StaticLatticeError, ZeroLandmarkError, LandmarkAnalysisError, MultipleOccupancyError
+from .errors import StaticLatticeError, ZeroLandmarkError, LandmarkAnalysisError
from .LandmarkAnalysis import LandmarkAnalysis
diff --git a/sitator/landmark/errors.py b/sitator/landmark/errors.py
index dbc12e6..69f06eb 100644
--- a/sitator/landmark/errors.py
+++ b/sitator/landmark/errors.py
@@ -38,7 +38,3 @@ def __init__(self, mobile_index, frame):
self.mobile_index = mobile_index
self.frame = frame
-
-class MultipleOccupancyError(LandmarkAnalysisError):
- """Error raised when multiple mobile atoms are assigned to the same site."""
- pass
From 98ea12a5b2e88bd52d3fc19d37f9723683d3a8b7 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 31 Jul 2019 15:31:36 -0600
Subject: [PATCH 117/129] Use InsufficientSitesError consistently
---
sitator/dynamics/RemoveUnoccupiedSites.py | 9 +++++++++
sitator/network/merging.py | 10 ++++++----
2 files changed, 15 insertions(+), 4 deletions(-)
diff --git a/sitator/dynamics/RemoveUnoccupiedSites.py b/sitator/dynamics/RemoveUnoccupiedSites.py
index 3a7d0df..22f3a41 100644
--- a/sitator/dynamics/RemoveUnoccupiedSites.py
+++ b/sitator/dynamics/RemoveUnoccupiedSites.py
@@ -1,6 +1,7 @@
import numpy as np
from sitator import SiteTrajectory
+from sitator.errors import InsufficientSitesError
import logging
logger = logging.getLogger(__name__)
@@ -37,6 +38,14 @@ def run(self, st, return_kept_sites = False):
logger.info("Removing unoccupied sites %s" % np.where(~seen_mask)[0])
n_new_sites = np.sum(seen_mask)
+
+ if n_new_sites < old_sn.n_mobile:
+ raise InsufficientSitesError(
+ verb = "Removing unoccupied sites",
+ n_sites = n_new_sites,
+ n_mobile = old_sn.n_mobile
+ )
+
translation = np.empty(shape = old_sn.n_sites + 1, dtype = np.int)
translation[:-1][seen_mask] = np.arange(n_new_sites)
translation[:-1][~seen_mask] = -4321
diff --git a/sitator/network/merging.py b/sitator/network/merging.py
index 25f7a51..c714e7c 100644
--- a/sitator/network/merging.py
+++ b/sitator/network/merging.py
@@ -4,6 +4,7 @@
from sitator.util import PBCCalculator
from sitator import SiteNetwork, SiteTrajectory
+from sitator.errors import InsufficientSitesError
import logging
logger = logging.getLogger(__name__)
@@ -14,9 +15,6 @@ class MergeSitesError(Exception):
class MergedSitesTooDistantError(MergeSitesError):
pass
-class TooFewMergedSitesError(MergeSitesError):
- pass
-
class MergeSites(abc.ABC):
"""Abstract base class for merging sites.
@@ -63,7 +61,11 @@ def run(self, st, **kwargs):
logger.info("After merging %i sites there will be %i sites for %i mobile particles" % (len(site_centers), new_n_sites, st.site_network.n_mobile))
if new_n_sites < st.site_network.n_mobile:
- raise TooFewMergedSitesError("There are %i mobile atoms in this system, but only %i sites after merge" % (np.sum(st.site_network.mobile_mask), new_n_sites))
+ raise InsufficientSitesError(
+ verb = "Merging",
+ n_sites = new_n_sites,
+ n_mobile = st.site_network.n_mobile
+ )
if self.check_types:
new_types = np.empty(shape = new_n_sites, dtype = np.int)
From a87063b367d4583a825517864cb1a1431d5ce2ba Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 31 Jul 2019 15:32:50 -0600
Subject: [PATCH 118/129] Gracefully handle a lack of sites
---
sitator/site_descriptors/SiteCoordinationEnvironment.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
index 4fb3b64..bd68aed 100644
--- a/sitator/site_descriptors/SiteCoordinationEnvironment.py
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -72,6 +72,9 @@ def run(self, sn):
"""
# -- Determine local environments
# Get an ASE structure with a single mobile site that we'll move around
+ if sn.n_sites == 0:
+ logger.warning("Site network had no sites.")
+ return sn
site_struct, site_species = sn[0:1].get_structure_with_sites()
pymat_struct = AseAtomsAdaptor.get_structure(site_struct)
lgf = cgf.LocalGeometryFinder()
From f3ee865bec3d6bd3ae2b3a5e4e23eee391f7cd7f Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 31 Jul 2019 22:06:52 -0600
Subject: [PATCH 119/129] Use correct site in error message
---
sitator/SiteTrajectory.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index db602d0..a994e28 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -216,10 +216,10 @@ def check_multiple_occupancy(self, max_mobile_per_site = 1):
avg_mobile_per_site = 0
divisor = 0
for frame_i, site_frame in enumerate(self._traj):
- _, counts = np.unique(site_frame[site_frame >= 0], return_counts = True)
+ sites, counts = np.unique(site_frame[site_frame >= 0], return_counts = True)
count_msk = counts > max_mobile_per_site
if np.any(count_msk):
- first_multi_site = np.where(count_msk)[0][0]
+ first_multi_site = sites[count_msk][0]
raise MultipleOccupancyError(
mobile = np.where(site_frame == first_multi_site)[0],
site = first_multi_site,
From bea93703816fb5ad0335f3453aa8c1c8cb42f381 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 1 Aug 2019 13:46:22 -0600
Subject: [PATCH 120/129] Added missing import
---
sitator/visualization/SiteTrajectoryPlotter.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/sitator/visualization/SiteTrajectoryPlotter.py b/sitator/visualization/SiteTrajectoryPlotter.py
index 5ed244e..9994dfe 100644
--- a/sitator/visualization/SiteTrajectoryPlotter.py
+++ b/sitator/visualization/SiteTrajectoryPlotter.py
@@ -1,3 +1,4 @@
+import numpy as np
import matplotlib
from matplotlib.collections import LineCollection
From 3aaf1a9afc90d7c2defc64829121e827126946c3 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 1 Aug 2019 13:46:34 -0600
Subject: [PATCH 121/129] Added better __str__
---
sitator/SiteNetwork.py | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 007a20a..bd71be4 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -94,6 +94,25 @@ def copy(self, with_computed = True):
def __len__(self):
return self.n_sites
+ def __str__(self):
+ return (
+ "{}: {:d} sites for {:d} mobile particles in static lattice of {:d} particles\n"
+ " Has vertices: {}\n"
+ " Has types: {}\n"
+ " Has site attributes: {}\n"
+ " Has edge attributes: {}\n"
+ ""
+ ).format(
+ type(self).__name__,
+ self.n_sites,
+ self.n_mobile,
+ self.n_static,
+ self._vertices is not None,
+ self._types is not None,
+ ", ".join(self._site_attrs.keys()),
+ ", ".join(self._edge_attrs.keys())
+ )
+
def __getitem__(self, key):
sn = self.__new__(type(self))
SiteNetwork.__init__(
From 2eab52a03208c94ed3000091b56eba9a8be70822 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 1 Aug 2019 13:49:24 -0600
Subject: [PATCH 122/129] Added IonicSiteNetwork
---
sitator/ionic.py | 127 +++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 127 insertions(+)
create mode 100644 sitator/ionic.py
diff --git a/sitator/ionic.py b/sitator/ionic.py
new file mode 100644
index 0000000..7f8dcb0
--- /dev/null
+++ b/sitator/ionic.py
@@ -0,0 +1,127 @@
+import numpy as np
+
+from sitator import SiteNetwork
+
+import ase.data
+
+try:
+ from pymatgen.io.ase import AseAtomsAdaptor
+ import pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder as cgf
+ from pymatgen.analysis.chemenv.coordination_environments.structure_environments import \
+ LightStructureEnvironments
+ from pymatgen.analysis.chemenv.utils.defs_utils import AdditionalConditions
+ from pymatgen.analysis.bond_valence import BVAnalyzer
+ has_pymatgen = True
+except ImportError:
+ has_pymatgen = False
+
+
+class IonicSiteNetwork(SiteNetwork):
+ """Site network for a species of mobile charged ions in a static lattice.
+
+ Imposes more restrictions than a plain ``SiteNetwork``:
+
+ - Has a mobile species in place of an arbitrary mobile mask
+ - Has a set of static species in place of an arbitrary static mask
+ - All atoms of the mobile species must have the same charge
+
+ And contains more information...
+
+ Attributes:
+ opposite_ion_mask (ndarray): A mask on ``structure`` indicating all
+ anions and neutrally charged atoms if the mobile species is a cation,
+ or vice versa if the mobile species is an anion.
+ opposite_ion_structure (ase.Atoms): An ``Atoms`` containing the atoms
+ indicated by ``opposite_ion_mask``.
+ same_ion_mask (ndarray): A mask on ``structure`` indicating all
+ atoms whose charge has the same sign as the mobile species.
+ same_ion_structure (ase.Atoms): An ``Atoms`` containing the atoms
+ indicated by ``same_ion_mask``.
+ n_opposite_charge (int): The number of opposite charge static atoms.
+ n_same_charge (int): The number of same charge static atoms.
+
+ Args:
+ structure (ase.Atoms)
+ mobile_species (int): Atomic number of the mobile species.
+ static_species (list of int): Atomic numbers of the static species.
+ mobile_charge (int): Charge of mobile atoms. If ``None``,
+ ``pymatgen``'s ``BVAnalyzer`` will be used to estimate valences.
+ static_charges (ndarray int): Charges of the atoms in the static
+ structure. If ``None``, ``sitator`` will try to use
+ ``pymatgen``'s ``BVAnalyzer`` to estimate valences.
+ """
+ def __init__(self,
+ structure,
+ mobile_species,
+ static_species,
+ mobile_charge = None,
+ static_charges = None):
+ if mobile_species in static_species:
+ raise ValueError("Mobile species %i cannot also be one of static species %s" % (mobile_species, static_species))
+ mobile_mask = structure.numbers == mobile_species
+ static_mask = np.in1d(structure.numbers, static_species)
+ super().__init__(
+ structure = structure,
+ mobile_mask = mobile_mask,
+ static_mask = static_mask
+ )
+
+ self.mobile_species = mobile_species
+ self.static_species = static_species
+ # Estimate bond valences if necessary
+ if mobile_charge is None or static_charges is None:
+ if not has_pymatgen:
+ raise ImportError("Pymatgen could not be imported, and is required for guessing charges.")
+ sim_struct = AseAtomsAdaptor.get_structure(structure)
+ bv = BVAnalyzer()
+ struct_valences = np.asarray(bv.get_valences(sim_struct))
+ if static_charges is None:
+ static_charges = struct_valences[static_mask]
+ if mobile_charge is None:
+ mob_val = struct_valences[mobile_mask]
+ if np.any(mob_val != mob_val[0]):
+ raise ValueError("Mobile atom estimated valences (%s) not uniform; arbitrarily taking first." % mob_val)
+ mobile_charge = mob_val[0]
+ self.mobile_charge = mobile_charge
+ self.static_charges = static_charges
+
+ # Create opposite ion stuff
+ mobile_sign = np.sign(mobile_charge)
+ static_signs = np.sign(static_charges)
+ self.opposite_ion_mask = np.empty_like(static_mask)
+ self.opposite_ion_mask.fill(False)
+ self.opposite_ion_mask[static_mask] = static_signs != mobile_sign
+ self.opposite_ion_structure = structure[self.opposite_ion_mask]
+
+ self.same_ion_mask = np.empty_like(static_mask)
+ self.same_ion_mask.fill(False)
+ self.same_ion_mask[static_mask] = static_signs == mobile_sign
+ self.same_ion_structure = structure[self.same_ion_mask]
+
+ @property
+ def n_opposite_charge(self):
+ return np.sum(self.opposite_ion_mask)
+
+ @property
+ def n_same_charge(self):
+ return np.sum(self.same_ion_mask)
+
+ def __str__(self):
+ out = super().__str__()
+ static_nums = self.static_structure.numbers
+ out += (
+ " Mobile species: {:2} (charge {:+d})\n"
+ " Static species: {}\n"
+ " # opposite charge: {}\n"
+ ).format(
+ ase.data.chemical_symbols[self.mobile_species],
+ self.mobile_charge,
+ ", ".join(
+ "{} (avg. charge {:+.1f})".format(
+ ase.data.chemical_symbols[s],
+ np.mean(self.static_charges[static_nums == s])
+ ) for s in self.static_species
+ ),
+ self.n_opposite_charge
+ )
+ return out
From df2ca086b40744cbd6902a21cb4cde31c726e652 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 1 Aug 2019 13:49:33 -0600
Subject: [PATCH 123/129] Allow choosing seed atoms
---
sitator/voronoi.py | 15 ++++++++++++---
1 file changed, 12 insertions(+), 3 deletions(-)
diff --git a/sitator/voronoi.py b/sitator/voronoi.py
index 5b94dbf..748e819 100644
--- a/sitator/voronoi.py
+++ b/sitator/voronoi.py
@@ -20,7 +20,7 @@ def __init__(self,
self._radial = radial
self._zeopy = Zeopy(zeopp_path)
- def run(self, sn):
+ def run(self, sn, seed_mask = None):
"""
Args:
sn (SiteNetwork): Any sites will be ignored; needed for structure
@@ -30,12 +30,21 @@ def run(self, sn):
"""
assert isinstance(sn, SiteNetwork)
+ if seed_mask is None:
+ seed_mask = sn.static_mask
+ assert not np.any(seed_mask & sn.mobile_mask), "Seed mask must not overlap with mobile mask"
+ assert not np.any(seed_mask & ~sn.static_mask), "All seed atoms must be static."
+ voro_struct = sn.structure[seed_mask]
+ translation = np.zeros(shape = len(sn.static_mask), dtype = np.int)
+ translation[sn.static_mask] = np.arange(sn.n_static)
+ translation = translation[seed_mask]
+
with self._zeopy:
- nodes, verts, edges, _ = self._zeopy.voronoi(sn.static_structure,
+ nodes, verts, edges, _ = self._zeopy.voronoi(voro_struct,
radial = self._radial)
out = sn.copy()
out.centers = nodes
- out.vertices = verts
+ out.vertices = [translation[v] for v in verts]
return out
From fe6eeec2ff74f694d73a698e92e674d9cc0d788b Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 1 Aug 2019 18:43:11 -0600
Subject: [PATCH 124/129] Added constraints for dynamic remapping
---
sitator/landmark/LandmarkAnalysis.py | 27 +++++++--
sitator/landmark/dynamic_mapping.py | 5 ++
sitator/landmark/helpers.pyx | 86 +++++++++++++++-------------
3 files changed, 74 insertions(+), 44 deletions(-)
create mode 100644 sitator/landmark/dynamic_mapping.py
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 2ade8b8..5250afc 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -57,8 +57,8 @@ class LandmarkAnalysis(object):
when all-zero landmark vectors are computed.
:param float static_movement_threshold: (Angstrom) the maximum allowed
distance between an instantanous static atom position and it's ideal position.
- :param bool dynamic_lattice_mapping: Whether to dynamically decide each
- frame which static atom represents each average lattice position;
+ :param bool/callable dynamic_lattice_mapping: Whether to dynamically decide
+ each frame which static atom represents each average lattice position;
this allows the LandmarkAnalysis to deal with, say, a rare exchage of
two static atoms that does not change the structure of the lattice.
@@ -66,6 +66,12 @@ class LandmarkAnalysis(object):
actually change over the course of the trajectory.
In certain cases this is better delt with by ``MergeSitesByDynamics``.
+
+ If ``False``, no mapping will occur. Otherwise, a callable taking a
+ ``SiteNetwork`` should be provided. The callable should return a list
+ of static atom indexes that can be validly assigned to each static lattice
+ position. If ``True``, ``sitator.landmark.dynamic_mapping.within_species``
+ is used.
:param int max_mobile_per_site: The maximum number of mobile atoms that can
be assigned to a single site without throwing an error. Regardless of the
value, assignments of more than one mobile atom to a single site will
@@ -116,7 +122,12 @@ def __init__(self,
self.verbose = verbose
self.check_for_zero_landmarks = check_for_zero_landmarks
self.site_centers_method = site_centers_method
+
+ if dynamic_lattice_mapping is True:
+ from sitator.landmark.dynamic_mapping import within_species
+ dynamic_lattice_mapping = within_species
self.dynamic_lattice_mapping = dynamic_lattice_mapping
+
self.relaxed_lattice_checks = relaxed_lattice_checks
self._landmark_vectors = None
@@ -203,7 +214,13 @@ def run(self, sn, frames):
# -- Step 2: Compute landmark vectors
logger.info(" - computing landmark vectors -")
- # Compute landmark vectors
+
+ if self.dynamic_lattice_mapping:
+ dynmap_compat = self.dynamic_lattice_mapping(sn)
+ else:
+ # If no dynamic mapping, each is only compatable with itself.
+ dynmap_compat = np.arange(sn.n_static)[:, np.newaxis]
+ assert len(dynmap_compat) == sn.n_static
# The dimension of one landmark vector is the number of Voronoi regions
shape = (n_frames * sn.n_mobile, self._landmark_dimension)
@@ -218,7 +235,9 @@ def run(self, sn, frames):
shape = shape)
helpers._fill_landmark_vectors(self, sn, verts_np, site_vert_dists,
- frames, check_for_zeros = self.check_for_zero_landmarks,
+ frames,
+ dynmap_compat = dynmap_compat,
+ check_for_zeros = self.check_for_zero_landmarks,
tqdm = tqdm, logger = logger)
if not self.check_for_zero_landmarks and self.n_all_zero_lvecs > 0:
diff --git a/sitator/landmark/dynamic_mapping.py b/sitator/landmark/dynamic_mapping.py
new file mode 100644
index 0000000..be7ce83
--- /dev/null
+++ b/sitator/landmark/dynamic_mapping.py
@@ -0,0 +1,5 @@
+import numpy as np
+
+def within_species(sn):
+ nums = sn.static_structure.numbers
+ return [(nums == nums[s]).nonzero()[0] for s in range(sn.n_static)]
diff --git a/sitator/landmark/helpers.pyx b/sitator/landmark/helpers.pyx
index 1dd736e..180ddc5 100644
--- a/sitator/landmark/helpers.pyx
+++ b/sitator/landmark/helpers.pyx
@@ -9,7 +9,15 @@ from sitator.landmark import StaticLatticeError, ZeroLandmarkError
ctypedef double precision
-def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_for_zeros = True, tqdm = lambda i: i, logger = None):
+def _fill_landmark_vectors(self,
+ sn,
+ verts_np,
+ site_vert_dists,
+ frames,
+ dynmap_compat,
+ check_for_zeros = True,
+ tqdm = lambda i: i,
+ logger = None):
if self._landmark_dimension is None:
raise ValueError("_fill_landmark_vectors called before Voronoi!")
@@ -24,15 +32,15 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
mobile_idexes = np.where(sn.mobile_mask)[0]
- if self.dynamic_lattice_mapping:
- lattice_map = np.empty(shape = sn.n_static, dtype = np.int)
- else:
- # Otherwise just map to themselves
- lattice_map = np.arange(sn.n_static, dtype = np.int)
+ lattice_map = np.empty(shape = sn.n_static, dtype = np.int)
lattice_pt = np.empty(shape = 3, dtype = sn.static_structure.positions.dtype)
- lattice_pt_dists = np.empty(shape = sn.n_static, dtype = np.float)
+ max_n_dynmat_compat = max(len(dm) for dm in dynmap_compat)
+ lattice_pt_dists = np.empty(shape = max_n_dynmat_compat, dtype = np.float)
+ static_pos_buffer = np.empty(shape = (max_n_dynmat_compat, 3), dtype = lattice_pt.dtype)
static_positions_seen = np.empty(shape = sn.n_static, dtype = np.bool)
+ static_positions = np.empty(shape = (sn.n_static, 3), dtype = frames.dtype)
+ static_mask_idexes = sn.static_mask.nonzero()[0]
# - Precompute cutoff function rounding point
# TODO: Think about the 0.0001 value
@@ -46,25 +54,41 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
cdef Py_ssize_t landmark_dim = self._landmark_dimension
cdef Py_ssize_t current_landmark_i = 0
+
+ cdef Py_ssize_t nearest_static_position
+ cdef precision nearest_static_distance
+ cdef Py_ssize_t n_dynmap_allowed
# Iterate through time
for i, frame in enumerate(tqdm(frames, desc = "Landmark Frame")):
- static_positions = frame[sn.static_mask]
+ #static_positions = frame[sn.static_mask]
+ np.take(frame,
+ static_mask_idexes,
+ out = static_positions,
+ axis = 0,
+ mode = 'clip')
# Every frame, update the lattice map
static_positions_seen.fill(False)
for lattice_index in xrange(sn.n_static):
- lattice_pt = sn.static_structure.positions[lattice_index]
-
- if self.dynamic_lattice_mapping:
- # Only compute all distances if dynamic remapping is on
- pbcc.distances(lattice_pt, static_positions, out = lattice_pt_dists)
- nearest_static_position = np.argmin(lattice_pt_dists)
- nearest_static_distance = lattice_pt_dists[nearest_static_position]
- else:
- nearest_static_position = lattice_index
- nearest_static_distance = pbcc.distances(lattice_pt, static_positions[nearest_static_position:nearest_static_position+1])[0]
+ dynmap_allowed = dynmap_compat[lattice_index]
+ n_dynmap_allowed = len(dynmap_allowed)
+ lattice_pt[:] = sn.static_structure.positions[lattice_index]
+ np.take(static_positions,
+ dynmap_allowed,
+ out = static_pos_buffer[:n_dynmap_allowed],
+ axis = 0,
+ mode = 'clip')
+
+ pbcc.distances(
+ lattice_pt,
+ static_pos_buffer[:n_dynmap_allowed],
+ out = lattice_pt_dists[:n_dynmap_allowed]
+ )
+ nearest_static_position = np.argmin(lattice_pt_dists[:n_dynmap_allowed])
+ nearest_static_distance = lattice_pt_dists[nearest_static_position]
+ nearest_static_position = dynmap_allowed[nearest_static_position]
if static_positions_seen[nearest_static_position]:
# We've already seen this one... error
@@ -79,8 +103,7 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
frame = i,
try_recentering = True)
- if self.dynamic_lattice_mapping:
- lattice_map[lattice_index] = nearest_static_position
+ lattice_map[lattice_index] = nearest_static_position
# In normal circumstances, every current static position should be assigned.
# Just a sanity check
@@ -96,8 +119,9 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
mobile_pt = frame[mobile_idexes[j]]
# Shift the Li in question to the center of the unit cell
- np.copyto(frame_shift, frame[sn.static_mask])
- frame_shift += (pbcc.cell_centroid - mobile_pt)
+ frame_shift[:] = static_positions
+ frame_shift -= mobile_pt
+ frame_shift += pbcc.cell_centroid
# Wrap all positions into the unit cell
pbcc.wrap_points(frame_shift)
@@ -148,18 +172,6 @@ cdef void fill_landmark_vec(precision [:,:] landmark_vectors,
precision cutoff_round_to_zero,
precision [:] distbuff) nogil:
- # Pure Python equiv:
- # for k in xrange(landmark_dim):
- # lvec = np.linalg.norm(lattice_positions[verts[k]] - cell_centroid, axis = 1)
- # past_cutoff = lvec > cutoff
-
- # # Short circut it, since the product then goes to zero too.
- # if np.any(past_cutoff):
- # landmark_vectors[(i * n_li) + j, k] = 0
- # else:
- # lvec = (np.cos((np.pi / cutoff) * lvec) + 1.0) / 2.0
- # landmark_vectors[(i * n_li) + j, k] = np.product(lvec)
-
# Fill the landmark vector
cdef int [:] vert
cdef Py_ssize_t v
@@ -177,11 +189,6 @@ cdef void fill_landmark_vec(precision [:,:] landmark_vectors,
distbuff[idex] = temp
- # if temp > cutoff:
- # distbuff[idex] = 0.0
- # else:
- # distbuff[idex] = (cos((M_PI / cutoff) * temp) + 1.0) * 0.5
-
# For each component
for k in xrange(landmark_dim):
ci = 1.0
@@ -205,7 +212,6 @@ cdef void fill_landmark_vec(precision [:,:] landmark_vectors,
temp = 1.0 / (1.0 + exp(cutoff_steepness * (temp - cutoff_midpoint)))
# Multiply into accumulator
- #ci *= distbuff[v]
ci *= temp
# "Normalize" to number of vertices
From c43571e9a7a6e31ce51c61d118096a8e02c90522 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 1 Aug 2019 18:43:25 -0600
Subject: [PATCH 125/129] Use IonicSiteNetwork information for ionic bonds
---
.../SiteCoordinationEnvironment.py | 29 +++++--------------
1 file changed, 7 insertions(+), 22 deletions(-)
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
index bd68aed..d4d1baa 100644
--- a/sitator/site_descriptors/SiteCoordinationEnvironment.py
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -42,9 +42,9 @@ class SiteCoordinationEnvironment(object):
- ``coordination_numbers``: The coordination number of the site.
Args:
- guess_ionic_bonds (bool): If True, uses ``pymatgen``'s bond valence
- analysis to guess valences and only consider ionic bonds for
- neighbor analysis. Otherwise, or if it fails, all bonds are fair game.
+ only_ionic_bonds (bool): If True, assumes ``sn`` is an ``IonicSiteNetwork``
+ and uses its charge information to only consider as neighbors
+ atoms with compatible (anion and cation, ion and neutral) charges.
full_chemenv_site_types (bool): If True, ``sitator`` site types on the
final ``SiteNetwork`` will be assigned based on unique chemical
environments, including shape. If False, they will be assigned
@@ -54,13 +54,13 @@ class SiteCoordinationEnvironment(object):
**kwargs: passed to ``compute_structure_environments``.
"""
def __init__(self,
- guess_ionic_bonds = True,
+ only_ionic_bonds = True,
full_chemenv_site_types = False,
**kwargs):
if not has_pymatgen:
raise ImportError("Pymatgen (or a recent enough version including `pymatgen.analysis.chemenv.coordination_environments`) cannot be imported.")
self._kwargs = kwargs
- self._guess_ionic_bonds = guess_ionic_bonds
+ self._only_ionic_bonds = only_ionic_bonds
self._full_chemenv_site_types = full_chemenv_site_types
def run(self, sn):
@@ -84,23 +84,8 @@ def run(self, sn):
vertices = []
valences = 'undefined'
- if self._guess_ionic_bonds:
- sim_struct = AseAtomsAdaptor.get_structure(sn.structure)
- valences = np.zeros(shape = len(site_struct), dtype = np.int)
- bv = BVAnalyzer()
- try:
- struct_valences = np.asarray(bv.get_valences(sim_struct))
- except ValueError as ve:
- logger.warning("Failed to compute bond valences: %s" % ve)
- else:
- valences = np.zeros(shape = len(site_struct), dtype = np.int)
- valences[:site_atom_index] = struct_valences[sn.static_mask]
- mob_val = struct_valences[sn.mobile_mask]
- if np.any(mob_val != mob_val[0]):
- logger.warning("Mobile atom estimated valences (%s) not uniform; arbitrarily taking first." % mob_val)
- valences[site_atom_index] = mob_val[0]
- finally:
- valences = list(valences)
+ if self._only_ionic_bonds:
+ valences = np.concatenate((sn.static_charges, [sn.mobile_charge]))
logger.info("Running site coordination environment analysis...")
# Do this once.
From bf1d6c2296fa10e00dca185bf1b60ca093c1ec4a Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 2 Aug 2019 13:36:08 -0600
Subject: [PATCH 126/129] Added polyatomic ion detection
---
sitator/util/chemistry.py | 100 ++++++++++++++++++++++++++++++++++++++
1 file changed, 100 insertions(+)
create mode 100644 sitator/util/chemistry.py
diff --git a/sitator/util/chemistry.py b/sitator/util/chemistry.py
new file mode 100644
index 0000000..c0070ea
--- /dev/null
+++ b/sitator/util/chemistry.py
@@ -0,0 +1,100 @@
+import numpy as np
+
+import ase.data
+try:
+ from ase.formula import Formula
+except ImportError:
+ from ase.util.formula import Formula
+
+from sitator.util import PBCCalculator
+
+# Key is a central atom, value is a list of the formulas for the rest.
+# See, http://www.fccj.us/PolyatomicIons/CompletePolyatomicIonList.htm
+DEFAULT_POLYATOMIC_IONS = {
+ 'Ar' : ['O4', 'O3'],
+ 'C' : ['N', 'O3', 'O2', 'NO', 'SN'],
+ 'B' : ['O3','O2'],
+ 'Br' : ['O4', 'O3', 'O2', 'O'],
+ 'Fe' : ['O4'],
+ 'I' : ['O3', 'O4', 'O2', 'O'],
+ 'Si' : ['O4', 'O3'],
+ 'S' : ['O5', 'O4', 'O3', 'SO3', 'O2'],
+ 'Sb' : ['O4', 'O3'],
+ 'Se' : ['O4', 'O3'],
+ 'Sn' : ['O3', 'O2'],
+ 'N' : ['O3', 'O2', 'CO'],
+ 'Re' : ['O4'],
+ 'Cl' : ['O4', 'O3', 'O2', 'O'],
+ 'Mn' : ['O4'],
+ 'Mo' : ['O4'],
+ 'Cr' : ['O4', 'CrO7', 'O2'],
+ 'P' : ['O4', 'O3', 'O2'],
+ 'Tc' : ['O4'],
+ 'Te' : ['O4', 'O6', 'O3'],
+ 'Pb' : ['O3', 'O2'],
+ 'W' : ['O4']
+}
+
+def identify_polyatomic_ions(structure,
+ cutoff_factor = 1.0,
+ ion_definitions = DEFAULT_POLYATOMIC_IONS):
+ """Find polyatomic ions in a structure.
+
+ Goes to each potential polyatomic ion center and first checks if the nearest
+ neighbor is of a viable species. If it is, all nearest neighbors within the
+ maximum possible pairwise summed covalent radii are found and matched against
+ the database.
+
+ Args:
+ structure (ase.Atoms)
+ cutoff_factor (float): Coefficient for the cutoff. Allows tuning just
+ how closely the polyatomic ions must be bound. Defaults to 1.0.
+ ion_definitions (dict): Database of polyatomic ions where the key is a
+ central atom species symbol and the value is a list of the formulas
+ for the rest. Defaults to a list of common, non-hydrogen-containing
+ polyatomic ions with Oxygen.
+ Returns:
+ list of tuples, one for each identified polyatomic ion containing its
+ formula, the index of the central atom, and the indexes of the remaining
+ atoms.
+ """
+ ion_definitions = {
+ k : [Formula(f) for f in v]
+ for k, v in ion_definitions.items()
+ }
+ out = []
+ # Go to each possible center atom in data, check that nearest neighbor is
+ # of right species, then get all within covalent distances and check if
+ # composition matches database...
+ pbcc = PBCCalculator(structure.cell)
+ dmat = pbcc.pairwise_distances(structure.positions) # Precompute
+ np.fill_diagonal(dmat, np.inf)
+ for center_i, center_symbol in enumerate(structure.symbols):
+ if center_symbol in ion_definitions:
+ nn_sym = structure.symbols[np.argmin(dmat[center_i])]
+ could_be = [f for f in ion_definitions[center_symbol] if nn_sym in f]
+ if len(could_be) == 0:
+ # Nearest neighbor isn't even a plausible polyatomic ion species,
+ # so skip this potential center.
+ continue
+ # Take the largest possible other species covalent radius for all
+ # other species it possibly could be.
+ cutoff = max(
+ ase.data.covalent_radii[ase.data.atomic_numbers[other_sym]]
+ for form in could_be for other_sym in form
+ )
+ cutoff += ase.data.covalent_radii[ase.data.atomic_numbers[center_symbol]]
+ cutoff *= cutoff_factor
+ neighbors = dmat[center_i] <= cutoff
+ neighbors = np.where(neighbors)[0]
+ neighbor_formula = Formula.from_list(structure.symbols[neighbors])
+ it_is = [f for f in could_be if f == neighbor_formula]
+ if len(it_is) > 1:
+ raise ValueError("Somehow identified single center %s (atom %i) as multiple polyatomic ions %s" % (center_symbol, center_i, it_is))
+ elif len(it_is) == 1:
+ out.append((
+ center_symbol + neighbor_formula.format('hill'),
+ center_i,
+ neighbors
+ ))
+ return out
From 7ff646d7c85733292cc005466fe09cc957cdfe06 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 2 Aug 2019 13:36:25 -0600
Subject: [PATCH 127/129] Handle newer ASE `Cell` objects
---
sitator/util/PBCCalculator.pyx | 1 +
1 file changed, 1 insertion(+)
diff --git a/sitator/util/PBCCalculator.pyx b/sitator/util/PBCCalculator.pyx
index c841fd2..caf827f 100644
--- a/sitator/util/PBCCalculator.pyx
+++ b/sitator/util/PBCCalculator.pyx
@@ -24,6 +24,7 @@ cdef class PBCCalculator(object):
:param DxD ndarray: the unit cell -- an array of cell vectors, like the
cell of an ASE :class:Atoms object.
"""
+ cell = cell[:]
cellmat = np.matrix(cell).T
assert cell.shape[1] == cell.shape[0], "Cell must be square"
From 22f9abf34cfed4572adb7b36fb95cee703de6d90 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 2 Aug 2019 15:15:34 -0600
Subject: [PATCH 128/129] Added static atom anchoring
---
sitator/landmark/LandmarkAnalysis.py | 32 ++++++++++++++++--
sitator/landmark/anchor.py | 49 ++++++++++++++++++++++++++++
sitator/landmark/helpers.pyx | 42 +++++++++++++++++-------
3 files changed, 109 insertions(+), 14 deletions(-)
create mode 100644 sitator/landmark/anchor.py
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index 5250afc..85dd5d4 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -10,8 +10,8 @@
from . import helpers
from sitator import SiteNetwork, SiteTrajectory
-from sitator.errors import MultipleOccupancyError
-
+from sitator.errors import MultipleOccupancyError, InsufficientSitesError
+from sitator.landmark.anchor import to_origin
import logging
logger = logging.getLogger(__name__)
@@ -108,6 +108,7 @@ def __init__(self,
check_for_zero_landmarks = True,
static_movement_threshold = 1.0,
dynamic_lattice_mapping = False,
+ static_anchoring = to_origin,
relaxed_lattice_checks = False,
max_mobile_per_site = 1,
force_no_memmap = False,
@@ -127,6 +128,7 @@ def __init__(self,
from sitator.landmark.dynamic_mapping import within_species
dynamic_lattice_mapping = within_species
self.dynamic_lattice_mapping = dynamic_lattice_mapping
+ self.static_anchoring = static_anchoring
self.relaxed_lattice_checks = relaxed_lattice_checks
@@ -221,6 +223,9 @@ def run(self, sn, frames):
# If no dynamic mapping, each is only compatable with itself.
dynmap_compat = np.arange(sn.n_static)[:, np.newaxis]
assert len(dynmap_compat) == sn.n_static
+ static_anchors = self.static_anchoring(sn)
+ # This also validates the anchors
+ lattice_pt_order = self._get_lattice_order_from_anchors(static_anchors)
# The dimension of one landmark vector is the number of Voronoi regions
shape = (n_frames * sn.n_mobile, self._landmark_dimension)
@@ -237,6 +242,8 @@ def run(self, sn, frames):
helpers._fill_landmark_vectors(self, sn, verts_np, site_vert_dists,
frames,
dynmap_compat = dynmap_compat,
+ lattice_pt_anchors = static_anchors,
+ lattice_pt_order = lattice_pt_order,
check_for_zeros = self.check_for_zero_landmarks,
tqdm = tqdm, logger = logger)
@@ -335,3 +342,24 @@ def run(self, sn, frames):
self._has_run = True
return out_st
+
+ def _get_lattice_order_from_anchors(self, lattice_pt_anchors):
+ absolute_lattice_mask = lattice_pt_anchors == -1
+ lattice_pt_order = []
+ # -1 (absolute anchor of origin) is always known
+ known = np.zeros(shape = len(lattice_pt_anchors) + 1, dtype = np.bool)
+ known[-1] = True
+
+ while True:
+ can_know = known[lattice_pt_anchors]
+ new_know = can_know & ~known[:-1]
+ known[:-1] |= can_know
+ lattice_pt_order.extend(np.where(new_know)[0])
+ if not np.any(new_know):
+ break
+ if len(lattice_pt_order) < len(lattice_pt_anchors):
+ raise ValueError("Lattice point anchors %s contains an unsatisfiable dependency (likely a circular dependency)." % lattice_pt_anchors)
+ # Remove points with absolute anchors from the order; there's no need to
+ # do any computations for them.
+ lattice_pt_order = lattice_pt_order[np.sum(absolute_lattice_mask):]
+ return lattice_pt_order
diff --git a/sitator/landmark/anchor.py b/sitator/landmark/anchor.py
new file mode 100644
index 0000000..067c859
--- /dev/null
+++ b/sitator/landmark/anchor.py
@@ -0,0 +1,49 @@
+import numpy as np
+
+from collections import Counter
+
+import ase.data
+
+from sitator.util.chemistry import identify_polyatomic_ions
+
+import logging
+logger = logging.getLogger(__name__)
+
+def to_origin(sn):
+ """Anchor all static atoms to the origin; i.e. their positions are absolute."""
+ return np.full(
+ shape = sn.n_static,
+ fill_value = -1,
+ dtype = np.int
+ )
+
+# ------
+
+def within_polyatomic_ions(**kwargs):
+ """Anchor the auxiliary atoms of a polyatomic ion to the central atom.
+
+ In phosphate (PO4), for example, the four coordinating Oxygen atoms
+ will be anchored to the central Phosphorous.
+
+ Args:
+ **kwargs: passed to ``sitator.util.chemistry.identify_polyatomic_ions``.
+ """
+ def func(sn):
+ anchors = np.full(shape = sn.n_static, fill_value = -1, dtype = np.int)
+ polyions = identify_polyatomic_ions(sn.static_structure, **kwargs)
+ logger.info("Identified %i polyatomic anions: %s" % (len(polyions), Counter(i[0] for i in polyions)))
+ for _, center, others in polyions:
+ anchors[others] = center
+ return anchors
+ return func
+
+# ------
+# TODO
+def to_heavy_elements(minimum_mass, maximum_distance = np.inf):
+ """Anchor "light" elements to their nearest "heavy" element.
+
+ Lightness/heaviness is determined by ``minimum_mass``, a cutoff in
+ atomic mass.
+ """
+ def func(sn):
+ pass
+ return func
diff --git a/sitator/landmark/helpers.pyx b/sitator/landmark/helpers.pyx
index 180ddc5..2ed9eb9 100644
--- a/sitator/landmark/helpers.pyx
+++ b/sitator/landmark/helpers.pyx
@@ -15,6 +15,8 @@ def _fill_landmark_vectors(self,
site_vert_dists,
frames,
dynmap_compat,
+ lattice_pt_anchors,
+ lattice_pt_order,
check_for_zeros = True,
tqdm = lambda i: i,
logger = None):
@@ -27,18 +29,29 @@ def _fill_landmark_vectors(self,
cdef pbcc = self._pbcc
- frame_shift = np.empty(shape = (sn.n_static, 3))
+ frame_shift = np.empty(shape = (sn.n_static, 3), dtype = frames.dtype)
temp_distbuff = np.empty(shape = sn.n_static, dtype = frames.dtype)
mobile_idexes = np.where(sn.mobile_mask)[0]
-
- lattice_map = np.empty(shape = sn.n_static, dtype = np.int)
-
- lattice_pt = np.empty(shape = 3, dtype = sn.static_structure.positions.dtype)
+ # Static lattice point buffers
+ lattice_pts_resolved = np.empty(shape = (sn.n_static, 3), dtype = sn.static_structure.positions.dtype)
+ # Determine resolution order
+ absolute_lattice_mask = lattice_pt_anchors == -1
+ assert len(lattice_pt_order) == len(lattice_pt_anchors) - np.sum(absolute_lattice_mask), "Order must contain all non-absolute anchored static lattice points"
+ cdef Py_ssize_t [:] lattice_pt_order_c = np.asarray(lattice_pt_order, dtype = np.int)
+ # Absolute (relative to origin) ones never need to be resolved, put it in
+ # the buffer now
+ lattice_pts_resolved[absolute_lattice_mask] = sn.static_structure.positions[absolute_lattice_mask]
+ assert not np.any(absolute_lattice_mask[lattice_pt_order]), "None of the absolute lattice points should be in the resolution order"
+ # Precompute the offsets for relative static lattice points:
+ relative_lattice_offsets = sn.static_structure.positions[lattice_pt_order] - sn.static_structure.positions[lattice_pt_anchors[lattice_pt_order]]
+ # Buffers for dynamic mapping
max_n_dynmat_compat = max(len(dm) for dm in dynmap_compat)
lattice_pt_dists = np.empty(shape = max_n_dynmat_compat, dtype = np.float)
- static_pos_buffer = np.empty(shape = (max_n_dynmat_compat, 3), dtype = lattice_pt.dtype)
+ static_pos_buffer = np.empty(shape = (max_n_dynmat_compat, 3), dtype = lattice_pts_resolved.dtype)
static_positions_seen = np.empty(shape = sn.n_static, dtype = np.bool)
+ lattice_map = np.empty(shape = sn.n_static, dtype = np.int)
+ # Instant static position buffers
static_positions = np.empty(shape = (sn.n_static, 3), dtype = frames.dtype)
static_mask_idexes = sn.static_mask.nonzero()[0]
@@ -60,8 +73,7 @@ def _fill_landmark_vectors(self,
cdef Py_ssize_t n_dynmap_allowed
# Iterate through time
for i, frame in enumerate(tqdm(frames, desc = "Landmark Frame")):
-
- #static_positions = frame[sn.static_mask]
+ # Copy static positions to buffer
np.take(frame,
static_mask_idexes,
out = static_positions,
@@ -71,10 +83,16 @@ def _fill_landmark_vectors(self,
# Every frame, update the lattice map
static_positions_seen.fill(False)
+ # - Resolve static lattice positions from their origins
+ for order_i, lattice_index in enumerate(lattice_pt_order_c):
+ lattice_pts_resolved[lattice_index] = static_positions[lattice_pt_anchors[lattice_index]]
+ lattice_pts_resolved[lattice_index] += relative_lattice_offsets[order_i]
+
+ # - Map static positions to static lattice sites
for lattice_index in xrange(sn.n_static):
dynmap_allowed = dynmap_compat[lattice_index]
n_dynmap_allowed = len(dynmap_allowed)
- lattice_pt[:] = sn.static_structure.positions[lattice_index]
+
np.take(static_positions,
dynmap_allowed,
out = static_pos_buffer[:n_dynmap_allowed],
@@ -82,7 +100,7 @@ def _fill_landmark_vectors(self,
mode = 'clip')
pbcc.distances(
- lattice_pt,
+ lattice_pts_resolved[lattice_index],
static_pos_buffer[:n_dynmap_allowed],
out = lattice_pt_dists[:n_dynmap_allowed]
)
@@ -98,7 +116,7 @@ def _fill_landmark_vectors(self,
static_positions_seen[nearest_static_position] = True
if nearest_static_distance > self.static_movement_threshold:
- raise StaticLatticeError("No static atom position within %f A threshold of static lattice position %i" % (self.static_movement_threshold, lattice_index),
+ raise StaticLatticeError("Nearest static atom to lattice position %i is %.2fÅ away, above threshold of %.2fÅ" % (lattice_index, nearest_static_distance, self.static_movement_threshold),
lattice_atoms = [lattice_index],
frame = i,
try_recentering = True)
@@ -114,7 +132,7 @@ def _fill_landmark_vectors(self,
frame = i,
try_recentering = True)
-
+ # - Compute landmark vectors for mobile
for j in xrange(sn.n_mobile):
mobile_pt = frame[mobile_idexes[j]]
From bdf717c1901b5623118dc66650d9d6f74645dc3b Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 2 Aug 2019 15:28:21 -0600
Subject: [PATCH 129/129] Fixed ionic data under copying
---
sitator/ionic.py | 12 ++++++++++++
.../site_descriptors/SiteCoordinationEnvironment.py | 2 +-
2 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/sitator/ionic.py b/sitator/ionic.py
index 7f8dcb0..538618d 100644
--- a/sitator/ionic.py
+++ b/sitator/ionic.py
@@ -106,6 +106,18 @@ def n_opposite_charge(self):
def n_same_charge(self):
return np.sum(self.same_ion_mask)
+ def __getitem__(self, key):
+ out = super().__getitem__(key)
+ out.mobile_species = self.mobile_species
+ out.static_species = self.static_species
+ out.mobile_charge = self.mobile_charge
+ out.static_charges = self.static_charges
+ out.opposite_ion_mask = self.opposite_ion_mask
+ out.opposite_ion_structure = self.opposite_ion_structure
+ out.same_ion_mask = self.same_ion_mask
+ out.same_ion_structure = self.same_ion_structure
+ return out
+
def __str__(self):
out = super().__str__()
static_nums = self.static_structure.numbers
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
index d4d1baa..8ad2ef8 100644
--- a/sitator/site_descriptors/SiteCoordinationEnvironment.py
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -85,7 +85,7 @@ def run(self, sn):
valences = 'undefined'
if self._only_ionic_bonds:
- valences = np.concatenate((sn.static_charges, [sn.mobile_charge]))
+ valences = list(sn.static_charges) + [sn.mobile_charge]
logger.info("Running site coordination environment analysis...")
# Do this once.