diff --git a/.gitignore b/.gitignore
index 4800c95..f71864f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,8 @@ build/
/dist/
/*.egg-info
+
+*.bak
+
+.tags
+.tags_swap
diff --git a/README.md b/README.md
index 607bcd1..c93db2c 100644
--- a/README.md
+++ b/README.md
@@ -9,41 +9,54 @@ A modular framework for conducting and visualizing site analysis of molecular dy
`sitator` contains an efficient implementation of our method, landmark analysis, as well as visualization tools, generic data structures for site analysis, pre- and post-processing tools, and more.
-For details on the method and its application, please see our paper:
+For details on landmark analysis and its application, please see our paper:
> L. Kahle, A. Musaelian, N. Marzari, and B. Kozinsky
> [Unsupervised landmark analysis for jump detection in molecular dynamics simulations](https://doi.org/10.1103/PhysRevMaterials.3.055404)
> Phys. Rev. Materials 3, 055404 – 21 May 2019
-If you use `sitator` in your research, please consider citing this paper. The BibTex citation can be found in [`CITATION.bib`](CITATION.bib).
+If you use `sitator` in your research, please consider citing this paper. The BibTeX citation can be found in [`CITATION.bib`](CITATION.bib).
## Installation
-`sitator` is currently built for Python 2.7. We recommend the use of a Python 2.7 virtual environment (`virtualenv`, `conda`, etc.). `sitator` has two external dependencies:
+`sitator` is built for Python >=3.2 (the older 1.x releases support Python 2.7). We recommend the use of a virtual environment (`virtualenv`, `conda`, etc.). `sitator` has a number of optional dependencies that enable various features:
- - The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does *not* have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.)
- - If you want to use the site type analysis features, a working installation of [Quippy](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) and Python bindings is required for computing SOAP descriptors.
+* Landmark Analysis
+ * The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does not have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.)
+* Site Type Analysis
+ * For SOAP-based site types: either the `quip` binary from [QUIP](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) **or** the [`DScribe`](https://singroup.github.io/dscribe/index.html) Python library.
+ * The Python 2.7 bindings for QUIP (`quippy`) are **not** required. Generally, `DScribe` is much simpler to install than QUIP. **Please note**, however, that the SOAP descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on the system you are analyzing.
+ * For coordination environment analysis (`sitator.site_descriptors.SiteCoordinationEnvironment`), we integrate the `pymatgen.analysis.chemenv` package; a somewhat recent installation of `pymatgen` is required.
After downloading, the package is installed with `pip`:
-```
+```bash
# git clone ... OR unzip ... OR ...
cd sitator
pip install .
```
-To enable site type analysis, add the `[SiteTypeAnalysis]` option:
+To enable site type analysis, add the `[SiteTypeAnalysis]` option (this adds two dependencies -- Python packages `pydpc` and `dscribe`):
-```
+```bash
pip install ".[SiteTypeAnalysis]"
```
## Examples and Documentation
-Two example Jupyter notebooks for conducting full landmark analyses of LiAlSiO4 and Li12La3Zr2O12, including data files, can be found [on Materials Cloud](https://archive.materialscloud.org/2019.0008/).
+Two example Jupyter notebooks for conducting full landmark analyses of LiAlSiO4 and Li12La3Zr2O12 as in our paper, including data files, can be found [on Materials Cloud](https://archive.materialscloud.org/2019.0008/).
+
+Full API documentation can be found at [ReadTheDocs](https://sitator.readthedocs.io/en/py3/).
+
+`sitator` generally assumes units of femtoseconds for time, Angstroms for space,
+and Cartesian (not crystal) coordinates.
+
+## Global Options
+
+`sitator` uses the `tqdm.autonotebook` module to automatically produce the correct fancy progress bars for terminals and Jupyter notebooks. To disable all progress bars, run with the environment variable `SITATOR_PROGRESSBAR` set to `false`.
-All individual classes and parameters are documented with docstrings in the source code.
+The `SITATOR_ZEO_PATH` and `SITATOR_QUIP_PATH` environment variables can set the default paths to the Zeo++ `network` and QUIP `quip` executables, respectively.
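+
+For example, to disable progress bars and point `sitator` at local builds of both executables (the paths below are placeholders):
+
+```bash
+export SITATOR_PROGRESSBAR=false
+export SITATOR_ZEO_PATH=/path/to/zeopp/network
+export SITATOR_QUIP_PATH=/path/to/quip
+```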
## License
-This software is made available under the MIT License. See `LICENSE` for more details.
+This software is made available under the MIT License. See [`LICENSE`](LICENSE) for more details.
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..d0c3cbf
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 0000000..6247f7e
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 0000000..8e7ec1b
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,55 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# http://www.sphinx-doc.org/en/master/config
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+# import os
+# import sys
+# sys.path.insert(0, os.path.abspath('.'))
+
+
+# -- Project information -----------------------------------------------------
+
+project = 'sitator'
+copyright = '2019, Alby Musaelian'
+author = 'Alby Musaelian'
+
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.napoleon'
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+master_doc = 'index'
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_rtd_theme'
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 0000000..e2f4d80
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,20 @@
+.. sitator documentation master file, created by
+ sphinx-quickstart on Tue Jul 16 15:15:41 2019.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+Welcome to sitator's documentation!
+===================================
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git a/docs/source/modules.rst b/docs/source/modules.rst
new file mode 100644
index 0000000..244d585
--- /dev/null
+++ b/docs/source/modules.rst
@@ -0,0 +1,7 @@
+sitator
+=======
+
+.. toctree::
+ :maxdepth: 4
+
+ sitator
diff --git a/docs/source/sitator.descriptors.rst b/docs/source/sitator.descriptors.rst
new file mode 100644
index 0000000..aeb7b2e
--- /dev/null
+++ b/docs/source/sitator.descriptors.rst
@@ -0,0 +1,15 @@
+sitator.descriptors package
+===========================
+
+Module contents
+---------------
+
+.. automodule:: sitator.descriptors
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.descriptors.ConfigurationalEntropy
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.dynamics.rst b/docs/source/sitator.dynamics.rst
new file mode 100644
index 0000000..c2510ac
--- /dev/null
+++ b/docs/source/sitator.dynamics.rst
@@ -0,0 +1,40 @@
+sitator.dynamics package
+========================
+
+Module contents
+---------------
+
+.. automodule:: sitator.dynamics
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.AverageVibrationalFrequency
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.JumpAnalysis
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.MergeSitesByDynamics
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.MergeSitesByThreshold
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.RemoveShortJumps
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.dynamics.RemoveUnoccupiedSites
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.landmark.cluster.rst b/docs/source/sitator.landmark.cluster.rst
new file mode 100644
index 0000000..c332d86
--- /dev/null
+++ b/docs/source/sitator.landmark.cluster.rst
@@ -0,0 +1,30 @@
+sitator.landmark.cluster package
+================================
+
+Submodules
+----------
+
+sitator.landmark.cluster.dotprod module
+---------------------------------------
+
+.. automodule:: sitator.landmark.cluster.dotprod
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.landmark.cluster.mcl module
+-----------------------------------
+
+.. automodule:: sitator.landmark.cluster.mcl
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.landmark.cluster
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.landmark.rst b/docs/source/sitator.landmark.rst
new file mode 100644
index 0000000..8bf446a
--- /dev/null
+++ b/docs/source/sitator.landmark.rst
@@ -0,0 +1,42 @@
+sitator.landmark package
+========================
+
+Subpackages
+-----------
+
+.. toctree::
+
+ sitator.landmark.cluster
+
+Submodules
+----------
+
+sitator.landmark.errors module
+------------------------------
+
+.. automodule:: sitator.landmark.errors
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.landmark.helpers module
+-------------------------------
+
+.. automodule:: sitator.landmark.helpers
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.landmark
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.landmark.LandmarkAnalysis
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.misc.rst b/docs/source/sitator.misc.rst
new file mode 100644
index 0000000..3b441d0
--- /dev/null
+++ b/docs/source/sitator.misc.rst
@@ -0,0 +1,34 @@
+sitator.misc package
+====================
+
+sitator.misc.oldio module
+-------------------------
+
+.. automodule:: sitator.misc.oldio
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.misc
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.misc.GenerateAroundSites
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.misc.GenerateClampedTrajectory
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.misc.NAvgsPerSite
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.network.rst b/docs/source/sitator.network.rst
new file mode 100644
index 0000000..be7fb28
--- /dev/null
+++ b/docs/source/sitator.network.rst
@@ -0,0 +1,25 @@
+sitator.network package
+=======================
+
+Module contents
+---------------
+
+.. automodule:: sitator.network
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.network.DiffusionPathwayAnalysis
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.network.MergeSitesByBarrier
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. automodule:: sitator.network.merging
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.rst b/docs/source/sitator.rst
new file mode 100644
index 0000000..8662328
--- /dev/null
+++ b/docs/source/sitator.rst
@@ -0,0 +1,38 @@
+sitator package
+===============
+
+Subpackages
+-----------
+
+.. toctree::
+
+ sitator.descriptors
+ sitator.dynamics
+ sitator.landmark
+ sitator.misc
+ sitator.network
+ sitator.site_descriptors
+ sitator.util
+ sitator.visualization
+ sitator.voronoi
+
+Submodules
+----------
+
+Module contents
+---------------
+
+.. automodule:: sitator
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.SiteNetwork
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.SiteTrajectory
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.site_descriptors.backend.rst b/docs/source/sitator.site_descriptors.backend.rst
new file mode 100644
index 0000000..a7afe2f
--- /dev/null
+++ b/docs/source/sitator.site_descriptors.backend.rst
@@ -0,0 +1,30 @@
+sitator.site\_descriptors.backend package
+=========================================
+
+Submodules
+----------
+
+sitator.site\_descriptors.backend.dscribe module
+------------------------------------------------
+
+.. automodule:: sitator.site_descriptors.backend.dscribe
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.site\_descriptors.backend.quip module
+---------------------------------------------
+
+.. automodule:: sitator.site_descriptors.backend.quip
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.site_descriptors.backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.site_descriptors.rst b/docs/source/sitator.site_descriptors.rst
new file mode 100644
index 0000000..4275695
--- /dev/null
+++ b/docs/source/sitator.site_descriptors.rst
@@ -0,0 +1,43 @@
+sitator.site\_descriptors package
+=================================
+
+Subpackages
+-----------
+
+.. toctree::
+
+ sitator.site_descriptors.backend
+
+Submodules
+----------
+
+sitator.site\_descriptors.SOAP module
+-------------------------------------
+
+.. automodule:: sitator.site_descriptors.SOAP
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: sitator.site_descriptors
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.site_descriptors.SiteCoordinationEnvironment
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.site_descriptors.SiteTypeAnalysis
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.site_descriptors.SiteVolumes
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.util.rst b/docs/source/sitator.util.rst
new file mode 100644
index 0000000..1f95c20
--- /dev/null
+++ b/docs/source/sitator.util.rst
@@ -0,0 +1,61 @@
+sitator.util package
+====================
+
+Submodules
+----------
+
+sitator.util.elbow module
+-------------------------
+
+.. automodule:: sitator.util.elbow
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.util.mcl module
+-----------------------
+
+.. automodule:: sitator.util.mcl
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.util.progress module
+----------------------------
+
+.. automodule:: sitator.util.progress
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.util.zeo module
+-----------------------
+
+.. automodule:: sitator.util.zeo
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.util
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.util.DotProdClassifier
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.util.PBCCalculator
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.util.RecenterTrajectory
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.visualization.rst b/docs/source/sitator.visualization.rst
new file mode 100644
index 0000000..9596867
--- /dev/null
+++ b/docs/source/sitator.visualization.rst
@@ -0,0 +1,40 @@
+sitator.visualization package
+=============================
+
+Submodules
+----------
+
+sitator.visualization.atoms module
+----------------------------------
+
+.. automodule:: sitator.visualization.atoms
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+sitator.visualization.common module
+-----------------------------------
+
+.. automodule:: sitator.visualization.common
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: sitator.visualization
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.visualization.SiteNetworkPlotter
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: sitator.visualization.SiteTrajectoryPlotter
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/sitator.voronoi.rst b/docs/source/sitator.voronoi.rst
new file mode 100644
index 0000000..adad362
--- /dev/null
+++ b/docs/source/sitator.voronoi.rst
@@ -0,0 +1,10 @@
+sitator.voronoi package
+=======================
+
+Module contents
+---------------
+
+.. automodule:: sitator.voronoi
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e719012
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+numpy
+Cython
+scipy
+matplotlib
+ase
+tqdm
+sklearn
diff --git a/setup.py b/setup.py
index f6067cd..9194cd1 100644
--- a/setup.py
+++ b/setup.py
@@ -1,33 +1,43 @@
from setuptools import setup, find_packages
from Cython.Build import cythonize
+import Cython.Compiler
import numpy as np
-setup(name = 'sitator',
- version = '1.0.1',
- description = 'Unsupervised landmark analysis for jump detection in molecular dynamics simulations.',
- download_url = "https://github.com/Linux-cpp-lisp/sitator",
- author = 'Alby Musaelian',
- license = "MIT",
- python_requires = '>=2.7, <3',
- packages = find_packages(),
- ext_modules = cythonize([
+# Allows cimport'ing PBCCalculator
+Cython.Compiler.Options.cimport_from_pyx = True
+
+setup(
+ name = 'sitator',
+ version = '2.0.0',
+ description = 'Unsupervised landmark analysis for jump detection in molecular dynamics simulations.',
+ download_url = "https://github.com/Linux-cpp-lisp/sitator",
+ author = 'Alby Musaelian',
+ license = "MIT",
+ python_requires = '>=3.2',
+ packages = find_packages(),
+ ext_modules = cythonize(
+ [
"sitator/landmark/helpers.pyx",
- "sitator/util/*.pyx"
- ]),
- include_dirs=[np.get_include()],
- install_requires = [
+ "sitator/util/*.pyx",
+ "sitator/dynamics/*.pyx",
+ "sitator/misc/*.pyx"
+ ],
+ language_level = 3
+ ),
+ include_dirs=[np.get_include()],
+ install_requires = [
"numpy",
"scipy",
"matplotlib",
"ase",
"tqdm",
- "backports.tempfile",
- "future",
"sklearn"
- ],
- extras_require = {
+ ],
+ extras_require = {
"SiteTypeAnalysis" : [
- "pydpc"
+ "pydpc",
+ "dscribe"
]
- },
- zip_safe = True)
+ },
+ zip_safe = True
+)
diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py
index 28909dc..bd71be4 100644
--- a/sitator/SiteNetwork.py
+++ b/sitator/SiteNetwork.py
@@ -1,14 +1,11 @@
-from __future__ import (absolute_import, division,
- print_function, unicode_literals)
-from builtins import *
-
import numpy as np
import re
import os
import tarfile
-from backports import tempfile
+import tempfile
+import ase
import ase.io
import matplotlib
@@ -17,19 +14,33 @@
class SiteNetwork(object):
"""A network of mobile particle sites in a static lattice.
- Stores the locations of sites (`centers`), their defining static atoms (`vertices`),
- and their "types" (`site_types`).
+ Stores the locations of sites (``centers``) for some indicated mobile atoms
+ (``mobile_mask``) in a structure (``structure``). Optionally includes their
+ defining static atoms (``vertices``) and "types" (``site_types``).
Arbitrary data can also be associated with each site and with each edge
between sites. Site data can be any array of length n_sites; edge data can be
any matrix of shape (n_sites, n_sites) where entry i, j is the value for the
- edge from site i to site j.
+ edge from site i to site j (edge attributes can be asymmetric).
+
+ Attributes can be marked as "computed"; this is a hint that the attribute
+ was computed based on a ``SiteTrajectory``. Most ``sitator`` algorithms that
+    modify/process ``SiteTrajectory``s will clear "computed" attributes,
+ assuming that they are invalidated by the changes to the ``SiteTrajectory``.
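+
+    Example (a minimal sketch; ``sn`` is assumed to be an existing
+    ``SiteNetwork`` and ``occ`` an ndarray of length ``n_sites``)::
+
+        sn.add_site_attribute('occupancies', occ)  # marked "computed" by default
+        sn.clear_computed_attributes()             # removes it again
+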
Attributes:
centers (ndarray): (n_sites, 3) coordinates of each site.
vertices (list, optional): list of lists of indexes of static atoms defining each
site.
site_types (ndarray, optional): (n_sites,) values grouping sites into types.
+
+ Args:
+    structure (ase.Atoms): an ASE ``Atoms`` containing whatever atoms exist
+ in the MD trajectory.
+ static_mask (ndarray): Boolean mask indicating which atoms make up the
+ host lattice.
+ mobile_mask (ndarray): Boolean mask indicating which atoms' movement we
+ are interested in.
"""
ATTR_NAME_REGEX = re.compile("^[a-zA-Z][a-zA-Z0-9_]*$")
@@ -38,16 +49,6 @@ def __init__(self,
structure,
static_mask,
mobile_mask):
- """
- Args:
- structure (Atoms): an ASE/Quippy ``Atoms`` containging whatever atoms exist
- in the MD trajectory.
- static_mask (ndarray): Boolean mask indicating which atoms make up the
- host lattice.
- mobile_mask (ndarray): Boolean mask indicating which atoms' movement we
- are interested in.
- """
-
assert static_mask.ndim == mobile_mask.ndim == 1, "The masks must be one-dimensional"
assert len(structure) == len(static_mask) == len(mobile_mask), "The masks must have the same length as the # of atoms in the structure."
@@ -72,20 +73,54 @@ def __init__(self,
self._site_attrs = {}
self._edge_attrs = {}
+ self._attr_computed = {}
+
+ def copy(self, with_computed = True):
+ """Returns a (shallowish) copy of self.
- def copy(self):
- """Returns a (shallowish) copy of self."""
+ Args:
+ with_computed (bool): If ``False``, attributes marked "computed" will
+ not be included in the copy.
+ Returns:
+ A ``SiteNetwork``.
+ """
# Use a mask to force a copy
msk = np.ones(shape = self.n_sites, dtype = np.bool)
- return self[msk]
+ sn = self[msk]
+ if not with_computed:
+ sn.clear_computed_attributes()
+ return sn
def __len__(self):
return self.n_sites
+ def __str__(self):
+ return (
+ "{}: {:d} sites for {:d} mobile particles in static lattice of {:d} particles\n"
+ " Has vertices: {}\n"
+ " Has types: {}\n"
+ " Has site attributes: {}\n"
+ " Has edge attributes: {}\n"
+ ""
+ ).format(
+ type(self).__name__,
+ self.n_sites,
+ self.n_mobile,
+ self.n_static,
+ self._vertices is not None,
+ self._types is not None,
+ ", ".join(self._site_attrs.keys()),
+ ", ".join(self._edge_attrs.keys())
+ )
+
def __getitem__(self, key):
- sn = type(self)(self.structure,
- self.static_mask,
- self.mobile_mask)
+ sn = self.__new__(type(self))
+ SiteNetwork.__init__(
+ sn,
+ self.structure,
+ self.static_mask,
+ self.mobile_mask
+ )
if not self._centers is None:
sn.centers = self._centers[key]
@@ -108,80 +143,14 @@ def __getitem__(self, key):
return sn
- _STRUCT_FNAME = "structure.xyz"
- _SMASK_FNAME = "static_mask.npy"
- _MMASK_FNAME = "mobile_mask.npy"
- _MAIN_FNAMES = ['centers', 'vertices', 'site_types']
-
- def save(self, file):
- """Save this SiteNetwork to a tar archive."""
- with tempfile.TemporaryDirectory() as tmpdir:
- # -- Write the structure
- ase.io.write(os.path.join(tmpdir, self._STRUCT_FNAME), self.structure)
- # -- Write masks
- np.save(os.path.join(tmpdir, self._SMASK_FNAME), self.static_mask)
- np.save(os.path.join(tmpdir, self._MMASK_FNAME), self.mobile_mask)
- # -- Write what we have
- for arrname in self._MAIN_FNAMES:
- if not getattr(self, arrname) is None:
- np.save(os.path.join(tmpdir, "%s.npy" % arrname), getattr(self, arrname))
- # -- Write all site/edge attributes
- for atype, attrs in zip(("site_attr", "edge_attr"), (self._site_attrs, self._edge_attrs)):
- for attr in attrs:
- np.save(os.path.join(tmpdir, "%s-%s.npy" % (atype, attr)), attrs[attr])
- # -- Write final archive
- with tarfile.open(file, mode = 'w:gz', format = tarfile.PAX_FORMAT) as outf:
- outf.add(tmpdir, arcname = "")
-
- @classmethod
- def from_file(cls, file):
- """Load a SiteNetwork from a tar file/file descriptor."""
- all_others = {}
- site_attrs = {}
- edge_attrs = {}
- structure = None
- with tarfile.open(file, mode = 'r:gz', format = tarfile.PAX_FORMAT) as input:
- # -- Load everything
- for member in input.getmembers():
- if member.name == '':
- continue
- f = input.extractfile(member)
- if member.name == cls._STRUCT_FNAME:
- with tempfile.TemporaryDirectory() as tmpdir:
- input.extract(member, path = tmpdir)
- structure = ase.io.read(os.path.join(tmpdir, member.name), format = 'xyz')
- else:
- basename = os.path.splitext(os.path.basename(member.name))[0]
- data = np.load(f)
- if basename.startswith("site_attr"):
- site_attrs[basename.split('-')[1]] = data
- elif basename.startswith("edge_attr"):
- edge_attrs[basename.split('-')[1]] = data
- else:
- all_others[basename] = data
-
- # Create SiteNetwork
- assert not structure is None
- assert all(k in all_others for k in ("static_mask", "mobile_mask")), "Malformed SiteNetwork file."
- sn = SiteNetwork(structure,
- all_others['static_mask'],
- all_others['mobile_mask'])
- if 'centers' in all_others:
- sn.centers = all_others['centers']
- for key in all_others:
- if key in ('centers', 'static_mask', 'mobile_mask'):
- continue
- setattr(sn, key, all_others[key])
-
- assert all(len(sa) == sn.n_sites for sa in site_attrs.values())
- assert all(ea.shape == (sn.n_sites, sn.n_sites) for ea in edge_attrs.values())
- sn._site_attrs = site_attrs
- sn._edge_attrs = edge_attrs
-
- return sn
-
def of_type(self, stype):
- """Returns a "view" to this SiteNetwork with only sites of a certain type."""
+ """Returns a subset of this ``SiteNetwork`` with only sites of a certain type.
+
+ Args:
+ stype (int)
+ Returns:
+ A ``SiteNetwork``.
+ """
if self._types is None:
raise ValueError("This SiteNetwork has no type information.")
@@ -190,18 +159,45 @@ def of_type(self, stype):
return self[self._types == stype]
+ def get_structure_with_sites(self, site_atomic_number = None):
+ """Get an ``ase.Atoms`` with the sites included.
+
+ Sites are appended to the static structure; the first ``np.sum(static_mask)``
+ atoms in the returned object are the static structure.
+
+ Args:
+ site_atomic_number: If ``None``, the species of the first mobile
+ atom will be used.
+ Returns:
+ ``ase.Atoms`` and final ``site_atomic_number``
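+
+        Example (a sketch; ``sn`` is an existing ``SiteNetwork``)::
+
+            atoms, site_z = sn.get_structure_with_sites()
+            ase.io.write("sites.xyz", atoms)  # hypothetical output path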
+ """
+ out = self.static_structure.copy()
+ if site_atomic_number is None:
+ site_atomic_number = self.structure.get_atomic_numbers()[self.mobile_mask][0]
+ numbers = np.full(len(self), site_atomic_number)
+ sites_atoms = ase.Atoms(
+ positions = self.centers,
+ numbers = numbers
+ )
+        site_indexes = len(out) + np.arange(self.n_sites)
+ out.extend(sites_atoms)
+ return out, site_atomic_number
+
@property
def n_sites(self):
+ """The number of sites."""
if self._centers is None:
return 0
return len(self._centers)
@property
def n_total(self):
+ """The total number of atoms in the system."""
return len(self.static_mask)
@property
def centers(self):
+ """The positions of the sites."""
view = self._centers.view()
view.flags.writeable = False
return view
@@ -215,27 +211,47 @@ def centers(self, value):
self._types = None
self._site_attrs = {}
self._edge_attrs = {}
+ self._attr_computed = {}
# Set centers
self._centers = value
def update_centers(self, newcenters):
- """Update the SiteNetwork's centers *without* reseting all other information."""
+ """Update the ``SiteNetwork``'s centers *without* reseting all other information.
+
+ Args:
+ newcenters (ndarray): Must have same length as current number of sites.
+ """
if newcenters.shape != self._centers.shape:
raise ValueError("New `centers` must have same shape as old; try using the setter `.centers = ...`")
self._centers = newcenters
@property
def vertices(self):
+ """The static atoms defining each site."""
return self._vertices
+ @property
+ def site_ids(self):
+ """Convenience property giving the index of each site."""
+ return np.arange(self.n_sites)
+
@vertices.setter
def vertices(self, value):
if not len(value) == len(self._centers):
raise ValueError("Wrong # of vertices %i; expected %i" % (len(value), len(self._centers)))
self._vertices = value
+ @property
+ def number_of_vertices(self):
+ """The number of vertices of each site."""
+ if self._vertices is None:
+ return None
+ else:
+ return [len(v) for v in self._vertices]
+
@property
def site_types(self):
+ """The type IDs of each site."""
if self._types is None:
return None
view = self._types.view()
@@ -250,24 +266,40 @@ def site_types(self, value):
@property
def n_types(self):
+ """The number of site types in the ``SiteNetwork``."""
return len(np.unique(self.site_types))
@property
def types(self):
+ """The unique site type IDs in the ``SiteNetwork``."""
return np.unique(self.site_types)
@property
def site_attributes(self):
- return self._site_attrs.keys()
+ """The names of the ``SiteNetwork``'s site attributes."""
+ return list(self._site_attrs.keys())
@property
def edge_attributes(self):
- return self._edge_attrs.keys()
+ """The names of the ``SiteNetwork``'s edge attributes."""
+ return list(self._edge_attrs.keys())
def has_attribute(self, attr):
+ """Whether the ``SiteNetwork`` has a given site or edge attrbute.
+
+ Args:
+ attr (str)
+ Returns:
+ bool
+ """
return (attr in self._site_attrs) or (attr in self._edge_attrs)
def remove_attribute(self, attr):
+ """Remove a site or edge attribute.
+
+ Args:
+ attr (str)
+ """
if attr in self._site_attrs:
del self._site_attrs[attr]
elif attr in self._edge_attrs:
@@ -275,10 +307,22 @@ def remove_attribute(self, attr):
else:
raise AttributeError("This SiteNetwork has no site or edge attribute `%s`" % attr)
+ def clear_attributes(self):
+ """Remove all site and edge attributes."""
+ self._site_attrs = {}
+ self._edge_attrs = {}
+
+ def clear_computed_attributes(self):
+ """Remove all attributes marked "computed"."""
+ for k, computed in self._attr_computed.items():
+ if computed:
+ self.remove_attribute(k)
+
def __getattr__(self, attrkey):
- if attrkey in self._site_attrs:
+ v = vars(self)
+ if '_site_attrs' in v and attrkey in self._site_attrs:
return self._site_attrs[attrkey]
- elif attrkey in self._edge_attrs:
+ elif '_edge_attrs' in v and attrkey in self._edge_attrs:
return self._edge_attrs[attrkey]
else:
raise AttributeError("This SiteNetwork has no site or edge attribute `%s`" % attrkey)
@@ -320,21 +364,37 @@ def get_edge(self, edge):
return out
- def add_site_attribute(self, name, attr):
+ def add_site_attribute(self, name, attr, computed = True):
+ """Add a site attribute.
+
+ Args:
+ name (str)
+ attr (ndarray): Must be of length ``n_sites``.
+ computed (bool): Whether to mark this attribute as "computed".
+ """
self._check_name(name)
attr = np.asarray(attr)
if not attr.shape[0] == self.n_sites:
raise ValueError("Attribute array has only %i entries; need one for all %i sites." % (len(attr), self.n_sites))
self._site_attrs[name] = attr
+ self._attr_computed[name] = computed
- def add_edge_attribute(self, name, attr):
+ def add_edge_attribute(self, name, attr, computed = True):
+ """Add an edge attribute.
+
+ Args:
+ name (str)
+ attr (ndarray): Must be of shape ``(n_sites, n_sites)``.
+ computed (bool): Whether to mark this attribute as "computed".
+ """
self._check_name(name)
attr = np.asarray(attr)
if not (attr.shape[0] == attr.shape[1] == self.n_sites):
raise ValueError("Attribute matrix has shape %s; need first two dimensions to be %i" % (attr.shape, self.n_sites))
self._edge_attrs[name] = attr
+ self._attr_computed[name] = computed
def _check_name(self, name):
if not SiteNetwork.ATTR_NAME_REGEX.match(name):
@@ -345,6 +405,6 @@ def _check_name(self, name):
raise ValueError("Attribute name `%s` reserved." % name)
def plot(self, *args, **kwargs):
- """Convenience method -- constructs a defualt SiteNetworkPlotter and calls it."""
+ """Convenience method -- constructs a defualt ``SiteNetworkPlotter`` and calls it."""
p = SiteNetworkPlotter(title = "Sites")
p(self, *args, **kwargs)
diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py
index f934edd..a994e28 100644
--- a/sitator/SiteTrajectory.py
+++ b/sitator/SiteTrajectory.py
@@ -1,14 +1,11 @@
-from __future__ import (absolute_import, division,
- print_function, unicode_literals)
-from builtins import *
-
import numpy as np
from sitator.util import PBCCalculator
-from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS
+from sitator.visualization import SiteTrajectoryPlotter
+from sitator.util.progress import tqdm
-import matplotlib
-from matplotlib.collections import LineCollection
+import logging
+logger = logging.getLogger(__name__)
class SiteTrajectory(object):
"""A trajectory capturing the dynamics of particles through a SiteNetwork."""
@@ -42,6 +39,8 @@ def __init__(self,
self._real_traj = None
+ self._default_plotter = None
+
def __len__(self):
return self.n_frames
@@ -51,28 +50,45 @@ def __getitem__(self, key):
confidences = None if self._confs is None else self._confs[key])
if not self._real_traj is None:
st.set_real_traj(self._real_traj[key])
-
return st
+ def __getstate__(self):
+ # Copy the object's state from self.__dict__ which contains
+ # all our instance attributes. Always use the dict.copy()
+ # method to avoid modifying the original state.
+ state = self.__dict__.copy()
+ # Don't want to pickle giant trajectories or uninteresting plotters
+ state['_real_traj'] = None
+ state['_default_plotter'] = None
+ return state
+
@property
def traj(self):
- """The underlying trajectory."""
+ """The site assignments over time."""
return self._traj
+ @property
+ def confidences(self):
+ return self._confs
+
@property
def n_frames(self):
+ """The number of frames in the trajectory."""
return len(self._traj)
@property
def n_unassigned(self):
+ """The total number of times a mobile particle is unassigned."""
return np.sum(self._traj < 0)
@property
def n_assigned(self):
+ """The total number of times a mobile particle was assigned to a site."""
return self._sn.n_mobile * self.n_frames - self.n_unassigned
@property
def percent_unassigned(self):
+ """Proportion of particle positions that are unassigned over all time."""
return float(self.n_unassigned) / (self._sn.n_mobile * self.n_frames)
@property
@@ -88,25 +104,49 @@ def site_network(self, value):
@property
def real_trajectory(self):
+ """The real-space trajectory this ``SiteTrajectory`` is based on."""
return self._real_traj
+ def copy(self, with_computed = True):
+ """Return a copy.
+
+ Args:
+ with_computed (bool): See ``SiteNetwork.copy()``.
+ """
+ st = self[:]
+ st.site_network = st.site_network.copy(with_computed = with_computed)
+ return st
+
def set_real_traj(self, real_traj):
"""Assocaite this SiteTrajectory with a trajectory of points in real space.
- The trajectory is not copied, and should have shape (n_frames, n_total)
+ The trajectory is not copied.
+
+ Args:
+ real_traj (ndarray of shape (n_frames, n_total))
"""
expected_shape = (self.n_frames, self._sn.n_total, 3)
if not real_traj.shape == expected_shape:
raise ValueError("real_traj of shape %s does not have expected shape %s" % (real_traj.shape, expected_shape))
self._real_traj = real_traj
+
def remove_real_traj(self):
"""Forget associated real trajectory."""
del self._real_traj
self._real_traj = None
+
def trajectory_for_particle(self, i, return_confidences = False):
- """Returns the array of sites particle i is assigned to over time."""
+ """Returns the array of sites particle i is assigned to over time.
+
+ Args:
+ i (int)
+ return_confidences (bool): If ``True``, also return the confidences
+ with which those assignments were made.
+ Returns:
+ ndarray (int) of length ``n_frames``[, ndarray (float) length ``n_frames``]
+ """
if return_confidences and self._confs is None:
raise ValueError("This SiteTrajectory has no confidences")
if return_confidences:
@@ -114,7 +154,19 @@ def trajectory_for_particle(self, i, return_confidences = False):
else:
return self._traj[:, i]
+
def real_positions_for_site(self, site, return_confidences = False):
+ """Get all real-space positions assocated with a site.
+
+ Args:
+ site (int)
+ return_confidences (bool): If ``True``, the confidences with which
+ each real-space position was assigned to ``site`` are also
+ returned.
+
+ Returns:
+ ndarray (N, 3)[, ndarray (N)]
+ """
if self._real_traj is None:
raise ValueError("This SiteTrajectory has no real trajectory")
if return_confidences and self._confs is None:
@@ -131,22 +183,69 @@ def real_positions_for_site(self, site, return_confidences = False):
else:
return pts
+
def compute_site_occupancies(self):
- """Computes site occupancies and adds site attribute `occupancies` to site_network."""
- occ = np.true_divide(np.bincount(self._traj[self._traj >= 0]), self.n_frames)
+ """Computes site occupancies.
+
+ Adds site attribute ``occupancies`` to ``site_network``.
+
+ In cases of multiple occupancy, this will be higher than the number of
+ frames in which the site is occupied and could be over 1.0.
+
+ Returns:
+ ndarray of occupancies (length ``n_sites``)
+ """
+ occ = np.true_divide(np.bincount(self._traj[self._traj >= 0], minlength = self._sn.n_sites), self.n_frames)
+ if self.site_network.has_attribute('occupancies'):
+ self.site_network.remove_attribute('occupancies')
self.site_network.add_site_attribute('occupancies', occ)
return occ
- def assign_to_last_known_site(self, frame_threshold = 1, verbose = True):
- """Assign unassigned mobile particles to their last known site within
- `frame_threshold` frames.
- :returns: information dictionary of debugging/diagnostic information.
+ def check_multiple_occupancy(self, max_mobile_per_site = 1):
+ """Count cases of "multiple occupancy" where more than one mobile share the same site at the same time.
+
+ These cases usually indicate bad site analysis.
+
+        Returns:
+            int: the total number of multiple assignment incidents; and
+            float: the average number of mobile atoms at any site at any one time.
+
+        Raises:
+            MultipleOccupancyError: if more than ``max_mobile_per_site`` mobile
+                atoms occupy the same site in the same frame.
+ """
+ from sitator.errors import MultipleOccupancyError
+ n_more_than_ones = 0
+ avg_mobile_per_site = 0
+ divisor = 0
+ for frame_i, site_frame in enumerate(self._traj):
+ sites, counts = np.unique(site_frame[site_frame >= 0], return_counts = True)
+ count_msk = counts > max_mobile_per_site
+ if np.any(count_msk):
+ first_multi_site = sites[count_msk][0]
+ raise MultipleOccupancyError(
+ mobile = np.where(site_frame == first_multi_site)[0],
+ site = first_multi_site,
+ frame = frame_i
+ )
+ n_more_than_ones += np.sum(counts > 1)
+ avg_mobile_per_site += np.sum(counts)
+ divisor += len(counts)
+ avg_mobile_per_site /= divisor
+ return n_more_than_ones, avg_mobile_per_site
+
+
+ def assign_to_last_known_site(self, frame_threshold = 1):
+ """Assign unassigned mobile particles to their last known site.
+
+ Args:
+ frame_threshold (int): The maximum number of frames between the last
+ known site and the present frame up to which the last known site
+ can be used.
+
+ Returns:
+            dict of debugging/diagnostic information with keys
+            ``max_time_unknown``, ``avg_time_unknown``, and ``total_reassigned``.
"""
total_unknown = self.n_unassigned
- if verbose:
- print("%i unassigned positions (%i%%); assigning unassigned mobile particles to last known positions within %i frames..." % (total_unknown, 100.0 * self.percent_unassigned, frame_threshold))
+ logger.info("%i unassigned positions (%i%%); assigning unassigned mobile particles to last known positions within %s frames..." % (total_unknown, 100.0 * self.percent_unassigned, frame_threshold))
last_known = np.empty(shape = self._sn.n_mobile, dtype = np.int)
last_known.fill(-1)
@@ -156,7 +255,7 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True):
max_time_unknown = 0
total_reassigned = 0
- for i in xrange(self.n_frames):
+ for i in range(self.n_frames):
# All those unknown this frame
unknown = self._traj[i] == -1
# Update last_known for assigned sites
@@ -184,10 +283,9 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True):
if avg_time_unknown_div > 0: # We corrected some unknowns
avg_time_unknown = float(avg_time_unknown) / avg_time_unknown_div
- if verbose:
- print(" Maximum # of frames any mobile particle spent unassigned: %i" % max_time_unknown)
- print(" Avg. # of frames spent unassigned: %f" % avg_time_unknown)
- print(" Assigned %i/%i unassigned positions, leaving %i (%i%%) unknown" % (total_reassigned, total_unknown, self.n_unassigned, self.percent_unassigned))
+ logger.info(" Maximum # of frames any mobile particle spent unassigned: %i" % max_time_unknown)
+ logger.info(" Avg. # of frames spent unassigned: %f" % avg_time_unknown)
+ logger.info(" Assigned %i/%i unassigned positions, leaving %i (%i%%) unknown" % (total_reassigned, total_unknown, self.n_unassigned, self.percent_unassigned))
res = {
'max_time_unknown' : max_time_unknown,
@@ -195,8 +293,7 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True):
'total_reassigned' : total_reassigned
}
else:
- if self.verbose:
- print(" None to correct.")
+ logger.info(" None to correct.")
res = {
'max_time_unknown' : 0,
@@ -206,119 +303,87 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True):
return res
- @plotter(is3D = True)
- def plot_frame(self, frame, **kwargs):
- sites_of_frame = np.unique(self._traj[frame])
- frame_sn = self._sn[sites_of_frame]
- frame_sn.plot(**kwargs)
+ def jumps(self, **kwargs):
+ """Iterate over all jumps in the trajectory, jump by jump.
- if not self._real_traj is None:
- mobile_atoms = self._sn.structure.copy()
- del mobile_atoms[~self._sn.mobile_mask]
-
- mobile_atoms.positions[:] = self._real_traj[frame, self._sn.mobile_mask]
- plot_atoms(atoms = mobile_atoms, **kwargs)
-
- kwargs['ax'].set_title("Frame %i/%i" % (frame, self.n_frames))
-
- @plotter(is3D = True)
- def plot_site(self, site, **kwargs):
- pbcc = PBCCalculator(self._sn.structure.cell)
- pts = self.real_positions_for_site(site).copy()
- offset = pbcc.cell_centroid - pts[3]
- pts += offset
- pbcc.wrap_points(pts)
- lattice_pos = self._sn.static_structure.positions.copy()
- lattice_pos += offset
- pbcc.wrap_points(lattice_pos)
- site_pos = self._sn.centers[site:site+1].copy()
- site_pos += offset
- pbcc.wrap_points(site_pos)
- # Plot point cloud
- plot_points(points = pts, alpha = 0.3, marker = '.', color = 'k', **kwargs)
- # Plot site
- plot_points(points = site_pos, color = 'cyan', **kwargs)
- # Plot everything else
- plot_atoms(self._sn.static_structure, positions = lattice_pos, **kwargs)
-
- title = "Site %i/%i" % (site, len(self._sn))
-
- if not self._sn.site_types is None:
- title += " (type %i)" % self._sn.site_types[site]
-
- kwargs['ax'].set_title(title)
-
- @plotter(is3D = False)
- def plot_particle_trajectory(self, particle, ax = None, fig = None, **kwargs):
- types = not self._sn.site_types is None
- if types:
- type_height_percent = 0.1
- axpos = ax.get_position()
- typeax_height = type_height_percent * axpos.height
- typeax = fig.add_axes([axpos.x0, axpos.y0, axpos.width, typeax_height], sharex = ax)
- ax.set_position([axpos.x0, axpos.y0 + typeax_height, axpos.width, axpos.height - typeax_height])
- type_height = 1
- # Draw trajectory
- segments = []
- linestyles = []
- colors = []
-
- traj = self._traj[:, particle]
- current_value = traj[0]
- last_value = traj[0]
- if types:
- last_type = None
- current_segment_start = 0
- puttext = False
-
- for i, f in enumerate(traj):
- if f != current_value or i == len(traj) - 1:
- val = last_value if current_value == -1 else current_value
- segments.append([[current_segment_start, last_value], [current_segment_start, val], [i, val]])
- linestyles.append(':' if current_value == -1 else '-')
- colors.append('lightgray' if current_value == -1 else 'k')
-
- if types:
- rxy = (current_segment_start, 0)
- this_type = self._sn.site_types[val]
- typerect = matplotlib.patches.Rectangle(rxy, i - current_segment_start, type_height,
- color = DEFAULT_COLORS[this_type], linewidth = 0)
- typeax.add_patch(typerect)
- if this_type != last_type:
- typeax.annotate("T%i" % this_type,
- xy = (rxy[0], rxy[1] + 0.5 * type_height),
- xytext = (3, -1),
- textcoords = 'offset points',
- fontsize = 'xx-small',
- va = 'center',
- fontweight = 'bold')
- last_type = this_type
-
- last_value = val
- current_segment_start = i
- current_value = f
-
- lc = LineCollection(segments, linestyles = linestyles, colors = colors, linewidth=1.5)
- ax.add_collection(lc)
-
- if types:
- typeax.set_xlabel("Frame")
- ax.tick_params(axis = 'x', which = 'both', bottom = False, top = False, labelbottom = False)
- typeax.tick_params(axis = 'y', which = 'both', left = False, right = False, labelleft = False)
- typeax.annotate("Type", xy = (0, 0.5), xytext = (-25, 0), xycoords = 'axes fraction', textcoords = 'offset points', va = 'center', fontsize = 'x-small')
- else:
- ax.set_xlabel("Frame")
- ax.set_ylabel("Atom %i's site" % particle)
+        A jump is considered to occur "at the frame" when it first achieves its
+ new site. For example,
- ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
- ax.grid()
+ - Frame 0: Atom 1 at site 4
+ - Frame 1: Atom 1 at site 5
- ax.set_xlim((0, self.n_frames - 1))
- margin_percent = 0.04
- ymargin = (margin_percent * self._sn.n_sites)
- ax.set_ylim((-ymargin, self._sn.n_sites - 1.0 + ymargin))
+ will yield a jump ``(1, 1, 4, 5)``.
- if types:
- typeax.set_xlim((0, self.n_frames - 1))
- typeax.set_ylim((0, type_height))
+ Args:
+ unknown_as_jump (bool): If ``True``, moving from a site to unknown
+ (or vice versa) is considered a jump; if ``False``, unassigned
+ mobile atoms are considered to be at their last known sites.
+ Yields:
+ tuple: (frame_number, mobile_atom_number, from_site, to_site)
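+
+        Example (a sketch; ``st`` is assumed to be a ``SiteTrajectory``)::
+
+            for frame, atom, from_site, to_site in st.jumps():
+                print(frame, atom, from_site, to_site)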
+ """
+ n_mobile = self.site_network.n_mobile
+ for frame_i, jumped, last_known, frame in self._jumped_generator(**kwargs):
+ for atom_i in range(n_mobile):
+ if jumped[atom_i]:
+ yield frame_i, atom_i, last_known[atom_i], frame[atom_i]
+
+ def jumps_by_frame(self, **kwargs):
+ """Iterate over all jumps in the trajectory, frame by frame.
+
+        A jump is considered to occur "at the frame" when it first achieves its
+ new site. For example,
+
+ - Frame 0: Atom 1 at site 4
+ - Frame 1: Atom 1 at site 5
+
+ will yield a jump ``(1, 1, 4, 5)``.
+
+ Args:
+ unknown_as_jump (bool): If ``True``, moving from a site to unknown
+ (or vice versa) is considered a jump; if ``False``, unassigned
+ mobile atoms are considered to be at their last known sites.
+ Yields:
+ tuple: (frame_number, mob_that_jumped, from_sites, to_sites)
+ """
+ n_mobile = self.site_network.n_mobile
+ for frame_i, jumped, last_known, frame in self._jumped_generator(**kwargs):
+ yield frame_i, np.where(jumped)[0], last_known[jumped], frame[jumped]
+
+ def _jumped_generator(self, unknown_as_jump = False):
+ """Internal jump generator that does not create intermediate arrays.
+
+        Wrapped by convenience functions.
+ """
+ traj = self.traj
+ n_mobile = self.site_network.n_mobile
+ assert n_mobile == traj.shape[1]
+ last_known = traj[0].copy()
+ known = np.ones(shape = len(last_known), dtype = np.bool)
+ jumped = np.zeros(shape = len(last_known), dtype = np.bool)
+ for frame_i in range(1, self.n_frames):
+ if not unknown_as_jump:
+ np.not_equal(traj[frame_i], SiteTrajectory.SITE_UNKNOWN, out = known)
+
+ np.not_equal(traj[frame_i], last_known, out = jumped)
+ jumped &= known # Must be currently known to have jumped
+
+ yield frame_i, jumped, last_known, traj[frame_i]
+
+ last_known[known] = traj[frame_i, known]
+
+ # ---- Plotting code
+ def plot_frame(self, *args, **kwargs):
+ if self._default_plotter is None:
+ self._default_plotter = SiteTrajectoryPlotter()
+ self._default_plotter.plot_frame(self, *args, **kwargs)
+
+ def plot_site(self, *args, **kwargs):
+ if self._default_plotter is None:
+ self._default_plotter = SiteTrajectoryPlotter()
+ self._default_plotter.plot_site(self, *args, **kwargs)
+
+ def plot_particle_trajectory(self, *args, **kwargs):
+ if self._default_plotter is None:
+ self._default_plotter = SiteTrajectoryPlotter()
+ self._default_plotter.plot_particle_trajectory(self, *args, **kwargs)
diff --git a/sitator/__init__.py b/sitator/__init__.py
index 0f393bc..1ca01b0 100644
--- a/sitator/__init__.py
+++ b/sitator/__init__.py
@@ -1,3 +1,3 @@
-from SiteNetwork import SiteNetwork
-from SiteTrajectory import SiteTrajectory
+from .SiteNetwork import SiteNetwork
+from .SiteTrajectory import SiteTrajectory
diff --git a/sitator/descriptors/ConfigurationalEntropy.py b/sitator/descriptors/ConfigurationalEntropy.py
index a1c31ac..4008588 100644
--- a/sitator/descriptors/ConfigurationalEntropy.py
+++ b/sitator/descriptors/ConfigurationalEntropy.py
@@ -3,21 +3,28 @@
from sitator import SiteTrajectory
from sitator.dynamics import JumpAnalysis
+import logging
+logger = logging.getLogger(__name__)
+
class ConfigurationalEntropy(object):
- """Compute the S~ configurational entropy.
+ """Compute the S~ (S tilde) configurational entropy.
If the SiteTrajectory lacks type information, the summation is taken over
the sites rather than the site types.
- Ref:
- Structural, Chemical, and Dynamical Frustration: Origins of Superionic Conductivity in closo-Borate Solid Electrolytes
- Kyoung E. Kweon, Joel B. Varley, Patrick Shea, Nicole Adelstein, Prateek Mehta, Tae Wook Heo, Terrence J. Udovic, Vitalie Stavila, and Brandon C. Wood
+ Reference:
+ Structural, Chemical, and Dynamical Frustration: Origins of Superionic
+ Conductivity in closo-Borate Solid Electrolytes
+
+ Kyoung E. Kweon, Joel B. Varley, Patrick Shea, Nicole Adelstein,
+ Prateek Mehta, Tae Wook Heo, Terrence J. Udovic, Vitalie Stavila,
+ and Brandon C. Wood
+
Chemistry of Materials 2017 29 (21), 9142-9153
DOI: 10.1021/acs.chemmater.7b02902
"""
- def __init__(self, acceptable_overshoot = 0.0001, verbose = True):
+ def __init__(self, acceptable_overshoot = 0.0):
self.acceptable_overshoot = acceptable_overshoot
- self.verbose = verbose
def compute(self, st):
assert isinstance(st, SiteTrajectory)
@@ -52,16 +59,13 @@ def compute(self, st):
size_of_problems = p2 - 1.0
forgivable = problems & (size_of_problems < self.acceptable_overshoot)
- if self.verbose:
- print "n_i " + ("{:5.3} " * len(n_i)).format(*n_i)
- print "N_i " + ("{:>5} " * len(N_i)).format(*N_i)
- print " " + ("------" * len(n_i))
- print "P_2 " + ("{:5.3} " * len(p2)).format(*p2)
+ logger.info("n_i " + ("{:5.3} " * len(n_i)).format(*n_i))
+ logger.info("N_i " + ("{:>5} " * len(N_i)).format(*N_i))
+ logger.info(" " + ("------" * len(n_i)))
+ logger.info("P_2 " + ("{:5.3} " * len(p2)).format(*p2))
if not np.all(problems == forgivable):
raise ValueError("P_2 values for site types %s larger than 1.0 + acceptable_overshoot (%f)" % (np.where(problems)[0], self.acceptable_overshoot))
- elif np.any(problems) and self.verbose:
- print ""
# Correct forgivable problems
p2[forgivable] = 1.0
diff --git a/sitator/descriptors/__init__.py b/sitator/descriptors/__init__.py
index 0f2d416..31d3db7 100644
--- a/sitator/descriptors/__init__.py
+++ b/sitator/descriptors/__init__.py
@@ -1 +1 @@
-from ConfigurationalEntropy import ConfigurationalEntropy
+from .ConfigurationalEntropy import ConfigurationalEntropy
diff --git a/sitator/dynamics/AverageVibrationalFrequency.py b/sitator/dynamics/AverageVibrationalFrequency.py
new file mode 100644
index 0000000..88b6d34
--- /dev/null
+++ b/sitator/dynamics/AverageVibrationalFrequency.py
@@ -0,0 +1,61 @@
+import numpy as np
+
+class AverageVibrationalFrequency(object):
+ """Compute the average vibrational frequency of indicated atoms in a trajectory.
+
+ Uses the method described in section 2.2 of this paper:
+
+ Klerk, Niek J.J. de, Eveline van der Maas, and Marnix Wagemaker.
+
+ Analysis of Diffusion in Solid-State Electrolytes through MD Simulations,
+ Improvement of the Li-Ion Conductivity in β-Li3PS4 as an Example.
+
+ ACS Applied Energy Materials 1, no. 7 (July 23, 2018): 3230–42.
+ https://doi.org/10.1021/acsaem.8b00457.
+
+ Args:
+ min_frequency (float, units: timestep^-1): Compute mean frequency of power
+ spectrum above this frequency.
+ max_frequency (float, units: timestep^-1): Compute mean frequency of power
+ spectrum below this frequency.
+ """
+ def __init__(self,
+ min_frequency = 0,
+ max_frequency = np.inf):
+ # Always want to exclude DC frequency
+ assert min_frequency >= 0
+ self.min_frequency = min_frequency
+ self.max_frequency = max_frequency
+
+ def compute_avg_vibrational_freq(self, traj, mask, return_stdev = False):
+ """Compute the average vibrational frequency.
+
+ Args:
+ traj (ndarray n_frames x n_atoms x 3): An MD trajectory.
+            mask (ndarray n_atoms bool): Which atoms to average over.
+            return_stdev (bool): If ``True``, also return the standard deviation
+                of the per-atom average frequencies.
+        Returns:
+            A frequency in units of (timestep)^-1 (and, if ``return_stdev``,
+            the standard deviation of the per-atom frequencies).
+ """
+ speeds = traj[1:, mask]
+ speeds -= traj[:-1, mask]
+ speeds = np.linalg.norm(speeds, axis = 2)
+
+ freqs = np.fft.rfftfreq(speeds.shape[0])
+ fmask = (freqs > self.min_frequency) & (freqs < self.max_frequency)
+ assert np.any(fmask), "Trajectory too short?"
+ freqs = freqs[fmask]
+
+        # de Klerk et al. do an average over the atom-by-atom averages
+ n_mob = speeds.shape[1]
+ avg_freqs = np.empty(shape = n_mob)
+
+ for mob in range(n_mob):
+ ps = np.abs(np.fft.rfft(speeds[:, mob])) ** 2
+ avg_freqs[mob] = np.average(freqs, weights = ps[fmask])
+
+ avg_freq = np.mean(avg_freqs)
+
+ if return_stdev:
+ return avg_freq, np.std(avg_freqs)
+ else:
+ return avg_freq
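+
+
+# Example usage (a sketch; `traj` is an (n_frames, n_atoms, 3) ndarray and
+# `mobile_mask` a boolean mask of length n_atoms):
+#
+#     avf = AverageVibrationalFrequency()
+#     freq, stdev = avf.compute_avg_vibrational_freq(traj, mobile_mask, return_stdev = True)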
diff --git a/sitator/dynamics/DiffusionPathwayAnalysis.py b/sitator/dynamics/DiffusionPathwayAnalysis.py
deleted file mode 100644
index 245f278..0000000
--- a/sitator/dynamics/DiffusionPathwayAnalysis.py
+++ /dev/null
@@ -1,78 +0,0 @@
-
-import numpy as np
-
-import numbers
-
-from scipy.sparse.csgraph import connected_components
-
-class DiffusionPathwayAnalysis(object):
- """Find connected diffusion pathways in a SiteNetwork.
-
- :param float|int connectivity_threshold: The percentage of the total number of
- (non-self) jumps, or absolute number of jumps, that must occur over an edge
- for it to be considered connected.
- :param int minimum_n_sites: The minimum number of sites that must be part of
- a pathway for it to be considered as such.
- """
-
- NO_PATHWAY = -1
-
- def __init__(self,
- connectivity_threshold = 0.001,
- minimum_n_sites = 4,
- verbose = True):
- assert minimum_n_sites >= 0
-
- self.connectivity_threshold = connectivity_threshold
- self.minimum_n_sites = minimum_n_sites
-
- self.verbose = verbose
-
- def run(self, sn):
- """
- Expects a SiteNetwork that has had a JumpAnalysis run on it.
- """
- if not sn.has_attribute('n_ij'):
- raise ValueError("SiteNetwork has no `n_ij`; run a JumpAnalysis on it first.")
-
- nondiag = np.ones(shape = sn.n_ij.shape, dtype = np.bool)
- np.fill_diagonal(nondiag, False)
- n_non_self_jumps = np.sum(sn.n_ij[nondiag])
-
- if isinstance(self.connectivity_threshold, numbers.Integral):
- threshold = self.connectivity_threshold
- elif isinstance(self.connectivity_threshold, numbers.Real):
- threshold = self.connectivity_threshold * n_non_self_jumps
- else:
- raise TypeError("Don't know how to interpret connectivity_threshold `%s`" % self.connectivity_threshold)
-
- connectivity_matrix = sn.n_ij >= threshold
-
- n_ccs, ccs = connected_components(connectivity_matrix,
- directed = False, # even though the matrix is symmetric
- connection = 'weak') # diffusion could be unidirectional
-
- _, counts = np.unique(ccs, return_counts = True)
-
- is_pathway = counts >= self.minimum_n_sites
-
- if self.verbose:
- print "Taking all edges with at least %i/%i jumps..." % (threshold, n_non_self_jumps)
- print "Found %i connected components, of which %i are large enough to qualify as pathways." % (n_ccs, np.sum(is_pathway))
-
- translation = np.empty(n_ccs, dtype = np.int)
- translation[~is_pathway] = DiffusionPathwayAnalysis.NO_PATHWAY
- translation[is_pathway] = np.arange(np.sum(is_pathway))
-
- node_pathways = translation[ccs]
-
- outmat = np.empty(shape = (sn.n_sites, sn.n_sites), dtype = np.int)
-
- for i in xrange(sn.n_sites):
- rowmask = node_pathways[i] == node_pathways
- outmat[i, rowmask] = node_pathways[i]
- outmat[i, ~rowmask] = DiffusionPathwayAnalysis.NO_PATHWAY
-
- sn.add_site_attribute('site_diffusion_pathway', node_pathways)
- sn.add_edge_attribute('edge_diffusion_pathway', outmat)
- return sn
diff --git a/sitator/dynamics/JumpAnalysis.py b/sitator/dynamics/JumpAnalysis.py
index af5085c..945fd7d 100644
--- a/sitator/dynamics/JumpAnalysis.py
+++ b/sitator/dynamics/JumpAnalysis.py
@@ -5,31 +5,39 @@
from sitator import SiteNetwork, SiteTrajectory
from sitator.visualization import plotter, plot_atoms, layers
+import logging
+logger = logging.getLogger(__name__)
+
class JumpAnalysis(object):
"""Given a SiteTrajectory, compute various statistics about the jumps it contains.
Adds these edge attributes to the SiteTrajectory's SiteNetwork:
- - `n_ij`: total number of jumps from i to j.
- - `p_ij`: being at i, the probability of jumping to j.
- - `jump_lag`: The average number of frames a particle spends at i before jumping
+ - ``n_ij``: total number of jumps from i to j.
+ - ``p_ij``: being at i, the probability of jumping to j.
+ - ``jump_lag``: The average number of frames a particle spends at i before jumping
      to j. Can be +inf if no such jumps ever occur.
And these site attributes:
- - `residence_times`: Avg. number of frames a particle spends at a site before jumping.
- - `total_corrected_residences`: Total number of frames when a particle was at the site,
+ - ``residence_times``: Avg. number of frames a particle spends at a site before jumping.
+ - ``total_corrected_residences``: Total number of frames when a particle was at the site,
*including* frames when an unassigned particle's last known site was this site.
"""
- def __init__(self, verbose = True):
- self.verbose = verbose
+ def __init__(self):
+ pass
def run(self, st):
"""Run the analysis.
- Adds edge attributes to st's SiteNetwork and returns st.
+ Adds edge attributes to ``st``'s ``SiteNetwork``.
+
+ Args:
+ st (SiteTrajectory)
+
+ Returns:
+ ``st``
"""
assert isinstance(st, SiteTrajectory)
- if self.verbose:
- print "Running JumpAnalysis..."
+ logger.info("Running JumpAnalysis...")
n_mobile = st.site_network.n_mobile
n_frames = st.n_frames
@@ -60,17 +68,13 @@ def run(self, st):
unassigned = frame == SiteTrajectory.SITE_UNKNOWN
# Reassign unassigned
frame[unassigned] = last_known[unassigned]
- fknown = frame >= 0
+ fknown = (frame >= 0) & (last_known >= 0)
- if np.any(~fknown) and self.verbose:
- print " at frame %i, %i uncorrectable unassigned particles" % (i, np.sum(~fknown))
+ n_problems += np.sum(~fknown)
# -- Update stats
total_time_spent_at_site[frame[fknown]] += 1
jumped = (frame != last_known) & fknown
- problems = last_known[jumped] == -1
- jumped[np.where(jumped)[0][problems]] = False
- n_problems += np.sum(problems)
n_ij[last_known[fknown], frame[fknown]] += 1
@@ -94,8 +98,8 @@ def run(self, st):
# The time before jumping to self should always be inf
assert not np.any(np.nonzero(avg_time_before_jump.diagonal()))
- if self.verbose and n_problems != 0:
- print "Came across %i times where assignment and last known assignment were unassigned." % n_problems
+ if n_problems != 0:
+            logger.warning("Encountered %i frames where both the current and the last known site assignment were unknown." % n_problems)
msk = avg_time_before_jump_n > 0
    # Zeros -- i.e. no jumps -- should actually be infs
@@ -103,12 +107,20 @@ def run(self, st):
# Do mean
avg_time_before_jump[msk] /= avg_time_before_jump_n[msk]
+ if st.site_network.has_attribute('n_ij'):
+ st.site_network.remove_attribute('n_ij')
+ st.site_network.remove_attribute('p_ij')
+ st.site_network.remove_attribute('jump_lag')
+ st.site_network.remove_attribute('residence_times')
+ st.site_network.remove_attribute('occupancy_freqs')
+ st.site_network.remove_attribute('total_corrected_residences')
+
st.site_network.add_edge_attribute('jump_lag', avg_time_before_jump)
st.site_network.add_edge_attribute('n_ij', n_ij)
st.site_network.add_edge_attribute('p_ij', n_ij / total_time_spent_at_site)
res_times = np.empty(shape = n_sites, dtype = np.float)
- for site in xrange(n_sites):
+ for site in range(n_sites):
times = avg_time_before_jump[site]
noninf = times < np.inf
if np.any(noninf):
@@ -125,18 +137,23 @@ def run(self, st):
def jump_lag_by_type(self,
sn,
return_counts = False):
- """Given a SiteNetwork with jump_lag info, compute avg. residence times by type
+ """Given a SiteNetwork with jump_lag info, compute avg. residence times by type.
Computes the average number of frames a mobile particle spends at each
type of site before jumping to each other type of site.
- Returns an (n_types, n_types) matrix. If no jumps of a given type occured,
+ Args:
+ sn (SiteNetwork)
+ return_counts (bool): Whether to also return a matrix giving the
+            number of each type of jump that occurred.
+
+
+ Returns:
+        An (n_types, n_types) matrix. If no jumps of a given type occurred,
the corresponding entry is +inf.
- If ``return_counts``, then also returns a matrix giving the number of
- each type of jump that occured.
+        If ``return_counts``, two such matrices.
"""
-
if sn.site_types is None:
raise ValueError("SiteNetwork has no type information.")
@@ -148,7 +165,7 @@ def jump_lag_by_type(self,
if return_counts:
countmat = np.empty(shape = outmat.shape, dtype = np.int)
- for stype_from, stype_to in itertools.product(xrange(len(all_types)), repeat = 2):
+ for stype_from, stype_to in itertools.product(range(len(all_types)), repeat = 2):
lags = sn.jump_lag[site_types == all_types[stype_from]][:, site_types == all_types[stype_to]]
# Only take things that aren't inf
lags = lags[lags < np.inf]
@@ -172,13 +189,11 @@ def plot_jump_lag(self, sn, mode = 'site', min_n_events = 1, ax = None, fig = No
"""Plot the jump lag of a site network.
:param SiteNetwork sn:
- :param str mode: If 'site', show jump lag between individual sites.
- If 'type', show jump lag between types of sites (see :func:jump_lag_by_type)
- Default: 'site'
+ :param str mode: If ``'site'``, show jump lag between individual sites.
+ If ``'type'``, show jump lag between types of sites (see :func:jump_lag_by_type)
:param int min_n_events: Minimum number of jump events of a given type
(i -> j or type -> type) to show a jump lag. If a given jump has
- occured less than min_n_events times, no jump lag will be shown.
- Default: 1
+        occurred less than ``min_n_events`` times, no jump lag will be shown.
"""
if mode == 'site':
mat = np.copy(sn.jump_lag)
@@ -192,7 +207,7 @@ def plot_jump_lag(self, sn, mode = 'site', min_n_events = 1, ax = None, fig = No
mat[counts < min_n_events] = np.nan
# Show diagonal
- ax.plot(*zip([0.0, 0.0], mat.shape), color = 'k', alpha = 0.5, linewidth = 1, linestyle = '--')
+ ax.plot(*list(zip([0.0, 0.0], mat.shape)), color = 'k', alpha = 0.5, linewidth = 1, linestyle = '--')
ax.grid()
im = ax.matshow(mat, zorder = 10, cmap = 'plasma')
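Since the analysis now reports through ``logging`` instead of ``print``, callers must opt in to see its output. A sketch (assuming an existing ``SiteTrajectory`` from an earlier pipeline stage):

```python
import logging
logging.basicConfig(level = logging.INFO)  # surface logger.info() messages

from sitator.dynamics import JumpAnalysis

def report_jump_statistics(st):
    """Attach and print jump statistics for a SiteTrajectory ``st``."""
    JumpAnalysis().run(st)
    sn = st.site_network
    print(sn.n_ij)       # total jumps i -> j
    print(sn.p_ij)       # probability, being at i, of jumping to j
    print(sn.jump_lag)   # avg. frames at i before jumping to j; +inf if never
```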
diff --git a/sitator/dynamics/MergeSitesByDynamics.py b/sitator/dynamics/MergeSitesByDynamics.py
index 5288cca..c8950e5 100644
--- a/sitator/dynamics/MergeSitesByDynamics.py
+++ b/sitator/dynamics/MergeSitesByDynamics.py
@@ -1,16 +1,23 @@
import numpy as np
-from sitator import SiteNetwork, SiteTrajectory
from sitator.dynamics import JumpAnalysis
from sitator.util import PBCCalculator
+from sitator.network.merging import MergeSites
+from sitator.util.mcl import markov_clustering
-class MergeSitesByDynamics(object):
+import logging
+logger = logging.getLogger(__name__)
+
+
+class MergeSitesByDynamics(MergeSites):
"""Merges sites using dynamical data.
Given a SiteTrajectory, merges sites using Markov Clustering.
:param float distance_threshold: Don't merge sites further than this
- in real space.
+ in real space. Zeros out the connectivity_matrix at distances greater than
+        this; a hard, step-function-style cutoff. For a gentler cutoff, try
+ changing `connectivity_matrix_generator` to incorporate distance.
:param float post_check_thresh_factor: Throw an error if proposed merge sites
are further than this * distance_threshold away. Only a sanity check; not
    a hard guarantee. Can be `None`; defaults to `1.5`. Can be loosely
@@ -25,38 +32,92 @@ class MergeSitesByDynamics(object):
Valid keys are ``'inflation'``, ``'expansion'``, and ``'pruning_threshold'``.
"""
def __init__(self,
+ connectivity_matrix_generator = None,
distance_threshold = 1.0,
post_check_thresh_factor = 1.5,
check_types = True,
- verbose = True,
- iterlimit = 100,
markov_parameters = {}):
- self.verbose = verbose
+ super().__init__(
+            maximum_merge_distance = np.inf if post_check_thresh_factor is None else post_check_thresh_factor * distance_threshold,
+ check_types = check_types
+ )
+
+ if connectivity_matrix_generator is None:
+ connectivity_matrix_generator = MergeSitesByDynamics.connectivity_n_ij
+ assert callable(connectivity_matrix_generator)
+
+ self.connectivity_matrix_generator = connectivity_matrix_generator
self.distance_threshold = distance_threshold
self.post_check_thresh_factor = post_check_thresh_factor
self.check_types = check_types
-        self.iterlimit = iterlimit
self.markov_parameters = markov_parameters
- def run(self, st):
- """Takes a SiteTrajectory and returns a SiteTrajectory, including a new SiteNetwork."""
+ # Connectivity Matrix Generation Schemes:
+
+ @staticmethod
+ def connectivity_n_ij(sn):
+ """Basic default connectivity scheme: uses n_ij directly as connectivity matrix.
+
+ Works well for systems with sufficient statistics.
+ """
+ return sn.n_ij
+
+ @staticmethod
+ def connectivity_jump_lag_biased(jump_lag_coeff = 1.0,
+ jump_lag_sigma = 20.0,
+ jump_lag_cutoff = np.inf,
+ distance_coeff = 0.5,
+ distance_sigma = 1.0):
+ """Bias the typical connectivity matrix p_ij with jump lag and distance contributions.
+
+ The jump lag and distance are processed through Gaussian functions with
+ the given sigmas (i.e. higher jump lag/larger distance => lower
+    connectivity value). The jump lag matrix is added to p_ij with prefactor
+    ``jump_lag_coeff``; the distance matrix then multiplies the result,
+    interpolated by ``distance_coeff``.
+
+ Site pairs with jump lags greater than ``jump_lag_cutoff`` have their bias
+ set to zero regardless of ``jump_lag_sigma``. Defaults to ``inf``.
+ """
+ def cfunc(sn):
+ jl = sn.jump_lag.copy()
+ jl -= 1.0 # Center it around 1 since that's the minimum lag, 1 frame
+ jl /= jump_lag_sigma
+ np.square(jl, out = jl)
+ jl *= -0.5
+ np.exp(jl, out = jl) # exp correctly takes the -infs to 0
+
+ jl[sn.jump_lag > jump_lag_cutoff] = 0.
+
+ # Distance term
+ pbccalc = PBCCalculator(sn.structure.cell)
+ dists = pbccalc.pairwise_distances(sn.centers)
+ dmat = dists.copy()
+
+ # We want to strongly boost the similarity of *very* close sites
+ dmat /= distance_sigma
+ np.square(dmat, out = dmat)
+ dmat *= -0.5
+ np.exp(dmat, out = dmat)
- if self.check_types and st.site_network.site_types is None:
- raise ValueError("Cannot run a check_types=True MergeSitesByDynamics on a SiteTrajectory without type information.")
+ return (sn.p_ij + jump_lag_coeff * jl) * (distance_coeff * dmat + (1 - distance_coeff))
+ return cfunc
+
+ # Real methods
+
+ def _get_sites_to_merge(self, st):
# -- Compute jump statistics
- if not st.site_network.has_attribute('p_ij'):
- ja = JumpAnalysis(verbose = self.verbose)
+ if not st.site_network.has_attribute('n_ij'):
+ ja = JumpAnalysis()
ja.run(st)
pbcc = PBCCalculator(st.site_network.structure.cell)
site_centers = st.site_network.centers
- if self.check_types:
- site_types = st.site_network.site_types
# -- Build connectivity_matrix
- connectivity_matrix = st.site_network.n_ij.copy()
+ connectivity_matrix = self.connectivity_matrix_generator(st.site_network).copy()
n_sites_before = st.site_network.n_sites
assert n_sites_before == connectivity_matrix.shape[0]
@@ -71,7 +132,7 @@ def run(self, st):
n_alarming_ignored_edges = 0
# Apply distance threshold
- for i in xrange(n_sites_before):
+ for i in range(n_sites_before):
dists = pbcc.distances(centers_before[i], centers_before[i + 1:])
js_too_far = np.where(dists > self.distance_threshold)[0]
js_too_far += i + 1
@@ -83,120 +144,10 @@ def run(self, st):
connectivity_matrix[i, js_too_far] = 0
connectivity_matrix[js_too_far, i] = 0 # Symmetry
- if self.verbose and n_alarming_ignored_edges > 0:
- print(" At least %i site pairs with high (z-score > 3) fluxes were over the given distance cutoff.\n"
- " This may or may not be a problem; but if `distance_threshold` is low, consider raising it." % n_alarming_ignored_edges)
+ if n_alarming_ignored_edges > 0:
+ logger.warning(" At least %i site pairs with high (z-score > 3) fluxes were over the given distance cutoff.\n"
+ " This may or may not be a problem; but if `distance_threshold` is low, consider raising it." % n_alarming_ignored_edges)
# -- Do Markov Clustering
- clusters = self._markov_clustering(connectivity_matrix, **self.markov_parameters)
-
- new_n_sites = len(clusters)
-
- if self.verbose:
- print "After merge there will be %i sites" % new_n_sites
-
- if self.check_types:
- new_types = np.empty(shape = new_n_sites, dtype = np.int)
-
- # -- Merge Sites
- new_centers = np.empty(shape = (new_n_sites, 3), dtype = st.site_network.centers.dtype)
- translation = np.empty(shape = st.site_network.n_sites, dtype = np.int)
- translation.fill(-1)
-
- for newsite in xrange(new_n_sites):
- mask = list(clusters[newsite])
- # Update translation table
- if np.any(translation[mask] != -1):
- # We've assigned a different cluster for this before... weird
- # degeneracy
- raise ValueError("Markov clustering tried to merge site(s) into more than one new site")
- translation[mask] = newsite
-
- to_merge = site_centers[mask]
-
- # Check distances
- if not self.post_check_thresh_factor is None:
- dists = pbcc.distances(to_merge[0], to_merge[1:])
- assert np.all(dists < self.post_check_thresh_factor * self.distance_threshold), \
- "Markov clustering tried to merge sites more than %f * %f apart. Lower your distance_threshold?" % (self.post_check_thresh_factor, self.distance_threshold)
-
- # New site center
- new_centers[newsite] = pbcc.average(to_merge)
- if self.check_types:
- assert np.all(site_types[mask] == site_types[mask][0])
- new_types[newsite] = site_types[mask][0]
-
- newsn = st.site_network.copy()
- newsn.centers = new_centers
- if self.check_types:
- newsn.site_types = new_types
-
- newtraj = translation[st._traj]
- newtraj[st._traj == SiteTrajectory.SITE_UNKNOWN] = SiteTrajectory.SITE_UNKNOWN
-
- # It doesn't make sense to propagate confidence information through a
- # transform that might completely invalidate it
- newst = SiteTrajectory(newsn, newtraj, confidences = None)
-
- if not st.real_trajectory is None:
- newst.set_real_traj(st.real_trajectory)
-
- return newst
-
- def _markov_clustering(self,
- transition_matrix,
- expansion = 2,
- inflation = 2,
- pruning_threshold = 0.00001):
- """
- See https://micans.org/mcl/.
-
- Because we're dealing with matrixes that are stochastic already,
- there's no need to add artificial loop values.
-
- Implementation inspired by https://github.com/GuyAllard/markov_clustering
- """
-
- assert transition_matrix.shape[0] == transition_matrix.shape[1]
-
- m1 = transition_matrix.copy()
-
- # Normalize (though it should be close already)
- m1 /= np.sum(m1, axis = 0)
-
- allcols = np.arange(m1.shape[1])
-
- converged = False
- for i in xrange(self.iterlimit):
- # -- Expansion
- m2 = np.linalg.matrix_power(m1, expansion)
- # -- Inflation
- np.power(m2, inflation, out = m2)
- m2 /= np.sum(m2, axis = 0)
- # -- Prune
- to_prune = m2 < pruning_threshold
- # Exclude the max of every column
- to_prune[np.argmax(m2, axis = 0), allcols] = False
- m2[to_prune] = 0.0
- # -- Check converged
- if np.allclose(m1, m2):
- converged = True
- if self.verbose:
- print "Markov Clustering converged in %i iterations" % i
- break
-
- m1[:] = m2
-
- if not converged:
- raise ValueError("Markov Clustering couldn't converge in %i iterations" % self.iterlimit)
-
- # -- Get clusters
- attractors = m2.diagonal().nonzero()[0]
-
- clusters = set()
-
- for a in attractors:
- cluster = tuple(m2[a].nonzero()[0])
- clusters.add(cluster)
-
- return list(clusters)
+ clusters = markov_clustering(connectivity_matrix, **self.markov_parameters)
+ return clusters
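A sketch of the new pluggable connectivity scheme (parameter values are illustrative only; this assumes the ``MergeSites`` base class exposes the ``run()`` entry point used elsewhere in this patch):

```python
from sitator.dynamics import MergeSitesByDynamics

def merge_with_biased_connectivity(st):
    """Merge sites of SiteTrajectory ``st`` using the jump-lag/distance bias."""
    merger = MergeSitesByDynamics(
        connectivity_matrix_generator = MergeSitesByDynamics.connectivity_jump_lag_biased(
            jump_lag_coeff = 1.0,
            jump_lag_sigma = 20.0,   # frames
            distance_coeff = 0.5,
            distance_sigma = 1.0,    # Angstrom
        ),
        distance_threshold = 2.0,    # hard real-space cutoff, Angstrom
    )
    return merger.run(st)
```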
diff --git a/sitator/dynamics/MergeSitesByThreshold.py b/sitator/dynamics/MergeSitesByThreshold.py
new file mode 100644
index 0000000..449c236
--- /dev/null
+++ b/sitator/dynamics/MergeSitesByThreshold.py
@@ -0,0 +1,87 @@
+import numpy as np
+
+import operator
+
+from scipy.sparse.csgraph import connected_components
+
+from sitator.util import PBCCalculator
+from sitator.network.merging import MergeSites
+
+
+class MergeSitesByThreshold(MergeSites):
+ """Merge sites using a strict threshold on any edge property.
+
+    Takes the edge property matrix given by ``attrname``, applies ``relation`` to it
+ with ``threshold``, and merges all connected components in the graph represented
+ by the resulting boolean adjacency matrix.
+
+ Threshold is given by a keyword argument to ``run()``.
+
+ Args:
+ attrname (str): Name of the edge attribute to merge on.
+ relation (func, default: ``operator.ge``): The relation to use for the
+ thresholding.
+ directed, connection (bool, str): Parameters for ``scipy.sparse.csgraph``'s
+ ``connected_components``.
+        distance_threshold (float): Never merge sites further apart than this,
+            regardless of the edge property (default: ``np.inf``).
+        forbid_multiple_occupancy (bool): If ``True``, never merge sites that
+            are simultaneously occupied at any frame.
+        **kwargs: Passed to ``MergeSites``.
+ """
+ def __init__(self,
+ attrname,
+ relation = operator.ge,
+ directed = True,
+ connection = 'strong',
+ distance_threshold = np.inf,
+ forbid_multiple_occupancy = False,
+ **kwargs):
+ self.attrname = attrname
+ self.relation = relation
+ self.directed = directed
+ self.connection = connection
+ self.distance_threshold = distance_threshold
+ self.forbid_multiple_occupancy = forbid_multiple_occupancy
+ super().__init__(**kwargs)
+
+
+ def _get_sites_to_merge(self, st, threshold = 0):
+ sn = st.site_network
+
+ attrmat = getattr(sn, self.attrname)
+ assert attrmat.shape == (sn.n_sites, sn.n_sites), "`attrname` doesn't seem to indicate an edge property."
+ connmat = self.relation(attrmat, threshold)
+
+ # Apply distance threshold
+ if self.distance_threshold < np.inf:
+ pbcc = PBCCalculator(sn.structure.cell)
+ centers = sn.centers
+ for i in range(sn.n_sites):
+ dists = pbcc.distances(centers[i], centers[i + 1:])
+ js_too_far = np.where(dists > self.distance_threshold)[0]
+ js_too_far += i + 1
+
+ connmat[i, js_too_far] = False
+ connmat[js_too_far, i] = False # Symmetry
+
+ if self.forbid_multiple_occupancy:
+ n_mobile = sn.n_mobile
+ for frame in st.traj:
+ frame = [s for s in frame if s >= 0]
+ for site in frame: # only known
+                # can't merge an occupied site with other simultaneously occupied sites
+ connmat[site, frame] = False
+
+ # Everything is always mergable with itself.
+ np.fill_diagonal(connmat, True)
+
+ # Get mergable groups
+ n_merged_sites, labels = connected_components(
+ connmat,
+ directed = self.directed,
+ connection = self.connection
+ )
+ # MergeSites will check pairwise distances; we just need to make it the
+ # right format.
+ merge_groups = []
+ for lbl in range(n_merged_sites):
+ merge_groups.append(np.where(labels == lbl)[0])
+
+ return merge_groups
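For example, after a ``JumpAnalysis`` has populated ``p_ij``, site pairs that frequently exchange occupants can be merged with a hard cutoff. A sketch (the threshold values are illustrative, and the base ``run()`` is assumed to forward its keyword arguments to ``_get_sites_to_merge``):

```python
import operator
from sitator.dynamics import MergeSitesByThreshold

def merge_high_flux_sites(st):
    """Merge site pairs with p_ij >= 5% that lie within 1 Angstrom."""
    merger = MergeSitesByThreshold(
        'p_ij',                      # edge attribute added by JumpAnalysis
        relation = operator.ge,
        distance_threshold = 1.0,    # Angstrom
        forbid_multiple_occupancy = True,
    )
    return merger.run(st, threshold = 0.05)
```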
diff --git a/sitator/dynamics/RemoveUnoccupiedSites.py b/sitator/dynamics/RemoveUnoccupiedSites.py
new file mode 100644
index 0000000..22f3a41
--- /dev/null
+++ b/sitator/dynamics/RemoveUnoccupiedSites.py
@@ -0,0 +1,71 @@
+import numpy as np
+
+from sitator import SiteTrajectory
+from sitator.errors import InsufficientSitesError
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+class RemoveUnoccupiedSites(object):
+ """Remove unoccupied sites."""
+ def __init__(self):
+ pass
+
+ def run(self, st, return_kept_sites = False):
+ """
+ Args:
+ return_kept_sites (bool): If ``True``, a list of the sites from ``st``
+ that were kept will be returned.
+
+ Returns:
+ A ``SiteTrajectory``, or ``st`` itself if it has no unoccupied sites.
+ """
+ assert isinstance(st, SiteTrajectory)
+
+ old_sn = st.site_network
+
+ # Allow for the -1 to affect nothing
+ seen_mask = np.zeros(shape = old_sn.n_sites + 1, dtype = np.bool)
+
+ for frame in st.traj:
+ seen_mask[frame] = True
+        if np.all(seen_mask[:-1]):
+            if return_kept_sites:
+                return st, np.arange(old_sn.n_sites)
+            return st
+
+ seen_mask = seen_mask[:-1]
+
+ logger.info("Removing unoccupied sites %s" % np.where(~seen_mask)[0])
+
+ n_new_sites = np.sum(seen_mask)
+
+ if n_new_sites < old_sn.n_mobile:
+ raise InsufficientSitesError(
+ verb = "Removing unoccupied sites",
+ n_sites = n_new_sites,
+ n_mobile = old_sn.n_mobile
+ )
+
+ translation = np.empty(shape = old_sn.n_sites + 1, dtype = np.int)
+ translation[:-1][seen_mask] = np.arange(n_new_sites)
+ translation[:-1][~seen_mask] = -4321
+ translation[-1] = SiteTrajectory.SITE_UNKNOWN # Map unknown to unknown
+
+ newtraj = translation[st.traj.reshape(-1)]
+ newtraj.shape = st.traj.shape
+
+ assert -4321 not in newtraj
+
+ # We don't clear computed attributes since nothing is changing for other sites.
+ newsn = old_sn[seen_mask]
+
+ new_st = SiteTrajectory(
+ site_network = newsn,
+ particle_assignments = newtraj
+ )
+ if st.real_trajectory is not None:
+ new_st.set_real_traj(st.real_trajectory)
+ if return_kept_sites:
+            return new_st, np.where(seen_mask)[0]
+ else:
+ return new_st
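A sketch of the ``return_kept_sites`` bookkeeping:

```python
from sitator.dynamics import RemoveUnoccupiedSites

def prune_unoccupied(st):
    """Drop never-occupied sites; also return which old sites survived."""
    new_st, kept = RemoveUnoccupiedSites().run(st, return_kept_sites = True)
    # ``kept`` indexes the sites of the *old* SiteNetwork that remain in new_st.
    return new_st, kept
```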
diff --git a/sitator/dynamics/ReplaceUnassignedPositions.py b/sitator/dynamics/ReplaceUnassignedPositions.py
new file mode 100644
index 0000000..1ee2a21
--- /dev/null
+++ b/sitator/dynamics/ReplaceUnassignedPositions.py
@@ -0,0 +1,117 @@
+import numpy as np
+
+from sitator import SiteTrajectory
+from sitator.util import PBCCalculator
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+# See https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi
+def rle(inarray):
+ """ run length encoding. Partial credit to R rle function.
+ Multi datatype arrays catered for including non Numpy
+ returns: tuple (runlengths, startpositions, values) """
+ ia = np.asarray(inarray) # force numpy
+ n = len(ia)
+ if n == 0:
+ return (None, None, None)
+ else:
+ y = np.array(ia[1:] != ia[:-1]) # pairwise unequal (string safe)
+ i = np.append(np.where(y), n - 1) # must include last element posi
+ z = np.diff(np.append(-1, i)) # run lengths
+ p = np.cumsum(np.append(0, z))[:-1] # positions
+ return(z, p, ia[i])
+
+class ReplaceUnassignedPositions(object):
+ """Fill in missing site assignments in a SiteTrajectory.
+
+ Args:
+ replacement_function (callable): Callable that takes
+ ``(st, mobile_atom, before_site, start_frame, after_site, end_frame)`` and
+ returns either:
+ - A single site assignment with which the unassigned will be replaced
+ - A timeseries of length ``end_frame - start_frame`` of site
+ assignments with which the unassigned will be replaced.
+ If ``None``, defaults to
+            ``ReplaceUnassignedPositions.replace_with_last_known``.
+ """
+ def __init__(self,
+ replacement_function = None):
+ if replacement_function is None:
+            replacement_function = ReplaceUnassignedPositions.replace_with_last_known
+ self.replacement_function = replacement_function
+
+ @staticmethod
+ def replace_with_last_known(st, mobile_atom, before_site, start_frame, after_site, end_frame):
+ """Replace unassigned with the last known site."""
+ return before_site
+
+ @staticmethod
+ def replace_with_next_known(st, mobile_atom, before_site, start_frame, after_site, end_frame):
+ """Replace unassigned with the next known site."""
+ return after_site
+
+ @staticmethod
+ def replace_with_closer():
+ """Create function to replace unknown with closest site over time.
+
+ Assigns each of the positions during an unassigned run to whichever of
+ the before and after sites it is closer to in real space.
+ """
+ pbcc = None
+ ptbuf = np.empty(shape = (2, 3))
+ distbuf = np.empty(shape = 2)
+        def replace_with_closer(st, mobile_atom, before_site, start_frame, after_site, end_frame):
+            nonlocal pbcc
+            if before_site == SiteTrajectory.SITE_UNKNOWN or \
+               after_site == SiteTrajectory.SITE_UNKNOWN:
+                return SiteTrajectory.SITE_UNKNOWN
+
+            if pbcc is None:
+                pbcc = PBCCalculator(st.site_network.structure.cell)
+            n_frames = end_frame - start_frame
+            out = np.empty(shape = n_frames, dtype = np.int)
+ for i in range(n_frames):
+ ptbuf[0] = st.site_network.centers[before_site]
+ ptbuf[1] = st.site_network.centers[after_site]
+ pbcc.distances(
+ st.real_trajectory[start_frame + i, mobile_atom],
+ ptbuf,
+ in_place = True,
+ out = distbuf
+ )
+ if distbuf[0] < distbuf[1]:
+ out[i] = before_site
+ else:
+ out[i] = after_site
+            return out
+        return replace_with_closer
+
+ def run(self,
+ st):
+ n_mobile = st.site_network.n_mobile
+ n_frames = st.n_frames
+ n_sites = st.site_network.n_sites
+
+ traj = st.traj
+ out = st.traj.copy()
+
+ for mob in range(n_mobile):
+ runlen, start, runsites = rle(traj[:, mob])
+ for runi in range(len(runlen)):
+ if runsites[runi] == SiteTrajectory.SITE_UNKNOWN:
+ unknown_start = start[runi]
+ unknown_end = unknown_start + runlen[runi]
+ unknown_before = runsites[runi - 1] if runi > 0 else SiteTrajectory.SITE_UNKNOWN
+ unknown_after = runsites[runi + 1] if runi < len(runlen) - 1 else SiteTrajectory.SITE_UNKNOWN
+ replace = self.replacement_function(
+ st, mob,
+ unknown_before, unknown_start,
+ unknown_after, unknown_end
+ )
+ out[unknown_start:unknown_end, mob] = replace
+
+ st = st.copy(with_computed = False)
+ st._traj = out
+
+ return st
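A sketch of filling unassigned stretches with the distance-based rule. Note that ``replace_with_closer`` is a factory and must be called to produce the replacement function, and the class is imported by module path since this patch does not re-export it from ``sitator.dynamics``:

```python
from sitator.dynamics.ReplaceUnassignedPositions import ReplaceUnassignedPositions

def fill_unassigned(st):
    """Assign each unassigned frame to the closer of its bracketing sites."""
    rup = ReplaceUnassignedPositions(
        replacement_function = ReplaceUnassignedPositions.replace_with_closer()
    )
    return rup.run(st)
```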
diff --git a/sitator/dynamics/SmoothSiteTrajectory.pyx b/sitator/dynamics/SmoothSiteTrajectory.pyx
new file mode 100644
index 0000000..379b2d7
--- /dev/null
+++ b/sitator/dynamics/SmoothSiteTrajectory.pyx
@@ -0,0 +1,111 @@
+# cython: language_level=3
+
+import numpy as np
+
+from sitator import SiteTrajectory
+from sitator.dynamics import RemoveUnoccupiedSites
+
+import logging
+logger = logging.getLogger(__name__)
+
+ctypedef Py_ssize_t site_int
+
+class SmoothSiteTrajectory(object):
+ """"Smooth" a SiteTrajectory by applying a rolling mode.
+
+    For each mobile particle, the assignment at each frame is replaced by the
+ mode of its site assignments over some number of frames centered around it.
+
+ Can be thought of as a discrete lowpass filter.
+
+ The ``set_unassigned_under_threshold`` parameter allows the user to control
+ how the smoothing handles "transitions" vs. "attempts"; setting it to True,
+ the default, will mark as unassigned transitional moments where neither
+ the source nor destination site have a sufficient (``threshold``) majority
+ in the window, while setting it to False will maintain the assignment to a
+ transitional site.
+
+ Args:
+ window_threshold_factor (float): The total width of the rolling window,
+ in terms of the threshold.
+ remove_unoccupied_sites (bool): If True, sites that are unoccupied after
+ the smoothing will be removed.
+ set_unassigned_under_threshold (bool): If True, if the multiplicity of
+ the mode is less than the threshold, the particle is marked
+ unassigned at that frame. If False, the particle's assignment will
+ not be modified.
+ """
+ def __init__(self,
+ window_threshold_factor = 2.1,
+ remove_unoccupied_sites = True,
+ set_unassigned_under_threshold = True):
+ self.window_threshold_factor = window_threshold_factor
+ self.remove_unoccupied_sites = remove_unoccupied_sites
+ self.set_unassigned_under_threshold = set_unassigned_under_threshold
+
+ def run(self,
+ st,
+ threshold):
+ n_mobile = st.site_network.n_mobile
+ n_frames = st.n_frames
+ n_sites = st.site_network.n_sites
+
+ traj = st.traj
+ out = st.traj.copy()
+
+ window = self.window_threshold_factor * threshold
+ wleft, wright = int(np.floor(window / 2)), int(np.ceil(window / 2))
+
+ running_windowed_mode(
+ traj,
+ out,
+ wleft,
+ wright,
+ threshold,
+ n_sites,
+ self.set_unassigned_under_threshold
+ )
+
+ st = st.copy(with_computed = False)
+ st._traj = out
+ if self.remove_unoccupied_sites:
+ # Removing short jumps could have made some sites completely unoccupied
+ st = RemoveUnoccupiedSites().run(st)
+ st.site_network.clear_attributes()
+
+ return st
+
+
+cpdef running_windowed_mode(site_int [:, :] traj,
+ site_int [:, :] out,
+ Py_ssize_t wleft,
+ Py_ssize_t wright,
+ Py_ssize_t threshold,
+ Py_ssize_t n_sites,
+ bint replace_no_winner_unknown):
+ countbuf_np = np.zeros(shape = n_sites + 1, dtype = np.int)
+ cdef Py_ssize_t [:] countbuf = countbuf_np
+ cdef Py_ssize_t n_mobile = traj.shape[1]
+ cdef Py_ssize_t n_frames = traj.shape[0]
+ cdef site_int s_unknown = SiteTrajectory.SITE_UNKNOWN
+ cdef site_int winner
+ cdef Py_ssize_t best_count
+
+ for mob in range(n_mobile):
+ for frame in range(n_frames):
+ for wi in range(max(frame - wleft, 0), min(frame + wright, n_frames)):
+ countbuf[traj[wi, mob] + 1] += 1
+            winner = 0  # This is actually -1, so unknown by default
+ best_count = 0
+ for site in range(n_sites + 1):
+ if countbuf[site] > best_count:
+ winner = site
+ best_count = countbuf[site]
+ if best_count >= threshold:
+ out[frame, mob] = winner - 1
+ else:
+ if replace_no_winner_unknown:
+ out[frame, mob] = s_unknown
+ else:
+ out[frame, mob] = traj[frame, mob]
+ countbuf[:] = 0
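A sketch of the rolling-mode smoothing (the threshold of 10 frames is an arbitrary example; the window spans roughly ``window_threshold_factor`` times the threshold):

```python
from sitator.dynamics import SmoothSiteTrajectory

def smooth_assignments(st):
    """Low-pass the site assignments of SiteTrajectory ``st``."""
    smoother = SmoothSiteTrajectory(set_unassigned_under_threshold = True)
    return smoother.run(st, threshold = 10)  # mode must win >= 10 frames
```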
diff --git a/sitator/dynamics/__init__.py b/sitator/dynamics/__init__.py
index ce898fc..812c503 100644
--- a/sitator/dynamics/__init__.py
+++ b/sitator/dynamics/__init__.py
@@ -1,3 +1,9 @@
-from JumpAnalysis import JumpAnalysis
+from .JumpAnalysis import JumpAnalysis
+from .MergeSitesByDynamics import MergeSitesByDynamics
+from .MergeSitesByThreshold import MergeSitesByThreshold
+from .RemoveUnoccupiedSites import RemoveUnoccupiedSites
+from .SmoothSiteTrajectory import SmoothSiteTrajectory
+from .AverageVibrationalFrequency import AverageVibrationalFrequency
-from MergeSitesByDynamics import MergeSitesByDynamics
+# For backwards compatibility, since this used to be in this module
+from sitator.network import DiffusionPathwayAnalysis
diff --git a/sitator/errors.py b/sitator/errors.py
new file mode 100644
index 0000000..06572a4
--- /dev/null
+++ b/sitator/errors.py
@@ -0,0 +1,21 @@
+
+class SiteAnalysisError(Exception):
+    """An error occurring as part of site analysis."""
+ pass
+
+class MultipleOccupancyError(SiteAnalysisError):
+ """Error raised when multiple mobile atoms are assigned to the same site at the same time."""
+ def __init__(self, mobile, site, frame):
+ super().__init__(
+ "Multiple mobile particles %s were assigned to site %i at frame %i." % (mobile, site, frame)
+ )
+ self.mobile_particles = mobile
+ self.site = site
+ self.frame = frame
+
+class InsufficientSitesError(SiteAnalysisError):
+ """Site detection/merging/etc. resulted in fewer sites than mobile particles."""
+ def __init__(self, verb, n_sites, n_mobile):
+ super().__init__("%s resulted in only %i sites for %i mobile particles." % (verb, n_sites, n_mobile))
+ self.n_sites = n_sites
+ self.n_mobile = n_mobile
diff --git a/sitator/ionic.py b/sitator/ionic.py
new file mode 100644
index 0000000..538618d
--- /dev/null
+++ b/sitator/ionic.py
@@ -0,0 +1,139 @@
+import numpy as np
+
+from sitator import SiteNetwork
+
+import ase.data
+
+try:
+ from pymatgen.io.ase import AseAtomsAdaptor
+ import pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder as cgf
+ from pymatgen.analysis.chemenv.coordination_environments.structure_environments import \
+ LightStructureEnvironments
+ from pymatgen.analysis.chemenv.utils.defs_utils import AdditionalConditions
+ from pymatgen.analysis.bond_valence import BVAnalyzer
+ has_pymatgen = True
+except ImportError:
+ has_pymatgen = False
+
+
+class IonicSiteNetwork(SiteNetwork):
+ """Site network for a species of mobile charged ions in a static lattice.
+
+ Imposes more restrictions than a plain ``SiteNetwork``:
+
+ - Has a mobile species in place of an arbitrary mobile mask
+    - Has a set of static species in place of an arbitrary static mask
+ - All atoms of the mobile species must have the same charge
+
+ And contains more information...
+
+ Attributes:
+ opposite_ion_mask (ndarray): A mask on ``structure`` indicating all
+ anions and neutrally charged atoms if the mobile species is a cation,
+ or vice versa if the mobile species is an anion.
+ opposite_ion_structure (ase.Atoms): An ``Atoms`` containing the atoms
+ indicated by ``opposite_ion_mask``.
+ same_ion_mask (ndarray): A mask on ``structure`` indicating all
+ atoms whose charge has the same sign as the mobile species.
+ same_ion_structure (ase.Atoms): An ``Atoms`` containing the atoms
+ indicated by ``same_ion_mask``.
+ n_opposite_charge (int): The number of opposite charge static atoms.
+ n_same_charge (int): The number of same charge static atoms.
+
+ Args:
+ structure (ase.Atoms)
+ mobile_species (int): Atomic number of the mobile species.
+ static_species (list of int): Atomic numbers of the static species.
+ mobile_charge (int): Charge of mobile atoms. If ``None``,
+ ``pymatgen``'s ``BVAnalyzer`` will be used to estimate valences.
+ static_charges (ndarray int): Charges of the atoms in the static
+ structure. If ``None``, ``sitator`` will try to use
+ ``pymatgen``'s ``BVAnalyzer`` to estimate valences.
+ """
+ def __init__(self,
+ structure,
+ mobile_species,
+ static_species,
+ mobile_charge = None,
+ static_charges = None):
+ if mobile_species in static_species:
+ raise ValueError("Mobile species %i cannot also be one of static species %s" % (mobile_species, static_species))
+ mobile_mask = structure.numbers == mobile_species
+ static_mask = np.in1d(structure.numbers, static_species)
+ super().__init__(
+ structure = structure,
+ mobile_mask = mobile_mask,
+ static_mask = static_mask
+ )
+
+ self.mobile_species = mobile_species
+ self.static_species = static_species
+ # Estimate bond valences if necessary
+ if mobile_charge is None or static_charges is None:
+ if not has_pymatgen:
+ raise ImportError("Pymatgen could not be imported, and is required for guessing charges.")
+ sim_struct = AseAtomsAdaptor.get_structure(structure)
+ bv = BVAnalyzer()
+ struct_valences = np.asarray(bv.get_valences(sim_struct))
+ if static_charges is None:
+ static_charges = struct_valences[static_mask]
+ if mobile_charge is None:
+ mob_val = struct_valences[mobile_mask]
+ if np.any(mob_val != mob_val[0]):
+                raise ValueError("Estimated valences of the mobile atoms (%s) are not uniform." % mob_val)
+ mobile_charge = mob_val[0]
+ self.mobile_charge = mobile_charge
+ self.static_charges = static_charges
+
+        # Create opposite/same ion masks and structures
+ mobile_sign = np.sign(mobile_charge)
+ static_signs = np.sign(static_charges)
+ self.opposite_ion_mask = np.empty_like(static_mask)
+ self.opposite_ion_mask.fill(False)
+ self.opposite_ion_mask[static_mask] = static_signs != mobile_sign
+ self.opposite_ion_structure = structure[self.opposite_ion_mask]
+
+ self.same_ion_mask = np.empty_like(static_mask)
+ self.same_ion_mask.fill(False)
+ self.same_ion_mask[static_mask] = static_signs == mobile_sign
+ self.same_ion_structure = structure[self.same_ion_mask]
+
+ @property
+ def n_opposite_charge(self):
+ return np.sum(self.opposite_ion_mask)
+
+ @property
+ def n_same_charge(self):
+ return np.sum(self.same_ion_mask)
+
+ def __getitem__(self, key):
+ out = super().__getitem__(key)
+ out.mobile_species = self.mobile_species
+ out.static_species = self.static_species
+ out.mobile_charge = self.mobile_charge
+ out.static_charges = self.static_charges
+ out.opposite_ion_mask = self.opposite_ion_mask
+ out.opposite_ion_structure = self.opposite_ion_structure
+ out.same_ion_mask = self.same_ion_mask
+ out.same_ion_structure = self.same_ion_structure
+ return out
+
+ def __str__(self):
+ out = super().__str__()
+ static_nums = self.static_structure.numbers
+ out += (
+ " Mobile species: {:2} (charge {:+d})\n"
+ " Static species: {}\n"
+ " # opposite charge: {}\n"
+ ).format(
+ ase.data.chemical_symbols[self.mobile_species],
+ self.mobile_charge,
+ ", ".join(
+ "{} (avg. charge {:+.1f})".format(
+ ase.data.chemical_symbols[s],
+ np.mean(self.static_charges[static_nums == s])
+ ) for s in self.static_species
+ ),
+ self.n_opposite_charge
+ )
+ return out
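A construction sketch (the structure file path is a placeholder; atomic numbers 3, 15, 16 assume a Li-P-S system, and the charges are left to pymatgen's ``BVAnalyzer``):

```python
import ase.io
from sitator.ionic import IonicSiteNetwork

def make_li_ps_network(path = "structure.cif"):
    structure = ase.io.read(path)   # placeholder input file
    return IonicSiteNetwork(
        structure,
        mobile_species = 3,         # Li
        static_species = [15, 16],  # P, S
    )  # mobile_charge/static_charges estimated via BVAnalyzer
```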
diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py
index a972ca3..85dd5d4 100644
--- a/sitator/landmark/LandmarkAnalysis.py
+++ b/sitator/landmark/LandmarkAnalysis.py
@@ -1,28 +1,20 @@
import numpy as np
from sitator.util import PBCCalculator
+from sitator.util.progress import tqdm
-# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
import sys
-try:
- ipy_str = str(type(get_ipython()))
- if 'zmqshell' in ipy_str:
- from tqdm import tqdm_notebook as tqdm
- if 'terminal' in ipy_str:
- from tqdm import tqdm
-except:
- if sys.stderr.isatty():
- from tqdm import tqdm
- else:
- def tqdm(iterable, **kwargs):
- return iterable
import importlib
import tempfile
-import helpers
+from . import helpers
from sitator import SiteNetwork, SiteTrajectory
+from sitator.errors import MultipleOccupancyError, InsufficientSitesError
+from sitator.landmark.anchor import to_origin
+import logging
+logger = logging.getLogger(__name__)
from functools import wraps
def analysis_result(func):
@@ -35,7 +27,76 @@ def wrapper(self, *args, **kwargs):
return wrapper
class LandmarkAnalysis(object):
- """Track a mobile species through a fixed lattice using landmark vectors."""
+ """Site analysis of mobile atoms in a static lattice with landmark analysis.
+
+ :param double cutoff_center: The midpoint for the logistic function used
+ as the landmark cutoff function. (unitless)
+ :param double cutoff_steepness: Steepness of the logistic cutoff function.
+    :param double minimum_site_occupancy = 0.01: Minimum occupancy (% of time occupied)
+ for a site to qualify as such.
+ :param str clustering_algorithm: The landmark clustering algorithm. ``sitator``
+ supplies two:
+ - ``"dotprod"``: The method described in our "Unsupervised landmark
+ analysis for jump detection in molecular dynamics simulations" paper.
+ - ``"mcl"``: A newer method we are developing.
+ :param dict clustering_params: Parameters for the chosen ``clustering_algorithm``.
+ :param str site_centers_method: The method to use for computing the real
+ space positions of the sites. Options:
+ - ``SITE_CENTERS_REAL_UNWEIGHTED``: A spatial average of all real-space
+ mobile atom positions assigned to the site is taken.
+ - ``SITE_CENTERS_REAL_WEIGHTED``: A spatial average of all real-space
+ mobile atom positions assigned to the site is taken, weighted
+ by the confidences with which they assigned to the site.
+ - ``SITE_CENTERS_REPRESENTATIVE_LANDMARK``: A spatial average over
+ all landmarks' centers is taken, weighted by the representative
+ or "typical" landmark vector at the site.
+ The "real" methods will generally be more faithful to the simulation,
+ but the representative landmark method can work better in cases with
+ short trajectories, producing a more "ideal" site location.
+ :param bool check_for_zero_landmarks: Whether to check for and raise exceptions
+ when all-zero landmark vectors are computed.
+ :param float static_movement_threshold: (Angstrom) the maximum allowed
+        distance between an instantaneous static atom position and its ideal position.
+ :param bool/callable dynamic_lattice_mapping: Whether to dynamically decide
+ each frame which static atom represents each average lattice position;
+        this allows the LandmarkAnalysis to deal with, say, a rare exchange of
+ two static atoms that does not change the structure of the lattice.
+
+ It does NOT allow LandmarkAnalysis to deal with lattices whose structures
+ actually change over the course of the trajectory.
+
+        In certain cases this is better dealt with by ``MergeSitesByDynamics``.
+
+ If ``False``, no mapping will occur. Otherwise, a callable taking a
+ ``SiteNetwork`` should be provided. The callable should return a list
+ of static atom indexes that can be validly assigned to each static lattice
+ position. If ``True``, ``sitator.landmark.dynamic_mapping.within_species``
+ is used.
+ :param int max_mobile_per_site: The maximum number of mobile atoms that can
+ be assigned to a single site without throwing an error. Regardless of the
+ value, assignments of more than one mobile atom to a single site will
+ be recorded and reported.
+
+ Setting this to 2 can be necessary for very diffusive, liquid-like
+ materials at high temperatures.
+
+ Statistics related to this are reported in ``self.avg_mobile_per_site``
+ and ``self.n_multiple_assignments``.
+ :param bool force_no_memmap: if True, landmark vectors will be stored only in memory.
+ Only useful if access to landmark vectors after the analysis has run is desired.
+ :param bool verbose: Verbosity for the ``clustering_algorithm``. Other output
+ controlled through ``logging``.
+ """
+
+ SITE_CENTERS_REAL_UNWEIGHTED = 'real-unweighted'
+ SITE_CENTERS_REAL_WEIGHTED = 'real-weighted'
+ SITE_CENTERS_REPRESENTATIVE_LANDMARK = 'representative-landmark'
+
+ CLUSTERING_CLUSTER_SIZE = 'cluster-size'
+ CLUSTERING_LABELS = 'cluster-labels'
+ CLUSTERING_CONFIDENCES = 'cluster-confs'
+ CLUSTERING_LANDMARK_GROUPINGS = 'cluster-landmark-groupings'
+ CLUSTERING_REPRESENTATIVE_LANDMARKS = 'cluster-representative-lvecs'
def __init__(self,
clustering_algorithm = 'dotprod',
@@ -43,57 +104,15 @@ def __init__(self,
cutoff_midpoint = 1.5,
cutoff_steepness = 30,
minimum_site_occupancy = 0.01,
- peak_evening = 'none',
- weighted_site_positions = True,
+ site_centers_method = SITE_CENTERS_REAL_WEIGHTED,
check_for_zero_landmarks = True,
static_movement_threshold = 1.0,
dynamic_lattice_mapping = False,
+ static_anchoring = to_origin,
relaxed_lattice_checks = False,
max_mobile_per_site = 1,
force_no_memmap = False,
verbose = True):
- """
- :param double cutoff_center: The midpoint for the logistic function used
- as the landmark cutoff function. (unitless)
- :param double cutoff_steepness: Steepness of the logistic cutoff function.
- :param double minimum_site_occupancy = 0.1: Minimum occupancy (% of time occupied)
- for a site to qualify as such.
- :param dict clustering_params: Parameters for the chosen clustering_algorithm
- :param str peak_evening: Whether and what kind of peak "evening" to apply;
- that is, processing that makes all large peaks in the landmark vector
- more similar in magnitude. This can help in site clustering.
-
- Valid options: 'none', 'clip'
- :param bool weighted_site_positions: When computing site positions, whether
- to weight the average by assignment confidence.
- :param bool check_for_zero_landmarks: Whether to check for and raise exceptions
- when all-zero landmark vectors are computed.
- :param float static_movement_threshold: (Angstrom) the maximum allowed
- distance between an instantanous static atom position and it's ideal position.
- :param bool dynamic_lattice_mapping: Whether to dynamically decide each
- frame which static atom represents each average lattice position;
- this allows the LandmarkAnalysis to deal with, say, a rare exchage of
- two static atoms that does not change the structure of the lattice.
-
- It does NOT allow LandmarkAnalysis to deal with lattices whose structures
- actually change over the course of the trajectory.
-
- In certain cases this is better delt with by MergeSitesByDynamics.
- :param int max_mobile_per_site: The maximum number of mobile atoms that can
- be assigned to a single site without throwing an error. Regardless of the
- value, assignments of more than one mobile atom to a single site will
- be recorded and reported.
-
- Setting this to 2 can be necessary for very diffusive, liquid-like
- materials at high temperatures.
-
- Statistics related to this are reported in self.avg_mobile_per_site
- and self.n_multiple_assignments.
- :param bool force_no_memmap: if True, landmark vectors will be stored only in memory.
- Only useful if access to landmark vectors after the analysis has run is desired.
- :param bool verbose: If `True`, progress bars and messages will be printed to stdout.
- """
-
self._cutoff_midpoint = cutoff_midpoint
self._cutoff_steepness = cutoff_steepness
self._minimum_site_occupancy = minimum_site_occupancy
@@ -101,14 +120,16 @@ def __init__(self,
self._cluster_algo = clustering_algorithm
self._clustering_params = clustering_params
- if not peak_evening in ['none', 'clip']:
- raise ValueError("Invalid value `%s` for peak_evening" % peak_evening)
- self._peak_evening = peak_evening
-
self.verbose = verbose
self.check_for_zero_landmarks = check_for_zero_landmarks
- self.weighted_site_positions = weighted_site_positions
+ self.site_centers_method = site_centers_method
+
+ if dynamic_lattice_mapping is True:
+ from sitator.landmark.dynamic_mapping import within_species
+ dynamic_lattice_mapping = within_species
self.dynamic_lattice_mapping = dynamic_lattice_mapping
+ self.static_anchoring = static_anchoring
+
self.relaxed_lattice_checks = relaxed_lattice_checks
self._landmark_vectors = None
@@ -123,26 +144,34 @@ def __init__(self,
@property
def cutoff(self):
- return self._cutoff
+ return self._cutoff
@analysis_result
def landmark_vectors(self):
+ """Landmark vectors from the last invocation of ``run()``"""
view = self._landmark_vectors[:]
view.flags.writeable = False
return view
@analysis_result
def landmark_dimension(self):
+ """Number of components in a single landmark vector."""
return self._landmark_dimension
-
def run(self, sn, frames):
"""Run the landmark analysis.
- The input SiteNetwork is a network of predicted sites; it's sites will
+        The input ``SiteNetwork`` is a network of predicted sites; its sites will
be used as the "basis" for the landmark vectors.
- Takes a SiteNetwork and returns a SiteTrajectory.
+ Wraps a copy of ``frames`` into the unit cell.
+
+ Args:
+ sn (SiteNetwork): The landmark basis. Each site is a landmark defined
+ by its vertex static atoms, as indicated by `sn.vertices`.
+ (Typically from ``VoronoiSiteGenerator``.)
+ frames (ndarray n_frames x n_atoms x 3): A trajectory. Can be unwrapped;
+ a copy will be wrapped before the analysis.
"""
assert isinstance(sn, SiteNetwork)
@@ -150,24 +179,33 @@ def run(self, sn, frames):
raise ValueError("Cannot rerun LandmarkAnalysis!")
if frames.shape[1:] != (sn.n_total, 3):
- raise ValueError("Wrong shape %s for frames." % frames.shape)
+ raise ValueError("Wrong shape %s for frames." % (frames.shape,))
if sn.vertices is None:
raise ValueError("Input SiteNetwork must have vertices")
n_frames = len(frames)
- if self.verbose:
- print "--- Running Landmark Analysis ---"
+ logger.info("--- Running Landmark Analysis ---")
# Create PBCCalculator
self._pbcc = PBCCalculator(sn.structure.cell)
+ # -- Step 0: Wrap to Unit Cell
+ orig_frames = frames # Keep a reference around
+ frames = frames.copy()
+ # Flatten to list of points for wrapping
+ orig_frame_shape = frames.shape
+ frames.shape = (orig_frame_shape[0] * orig_frame_shape[1], 3)
+ self._pbcc.wrap_points(frames)
+ # Back to list of frames
+ frames.shape = orig_frame_shape
+
# -- Step 1: Compute site-to-vertex distances
self._landmark_dimension = sn.n_sites
longest_vert_set = np.max([len(v) for v in sn.vertices])
- verts_np = np.array([v + [-1] * (longest_vert_set - len(v)) for v in sn.vertices])
+ verts_np = np.array([np.concatenate((v, [-1] * (longest_vert_set - len(v)))) for v in sn.vertices], dtype = np.int)
site_vert_dists = np.empty(shape = verts_np.shape, dtype = np.float)
site_vert_dists.fill(np.nan)
@@ -177,8 +215,17 @@ def run(self, sn, frames):
site_vert_dists[i, :len(polyhedron)] = dists
# -- Step 2: Compute landmark vectors
- if self.verbose: print " - computing landmark vectors -"
- # Compute landmark vectors
+ logger.info(" - computing landmark vectors -")
+
+ if self.dynamic_lattice_mapping:
+ dynmap_compat = self.dynamic_lattice_mapping(sn)
+ else:
+ # If no dynamic mapping, each is only compatable with itself.
+ dynmap_compat = np.arange(sn.n_static)[:, np.newaxis]
+            # If no dynamic mapping, each is only compatible with itself.
+ static_anchors = self.static_anchoring(sn)
+ # This also validates the anchors
+ lattice_pt_order = self._get_lattice_order_from_anchors(static_anchors)
# The dimension of one landmark vector is the number of Voronoi regions
shape = (n_frames * sn.n_mobile, self._landmark_dimension)
@@ -193,25 +240,48 @@ def run(self, sn, frames):
shape = shape)
helpers._fill_landmark_vectors(self, sn, verts_np, site_vert_dists,
- frames, check_for_zeros = self.check_for_zero_landmarks,
- tqdm = tqdm)
+ frames,
+ dynmap_compat = dynmap_compat,
+ lattice_pt_anchors = static_anchors,
+ lattice_pt_order = lattice_pt_order,
+ check_for_zeros = self.check_for_zero_landmarks,
+ tqdm = tqdm, logger = logger)
+
+ if not self.check_for_zero_landmarks and self.n_all_zero_lvecs > 0:
+ logger.warning(" Had %i all-zero landmark vectors; no error because `check_for_zero_landmarks = False`." % self.n_all_zero_lvecs)
+ elif self.check_for_zero_landmarks:
+ assert self.n_all_zero_lvecs == 0
# -- Step 3: Cluster landmark vectors
- if self.verbose: print " - clustering landmark vectors -"
- # - Preprocess -
- self._do_peak_evening()
+ logger.info(" - clustering landmark vectors -")
# - Cluster -
- cluster_func = importlib.import_module("..cluster." + self._cluster_algo, package = __name__).do_landmark_clustering
+ # FIXME: remove reload after development done
+ clustermod = importlib.import_module("..cluster." + self._cluster_algo, package = __name__)
+ importlib.reload(clustermod)
+ cluster_func = clustermod.do_landmark_clustering
- cluster_counts, lmk_lbls, lmk_confs = \
+ clustering = \
cluster_func(self._landmark_vectors,
clustering_params = self._clustering_params,
min_samples = self._minimum_site_occupancy / float(sn.n_mobile),
verbose = self.verbose)
- if self.verbose:
- print " Failed to assign %i%% of mobile particle positions to sites." % (100.0 * np.sum(lmk_lbls < 0) / float(len(lmk_lbls)))
+ cluster_counts = clustering[LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE]
+ lmk_lbls = clustering[LandmarkAnalysis.CLUSTERING_LABELS]
+ lmk_confs = clustering[LandmarkAnalysis.CLUSTERING_CONFIDENCES]
+ if LandmarkAnalysis.CLUSTERING_LANDMARK_GROUPINGS in clustering:
+ landmark_clusters = clustering[LandmarkAnalysis.CLUSTERING_LANDMARK_GROUPINGS]
+ assert len(cluster_counts) == len(landmark_clusters)
+ else:
+ landmark_clusters = None
+ if LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS in clustering:
+ rep_lvecs = np.asarray(clustering[LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS])
+ assert rep_lvecs.shape == (len(cluster_counts), self._landmark_vectors.shape[1])
+ else:
+ rep_lvecs = None
+
+        logger.info(" Failed to assign %i%% of mobile particle positions to sites." % (100.0 * np.sum(lmk_lbls < 0) / float(len(lmk_lbls))))
# reshape lables and confidences
lmk_lbls.shape = (n_frames, sn.n_mobile)
@@ -219,61 +289,77 @@ def run(self, sn, frames):
n_sites = len(cluster_counts)
- if n_sites < sn.n_mobile:
- raise ValueError("There are %i mobile particles, but only identified %i sites. Check clustering_params." % (sn.n_mobile, n_sites))
+ if n_sites < (sn.n_mobile / self.max_mobile_per_site):
+ raise InsufficientSitesError(
+ verb = "Landmark analysis",
+ n_sites = n_sites,
+ n_mobile = sn.n_mobile
+ )
- if self.verbose:
- print " Identified %i sites with assignment counts %s" % (n_sites, cluster_counts)
-
- # Check that multiple particles are never assigned to one site at the
- # same time, cause that would be wrong.
- n_more_than_ones = 0
- avg_mobile_per_site = 0
- divisor = 0
- for frame_i, site_frame in enumerate(lmk_lbls):
- _, counts = np.unique(site_frame[site_frame >= 0], return_counts = True)
- count_msk = counts > self.max_mobile_per_site
- if np.any(count_msk):
- raise ValueError("%i mobile particles were assigned to only %i site(s) (%s) at frame %i." % (np.sum(counts[count_msk]), np.sum(count_msk), np.where(count_msk)[0], frame_i))
- n_more_than_ones += np.sum(counts > 1)
- avg_mobile_per_site += np.sum(counts)
- divisor += len(counts)
-
- self.n_multiple_assignments = n_more_than_ones
- self.avg_mobile_per_site = avg_mobile_per_site / float(divisor)
+        logger.info(" Identified %i sites with assignment counts %s" % (n_sites, cluster_counts))
# -- Do output
+ out_sn = sn.copy()
# - Compute site centers
site_centers = np.empty(shape = (n_sites, 3), dtype = frames.dtype)
-
- for site in xrange(n_sites):
- mask = lmk_lbls == site
- pts = frames[:, sn.mobile_mask][mask]
- if self.weighted_site_positions:
- site_centers[site] = self._pbcc.average(pts, weights = lmk_confs[mask])
- else:
- site_centers[site] = self._pbcc.average(pts)
-
- # Build output obejcts
- out_sn = sn.copy()
-
+ if self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REAL_WEIGHTED or \
+ self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REAL_UNWEIGHTED:
+ for site in range(n_sites):
+ mask = lmk_lbls == site
+ pts = frames[:, sn.mobile_mask][mask]
+ if self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REAL_WEIGHTED:
+ site_centers[site] = self._pbcc.average(pts, weights = lmk_confs[mask])
+ else:
+ site_centers[site] = self._pbcc.average(pts)
+ elif self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REPRESENTATIVE_LANDMARK:
+ if rep_lvecs is None:
+ raise ValueError("Chosen clustering method (with current parameters) didn't return representative landmark vectors; can't use SITE_CENTERS_REPRESENTATIVE_LANDMARK.")
+ for site in range(n_sites):
+ weights_nonzero = rep_lvecs[site] > 0
+ site_centers[site] = self._pbcc.average(
+ sn.centers[weights_nonzero],
+ weights = rep_lvecs[site, weights_nonzero]
+ )
+ else:
+ raise ValueError("Invalid site centers method '%s'" % self.site_centers_method)
out_sn.centers = site_centers
- assert out_sn.vertices is None
+ # - If clustering gave us that, compute site vertices
+ if landmark_clusters is not None:
+ vertices = []
+ for lclust in landmark_clusters:
+ vertices.append(set.union(*[set(sn.vertices[l]) for l in lclust]))
+ out_sn.vertices = vertices
out_st = SiteTrajectory(out_sn, lmk_lbls, lmk_confs)
- out_st.set_real_traj(frames)
+
+ # Check that multiple particles are never assigned to one site at the
+ # same time, cause that would be wrong.
+ self.n_multiple_assignments, self.avg_mobile_per_site = out_st.check_multiple_occupancy(
+ max_mobile_per_site = self.max_mobile_per_site
+ )
+
+ out_st.set_real_traj(orig_frames)
self._has_run = True
return out_st
- # -------- "private" methods --------
-
- def _do_peak_evening(self):
- if self._peak_evening == 'none':
- return
- elif self._peak_evening == 'clip':
- lvec_peaks = np.max(self._landmark_vectors, axis = 1)
- # Clip all peaks to the lowest "normal" (stdev.) peak
- lvec_clip = np.mean(lvec_peaks) - np.std(lvec_peaks)
- # Do the clipping
- self._landmark_vectors[self._landmark_vectors > lvec_clip] = lvec_clip
+ def _get_lattice_order_from_anchors(self, lattice_pt_anchors):
+ absolute_lattice_mask = lattice_pt_anchors == -1
+ lattice_pt_order = []
+        # -1 (absolute anchor of origin) is always known
+ known = np.zeros(shape = len(lattice_pt_anchors) + 1, dtype = np.bool)
+ known[-1] = True
+
+ while True:
+ can_know = known[lattice_pt_anchors]
+ new_know = can_know & ~known[:-1]
+ known[:-1] |= can_know
+ lattice_pt_order.extend(np.where(new_know)[0])
+ if not np.any(new_know):
+ break
+ if len(lattice_pt_order) < len(lattice_pt_anchors):
+            raise ValueError("Lattice point anchors %s contain an unsatisfiable dependency (likely a circular dependency)." % lattice_pt_anchors)
+ # Remove points with absolute anchors from the order; there's no need to
+ # do any computations for them.
+ lattice_pt_order = lattice_pt_order[np.sum(absolute_lattice_mask):]
+ return lattice_pt_order
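
The anchor-resolution helper above is effectively a breadth-first pass over a dependency array: the origin (anchor value `-1`) is known from the start, and a point becomes resolvable once its anchor is known. A minimal standalone sketch of the same pass, with a hypothetical three-point anchor array:

```python
import numpy as np

def lattice_order(anchors):
    """Order in which relatively-anchored lattice points can be resolved."""
    known = np.zeros(len(anchors) + 1, dtype = bool)
    known[-1] = True  # the origin (anchor value -1) is always known
    order = []
    while True:
        can_know = known[anchors]        # resolvable once the anchor is known
        new = can_know & ~known[:-1]
        known[:-1] |= can_know
        order.extend(np.where(new)[0])
        if not np.any(new):
            break
    if len(order) < len(anchors):
        raise ValueError("Unsatisfiable (likely circular) anchor dependency.")
    return order[np.sum(anchors == -1):]  # absolute points need no resolution

# Point 0 is absolute; point 1 hangs off 0; point 2 hangs off 1.
print(lattice_order(np.array([-1, 0, 1])))  # -> [1, 2]
```
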
diff --git a/sitator/landmark/__init__.py b/sitator/landmark/__init__.py
index e100d24..a084133 100644
--- a/sitator/landmark/__init__.py
+++ b/sitator/landmark/__init__.py
@@ -1,4 +1,4 @@
-from errors import StaticLatticeError, ZeroLandmarkError, LandmarkAnalysisError
+from .errors import StaticLatticeError, ZeroLandmarkError, LandmarkAnalysisError
-from LandmarkAnalysis import LandmarkAnalysis
+from .LandmarkAnalysis import LandmarkAnalysis
diff --git a/sitator/landmark/anchor.py b/sitator/landmark/anchor.py
new file mode 100644
index 0000000..067c859
--- /dev/null
+++ b/sitator/landmark/anchor.py
@@ -0,0 +1,49 @@
+import numpy as np
+
+from collections import Counter
+
+import ase.data
+
+from sitator.util.chemistry import identify_polyatomic_ions
+
+import logging
+logger = logging.getLogger(__name__)
+
+def to_origin(sn):
+ """Anchor all static atoms to the origin; i.e. their positions are absolute."""
+ return np.full(
+ shape = sn.n_static,
+ fill_value = -1,
+ dtype = np.int
+ )
+
+# ------
+
+def within_polyatomic_ions(**kwargs):
+ """Anchor the auxiliary atoms of a polyatomic ion to the central atom.
+
+    In phosphate (PO4), for example, the four coordinating oxygen atoms
+    will be anchored to the central phosphorus.
+
+ Args:
+ **kwargs: passed to ``sitator.util.chemistry.identify_polyatomic_ions``.
+ """
+ def func(sn):
+ anchors = np.full(shape = sn.n_static, fill_value = -1, dtype = np.int)
+ polyions = identify_polyatomic_ions(sn.static_structure, **kwargs)
+ logger.info("Identified %i polyatomic anions: %s" % (len(polyions), Counter(i[0] for i in polyions)))
+ for _, center, others in polyions:
+ anchors[others] = center
+ return anchors
+ return func
+
+# ------
+# TODO
+def to_heavy_elements(minimum_mass, maximum_distance = np.inf):
+ """Anchor "light" elements to their nearest "heavy" element.
+
+    Lightness/heaviness is determined by ``minimum_mass``, a cutoff in
+    atomic mass units: atoms lighter than the cutoff are anchored to the
+    nearest heavier atom within ``maximum_distance``.
+    """
+    def func(sn):
+        raise NotImplementedError("to_heavy_elements is not yet implemented")
+ return func
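
To make the anchor-array convention concrete (a hypothetical example, not from the source): for a five-atom static structure with P at index 0 and its four O at indices 1-4, `within_polyatomic_ions` would produce the array below, while `to_origin` returns all `-1`.

```python
import numpy as np

# -1 means "anchored to the origin", i.e. an absolute position.
anchors = np.array([-1, 0, 0, 0, 0])
# Atom 0 (P) is tracked absolutely; atoms 1-4 (O) keep their fixed offsets
# from wherever the P atom is actually found at each frame.
```
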
diff --git a/sitator/landmark/cluster/dbscan.py b/sitator/landmark/cluster/dbscan.py
deleted file mode 100644
index 0ff4856..0000000
--- a/sitator/landmark/cluster/dbscan.py
+++ /dev/null
@@ -1,68 +0,0 @@
-
-import numpy as np
-
-import numbers
-from sklearn.cluster import DBSCAN
-
-DEFAULT_PARAMS = {
- 'eps' : 0.05,
- 'min_samples' : 5,
- 'n_jobs' : -1
-}
-
-def do_landmark_clustering(landmark_vectors,
- clustering_params,
- min_samples,
- verbose):
-
- tmp = DEFAULT_PARAMS.copy()
- tmp.update(clustering_params)
- clustering_params = tmp
-
- landmark_classifier = \
- DBSCAN(eps = clustering_params['eps'],
- min_samples = clustering_params['min_samples'],
- n_jobs = clustering_params['n_jobs'],
- metric = 'cosine')
-
- lmk_lbls = \
- landmark_classifier.fit_predict(landmark_vectors)
-
- # - Filter low occupancy sites
- cluster_counts = np.bincount(lmk_lbls[lmk_lbls >= 0])
- n_assigned = np.sum(cluster_counts)
-
- min_n_samples_cluster = None
- if isinstance(min_samples, numbers.Integral):
- min_n_samples_cluster = min_samples
- elif isinstance(min_samples, numbers.Real):
- min_n_samples_cluster = int(np.floor(min_samples * n_assigned))
- else:
- raise ValueError("Invalid value `%s` for min_samples; must be integral or float." % self._min_samples)
-
- to_remove_mask = cluster_counts < min_n_samples_cluster
- to_remove = np.where(to_remove_mask)[0]
-
- trans_table = np.empty(shape = len(cluster_counts) + 1, dtype = np.int)
- # Map unknown to unknown
- trans_table[-1] = -1
- # Map removed to unknwon
- trans_table[:-1][to_remove_mask] = -1
- # Map known to rescaled known
- trans_table[:-1][~to_remove_mask] = np.arange(len(cluster_counts) - len(to_remove))
- # Do the remapping
- lmk_lbls = trans_table[lmk_lbls]
-
- if verbose:
- print "DBSCAN landmark: %i/%i assignment counts below threshold %f (%i); %i clusters remain." % \
- (len(to_remove), len(cluster_counts), min_samples, min_n_samples_cluster, len(cluster_counts) - len(to_remove))
-
- # Remove counts
- cluster_counts = cluster_counts[~to_remove_mask]
-
- # There are no confidences with DBSCAN, so just give everything confidence 1
- # so as not to screw up later weighting.
- confs = np.ones(shape = lmk_lbls.shape, dtype = np.float)
- confs[lmk_lbls == -1] = 0.0
-
- return cluster_counts, lmk_lbls, confs
diff --git a/sitator/landmark/cluster/dotprod.py b/sitator/landmark/cluster/dotprod.py
index 20cfc81..5b98e90 100644
--- a/sitator/landmark/cluster/dotprod.py
+++ b/sitator/landmark/cluster/dotprod.py
@@ -1,5 +1,7 @@
+"""Cluster landmark vectors using the custom online algorithm from the original paper."""
from sitator.util import DotProdClassifier
+from sitator.landmark import LandmarkAnalysis
DEFAULT_PARAMS = {
'clustering_threshold' : 0.45,
@@ -23,4 +25,9 @@ def do_landmark_clustering(landmark_vectors,
predict_threshold = clustering_params['assignment_threshold'],
verbose = verbose)
- return landmark_classifier.cluster_counts, lmk_lbls, lmk_confs
+ return {
+ LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE : landmark_classifier.cluster_counts,
+ LandmarkAnalysis.CLUSTERING_LABELS: lmk_lbls,
+ LandmarkAnalysis.CLUSTERING_CONFIDENCES : lmk_confs,
+ LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS : landmark_classifier.cluster_centers
+ }
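
Clustering backends now return a dict keyed by `LandmarkAnalysis` constants instead of a positional tuple; the groupings and representative-landmark entries are optional, since the caller checks for their absence (see the `rep_lvecs is None` check in `LandmarkAnalysis` above). A minimal sketch of a conforming backend, using a deliberately trivial strongest-landmark assignment for illustration only:

```python
import numpy as np

from sitator.landmark import LandmarkAnalysis

def do_landmark_clustering(landmark_vectors, clustering_params, min_samples, verbose):
    # Toy rule: assign each frame's vector to its single strongest landmark.
    lbls = np.argmax(landmark_vectors, axis = 1)
    lbls[landmark_vectors.max(axis = 1) == 0] = -1  # all-zero vectors stay unassigned
    confs = np.where(lbls >= 0, 1.0, 0.0)
    return {
        LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE : np.bincount(lbls[lbls >= 0]),
        LandmarkAnalysis.CLUSTERING_LABELS : lbls,
        LandmarkAnalysis.CLUSTERING_CONFIDENCES : confs,
        # CLUSTERING_LANDMARK_GROUPINGS and CLUSTERING_REPRESENTATIVE_LANDMARKS
        # may be omitted; LandmarkAnalysis tolerates their absence.
    }
```
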
diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py
new file mode 100644
index 0000000..741cfcf
--- /dev/null
+++ b/sitator/landmark/cluster/mcl.py
@@ -0,0 +1,131 @@
+"""Cluster landmarks into sites using Markov Clustering and then assign each landmark vector.
+
+Valid clustering params include:
+ - ``"assignment_threshold"`` (float between 0 and 1): The similarity threshold
+ below which a landmark vector will be marked unassigned.
+ - ``"good_site_normed_threshold"`` (float between 0 and 1): The minimum for
+ the cosine similarity between a good site's representative unit vector and
+ its best match landmark vector.
+ - ``"good_site_projected_threshold"`` (positive float): The minimum inner product
+ between a good site's representative unit vector and its best match
+ landmark vector.
+  - ``"weighted_representative_landmarks"`` (bool, default ``True``): Whether
+    each site's representative landmark vector is a confidence-weighted average.
+  - All other params are passed along to `sitator.util.mcl.markov_clustering`.
+"""
+
+import numpy as np
+
+from sitator.util.progress import tqdm
+from sitator.util.mcl import markov_clustering
+from sitator.util import DotProdClassifier
+from sitator.landmark import LandmarkAnalysis
+
+from sklearn.covariance import empirical_covariance
+
+from scipy.sparse.linalg import eigsh
+
+import logging
+logger = logging.getLogger(__name__)
+
+DEFAULT_PARAMS = {
+ 'inflation' : 4,
+ 'assignment_threshold' : 0.7,
+}
+
+def cov2corr( A ):
+ """
+ covariance matrix to correlation matrix.
+ """
+ d = np.sqrt(A.diagonal())
+ d[d == 0] = np.inf # Forces correlations to zero where variance is 0
+ A = ((A.T/d).T)/d
+ return A
+
+def do_landmark_clustering(landmark_vectors,
+ clustering_params,
+ min_samples,
+ verbose):
+ tmp = DEFAULT_PARAMS.copy()
+ tmp.update(clustering_params)
+ clustering_params = tmp
+
+ n_lmk = landmark_vectors.shape[1]
+    # Uncentered second moment ("covariance") of the landmark vectors
+ seen_ntimes = np.count_nonzero(landmark_vectors, axis = 0)
+ cov = np.dot(landmark_vectors.T, landmark_vectors) / landmark_vectors.shape[0]
+ corr = cov2corr(cov)
+ graph = np.clip(corr, 0, None)
+ for i in range(n_lmk):
+ if graph[i, i] == 0: # i.e. no self correlation = 0 variance = landmark never seen
+ graph[i, i] = 1 # Needs a self loop for Markov clustering not to degenerate. Arbitrary value, shouldn't affect anyone else.
+
+ predict_threshold = clustering_params.pop('assignment_threshold')
+ good_site_normed_threshold = clustering_params.pop('good_site_normed_threshold', predict_threshold)
+    good_site_project_thresh = clustering_params.pop('good_site_projected_threshold', predict_threshold)
+    weighted_reps = clustering_params.pop('weighted_representative_landmarks', True)
+
+ # -- Cluster Landmarks
+ clusters = markov_clustering(graph, **clustering_params)
+ # Filter out single element clusters of landmarks that never appear.
+ clusters = [list(c) for c in clusters if seen_ntimes[c[0]] > 0]
+ n_clusters = len(clusters)
+ centers = np.zeros(shape = (n_clusters, n_lmk))
+ maxbuf = np.empty(shape = len(landmark_vectors))
+ good_clusters = np.zeros(shape = n_clusters, dtype = np.bool)
+ for i, cluster in enumerate(clusters):
+ if len(cluster) == 1:
+ eigenvec = [1.0] # Eigenvec is trivial
+ else:
+ # PCA inspired:
+ _, eigenvec = eigsh(cov[cluster][:, cluster], k = 1)
+ eigenvec = eigenvec.T
+ centers[i, cluster] = eigenvec
+ np.dot(landmark_vectors, centers[i], out = maxbuf)
+ np.abs(maxbuf, out = maxbuf)
+ best_match = np.argmax(maxbuf)
+ best_match_lvec = landmark_vectors[best_match]
+ best_match_dot = np.abs(np.dot(best_match_lvec, centers[i]))
+ best_match_dot_norm = best_match_dot / np.linalg.norm(best_match_lvec)
+ good_clusters[i] = best_match_dot_norm >= good_site_normed_threshold
+ good_clusters[i] &= best_match_dot >= good_site_project_thresh
+ centers[i] /= best_match_dot
+
+ logger.debug("Kept %i/%i landmark clusters as good sites" % (np.sum(good_clusters), len(good_clusters)))
+
+ # Filter out "bad" sites
+ clusters = [c for i, c in enumerate(clusters) if good_clusters[i]]
+ centers = centers[good_clusters]
+ n_clusters = len(clusters)
+
+ landmark_classifier = \
+ DotProdClassifier(threshold = np.nan, # We're not fitting
+ min_samples = min_samples)
+
+ landmark_classifier.set_cluster_centers(centers)
+
+ lmk_lbls, lmk_confs, info = \
+ landmark_classifier.fit_predict(landmark_vectors,
+ predict_threshold = predict_threshold,
+ predict_normed = False,
+ verbose = verbose,
+ return_info = True)
+
+ msk = info['kept_clusters_mask']
+ clusters = [c for i, c in enumerate(clusters) if msk[i]] # Only need the ones above the threshold
+
+    # Find the average landmark vector at each site
+ centers = np.zeros(shape = (len(clusters), n_lmk))
+ weights = np.empty(shape = lmk_lbls.shape)
+ for site in range(len(clusters)):
+ np.equal(lmk_lbls, site, out = weights)
+ if weighted_reps:
+ weights *= lmk_confs
+ centers[site] = np.average(landmark_vectors, weights = weights, axis = 0)
+
+ return {
+ LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE : landmark_classifier.cluster_counts,
+ LandmarkAnalysis.CLUSTERING_LABELS : lmk_lbls,
+ LandmarkAnalysis.CLUSTERING_CONFIDENCES: lmk_confs,
+ LandmarkAnalysis.CLUSTERING_LANDMARK_GROUPINGS : clusters,
+ LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS : centers
+ }
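
For reference, Markov Clustering alternates expansion (a matrix power, which spreads flow) with inflation (an elementwise power plus renormalization, which sharpens it) on a column-stochastic matrix until convergence. A self-contained dense-matrix sketch (not `sitator.util.mcl` itself, whose implementation isn't shown in this diff):

```python
import numpy as np

def markov_cluster(adj, inflation = 4, expansion = 2, max_iter = 100, tol = 1e-8):
    """Plain-NumPy sketch of Markov Clustering on a small dense graph."""
    m = adj.astype(float)
    np.fill_diagonal(m, np.maximum(m.diagonal(), tol))  # ensure self-loops
    m /= m.sum(axis = 0)                                # column-stochastic
    for _ in range(max_iter):
        last = m.copy()
        m = np.linalg.matrix_power(m, expansion)  # expansion: flow spreads
        m **= inflation                           # inflation: strong flow wins
        m /= m.sum(axis = 0)
        if np.allclose(m, last, atol = tol):
            break
    # Nonzero rows are attractors; their nonzero columns are cluster members.
    return sorted({tuple(np.where(row > tol)[0]) for row in m if row.max() > tol})

adj = np.array([[1, 1, 0, 0],
                [1, 1, 0, 0],
                [0, 0, 1, 1],
                [0, 0, 1, 1]], dtype = float)
print(markov_cluster(adj))  # -> [(0, 1), (2, 3)]
```
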
diff --git a/sitator/landmark/dynamic_mapping.py b/sitator/landmark/dynamic_mapping.py
new file mode 100644
index 0000000..be7ce83
--- /dev/null
+++ b/sitator/landmark/dynamic_mapping.py
@@ -0,0 +1,5 @@
+import numpy as np
+
+def within_species(sn):
+ nums = sn.static_structure.numbers
+ return [(nums == nums[s]).nonzero()[0] for s in range(sn.n_static)]
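
`within_species` restricts the dynamic lattice mapping so that a static lattice point can only be matched to atoms of its own element; each entry of the returned list holds the compatible atom indices for that lattice point. For a hypothetical static structure with atomic numbers `[8, 8, 15]`:

```python
import numpy as np

nums = np.array([8, 8, 15])  # two O and one P (hypothetical)
compat = [(nums == nums[s]).nonzero()[0] for s in range(len(nums))]
print(compat)  # -> [array([0, 1]), array([0, 1]), array([2])]
```
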
diff --git a/sitator/landmark/errors.py b/sitator/landmark/errors.py
index 55dcfb2..69f06eb 100644
--- a/sitator/landmark/errors.py
+++ b/sitator/landmark/errors.py
@@ -2,16 +2,14 @@
class LandmarkAnalysisError(Exception):
pass
-class StaticLatticeError(Exception):
+class StaticLatticeError(LandmarkAnalysisError):
"""Error raised when static lattice atoms break any limits on their movement/position.
Attributes:
lattice_atoms (list, optional): The indexes of the atoms in the static lattice that
caused the error.
frame (int, optional): The frame in the trajectory at which the error occured.
-
"""
-
TRY_RECENTERING_MSG = "Try recentering the input trajectory (sitator.util.RecenterTrajectory)"
def __init__(self, message, lattice_atoms = None, frame = None, try_recentering = False):
@@ -25,7 +23,13 @@ def __init__(self, message, lattice_atoms = None, frame = None, try_recentering
self.lattice_atoms = lattice_atoms
self.frame = frame
-class ZeroLandmarkError(Exception):
+class ZeroLandmarkError(LandmarkAnalysisError):
+ """Error raised when a landmark vector containing only zeros is encountered.
+
+ Attributes:
+ mobile_index (int): Which mobile atom had the all-zero vector.
+ frame (int): At which frame it was encountered.
+ """
def __init__(self, mobile_index, frame):
message = "Encountered a zero landmark vector for mobile ion %i at frame %i. Try increasing `cutoff_midpoint` and/or decreasing `cutoff_steepness`." % (mobile_index, frame)
diff --git a/sitator/landmark/helpers.pyx b/sitator/landmark/helpers.pyx
index d54bbfd..2ed9eb9 100644
--- a/sitator/landmark/helpers.pyx
+++ b/sitator/landmark/helpers.pyx
@@ -9,7 +9,17 @@ from sitator.landmark import StaticLatticeError, ZeroLandmarkError
ctypedef double precision
-def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_for_zeros = True, tqdm = lambda i: i):
+def _fill_landmark_vectors(self,
+ sn,
+ verts_np,
+ site_vert_dists,
+ frames,
+ dynmap_compat,
+ lattice_pt_anchors,
+ lattice_pt_order,
+ check_for_zeros = True,
+ tqdm = lambda i: i,
+ logger = None):
if self._landmark_dimension is None:
raise ValueError("_fill_landmark_vectors called before Voronoi!")
@@ -19,20 +29,31 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
cdef pbcc = self._pbcc
- frame_shift = np.empty(shape = (sn.n_static, 3))
+ frame_shift = np.empty(shape = (sn.n_static, 3), dtype = frames.dtype)
temp_distbuff = np.empty(shape = sn.n_static, dtype = frames.dtype)
mobile_idexes = np.where(sn.mobile_mask)[0]
-
- if self.dynamic_lattice_mapping:
- lattice_map = np.empty(shape = sn.n_static, dtype = np.int)
- else:
- # Otherwise just map to themselves
- lattice_map = np.arange(sn.n_static, dtype = np.int)
-
- lattice_pt = np.empty(shape = 3, dtype = sn.static_structure.positions.dtype)
- lattice_pt_dists = np.empty(shape = sn.n_static, dtype = np.float)
+ # Static lattice point buffers
+ lattice_pts_resolved = np.empty(shape = (sn.n_static, 3), dtype = sn.static_structure.positions.dtype)
+ # Determine resolution order
+ absolute_lattice_mask = lattice_pt_anchors == -1
+ assert len(lattice_pt_order) == len(lattice_pt_anchors) - np.sum(absolute_lattice_mask), "Order must contain all non-absolute anchored static lattice points"
+ cdef Py_ssize_t [:] lattice_pt_order_c = np.asarray(lattice_pt_order, dtype = np.int)
+ # Absolute (relative to origin) ones never need to be resolved, put it in
+ # the buffer now
+ lattice_pts_resolved[absolute_lattice_mask] = sn.static_structure.positions[absolute_lattice_mask]
+ assert not np.any(absolute_lattice_mask[lattice_pt_order]), "None of the absolute lattice points should be in the resolution order"
+ # Precompute the offsets for relative static lattice points:
+ relative_lattice_offsets = sn.static_structure.positions[lattice_pt_order] - sn.static_structure.positions[lattice_pt_anchors[lattice_pt_order]]
+ # Buffers for dynamic mapping
+ max_n_dynmat_compat = max(len(dm) for dm in dynmap_compat)
+ lattice_pt_dists = np.empty(shape = max_n_dynmat_compat, dtype = np.float)
+ static_pos_buffer = np.empty(shape = (max_n_dynmat_compat, 3), dtype = lattice_pts_resolved.dtype)
static_positions_seen = np.empty(shape = sn.n_static, dtype = np.bool)
+ lattice_map = np.empty(shape = sn.n_static, dtype = np.int)
+ # Instant static position buffers
+ static_positions = np.empty(shape = (sn.n_static, 3), dtype = frames.dtype)
+ static_mask_idexes = sn.static_mask.nonzero()[0]
# - Precompute cutoff function rounding point
# TODO: Think about the 0.0001 value
@@ -42,43 +63,65 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
self._cutoff_steepness,
0.0001)
+ cdef Py_ssize_t n_all_zero_lvecs = 0
+
cdef Py_ssize_t landmark_dim = self._landmark_dimension
cdef Py_ssize_t current_landmark_i = 0
- # Iterate through time
- for i, frame in enumerate(tqdm(frames, desc = "Frame")):
- static_positions = frame[sn.static_mask]
+ cdef Py_ssize_t nearest_static_position
+ cdef precision nearest_static_distance
+ cdef Py_ssize_t n_dynmap_allowed
+ # Iterate through time
+ for i, frame in enumerate(tqdm(frames, desc = "Landmark Frame")):
+ # Copy static positions to buffer
+ np.take(frame,
+ static_mask_idexes,
+ out = static_positions,
+ axis = 0,
+ mode = 'clip')
# Every frame, update the lattice map
static_positions_seen.fill(False)
- for lattice_index in xrange(sn.n_static):
- lattice_pt = sn.static_structure.positions[lattice_index]
+ # - Resolve static lattice positions from their origins
+ for order_i, lattice_index in enumerate(lattice_pt_order_c):
+ lattice_pts_resolved[lattice_index] = static_positions[lattice_pt_anchors[lattice_index]]
+ lattice_pts_resolved[lattice_index] += relative_lattice_offsets[order_i]
- if self.dynamic_lattice_mapping:
- # Only compute all distances if dynamic remapping is on
- pbcc.distances(lattice_pt, static_positions, out = lattice_pt_dists)
- nearest_static_position = np.argmin(lattice_pt_dists)
- nearest_static_distance = lattice_pt_dists[nearest_static_position]
- else:
- nearest_static_position = lattice_index
- nearest_static_distance = pbcc.distances(lattice_pt, static_positions[nearest_static_position:nearest_static_position+1])[0]
+ # - Map static positions to static lattice sites
+ for lattice_index in xrange(sn.n_static):
+ dynmap_allowed = dynmap_compat[lattice_index]
+ n_dynmap_allowed = len(dynmap_allowed)
+
+ np.take(static_positions,
+ dynmap_allowed,
+ out = static_pos_buffer[:n_dynmap_allowed],
+ axis = 0,
+ mode = 'clip')
+
+ pbcc.distances(
+ lattice_pts_resolved[lattice_index],
+ static_pos_buffer[:n_dynmap_allowed],
+ out = lattice_pt_dists[:n_dynmap_allowed]
+ )
+ nearest_static_position = np.argmin(lattice_pt_dists[:n_dynmap_allowed])
+ nearest_static_distance = lattice_pt_dists[nearest_static_position]
+ nearest_static_position = dynmap_allowed[nearest_static_position]
if static_positions_seen[nearest_static_position]:
# We've already seen this one... error
- print "Static atom %i is the closest to more than one static lattice position" % nearest_static_position
+ logger.warning("Static atom %i is the closest to more than one static lattice position" % nearest_static_position)
#raise ValueError("Static atom %i is the closest to more than one static lattice position" % nearest_static_position)
static_positions_seen[nearest_static_position] = True
if nearest_static_distance > self.static_movement_threshold:
- raise StaticLatticeError("No static atom position within %f A threshold of static lattice position %i" % (self.static_movement_threshold, lattice_index),
+ raise StaticLatticeError("Nearest static atom to lattice position %i is %.2fÅ away, above threshold of %.2fÅ" % (lattice_index, nearest_static_distance, self.static_movement_threshold),
lattice_atoms = [lattice_index],
frame = i,
try_recentering = True)
- if self.dynamic_lattice_mapping:
- lattice_map[lattice_index] = nearest_static_position
+ lattice_map[lattice_index] = nearest_static_position
# In normal circumstances, every current static position should be assigned.
# Just a sanity check
@@ -89,13 +132,14 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
frame = i,
try_recentering = True)
-
+ # - Compute landmark vectors for mobile
for j in xrange(sn.n_mobile):
mobile_pt = frame[mobile_idexes[j]]
# Shift the Li in question to the center of the unit cell
- np.copyto(frame_shift, frame[sn.static_mask])
- frame_shift += (pbcc.cell_centroid - mobile_pt)
+ frame_shift[:] = static_positions
+ frame_shift -= mobile_pt
+ frame_shift += pbcc.cell_centroid
# Wrap all positions into the unit cell
pbcc.wrap_points(frame_shift)
@@ -111,17 +155,24 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo
cutoff_round_to_zero,
temp_distbuff)
- if check_for_zeros and (np.count_nonzero(self._landmark_vectors[current_landmark_i]) == 0):
- raise ZeroLandmarkError(mobile_index = j, frame = i)
+ if np.count_nonzero(self._landmark_vectors[current_landmark_i]) == 0:
+ if check_for_zeros:
+ raise ZeroLandmarkError(mobile_index = j, frame = i)
+ else:
+ n_all_zero_lvecs += 1
current_landmark_i += 1
+ self.n_all_zero_lvecs = n_all_zero_lvecs
+
+
cdef precision cutoff_round_to_zero_point(precision cutoff_midpoint,
precision cutoff_steepness,
precision threshold):
# Computed by solving for x:
return cutoff_midpoint + log((1/threshold) - 1.) / cutoff_steepness
+
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void fill_landmark_vec(precision [:,:] landmark_vectors,
@@ -139,18 +190,6 @@ cdef void fill_landmark_vec(precision [:,:] landmark_vectors,
precision cutoff_round_to_zero,
precision [:] distbuff) nogil:
- # Pure Python equiv:
- # for k in xrange(landmark_dim):
- # lvec = np.linalg.norm(lattice_positions[verts[k]] - cell_centroid, axis = 1)
- # past_cutoff = lvec > cutoff
-
- # # Short circut it, since the product then goes to zero too.
- # if np.any(past_cutoff):
- # landmark_vectors[(i * n_li) + j, k] = 0
- # else:
- # lvec = (np.cos((np.pi / cutoff) * lvec) + 1.0) / 2.0
- # landmark_vectors[(i * n_li) + j, k] = np.product(lvec)
-
# Fill the landmark vector
cdef int [:] vert
cdef Py_ssize_t v
@@ -168,11 +207,6 @@ cdef void fill_landmark_vec(precision [:,:] landmark_vectors,
distbuff[idex] = temp
- # if temp > cutoff:
- # distbuff[idex] = 0.0
- # else:
- # distbuff[idex] = (cos((M_PI / cutoff) * temp) + 1.0) * 0.5
-
# For each component
for k in xrange(landmark_dim):
ci = 1.0
@@ -196,7 +230,6 @@ cdef void fill_landmark_vec(precision [:,:] landmark_vectors,
temp = 1.0 / (1.0 + exp(cutoff_steepness * (temp - cutoff_midpoint)))
# Multiply into accumulator
- #ci *= distbuff[v]
ci *= temp
# "Normalize" to number of vertices
diff --git a/sitator/landmark/pointmerge.py b/sitator/landmark/pointmerge.py
deleted file mode 100644
index ece3107..0000000
--- a/sitator/landmark/pointmerge.py
+++ /dev/null
@@ -1,86 +0,0 @@
-
-import numpy as np
-
-# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
-import sys
-try:
- ipy_str = str(type(get_ipython()))
- if 'zmqshell' in ipy_str:
- from tqdm import tqdm_notebook as tqdm
- if 'terminal' in ipy_str:
- from tqdm import tqdm
-except:
- if sys.stderr.isatty():
- from tqdm import tqdm
- else:
- def tqdm(iterable, **kwargs):
- return iterable
-
-def merge_points_soap_paths(tsoap,
- pbcc,
- points,
- connectivity_dict,
- threshold,
- n_steps = 5,
- sanity_check_cutoff = np.inf,
- verbose = True):
- """Merge points using SOAP paths method.
-
- :param SOAP tsoap: to compute SOAPs with.
- :param dict connectivity_dict: Maps a point index to a set of point indexes
- it is connected to, however defined.
- :param threshold: Similarity threshold, 0 < threshold <= 1
- """
-
- merge_sets = set()
-
- points_along = np.empty(shape = (n_steps, 3), dtype = np.float)
- step_vec_mult = np.linspace(0.0, 1.0, num = n_steps)[:, np.newaxis]
-
- for pt_idex in tqdm(connectivity_dict.keys()):
- merge_set = set()
- current_pts = [pt_idex]
- from_soap = None
- keep_going = True
- while keep_going:
- added_this_iter = set()
- for edge_from in current_pts:
- offset = pbcc.cell_centroid - points[edge_from]
- edge_from_pt = pbcc.cell_centroid
-
- for edge_to in connectivity_dict[edge_from] - merge_set:
- edge_to_pt = points[edge_to].copy()
- edge_to_pt += offset
- pbcc.wrap_point(edge_to_pt)
-
- step_vec = edge_to_pt - edge_from_pt
- edge_length = np.linalg.norm(step_vec)
-
- assert edge_length <= sanity_check_cutoff, "edge_length %s" % edge_length
-
- # Points along the line
- for i in xrange(n_steps):
- points_along[i] = step_vec
- points_along *= step_vec_mult
- points_along += edge_from_pt
- # Re-center back to original center
- points_along -= offset
- # Wrap back into original unit cell - the one frame_atoms has
- pbcc.wrap_points(points_along)
-
- merge = tsoap.soaps_similar_for_points(points_along, threshold = threshold)
-
- if merge:
- added_this_iter.add(edge_from)
- added_this_iter.add(edge_to)
-
- if len(added_this_iter) == 0:
- keep_going = False
- else:
- current_pts = added_this_iter - merge_set
- merge_set.update(added_this_iter)
-
- if len(merge_set) > 0:
- merge_sets.add(frozenset(merge_set))
-
- return merge_sets
diff --git a/sitator/misc/GenerateAroundSites.py b/sitator/misc/GenerateAroundSites.py
index dea0306..c668239 100644
--- a/sitator/misc/GenerateAroundSites.py
+++ b/sitator/misc/GenerateAroundSites.py
@@ -4,7 +4,12 @@
from sitator.util import PBCCalculator
class GenerateAroundSites(object):
- """Generate n normally distributed sites around each input site"""
+ """Generate ``n`` normally distributed sites around each site.
+
+ Args:
+ n (int): How many sites to produce for each input site.
+ sigma (float): Standard deviation of the spatial Gaussian.
+ """
def __init__(self, n, sigma):
self.n = n
self.sigma = sigma
@@ -14,9 +19,9 @@ def run(self, sn):
out = sn.copy()
pbcc = PBCCalculator(sn.structure.cell)
- print out.centers.shape
+
newcenters = out.centers.repeat(self.n, axis = 0)
- print newcenters.shape
+ assert len(newcenters) == self.n * len(out.centers)
newcenters += self.sigma * np.random.standard_normal(size = newcenters.shape)
pbcc.wrap_points(newcenters)
diff --git a/sitator/misc/GenerateClampedTrajectory.pyx b/sitator/misc/GenerateClampedTrajectory.pyx
new file mode 100644
index 0000000..d442664
--- /dev/null
+++ b/sitator/misc/GenerateClampedTrajectory.pyx
@@ -0,0 +1,128 @@
+# cython: language_level=3
+
+import numpy as np
+
+from sitator import SiteTrajectory
+from sitator.util.PBCCalculator cimport PBCCalculator, precision
+from sitator.util.progress import tqdm
+
+from libc.math cimport floor
+
+ctypedef Py_ssize_t site_int
+
+class GenerateClampedTrajectory(object):
+ """Create a real-space trajectory with the fixed site/static structure positions.
+
+ Generate a real-space trajectory where the atoms are clamped to the fixed
+ positions of the current site/their fixed static position.
+
+ Args:
+ wrap (bool): If ``True``, all clamped positions will be in
+ the unit cell; if ``False``, the clamped position will be the minimum
+ image of the clamped position with respect to the real-space position.
+ (This can generate a clamped, unwrapped real-space trajectory
+ from an unwrapped real space trajectory.)
+ pass_through_unassigned (bool): If ``True``, when a
+ mobile atom is supposed to be clamped but is unassigned at some
+ frame, its real-space position will be passed through from the
+ real trajectory. If False, an error will be raised.
+ """
+ def __init__(self, wrap = False, pass_through_unassigned = False):
+ self.wrap = wrap
+ self.pass_through_unassigned = pass_through_unassigned
+
+
+ def run(self, st, clamp_mask = None):
+ """Create a real-space trajectory with the fixed site/static structure positions.
+
+ Generate a real-space trajectory where the atoms indicated in ``clamp_mask`` --
+ any mixture of static and mobile -- are clamped to: (1) the fixed position of
+ their current site, if mobile, or (2) the corresponding fixed position in
+ the ``SiteNetwork``'s static structure, if static.
+
+ Atoms not indicated in ``clamp_mask`` will have their positions from
+ ``real_traj`` passed through.
+
+ Args:
+            clamp_mask (ndarray of bool, len(sn.structure)): Which atoms to clamp; if ``None``, all atoms are clamped.
+ Returns:
+ ndarray (n_frames x n_atoms x 3)
+ """
+ wrap = self.wrap
+ pass_through_unassigned = self.pass_through_unassigned
+ cell = st._sn.structure.cell
+ cdef PBCCalculator pbcc = PBCCalculator(cell)
+
+ n_atoms = len(st._sn.structure)
+ if clamp_mask is None:
+ clamp_mask = np.ones(shape = n_atoms, dtype = np.bool)
+ if st._real_traj is None and not np.all(clamp_mask):
+ raise RuntimeError("This `SiteTrajectory` has no real-space trajectory, but the given clamp mask leaves some atoms unclamped.")
+
+ clamptrj = np.empty(shape = (st.n_frames, n_atoms, 3))
+ # Pass through unclamped positions
+ if not np.all(clamp_mask):
+ clamptrj[:, ~clamp_mask, :] = st._real_traj[:, ~clamp_mask, :]
+ # Clamp static atoms
+ static_clamp = clamp_mask & st._sn.static_mask
+ clamptrj[:, static_clamp, :] = st._sn.structure.get_positions()[static_clamp]
+ # Clamp mobile atoms
+ mobile_clamp = clamp_mask & st._sn.mobile_mask
+ selected_sitetraj = st._traj[:, mobile_clamp]
+ mobile_clamp_indexes = np.where(mobile_clamp)[0]
+ if not pass_through_unassigned and np.min(selected_sitetraj) < 0:
+ raise RuntimeError("The mobile atoms indicated for clamping are unassigned at some point during the trajectory and `pass_through_unassigned` is set to False. Try `assign_to_last_known_site()`?")
+
+ cdef site_int at_site
+ cdef Py_ssize_t frame_i
+ cdef Py_ssize_t mobile_i
+ cdef Py_ssize_t [:] mobile_clamp_indexes_c = mobile_clamp_indexes
+ cdef precision [:, :] buf
+ cdef precision [:] site_pt
+ cdef site_int site_unknown = SiteTrajectory.SITE_UNKNOWN
+ cdef const site_int [:, :] sitetrj_c = st._traj
+ cdef precision [:, :, :] clamptrj_c = clamptrj
+ cdef const precision [:, :, :] realtrj_c = st._real_traj
+ cdef const precision [:, :] centers_c = st.site_network.centers
+ cdef int site_mic_int
+ cdef int [3] site_mic
+ cdef int [3] pt_in_image
+ cdef precision [:, :] centers_crystal_c
+ cdef Py_ssize_t dim
+ if wrap:
+ for frame_i in tqdm(range(len(clamptrj))):
+ for mobile_i in mobile_clamp_indexes_c:
+ at_site = sitetrj_c[frame_i, mobile_i]
+ if at_site == site_unknown: # we already know that this means pass_through_unassigned = True
+ clamptrj_c[frame_i, mobile_i] = realtrj_c[frame_i, mobile_i]
+ else:
+ clamptrj_c[frame_i, mobile_i] = centers_c[at_site]
+ else:
+ buf = np.empty(shape = (1, 3))
+ site_pt = np.empty(shape = 3)
+ centers_crystal_c = st.site_network.centers.copy()
+ pbcc.to_cell_coords(centers_crystal_c)
+ for frame_i in tqdm(range(len(clamptrj))):
+ for mobile_i in mobile_clamp_indexes_c:
+ buf[:, :] = realtrj_c[frame_i, mobile_i]
+ at_site = sitetrj_c[frame_i, mobile_i]
+ if at_site == site_unknown: # we already know that this means pass_through_unassigned = True
+ clamptrj_c[frame_i, mobile_i] = realtrj_c[frame_i, mobile_i]
+ continue
+ site_pt[:] = centers_c[at_site]
+ pbcc.wrap_point(site_pt)
+ pbcc.wrap_points(buf)
+ site_mic_int = pbcc.min_image(buf[0], site_pt)
+ for dim in range(3):
+ site_mic[dim] = (site_mic_int // 10**(2 - dim) % 10) - 1
+ buf[:, :] = realtrj_c[frame_i, mobile_i]
+ pbcc.to_cell_coords(buf)
+ for dim in range(3):
+ pt_in_image[dim] = floor(buf[0, dim]) + site_mic[dim]
+ buf[0] = centers_crystal_c[at_site]
+ for dim in range(3):
+ buf[0, dim] += pt_in_image[dim]
+ pbcc.to_real_coords(buf)
+ clamptrj_c[frame_i, mobile_i] = buf[0]
+
+ return clamptrj
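
As used here (and again in `DiffusionPathwayAnalysis` below), `PBCCalculator.min_image` evidently packs the per-axis minimum-image shifts into a three-decimal-digit integer, each digit offset by one, so `111` means the same periodic image. The decoding loop above, in isolation:

```python
def decode_mic(mic):
    """Unpack a three-digit minimum-image code into per-axis shifts in {-1, 0, 1}."""
    return tuple((mic // 10 ** (2 - dim)) % 10 - 1 for dim in range(3))

assert decode_mic(111) == (0, 0, 0)   # same image, no shift
assert decode_mic(201) == (1, -1, 0)  # +a, -b, same c
```
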
diff --git a/sitator/misc/NAvgsPerSite.py b/sitator/misc/NAvgsPerSite.py
index 41d74bf..b4d1f01 100644
--- a/sitator/misc/NAvgsPerSite.py
+++ b/sitator/misc/NAvgsPerSite.py
@@ -5,15 +5,15 @@
from sitator.util import PBCCalculator
class NAvgsPerSite(object):
- """Given a SiteTrajectory, return a SiteNetwork containing n avg. positions per site.
+ """Given a ``SiteTrajectory``, return a ``SiteNetwork`` containing n avg. positions per site.
- The `types` of sites in the output are the index of the site in the input that generated
- that average.
+ The ``site_types`` of sites in the output are the index of the site in the
+ input that generated that average.
:param int n: How many averages to take
:param bool error_on_insufficient: Whether to throw an error if n points cannot
be provided for a site, or just take all that are available.
- :param bool weighted: Use SiteTrajectory confidences to weight the averages.
+ :param bool weighted: Use ``SiteTrajectory`` confidences to weight the averages.
"""
def __init__(self, n,
@@ -25,6 +25,12 @@ def __init__(self, n,
self.weighted = weighted
def run(self, st):
+ """
+ Args:
+ st (SiteTrajectory)
+ Returns:
+ A ``SiteNetwork``.
+ """
assert isinstance(st, SiteTrajectory)
if st.real_trajectory is None:
raise ValueError("SiteTrajectory must have associated real trajectory.")
@@ -32,10 +38,12 @@ def run(self, st):
pbcc = PBCCalculator(st.site_network.structure.cell)
# Maximum length
centers = np.empty(shape = (self.n * st.site_network.n_sites, 3), dtype = st.real_trajectory.dtype)
+ centers.fill(np.nan)
types = np.empty(shape = centers.shape[0], dtype = np.int)
+    types.fill(-1) # integer sentinel; np.nan cannot be stored in an int array
current_idex = 0
- for site in xrange(st.site_network.n_sites):
+ for site in range(st.site_network.n_sites):
if self.weighted:
pts, confs = st.real_positions_for_site(site, return_confidences = True)
else:
@@ -46,7 +54,7 @@ def run(self, st):
if len(pts) > self.n:
sanity = 0
- for i in xrange(self.n):
+ for i in range(self.n):
ps = pts[i::self.n]
sanity += len(ps)
c = confs[i::self.n]
@@ -64,7 +72,9 @@ def run(self, st):
types[old_idex:current_idex] = site
sn = st.site_network.copy()
- sn.centers = centers
- sn.site_types = types
+ sn.centers = centers[:current_idex]
+ sn.site_types = types[:current_idex]
+
+    assert not (np.isnan(np.sum(sn.centers)) or np.any(sn.site_types < 0))
return sn
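
The ``n`` averages for each site are taken over interleaved strided slices (`pts[i::n]`), which spreads a site's frames evenly in time across the averages; the `sanity` counter above just confirms the slices partition the points. The slicing pattern, in isolation:

```python
import numpy as np

pts = np.arange(10)  # stand-in for one site's positions over time
n = 3
groups = [pts[i::n] for i in range(n)]
# -> [array([0, 3, 6, 9]), array([1, 4, 7]), array([2, 5, 8])]
assert sum(len(g) for g in groups) == len(pts)
```
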
diff --git a/sitator/misc/SiteVolumes.py b/sitator/misc/SiteVolumes.py
deleted file mode 100644
index 7a8f48c..0000000
--- a/sitator/misc/SiteVolumes.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import numpy as np
-
-from scipy.spatial import ConvexHull
-from scipy.spatial.qhull import QhullError
-
-from sitator import SiteTrajectory
-from sitator.util import PBCCalculator
-
-class SiteVolumes(object):
- """Computes the volumes of convex hulls around all positions associated with a site.
-
- Adds the `site_volumes` and `site_surface_areas` attributes to the SiteNetwork.
- """
- def __init__(self, n_recenterings = 8):
- self.n_recenterings = n_recenterings
-
- def run(self, st):
- vols = np.empty(shape = st.site_network.n_sites, dtype = np.float)
- areas = np.empty(shape = st.site_network.n_sites, dtype = np.float)
-
- pbcc = PBCCalculator(st.site_network.structure.cell)
-
- for site in xrange(st.site_network.n_sites):
- pos = st.real_positions_for_site(site)
-
- assert pos.flags['OWNDATA']
-
- vol = np.inf
- area = None
- for i in xrange(self.n_recenterings):
- # Recenter
- offset = pbcc.cell_centroid - pos[int(i * (len(pos)/self.n_recenterings))]
- pos += offset
- pbcc.wrap_points(pos)
-
- try:
- hull = ConvexHull(pos)
- except QhullError as qhe:
- print "For site %i, iter %i: %s" % (site, i, qhe)
- vols[site] = np.nan
- areas[site] = np.nan
- continue
-
- if hull.volume < vol:
- vol = hull.volume
- area = hull.area
-
- vols[site] = vol
- areas[site] = area
-
- st.site_network.add_site_attribute('site_volumes', vols)
- st.site_network.add_site_attribute('site_surface_areas', areas)
diff --git a/sitator/misc/__init__.py b/sitator/misc/__init__.py
index 66901f4..8240d62 100644
--- a/sitator/misc/__init__.py
+++ b/sitator/misc/__init__.py
@@ -1,4 +1,4 @@
-from NAvgsPerSite import NAvgsPerSite
-from GenerateAroundSites import GenerateAroundSites
-from SiteVolumes import SiteVolumes
+from .NAvgsPerSite import NAvgsPerSite
+from .GenerateAroundSites import GenerateAroundSites
+from .GenerateClampedTrajectory import GenerateClampedTrajectory
diff --git a/sitator/misc/oldio.py b/sitator/misc/oldio.py
new file mode 100644
index 0000000..ddd7ce2
--- /dev/null
+++ b/sitator/misc/oldio.py
@@ -0,0 +1,82 @@
+import numpy as np
+
+import tempfile
+import tarfile
+import os
+
+import ase
+import ase.io
+
+from sitator import SiteNetwork
+
+_STRUCT_FNAME = "structure.xyz"
+_SMASK_FNAME = "static_mask.npy"
+_MMASK_FNAME = "mobile_mask.npy"
+_MAIN_FNAMES = ['centers', 'vertices', 'site_types']
+
+def save(sn, file):
+ """Save this SiteNetwork to a tar archive."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # -- Write the structure
+ ase.io.write(os.path.join(tmpdir, _STRUCT_FNAME), sn.structure, parallel = False)
+ # -- Write masks
+ np.save(os.path.join(tmpdir, _SMASK_FNAME), sn.static_mask)
+ np.save(os.path.join(tmpdir, _MMASK_FNAME), sn.mobile_mask)
+ # -- Write what we have
+ for arrname in _MAIN_FNAMES:
+ if not getattr(sn, arrname) is None:
+ np.save(os.path.join(tmpdir, "%s.npy" % arrname), getattr(sn, arrname))
+ # -- Write all site/edge attributes
+ for atype, attrs in zip(("site_attr", "edge_attr"), (sn._site_attrs, sn._edge_attrs)):
+ for attr in attrs:
+ np.save(os.path.join(tmpdir, "%s-%s.npy" % (atype, attr)), attrs[attr])
+ # -- Write final archive
+ with tarfile.open(file, mode = 'w:gz', format = tarfile.PAX_FORMAT) as outf:
+ outf.add(tmpdir, arcname = "")
+
+
+def from_file(file):
+ """Load a SiteNetwork from a tar file/file descriptor."""
+ all_others = {}
+ site_attrs = {}
+ edge_attrs = {}
+ structure = None
+ with tarfile.open(file, mode = 'r:gz', format = tarfile.PAX_FORMAT) as input:
+ # -- Load everything
+ for member in input.getmembers():
+ if member.name == '':
+ continue
+ f = input.extractfile(member)
+ if member.name == _STRUCT_FNAME:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ input.extract(member, path = tmpdir)
+ structure = ase.io.read(os.path.join(tmpdir, member.name), format = 'xyz')
+ else:
+ basename = os.path.splitext(os.path.basename(member.name))[0]
+ data = np.load(f)
+ if basename.startswith("site_attr"):
+ site_attrs[basename.split('-')[1]] = data
+ elif basename.startswith("edge_attr"):
+ edge_attrs[basename.split('-')[1]] = data
+ else:
+ all_others[basename] = data
+
+ # Create SiteNetwork
+ assert not structure is None
+ assert all(k in all_others for k in ("static_mask", "mobile_mask")), "Malformed SiteNetwork file."
+ sn = SiteNetwork(structure,
+ all_others['static_mask'],
+ all_others['mobile_mask'])
+ if 'centers' in all_others:
+ sn.centers = all_others['centers']
+ for key in all_others:
+ if key in ('centers', 'static_mask', 'mobile_mask'):
+ continue
+ setattr(sn, key, all_others[key])
+
+ assert all(len(sa) == sn.n_sites for sa in site_attrs.values())
+ assert all(ea.shape == (sn.n_sites, sn.n_sites) for ea in edge_attrs.values())
+ sn._site_attrs = site_attrs
+ sn._edge_attrs = edge_attrs
+
+ return sn
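
Usage is a simple round trip: `save` writes a gzipped tar containing `structure.xyz`, the masks, the main arrays, and one `.npy` per site/edge attribute, and `from_file` rebuilds the `SiteNetwork`. Assuming `sn` is an existing `SiteNetwork`:

```python
from sitator.misc import oldio

oldio.save(sn, "my_sites.tar.gz")
sn2 = oldio.from_file("my_sites.tar.gz")
assert sn2.n_sites == sn.n_sites
```
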
diff --git a/sitator/network/DiffusionPathwayAnalysis.py b/sitator/network/DiffusionPathwayAnalysis.py
new file mode 100644
index 0000000..7dc714c
--- /dev/null
+++ b/sitator/network/DiffusionPathwayAnalysis.py
@@ -0,0 +1,228 @@
+
+import numpy as np
+
+import numbers
+import itertools
+
+from scipy.sparse import lil_matrix
+from scipy.sparse.csgraph import connected_components
+
+from sitator import SiteNetwork
+from sitator.util import PBCCalculator
+
+import logging
+logger = logging.getLogger(__name__)
+
+class DiffusionPathwayAnalysis(object):
+ """Find connected diffusion pathways in a SiteNetwork.
+
+    :param float|int connectivity_threshold: The fraction of the total number of
+ (non-self) jumps, or absolute number of jumps, that must occur over an edge
+ for it to be considered connected.
+ :param int minimum_n_sites: The minimum number of sites that must be part of
+ a pathway for it to be considered as such.
+ :param bool true_periodic_pathways: Whether only to return true periodic
+ pathways that include sites and their periodic images (i.e. conductive
+ in the bulk) rather than just connected components. If ``True``,
+ ``minimum_n_sites`` is NOT respected.
+ """
+
+ NO_PATHWAY = -1
+
+ def __init__(self,
+ connectivity_threshold = 1,
+ true_periodic_pathways = True,
+ minimum_n_sites = 0):
+ assert minimum_n_sites >= 0
+
+ self.true_periodic_pathways = true_periodic_pathways
+ self.connectivity_threshold = connectivity_threshold
+ self.minimum_n_sites = minimum_n_sites
+
+ def run(self, sn, return_count = False, return_direction = False):
+ """
+ Expects a ``SiteNetwork`` that has had a ``JumpAnalysis`` run on it.
+
+ Adds information to ``sn`` in place.
+
+ Args:
+ sn (SiteNetwork): Must have jump statistics from a ``JumpAnalysis``.
+ return_count (bool): Return the number of connected pathways.
+            return_direction (bool): If True and `self.true_periodic_pathways`,
+                also return, for each pathway, the set of direction tuples
+                indicating which directions it connects across periodic boundaries.
+ Returns:
+ sn, [n_pathways], [list of set of tuple]
+ """
+ if not sn.has_attribute('n_ij'):
+ raise ValueError("SiteNetwork has no `n_ij`; run a JumpAnalysis on it first.")
+
+ nondiag = np.ones(shape = sn.n_ij.shape, dtype = np.bool)
+ np.fill_diagonal(nondiag, False)
+ n_non_self_jumps = np.sum(sn.n_ij[nondiag])
+
+ if isinstance(self.connectivity_threshold, numbers.Integral):
+ threshold = self.connectivity_threshold
+ elif isinstance(self.connectivity_threshold, numbers.Real):
+ threshold = self.connectivity_threshold * n_non_self_jumps
+ else:
+ raise TypeError("Don't know how to interpret connectivity_threshold `%s`" % self.connectivity_threshold)
+
+ connectivity_matrix = sn.n_ij >= threshold
+
+ if self.true_periodic_pathways:
+ connectivity_matrix, mask_000, images = self._build_mic_connmat(sn, connectivity_matrix)
+
+ n_ccs, ccs = connected_components(connectivity_matrix,
+ directed = False, # even though the matrix is symmetric
+ connection = 'weak') # diffusion could be unidirectional
+
+ _, counts = np.unique(ccs, return_counts = True)
+
+ if self.true_periodic_pathways:
+ # is_pathway = np.ones(shape = n_ccs, dtype = np.bool)
+ # We have to check that the pathways include a site and its periodic
+ # image, and throw out those that don't
+ new_n_ccs = 1
+ new_ccs = np.zeros(shape = len(sn), dtype = np.int)
+
+ # Add a non-path (contains no sites, all False) so the broadcasting works
+ site_masks = [np.zeros(shape = len(sn), dtype = np.bool)]
+
+ pathway_dirs = [set()]
+
+ for pathway_i in np.arange(n_ccs):
+ path_mask = ccs == pathway_i
+
+ if not np.any(path_mask & mask_000):
+ # If the pathway is entirely outside the unit cell, we don't care
+ continue
+
+ # Sum along each site's periodic images, giving a count site-by-site
+ site_counts = np.sum(path_mask.reshape((-1, sn.n_sites)).astype(np.int), axis = 0)
+ if not np.any(site_counts > 1):
+ # Not percolating; doesn't contain any site and its periodic image.
+ continue
+
+ pdirs = set()
+ for periodic_site in np.where(site_counts > 1)[0]:
+ at_images = images[path_mask[periodic_site::len(sn)]]
+ # The direction from 0 to 1 should be the same as any other pair.
+ # Cause periodic.
+ direction = (at_images[0] - at_images[1]) != 0
+ pdirs.add(tuple(direction))
+
+ cur_site_mask = site_counts > 0
+
+ intersects_with = np.where(np.any(np.logical_and(site_masks, cur_site_mask), axis = 1))[0]
+ # Merge them:
+ if len(intersects_with) > 0:
+ path_mask = cur_site_mask | np.logical_or.reduce([site_masks[i] for i in intersects_with], axis = 0)
+ pdirs = pdirs.union(*[pathway_dirs[i] for i in intersects_with])
+ else:
+ path_mask = cur_site_mask
+ # Remove individual merged paths
+ # Going in reverse order means indexes don't become invalid as deletes happen
+ for i in sorted(intersects_with, reverse=True):
+ del site_masks[i]
+ del pathway_dirs[i]
+ # Add new (super)path
+ site_masks.append(path_mask)
+ pathway_dirs.append(pdirs)
+
+ new_ccs[path_mask] = new_n_ccs
+ new_n_ccs += 1
+
+ n_ccs = new_n_ccs
+ ccs = new_ccs
+ # Only actually take the ones that were assigned to in the end
+ # This will deal with the ones that were merged.
+ is_pathway = np.in1d(np.arange(n_ccs), ccs)
+            is_pathway[0] = False # Because 0 was the "unassigned" value new_ccs was initialized with above
+ pathway_dirs = pathway_dirs[1:] # Get rid of the dummy pathway's direction
+ assert len(pathway_dirs) == np.sum(is_pathway)
+ else:
+ is_pathway = counts >= self.minimum_n_sites
+
+ logging.info("Taking all edges with at least %i/%i jumps..." % (threshold, n_non_self_jumps))
+ logging.info("Found %i connected components, of which %i are large enough to qualify as pathways (%i sites)." % (n_ccs, np.sum(is_pathway), self.minimum_n_sites))
+
+ n_pathway = np.sum(is_pathway)
+ translation = np.empty(n_ccs, dtype = np.int)
+ translation[~is_pathway] = DiffusionPathwayAnalysis.NO_PATHWAY
+ translation[is_pathway] = np.arange(n_pathway)
+
+ node_pathways = translation[ccs]
+
+ outmat = np.empty(shape = (sn.n_sites, sn.n_sites), dtype = np.int)
+
+ for i in range(sn.n_sites):
+ rowmask = node_pathways[i] == node_pathways
+ outmat[i, rowmask] = node_pathways[i]
+ outmat[i, ~rowmask] = DiffusionPathwayAnalysis.NO_PATHWAY
+
+ sn.add_site_attribute('site_diffusion_pathway', node_pathways)
+ sn.add_edge_attribute('edge_diffusion_pathway', outmat)
+
+ retval = [sn]
+ if return_count:
+ retval.append(n_pathway)
+ if return_direction:
+ retval.append(pathway_dirs)
+ return tuple(retval)
+
+
+ def _build_mic_connmat(self, sn, connectivity_matrix):
+ # We use a 3x3x3 = 27 supercell, so there are 27x as many sites
+ assert len(sn) == connectivity_matrix.shape[0]
+
+ images = np.asarray(list(itertools.product(range(-1, 2), repeat = 3)))
+ image_to_idex = dict((100 * (image[0] + 1) + 10 * (image[1] + 1) + (image[2] + 1), i) for i, image in enumerate(images))
+ n_images = len(images)
+ assert n_images == 27
+
+ n_sites = len(sn)
+ pos = sn.centers #.copy() # TODO: copy not needed after reinstall of sitator!
+ n_total_sites = len(images) * n_sites
+ newmat = lil_matrix((n_total_sites, n_total_sites), dtype = np.bool)
+
+ mask_000 = np.zeros(shape = n_total_sites, dtype = np.bool)
+ index_000 = image_to_idex[111]
+        mask_000[index_000 * n_sites:(index_000 + 1) * n_sites] = True
+ assert np.sum(mask_000) == len(sn)
+
+ pbcc = PBCCalculator(sn.structure.cell)
+ buf = np.empty(shape = 3)
+
+ internal_mat = np.zeros_like(connectivity_matrix)
+ external_connections = []
+ for from_site, to_site in zip(*np.where(connectivity_matrix)):
+ buf[:] = pos[to_site]
+ if pbcc.min_image(pos[from_site], buf) == 111:
+ # If we're in the main image, keep the connection: it's internal
+ internal_mat[from_site, to_site] = True
+ #internal_mat[to_site, from_site] = True # fake FIXME
+ else:
+ external_connections.append((from_site, to_site))
+ #external_connections.append((to_site, from_site)) # FAKE FIXME
+
+ for image_idex, image in enumerate(images):
+ # Make the block diagonal
+ newmat[image_idex * n_sites:(image_idex + 1) * n_sites,
+ image_idex * n_sites:(image_idex + 1) * n_sites] = internal_mat
+
+ # Check all external connections from this image; add other sparse entries
+ for from_site, to_site in external_connections:
+ buf[:] = pos[to_site]
+ to_mic = pbcc.min_image(pos[from_site], buf)
+ to_in_image = image + [(to_mic // 10**(2 - i) % 10) - 1 for i in range(3)] # FIXME: is the -1 right
+ assert to_in_image is not None, "%s" % to_in_image
+ assert np.max(np.abs(to_in_image)) <= 2
+ if not np.any(np.abs(to_in_image) > 1):
+ to_in_image = 100 * (to_in_image[0] + 1) + 10 * (to_in_image[1] + 1) + 1 * (to_in_image[2] + 1)
+ newmat[image_idex * n_sites + from_site,
+ image_to_idex[to_in_image] * n_sites + to_site] = True
+
+ assert np.sum(newmat) >= n_images * np.sum(internal_mat) # Lowest it can be is if every one is internal
+
+ return newmat, mask_000, images
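
The percolation test in `run` works because the supercell sites are laid out image-major (`index = image_idex * n_sites + site`): reshaping a component mask to `(n_images, n_sites)` and summing over images counts how many periodic copies of each site the component contains, and any count above one means the pathway reaches a site's own periodic image. In miniature, with hypothetical numbers:

```python
import numpy as np

n_sites, n_images = 4, 27
path_mask = np.zeros(n_images * n_sites, dtype = bool)
path_mask[0 * n_sites + 1] = True   # site 1 in image 0...
path_mask[13 * n_sites + 1] = True  # ...and site 1 in the home image (index 13)

site_counts = path_mask.reshape((-1, n_sites)).sum(axis = 0)
assert np.any(site_counts > 1)  # the pathway percolates through site 1
```
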
diff --git a/sitator/network/MergeSitesByBarrier.py b/sitator/network/MergeSitesByBarrier.py
new file mode 100644
index 0000000..3865952
--- /dev/null
+++ b/sitator/network/MergeSitesByBarrier.py
@@ -0,0 +1,139 @@
+import numpy as np
+
+import itertools
+import math
+
+from scipy.sparse.csgraph import connected_components
+
+from ase.calculators.calculator import all_changes
+
+from sitator.util import PBCCalculator
+from sitator.util.progress import tqdm
+from sitator.network.merging import MergeSites
+from sitator.dynamics import JumpAnalysis
+
+import logging
+logger = logging.getLogger(__name__)
+
+class MergeSitesByBarrier(MergeSites):
+ """Merge sites based on the energy barrier between them.
+
+ Uses a cheap coordinate driving system; this may not be sophisticated enough
+ for complex cases. For each pair of sites within the pairwise distance cutoff,
+ a linear spatial interpolation is applied to produce ``n_driven_images``.
+    Two sites are considered mergable if the barrier between them is below
+    ``barrier_threshold`` in both directions, where the barrier in each
+    direction is the maximum image energy minus the initial or final site's
+    energy, respectively.
+
+ The energies of the mobile atom are calculated in a static lattice given
+ by ``coordinating_mask``; if ``None``, this is set to the system's
+ ``static_mask``.
+
+    For reasonable performance, ``calculator`` should be something simple like
+ ``ase.calculators.lj.LennardJones``.
+
+ Takes species of first mobile atom as mobile species.
+
+ Args:
+ calculator (ase.Calculator): For computing total potential energies.
+ barrier_threshold (float, eV): The barrier value above which two sites
+ are not mergable.
+        n_driven_images (int, default: 20): The number of evenly distributed
+ driven images to use.
+ maximum_pairwise_distance (float, Angstrom): The maximum distance
+ between two sites for them to be considered for merging.
+ minimum_jumps_mergable (int): The minimum number of observed jumps
+ between two sites for their merging to be considered. Setting this
+ higher can avoid unnecessary computations.
+        maximum_merge_distance (float, Angstrom): The maximum pairwise distance
+            among a group of sites chosen to be merged.
+ """
+ def __init__(self,
+ calculator,
+ barrier_threshold,
+ n_driven_images = 20,
+ maximum_pairwise_distance = 2,
+ minimum_jumps_mergable = 1,
+ maximum_merge_distance = 2,
+ **kwargs):
+ super().__init__(maximum_merge_distance = maximum_merge_distance, **kwargs)
+ self.barrier_threshold = barrier_threshold
+ self.maximum_pairwise_distance = maximum_pairwise_distance
+ self.minimum_jumps_mergable = minimum_jumps_mergable
+ assert n_driven_images >= 3, "Must have at least initial, transition, and final."
+ self.n_driven_images = n_driven_images
+ self.calculator = calculator
+
+
+ def _get_sites_to_merge(self, st, coordinating_mask = None):
+ sn = st.site_network
+
+ # -- Compute jump statistics
+ if not sn.has_attribute('n_ij'):
+ ja = JumpAnalysis()
+ ja.run(st)
+
+ pos = sn.centers
+ if coordinating_mask is None:
+ coordinating_mask = sn.static_mask
+ else:
+ assert not np.any(coordinating_mask & sn.mobile_mask)
+ # -- Build images
+ mobile_idex = np.where(sn.mobile_mask)[0][0]
+ one_mobile_structure = sn.structure[coordinating_mask]
+ one_mobile_structure.extend(sn.structure[mobile_idex])
+ mobile_idex = -1
+ one_mobile_structure.set_calculator(self.calculator)
+ interpolation_coeffs = np.linspace(0, 1, self.n_driven_images)
+ energies = np.empty(shape = self.n_driven_images)
+
+ # -- Decide on pairs to check
+ pbcc = PBCCalculator(sn.structure.cell)
+ dists = pbcc.pairwise_distances(pos)
+ # At the start, all within distance cutoff are mergable
+ mergable = dists <= self.maximum_pairwise_distance
+ mergable &= sn.n_ij >= self.minimum_jumps_mergable
+
+ # -- Check pairs' barriers
+ # Symmetric, and diagonal is trivially true. Combinations avoids those cases.
+ jbuf = pos[0].copy()
+ first_calculate = True
+ mergable_pairs = (p for p in itertools.combinations(range(sn.n_sites), r = 2) if mergable[p] or mergable[p[1], p[0]])
+ n_mergable = (np.sum(mergable) - sn.n_sites) // 2
+ for i, j in tqdm(mergable_pairs, total = n_mergable):
+ jbuf[:] = pos[j]
+ # Get minimage
+ _ = pbcc.min_image(pos[i], jbuf)
+ # Do coordinate driving
+ vector = jbuf - pos[i]
+ for image_i in range(self.n_driven_images):
+ one_mobile_structure.positions[mobile_idex] = vector
+ one_mobile_structure.positions[mobile_idex] *= interpolation_coeffs[image_i]
+ one_mobile_structure.positions[mobile_idex] += pos[i]
+ energies[image_i] = one_mobile_structure.get_potential_energy()
+ first_calculate = False
+ # Check barrier
+ barrier_idex = np.argmax(energies)
+ forward_barrier = energies[barrier_idex] - energies[0]
+ backward_barrier = energies[barrier_idex] - energies[-1]
+ # If it's an actual maxima barrier between them, then we want to
+ # check its height
+ if barrier_idex != 0 and barrier_idex != self.n_driven_images - 1:
+ mergable[i, j] = forward_barrier <= self.barrier_threshold
+ mergable[j, i] = backward_barrier <= self.barrier_threshold
+ # Otherwise, if there's no maxima between them, they are in the same
+ # basin.
+
+ # Get mergable groups
+ n_merged_sites, labels = connected_components(
+ mergable,
+ directed = True,
+ connection = 'strong'
+ )
+ # MergeSites will check pairwise distances; we just need to make it the
+ # right format.
+ merge_groups = []
+ for lbl in range(n_merged_sites):
+ merge_groups.append(np.where(labels == lbl)[0])
+
+ return merge_groups
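
The barrier estimate is just the maximum energy along the linearly driven path, measured from each endpoint; an interior maximum separates two basins, and both directional barriers must pass `barrier_threshold` (strong connectivity) for a merge. A toy illustration with a synthetic energy profile in place of the ASE calculator:

```python
import numpy as np

coeffs = np.linspace(0, 1, 20)                     # n_driven_images = 20
energies = 0.5 - 0.5 * np.cos(2 * np.pi * coeffs)  # toy profile, peak at the midpoint

barrier_idex = np.argmax(energies)
forward_barrier = energies[barrier_idex] - energies[0]
backward_barrier = energies[barrier_idex] - energies[-1]
assert 0 < barrier_idex < len(coeffs) - 1  # interior maximum: two distinct basins
```
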
diff --git a/sitator/network/__init__.py b/sitator/network/__init__.py
new file mode 100644
index 0000000..0dd8174
--- /dev/null
+++ b/sitator/network/__init__.py
@@ -0,0 +1,4 @@
+from .DiffusionPathwayAnalysis import DiffusionPathwayAnalysis
+
+from . import merging
+from .MergeSitesByBarrier import MergeSitesByBarrier
diff --git a/sitator/network/merging.py b/sitator/network/merging.py
new file mode 100644
index 0000000..c714e7c
--- /dev/null
+++ b/sitator/network/merging.py
@@ -0,0 +1,145 @@
+import numpy as np
+
+import abc
+
+from sitator.util import PBCCalculator
+from sitator import SiteNetwork, SiteTrajectory
+from sitator.errors import InsufficientSitesError
+
+import logging
+logger = logging.getLogger(__name__)
+
+class MergeSitesError(Exception):
+ pass
+
+class MergedSitesTooDistantError(MergeSitesError):
+ pass
+
+
+class MergeSites(abc.ABC):
+ """Abstract base class for merging sites.
+
+ :param bool check_types: If ``True``, only sites of the same type are candidates to
+ be merged; if ``False``, type information is ignored. Merged sites will only
+ be assigned types if this is ``True``.
+ :param float maximum_merge_distance: Maximum distance between two sites
+ that are in a merge group, above which an error will be raised.
+ :param bool set_merged_into: If ``True``, a site attribute ``"merged_into"``
+ will be added to the original ``SiteNetwork`` indicating which new site
+ each old site was merged into.
+ :param bool weighted_spatial_average: If ``True``, the spatial average giving
+ the position of the merged site will be weighted by occupancy.
+ """
+ def __init__(self,
+ check_types = True,
+ maximum_merge_distance = None,
+ set_merged_into = False,
+ weighted_spatial_average = True):
+ self.check_types = check_types
+ self.maximum_merge_distance = maximum_merge_distance
+ self.set_merged_into = set_merged_into
+ self.weighted_spatial_average = weighted_spatial_average
+
+
+ def run(self, st, **kwargs):
+ """Takes a ``SiteTrajectory`` and returns a new ``SiteTrajectory``."""
+
+ if self.check_types and st.site_network.site_types is None:
+ raise ValueError("Cannot run a check_types=True MergeSites on a SiteTrajectory without type information.")
+
+ # -- Compute jump statistics
+ pbcc = PBCCalculator(st.site_network.structure.cell)
+ site_centers = st.site_network.centers
+ if self.check_types:
+ site_types = st.site_network.site_types
+
+ clusters = self._get_sites_to_merge(st, **kwargs)
+
+ old_n_sites = st.site_network.n_sites
+ new_n_sites = len(clusters)
+
+ logger.info("After merging %i sites there will be %i sites for %i mobile particles" % (len(site_centers), new_n_sites, st.site_network.n_mobile))
+
+ if new_n_sites < st.site_network.n_mobile:
+ raise InsufficientSitesError(
+ verb = "Merging",
+ n_sites = new_n_sites,
+ n_mobile = st.site_network.n_mobile
+ )
+
+ if self.check_types:
+ new_types = np.empty(shape = new_n_sites, dtype = np.int)
+ merge_verts = st.site_network.vertices is not None
+ if merge_verts:
+ new_verts = []
+
+ # -- Merge Sites
+ new_centers = np.empty(shape = (new_n_sites, 3), dtype = st.site_network.centers.dtype)
+ translation = np.empty(shape = st.site_network.n_sites, dtype = np.int)
+ translation.fill(-1)
+
+ for newsite in range(new_n_sites):
+ mask = list(clusters[newsite])
+ # Update translation table
+ if np.any(translation[mask] != -1):
+ # We've assigned a different cluster for this before... weird
+ # degeneracy
+ raise ValueError("Site merging tried to merge site(s) into more than one new site. This shouldn't happen.")
+ translation[mask] = newsite
+
+ to_merge = site_centers[mask]
+
+ # Check distances
+ if not self.maximum_merge_distance is None:
+ dists = pbcc.distances(to_merge[0], to_merge[1:])
+ if not np.all(dists <= self.maximum_merge_distance):
+ raise MergedSitesTooDistantError("Markov clustering tried to merge sites more than %.2f apart. Lower your distance_threshold?" % self.maximum_merge_distance)
+
+ # New site center
+            if self.weighted_spatial_average:
+                occs = st.site_network.occupancies[mask]
+                new_centers[newsite] = pbcc.average(to_merge, weights = occs)
+            else:
+                new_centers[newsite] = pbcc.average(to_merge)
+
+ if self.check_types:
+ assert np.all(site_types[mask] == site_types[mask][0])
+ new_types[newsite] = site_types[mask][0]
+ if merge_verts:
+ new_verts.append(set.union(*[set(st.site_network.vertices[i]) for i in mask]))
+
+ newsn = st.site_network.copy()
+ newsn.centers = new_centers
+ if self.check_types:
+ newsn.site_types = new_types
+ if merge_verts:
+ newsn.vertices = new_verts
+
+ newtraj = translation[st._traj]
+ newtraj[st._traj == SiteTrajectory.SITE_UNKNOWN] = SiteTrajectory.SITE_UNKNOWN
+
+ # It doesn't make sense to propagate confidence information through a
+ # transform that might completely invalidate it
+ newst = SiteTrajectory(newsn, newtraj, confidences = None)
+
+ if not st.real_trajectory is None:
+ newst.set_real_traj(st.real_trajectory)
+
+ if self.set_merged_into:
+ if st.site_network.has_attribute("merged_into"):
+ st.site_network.remove_attribute("merged_into")
+ st.site_network.add_site_attribute("merged_into", translation)
+
+ return newst
+
+ @abc.abstractmethod
+ def _get_sites_to_merge(self, st, **kwargs):
+ """Get the groups of sites to merge.
+
+        Returns a list of lists/tuples, each containing the indices of the
+        sites to be merged into a single new site. The groups must not
+        overlap: every site may appear in at most one merge group.
+
+        Any site not mentioned in a group will be dropped entirely.
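+
+        A minimal, hypothetical subclass (purely illustrative, not part of
+        the package) could simply return predefined groups::
+
+            class MergePredefined(MergeSites):
+                def __init__(self, groups, **kwargs):
+                    self.groups = groups
+                    super().__init__(**kwargs)
+
+                def _get_sites_to_merge(self, st, **kwargs):
+                    return self.groups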
+ """
+ pass
diff --git a/sitator/plotting.py b/sitator/plotting.py
deleted file mode 100644
index cff27bd..0000000
--- a/sitator/plotting.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import numpy as np
-
-import matplotlib.pyplot as plt
-import matplotlib
-from mpl_toolkits.mplot3d import Axes3D
-
-import ase
-
-from samos.analysis.jumps.voronoi import collapse_into_unit_cell
-
-DEFAULT_COLORS = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'] * 2
-
-# From https://stackoverflow.com/questions/13685386/matplotlib-equal-unit-length-with-equal-aspect-ratio-z-axis-is-not-equal-to
-def set_axes_equal(ax):
- '''Make axes of 3D plot have equal scale so that spheres appear as spheres,
- cubes as cubes, etc.. This is one possible solution to Matplotlib's
- ax.set_aspect('equal') and ax.axis('equal') not working for 3D.
-
- Input
- ax: a matplotlib axis, e.g., as output from plt.gca().
- '''
-
- x_limits = ax.get_xlim3d()
- y_limits = ax.get_ylim3d()
- z_limits = ax.get_zlim3d()
-
- x_range = abs(x_limits[1] - x_limits[0])
- x_middle = np.mean(x_limits)
- y_range = abs(y_limits[1] - y_limits[0])
- y_middle = np.mean(y_limits)
- z_range = abs(z_limits[1] - z_limits[0])
- z_middle = np.mean(z_limits)
-
- # The plot bounding box is a sphere in the sense of the infinity
- # norm, hence I call half the max range the plot radius.
- plot_radius = 0.5*max([x_range, y_range, z_range])
-
- ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
- ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
- ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])
-
-def plot_atoms(atms, species,
- pts=None, pts_cs = None, pts_marker = 'x',
- cell = None,
- hide_species = (),
- wrap = False,
- title = ""):
- fig = plt.figure()
- ax = fig.add_subplot(111, projection="3d")
- cs = {
- 'Li' : "blue", 'O' : "red", 'Ta' : "gray", 'Ge' : "darkgray", 'P' : "orange",
- 'point' : 'black'
- }
-
- if wrap and not cell is None:
- atms = np.asarray([collapse_into_unit_cell(pt, cell) for pt in atms])
- if not pts is None:
- pts = np.asarray([collapse_into_unit_cell(pt, cell) for pt in pts])
-
- for s in hide_species:
- cs[s] = 'none'
-
- ax.scatter(atms[:,0],
- atms[:,1],
- atms[:,2],
- c = [cs.get(e, 'gray') for e in species],
- s = [10.0*(ase.data.atomic_numbers[s])**0.5 for s in species])
-
-
- if not pts is None:
- c = None
- if pts_cs is None:
- c = cs['point']
- else:
- c = pts_cs
- ax.scatter(pts[:,0],
- pts[:,1],
- pts[:,2],
- marker = pts_marker,
- c = c,
- cmap=matplotlib.cm.Dark2)
-
- if not cell is None:
- for cvec in cell:
- cvec = np.array([[0, 0, 0], cvec])
- ax.plot(cvec[:,0],
- cvec[:,1],
- cvec[:,2],
- color = "gray",
- alpha=0.5,
- linestyle="--")
-
- ax.set_title(title)
-
- set_axes_equal(ax)
-
- return ax
diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py
index f2575ae..df0c128 100644
--- a/sitator/site_descriptors/SOAP.py
+++ b/sitator/site_descriptors/SOAP.py
@@ -2,40 +2,15 @@
import numpy as np
from abc import ABCMeta, abstractmethod
-from sitator.SiteNetwork import SiteNetwork
-from sitator.SiteTrajectory import SiteTrajectory
-
-try:
- import quippy as qp
- from quippy import descriptors
-except ImportError:
- raise ImportError("Quippy with GAP is required for using SOAP descriptors.")
+from sitator import SiteNetwork, SiteTrajectory
+from sitator.util.progress import tqdm
from ase.data import atomic_numbers
-DEFAULT_SOAP_PARAMS = {
- 'cutoff' : 3.0,
- 'cutoff_transition_width' : 1.0,
- 'l_max' : 6, 'n_max' : 6,
- 'atom_sigma' : 0.4
-}
-
-# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698
-import sys
-try:
- ipy_str = str(type(get_ipython()))
- if 'zmqshell' in ipy_str:
- from tqdm import tqdm_notebook as tqdm
- if 'terminal' in ipy_str:
- from tqdm import tqdm
-except:
- if sys.stderr.isatty():
- from tqdm import tqdm
- else:
- def tqdm(iterable, **kwargs):
- return iterable
-
-class SOAP(object):
+import logging
+logger = logging.getLogger(__name__)
+
+class SOAP(object, metaclass=ABCMeta):
"""Abstract base class for computing SOAP vectors in a SiteNetwork.
SOAP computations are *not* thread-safe; use one SOAP object per thread.
@@ -45,42 +20,42 @@ class SOAP(object):
of the environment to consider. I.e. for Li2CO3, can be set to ['O'] or [8]
for oxygen only, or ['C', 'O'] / ['C', 8] / [6,8] if carbon and oxygen
are considered an environment.
- Defaults to `None`, in which case all non-mobile atoms are considered
+ Defaults to ``None``, in which case all non-mobile atoms are considered
regardless of species.
- :param soap_mask: Which atoms in the SiteNetwork's structure
+ :param soap_mask: Which atoms in the ``SiteNetwork``'s structure
to use in SOAP calculations.
Can be either a boolean mask ndarray or a tuple of species.
- If `None`, the entire static_structure of the SiteNetwork will be used.
+ If ``None``, the entire ``static_structure`` of the ``SiteNetwork`` will be used.
Mobile atoms cannot be used for the SOAP host structure.
        Even if not masked, species that are not part of the environment will not be accounted for.
- For ideal performance: Specify environment and soap_mask correctly!
+ For ideal performance: Specify environment and ``soap_mask`` correctly!
:param dict soap_params = {}: Any custom SOAP params.
+    :param func backend: A function that can be called with
+        ``sn, soap_mask, tracer_atomic_number, environment_list`` as
+        parameters, returning a function that, given the current SOAP host
+        structure along with the tracer positions, returns the SOAP vectors
+        as a numpy array (i.e., its signature is ``soap(structure, positions)``).
+        The returned function may also have a property, ``n_dim``, giving the
+        length of a single SOAP vector.
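+
+        A hypothetical sketch of a conforming backend (the constant descriptor
+        is purely illustrative)::
+
+            def my_backend(sn, soap_mask, tracer_atomic_number, environment_list):
+                def soaper(structure, positions):
+                    # One (dummy) descriptor vector per queried position.
+                    return np.zeros((len(positions), 10))
+                soaper.n_dim = 10
+                return soaper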
"""
- __metaclass__ = ABCMeta
+
+ from .backend.quip import quip_soap_backend as backend_quip
+ from .backend.dscribe import dscribe_soap_backend as backend_dscribe
+
def __init__(self, tracer_atomic_number, environment = None,
- soap_mask=None, soap_params={}, verbose =True):
+ soap_mask = None,
+ backend = None):
from ase.data import atomic_numbers
- # Creating a dictionary for convenience, to check the types and values:
- self.tracer_atomic_number = 3
- centers_list = [self.tracer_atomic_number]
+ self.tracer_atomic_number = tracer_atomic_number
self._soap_mask = soap_mask
- # -- Create the descriptor object
- soap_opts = dict(DEFAULT_SOAP_PARAMS)
- soap_opts.update(soap_params)
- soap_cmd_line = ["soap"]
-
- # User options
- for opt in soap_opts:
- soap_cmd_line.append("{}={}".format(opt, soap_opts[opt]))
+ if backend is None:
+            backend = SOAP.backend_dscribe()  # call the factory to get the actual backend function
+ self._backend = backend
- #
- soap_cmd_line.append('n_Z={} Z={{{}}}'.format(len(centers_list), ' '.join(map(str, centers_list))))
-
- # - Add environment species controls if given
- self._environment = None
- if not environment is None:
+ # - Standardize environment species controls if given
+ if not environment is None: # User given environment
if not isinstance(environment, (list, tuple)):
raise TypeError('environment has to be a list or tuple of species (atomic number'
' or symbol of the environment to consider')
@@ -99,40 +74,38 @@ def __init__(self, tracer_atomic_number, environment = None,
raise TypeError("Environment has to be a list of atomic numbers or atomic symbols")
self._environment = environment_list
- soap_cmd_line.append('n_species={} species_Z={{{}}}'.format(len(environment_list), ' '.join(map(str, environment_list))))
-
- soap_cmd_line = " ".join(soap_cmd_line)
-
- if verbose:
- print("SOAP command line: %s" % soap_cmd_line)
-
- self._soaper = descriptors.Descriptor(soap_cmd_line)
- self._verbose = verbose
- self._cutoff = soap_opts['cutoff']
-
-
-
- @property
- def n_dim(self):
- return self._soaper.n_dim
+ else:
+ self._environment = None
def get_descriptors(self, stn):
- """
- Get the descriptors.
- :param stn: A valid instance of SiteTrajectory or SiteNetwork
- :returns: an array of descriptor vectors and an equal length array of
- labels indicating which descriptors correspond to which sites.
+ """Get the descriptors.
+
+ Args:
+ stn (SiteTrajectory or SiteNetwork)
+ Returns:
+ An array of descriptor vectors and an equal length array of labels
+ indicating which descriptors correspond to which sites.
"""
# Build SOAP host structure
if isinstance(stn, SiteTrajectory):
- structure, tracer_index, soap_mask = self._make_structure(stn.site_network)
+ sn = stn.site_network
elif isinstance(stn, SiteNetwork):
- structure, tracer_index, soap_mask = self._make_structure(stn)
+ sn = stn
else:
raise TypeError("`stn` must be SiteNetwork or SiteTrajectory")
+ structure, tracer_atomic_number, soap_mask = self._make_structure(sn)
+
+ if self._environment is not None:
+ environment_list = self._environment
+ else:
+ # Set it to all species represented by the soap_mask
+ environment_list = np.unique(sn.structure.get_atomic_numbers()[soap_mask])
+
+ soaper = self._backend(sn, soap_mask, tracer_atomic_number, environment_list)
+
# Compute descriptors
- return self._get_descriptors(stn, structure, tracer_index, soap_mask)
+ return self._get_descriptors(stn, structure, tracer_atomic_number, soap_mask, soaper)
# ----
@@ -140,7 +113,7 @@ def _make_structure(self, sn):
if self._soap_mask is None:
# Make a copy of the static structure
- structure = qp.Atoms(sn.static_structure)
+ structure = sn.static_structure.copy()
soap_mask = sn.static_mask # soap mask is the
else:
if isinstance(self._soap_mask, tuple):
@@ -149,7 +122,7 @@ def _make_structure(self, sn):
soap_mask = self._soap_mask
assert not np.any(soap_mask & sn.mobile_mask), "Error for atoms %s; No atom can be both static and mobile" % np.where(soap_mask & sn.mobile_mask)[0]
- structure = qp.Atoms(sn.structure[soap_mask])
+ structure = sn.structure[soap_mask]
assert np.any(soap_mask), "Given `soap_mask` excluded all host atoms."
if not self._environment is None:
@@ -161,14 +134,16 @@ def _make_structure(self, sn):
else:
tracer_atomic_number = self.tracer_atomic_number
- structure.add_atoms((0.0, 0.0, 0.0), tracer_atomic_number)
+ if np.any(structure.get_atomic_numbers() == tracer_atomic_number):
+ raise ValueError("Structure cannot have static atoms (that are enabled in the SOAP mask) of the same species as `tracer_atomic_number`.")
+
structure.set_pbc([True, True, True])
- tracer_index = len(structure) - 1
- return structure, tracer_index, soap_mask
+ return structure, tracer_atomic_number, soap_mask
+
@abstractmethod
- def _get_descriptors(self, stn, structure, tracer_index):
+    def _get_descriptors(self, stn, structure, tracer_atomic_number, soap_mask, soaper):
pass
@@ -177,46 +152,35 @@ def _get_descriptors(self, stn, structure, tracer_index):
class SOAPCenters(SOAP):
"""Compute the SOAPs of the site centers in the fixed host structure.
- Requires a SiteNetwork as input.
+ Requires a ``SiteNetwork`` as input.
"""
- def _get_descriptors(self, sn, structure, tracer_index, soap_mask):
- assert isinstance(sn, SiteNetwork), "SOAPCenters requires a SiteNetwork, not `%s`" % sn
+ def _get_descriptors(self, sn, structure, tracer_atomic_number, soap_mask, soaper):
+ if isinstance(sn, SiteTrajectory):
+ sn = sn.site_network
+ assert isinstance(sn, SiteNetwork), "SOAPCenters requires a SiteNetwork or SiteTrajectory, not `%s`" % sn
pts = sn.centers
- out = np.empty(shape = (len(pts), self.n_dim), dtype = np.float)
-
- structure.set_cutoff(self._soaper.cutoff())
-
- for i, pt in enumerate(tqdm(pts, desc="SOAP") if self._verbose else pts):
- # Move tracer
- structure.positions[tracer_index] = pt
-
- # SOAP requires connectivity data to be computed first
- structure.calc_connect()
-
- #There should only be one descriptor, since there should only be one Li
- out[i] = self._soaper.calc(structure)['descriptor'][0]
+ out = soaper(structure, pts)
return out, np.arange(sn.n_sites)
class SOAPSampledCenters(SOAPCenters):
- """Compute the SOAPs of representative points for each site, as determined by `sampling_transform`.
+ """Compute the SOAPs of representative points for each site, as determined by ``sampling_transform``.
- Takes either a SiteNetwork or SiteTrajectory as input; requires that
- `sampling_transform` produce a SiteNetwork where `site_types` indicates
- which site in the original SiteNetwork/SiteTrajectory it was sampled from.
+ Takes either a ``SiteNetwork`` or ``SiteTrajectory`` as input; requires that
+ ``sampling_transform`` produce a ``SiteNetwork`` where ``site_types`` indicates
+ which site in the original ``SiteNetwork``/``SiteTrajectory`` it was sampled from.
- Typical sampling transforms are `sitator.misc.NAvgsPerSite` (for a SiteTrajectory)
- and `sitator.misc.GenerateAroundSites` (for a SiteNetwork).
+ Typical sampling transforms are ``sitator.misc.NAvgsPerSite`` (for a ``SiteTrajectory``)
+ and ``sitator.misc.GenerateAroundSites`` (for a ``SiteNetwork``).
"""
def __init__(self, *args, **kwargs):
self.sampling_transform = kwargs.pop('sampling_transform', 1)
super(SOAPSampledCenters, self).__init__(*args, **kwargs)
def get_descriptors(self, stn):
-
# Do sampling
sampled = self.sampling_transform.run(stn)
assert isinstance(sampled, SiteNetwork), "Sampling transform returned `%s`, not a SiteNetwork" % sampled
@@ -237,11 +201,11 @@ class SOAPDescriptorAverages(SOAP):
then averaged in SOAP space to give the final SOAP vectors for each site.
This method often performs better than SOAPSampledCenters on more dynamic
- systems, but requires significantly more computation.
+ systems, but requires significantly more computation.
:param int stepsize: Stride (in frames) when computing SOAPs. Default 1.
:param int averaging: Number of SOAP vectors to average for each output vector.
- :param int avg_descriptors_per_site: Can be specified instead of `averaging`.
+ :param int avg_descriptors_per_site: Can be specified instead of ``averaging``.
Specifies the _average_ number of average SOAP vectors to compute for each
        site. This does not guarantee that number of SOAP vectors for any site,
rather, it allows a trajectory-size agnostic way to specify approximately
@@ -281,7 +245,7 @@ def __init__(self, *args, **kwargs):
super(SOAPDescriptorAverages, self).__init__(*args, **kwargs)
- def _get_descriptors(self, site_trajectory, structure, tracer_index, soap_mask):
+ def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soap_mask, soaper):
"""
calculate descriptors
"""
@@ -291,63 +255,66 @@ def _get_descriptors(self, site_trajectory, structure, tracer_index, soap_mask):
mob_indices = np.where(site_trajectory.site_network.mobile_mask)[0]
# real_traj is the real space positions, site_traj the site trajectory
# (i.e. for every mobile species the site index)
- # I load into new variable, only the steps I need (memory???)
real_traj = site_trajectory._real_traj[::self._stepsize]
site_traj = site_trajectory.traj[::self._stepsize]
# Now, I need to allocate the output
# so for each site, I count how much data there is!
- counts = np.array([np.count_nonzero(site_traj==site_idx) for site_idx in xrange(nsit)], dtype=int)
+ counts = np.zeros(shape = nsit + 1, dtype = np.int)
+ for frame in site_traj:
+            counts[frame] += 1 # NumPy's buffered fancy indexing: a site duplicated in `frame` still only adds one.
+ counts = counts[:-1]
if self._averaging is not None:
averaging = self._averaging
else:
averaging = int(np.floor(np.mean(counts) / self._avg_desc_per_site))
+ logger.debug("Will average %i SOAP vectors for every output vector" % averaging)
- nr_of_descs = counts // averaging
+ if averaging == 0:
+ logger.warning("Asking for too many average descriptors per site; got averaging = 0; setting averaging = 1")
+ averaging = 1
- if np.any(nr_of_descs == 0):
- raise ValueError("You are asking too much, averaging with {} gives a problem".format(averaging))
+ nr_of_descs = counts // averaging
+ insufficient = nr_of_descs == 0
+ if np.any(insufficient):
+ logger.warning("You're asking to average %i SOAP vectors, but at this stepsize, %i sites are insufficiently occupied. Num occ./averaging: %s" % (averaging, np.sum(insufficient), counts[insufficient] / averaging))
+ averagings = np.full(shape = len(nr_of_descs), fill_value = averaging)
+ averagings[insufficient] = counts[insufficient]
+ nr_of_descs = np.maximum(nr_of_descs, 1) # If it's 0, just make one with whatever we've got
+ assert np.all(nr_of_descs >= 1)
+ logger.debug("Minimum # of descriptors/site: %i; maximum: %i" % (np.min(nr_of_descs), np.max(nr_of_descs)))
# This is where I load the descriptor:
- descs = np.zeros((np.sum(nr_of_descs), self.n_dim))
+ descs = np.zeros(shape = (np.sum(nr_of_descs), soaper.n_dim))
# An array that tells me the index I'm at for each site type
- desc_index = [np.sum(nr_of_descs[:i]) for i in range(len(nr_of_descs))]
- max_index = [np.sum(nr_of_descs[:i+1]) for i in range(len(nr_of_descs))]
+ desc_index = np.asarray([np.sum(nr_of_descs[:i]) for i in range(len(nr_of_descs))])
+ max_index = np.asarray([np.sum(nr_of_descs[:i+1]) for i in range(len(nr_of_descs))])
count_of_site = np.zeros(len(nr_of_descs), dtype=int)
- blocked = np.empty(nsit, dtype=bool)
- blocked[:] = False
- structure.set_cutoff(self._soaper.cutoff())
- for site_traj_t, pos in tqdm(zip(site_traj, real_traj), desc="SOAP"):
+ allowed = np.ones(nsit, dtype = np.bool)
+
+ for site_traj_t, pos in zip(tqdm(site_traj, desc="SOAP Frame"), real_traj):
# I update the host lattice positions here, once for every timestep
- structure.positions[:tracer_index] = pos[soap_mask]
- for mob_idx, site_idx in enumerate(site_traj_t):
- if site_idx >= 0 and not blocked[site_idx]:
- # Now, for every lithium that has been associated to a site of index site_idx,
- # I take my structure and load the position of this mobile atom:
- structure.positions[tracer_index] = pos[mob_indices[mob_idx]]
- # calc_connect to calculated distance
-# structure.calc_connect()
- #There should only be one descriptor, since there should only be one mobile
- # I also divide by averaging, to avoid getting into large numbers.
-# soapv = self._soaper.calc(structure)['descriptor'][0] / self._averaging
- structure.set_cutoff(self._cutoff)
- structure.calc_connect()
- soapv = self._soaper.calc(structure, grad=False)["descriptor"]
-
- #~ soapv ,_,_ = get_fingerprints([structure], d)
- # So, now I need to figure out where to load the soapv into desc
- idx_to_add_desc = desc_index[site_idx]
- descs[idx_to_add_desc, :] += soapv[0] / averaging
- count_of_site[site_idx] += 1
- # Now, if the count reaches the averaging I want, I augment
- if count_of_site[site_idx] == averaging:
- desc_index[site_idx] += 1
- count_of_site[site_idx] = 0
- # Now I check whether I have to block this site from accumulating more descriptors
- if max_index[site_idx] == desc_index[site_idx]:
- blocked[site_idx] = True
-
- desc_to_site = np.repeat(range(nsit), nr_of_descs)
+ structure.positions[:] = pos[soap_mask]
+
+ to_describe = (site_traj_t != SiteTrajectory.SITE_UNKNOWN) & allowed[site_traj_t]
+
+ if np.any(to_describe):
+ sites_to_describe = site_traj_t[to_describe]
+ soaps = soaper(structure, pos[mob_indices[to_describe]])
+ soaps /= averagings[sites_to_describe][:, np.newaxis]
+ idx_to_add_desc = desc_index[sites_to_describe]
+ descs[idx_to_add_desc] += soaps
+ count_of_site[sites_to_describe] += 1
+
+ # Reset and increment full averages
+ full_average = count_of_site == averagings
+ desc_index[full_average] += 1
+ count_of_site[full_average] = 0
+ allowed[max_index == desc_index] = False
+
+ assert not np.any(allowed) # We should have maxed out all of them after processing all frames.
+
+ desc_to_site = np.repeat(list(range(nsit)), nr_of_descs)
return descs, desc_to_site
diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py
new file mode 100644
index 0000000..8ad2ef8
--- /dev/null
+++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py
@@ -0,0 +1,148 @@
+import numpy as np
+
+from sitator.util.progress import tqdm
+
+try:
+ from pymatgen.io.ase import AseAtomsAdaptor
+ import pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder as cgf
+ from pymatgen.analysis.chemenv.coordination_environments.structure_environments import \
+ LightStructureEnvironments
+ from pymatgen.analysis.chemenv.utils.defs_utils import AdditionalConditions
+ from pymatgen.analysis.bond_valence import BVAnalyzer
+ has_pymatgen = True
+except ImportError:
+ has_pymatgen = False
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+class SiteCoordinationEnvironment(object):
+ """Determine site types based on local coordination environments.
+
+ Determine site types using the method from the following paper:
+
+ David Waroquiers, Xavier Gonze, Gian-Marco Rignanese,
+ Cathrin Welker-Nieuwoudt, Frank Rosowski, Michael Goebel, Stephan Schenk,
+ Peter Degelmann, Rute Andre, Robert Glaum, and Geoffroy Hautier
+
+ Statistical analysis of coordination environments in oxides
+
+ Chem. Mater., 2017, 29 (19), pp 8346–8360, DOI: 10.1021/acs.chemmater.7b02766
+
+    as implemented in ``pymatgen``'s
+ ``pymatgen.analysis.chemenv.coordination_environments``.
+
+ Adds three site attributes:
+ - ``coordination_environments``: The name of the coordination environment,
+ as returned by ``pymatgen``. Example: ``"T:4"`` (tetrahedral, coordination
+ of 4).
+ - ``site_type_confidences``: The ``ce_fraction`` of the best match chemical
+ environment (from 0 to 1).
+ - ``coordination_numbers``: The coordination number of the site.
+
+ Args:
+ only_ionic_bonds (bool): If True, assumes ``sn`` is an ``IonicSiteNetwork``
+ and uses its charge information to only consider as neighbors
+            atoms with compatible (anion and cation, ion and neutral) charges.
+ full_chemenv_site_types (bool): If True, ``sitator`` site types on the
+ final ``SiteNetwork`` will be assigned based on unique chemical
+ environments, including shape. If False, they will be assigned
+ solely based on coordination number. Either way, both sets of information
+            are included in the ``SiteNetwork``; this option only changes which
+            one determines the ``site_types``.
+ **kwargs: passed to ``compute_structure_environments``.
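+
+    Example (a hypothetical usage sketch; assumes ``sn`` is an already
+    constructed ``SiteNetwork``)::
+
+        sce = SiteCoordinationEnvironment(only_ionic_bonds = False)
+        sn = sce.run(sn)
+        print(sn.site_types)  # by default, derived from coordination numbers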
+ """
+ def __init__(self,
+ only_ionic_bonds = True,
+ full_chemenv_site_types = False,
+ **kwargs):
+ if not has_pymatgen:
+ raise ImportError("Pymatgen (or a recent enough version including `pymatgen.analysis.chemenv.coordination_environments`) cannot be imported.")
+ self._kwargs = kwargs
+ self._only_ionic_bonds = only_ionic_bonds
+ self._full_chemenv_site_types = full_chemenv_site_types
+
+ def run(self, sn):
+ """
+ Args:
+ sn (SiteNetwork)
+ Returns:
+ ``sn``, with type information.
+ """
+ # -- Determine local environments
+ # Get an ASE structure with a single mobile site that we'll move around
+ if sn.n_sites == 0:
+ logger.warning("Site network had no sites.")
+ return sn
+ site_struct, site_species = sn[0:1].get_structure_with_sites()
+ pymat_struct = AseAtomsAdaptor.get_structure(site_struct)
+ lgf = cgf.LocalGeometryFinder()
+ site_atom_index = len(site_struct) - 1
+
+ coord_envs = []
+ vertices = []
+
+ valences = 'undefined'
+ if self._only_ionic_bonds:
+ valences = list(sn.static_charges) + [sn.mobile_charge]
+
+ logger.info("Running site coordination environment analysis...")
+ # Do this once.
+ # __init__ here defaults to disabling structure refinement, so all this
+ # method is doing is making a copy of the structure and setting some
+ # variables to None.
+ lgf.setup_structure(structure = pymat_struct)
+
+ for site in tqdm(range(sn.n_sites), desc = 'Site'):
+ # Update the position of the site
+ lgf.structure[site_atom_index].coords = sn.centers[site]
+ # Compute structure environments for the site
+ struct_envs = lgf.compute_structure_environments(
+ only_indices = [site_atom_index],
+ valences = valences,
+                additional_conditions = [AdditionalConditions.ONLY_ANION_CATION_BONDS] if self._only_ionic_bonds else [],
+ **self._kwargs
+ )
+ struct_envs = LightStructureEnvironments.from_structure_environments(
+ strategy = cgf.LocalGeometryFinder.DEFAULT_STRATEGY,
+ structure_environments = struct_envs
+ )
+ # Store the results
+ # We take the first environment for each site since it's the most likely
+ this_site_envs = struct_envs.coordination_environments[site_atom_index]
+ if len(this_site_envs) > 0:
+ coord_envs.append(this_site_envs[0])
+ vertices.append(
+ [n['index'] for n in struct_envs.neighbors_sets[site_atom_index][0].neighb_indices_and_images]
+ )
+ else:
+ coord_envs.append({'ce_symbol' : 'BAD:0', 'ce_fraction' : 0.})
+ vertices.append([])
+
+ del lgf
+ del struct_envs
+
+ # -- Postprocess
+ # TODO: allow user to ask for full fractional breakdown
+ str_coord_environments = [env['ce_symbol'] for env in coord_envs]
+ # The closer to 1 this is, the better
+ site_type_confidences = np.array([env['ce_fraction'] for env in coord_envs])
+ coordination_numbers = np.array([int(env['ce_symbol'].split(':')[1]) for env in coord_envs])
+ assert np.all(coordination_numbers == [len(v) for v in vertices])
+
+ typearr = str_coord_environments if self._full_chemenv_site_types else coordination_numbers
+ unique_envs = list(set(typearr))
+ site_types = np.array([unique_envs.index(t) for t in typearr])
+
+ n_types = len(unique_envs)
+ logger.info(("Type " + "{:<8}" * n_types).format(*unique_envs))
+ logger.info(("# of sites " + "{:<8}" * n_types).format(*np.bincount(site_types)))
+
+ sn.site_types = site_types
+ sn.vertices = vertices
+ sn.add_site_attribute("coordination_environments", str_coord_environments)
+ sn.add_site_attribute("site_type_confidences", site_type_confidences)
+ sn.add_site_attribute("coordination_numbers", coordination_numbers)
+
+ return sn
diff --git a/sitator/site_descriptors/SiteTypeAnalysis.py b/sitator/site_descriptors/SiteTypeAnalysis.py
index 30a7143..26efa8f 100644
--- a/sitator/site_descriptors/SiteTypeAnalysis.py
+++ b/sitator/site_descriptors/SiteTypeAnalysis.py
@@ -1,7 +1,3 @@
-from __future__ import (absolute_import, division,
- print_function, unicode_literals)
-from builtins import *
-
import numpy as np
from sitator.misc import GenerateAroundSites
@@ -14,39 +10,59 @@
import itertools
+import logging
+logger = logging.getLogger(__name__)
+
+has_pydpc = False
try:
import pydpc
+ has_pydpc = True
except ImportError:
- raise ImportError("SiteTypeAnalysis requires the `pydpc` package")
+ pass
class SiteTypeAnalysis(object):
- """Cluster sites into types using a descriptor and DPCLUS.
-
- -- descriptor --
- Some kind of object implementing:
- - n_dim: the number of components in a descriptor vector
- - get_descriptors(site_traj or site_network): returns an array of descriptor vectors
- of dimension (M, n_dim) and an array of length M indicating which
- descriptor vectors correspond to which sites in (site_traj.)site_network.
+ """Cluster sites into types using a continuous descriptor and Density Peak Clustering.
+
+ Computes descriptor vectors, processes them with Principal Component Analysis,
+ and then clusters using Density Peak Clustering.
+
+ Args:
+ descriptor (object): Must implement ``get_descriptors(st|sn)``, which
+ returns an array of descriptor vectors of dimension (M, n_dim) and
+ an array of length M indicating which descriptor vectors correspond
+ to which sites in (``site_traj.``)``site_network``.
+ min_pca_variance (float): The minimum proportion of the total variance
+ that the taken principal components of the descriptor must explain.
+ min_pca_dimensions (int): Force taking at least this many principal
+ components.
+ n_site_types_max (int): Maximum number of clusters. Must be set reasonably
+ for the automatic selection of cluster number to work.
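+
+    Example (a hypothetical sketch; ``descriptor`` is any object satisfying
+    the protocol above, e.g. a SOAP descriptor)::
+
+        sta = SiteTypeAnalysis(descriptor, n_site_types_max = 10)
+        sn_with_types = sta.run(site_trajectory)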
"""
def __init__(self, descriptor,
min_pca_variance = 0.9, min_pca_dimensions = 2,
- verbose = True, n_site_types_max = 20):
+ n_site_types_max = 20):
+ if not has_pydpc:
+ raise ImportError("SiteTypeAnalysis requires the `pydpc` package")
+
self.descriptor = descriptor
self.min_pca_variance = min_pca_variance
self.min_pca_dimensions = min_pca_dimensions
- self.verbose = verbose
self.n_site_types_max = n_site_types_max
self._n_dvecs = None
def run(self, descriptor_input, **kwargs):
+ """
+ Args:
+ descriptor_input (SiteNetwork or SiteTrajectory)
+ Returns:
+ ``SiteNetwork``
+ """
if not self._n_dvecs is None:
raise ValueError("Can't run SiteTypeAnalysis more than once!")
# -- Sample enough points
- if self.verbose:
- print(" -- Running SiteTypeAnalysis --")
+ logger.info(" -- Running SiteTypeAnalysis --")
if isinstance(descriptor_input, SiteNetwork):
sn = descriptor_input.copy()
@@ -56,30 +72,26 @@ def run(self, descriptor_input, **kwargs):
raise RuntimeError("Input {}".format(type(descriptor_input)))
# -- Compute descriptor vectors
- if self.verbose:
- print(" - Computing Descriptor Vectors")
+ logger.info(" - Computing Descriptor Vectors")
self.dvecs, dvecs_to_site = self.descriptor.get_descriptors(descriptor_input, **kwargs)
assert len(self.dvecs) == len(dvecs_to_site), "Length mismatch in descriptor return values"
assert np.min(dvecs_to_site) == 0 and np.max(dvecs_to_site) < sn.n_sites
# -- Dimensionality Reduction
- if self.verbose:
- print(" - Clustering Descriptor Vectors")
+ logger.info(" - Clustering Descriptor Vectors")
self.pca = PCA(self.min_pca_variance)
pca_dvecs = self.pca.fit_transform(self.dvecs)
if pca_dvecs.shape[1] < self.min_pca_dimensions:
- if self.verbose:
- print(" PCA accounted for %.0f%% variance in only %i dimensions; less than minimum of %.0f." % (100.0 * np.sum(self.pca.explained_variance_ratio_), pca_dvecs.shape[1], self.min_pca_dimensions))
- print(" Forcing PCA to use %i dimensions." % self.min_pca_dimensions)
- self.pca = PCA(n_components = self.min_pca_dimensions)
- pca_dvecs = self.pca.fit_transform(self.dvecs)
+ logger.info(" PCA accounted for %.0f%% variance in only %i dimensions; less than minimum of %.0f." % (100.0 * np.sum(self.pca.explained_variance_ratio_), pca_dvecs.shape[1], self.min_pca_dimensions))
+ logger.info(" Forcing PCA to use %i dimensions." % self.min_pca_dimensions)
+ self.pca = PCA(n_components = self.min_pca_dimensions)
+ pca_dvecs = self.pca.fit_transform(self.dvecs)
self.dvecs = pca_dvecs
- if self.verbose:
- print(" Accounted for %.0f%% of variance in %i dimensions" % (100.0 * np.sum(self.pca.explained_variance_ratio_), self.dvecs.shape[1]))
+ logger.info((" Accounted for %.0f%% of variance in %i dimensions" % (100.0 * np.sum(self.pca.explained_variance_ratio_), self.dvecs.shape[1])))
# -- Do clustering
# pydpc requires a C-contiguous array
@@ -122,14 +134,13 @@ def run(self, descriptor_input, **kwargs):
assert self.n_types == len(site_type_counts), "Got %i types from pydpc, but counted %i" % (self.n_types, len(site_type_counts))
- if self.verbose:
- print(" Found %i site type clusters" % self.n_types )
- print(" Failed to assign %i/%i descriptor vectors to clusters." % (self._n_unassigned, self._n_dvecs))
+ logger.info(" Found %i site type clusters" % self.n_types )
+ logger.info(" Failed to assign %i/%i descriptor vectors to clusters." % (self._n_unassigned, self._n_dvecs))
# -- Voting
types = np.empty(shape = sn.n_sites, dtype = np.int)
self.winning_vote_percentages = np.empty(shape = sn.n_sites, dtype = np.float)
- for site in xrange(sn.n_sites):
+ for site in range(sn.n_sites):
corresponding_samples = dvecs_to_site == site
votes = assignments[corresponding_samples]
n_votes = len(votes)
@@ -145,12 +156,11 @@ def run(self, descriptor_input, **kwargs):
n_sites_of_each_type = np.bincount(types, minlength = self.n_types)
sn.site_types = types
- if self.verbose:
- print((" " + "Type {:<2} " * self.n_types).format(*xrange(self.n_types)))
- print(("# of sites " + "{:<8}" * self.n_types).format(*n_sites_of_each_type))
+ logger.info((" " + "Type {:<2} " * self.n_types).format(*range(self.n_types)))
+ logger.info(("# of sites " + "{:<8}" * self.n_types).format(*n_sites_of_each_type))
- if np.any(n_sites_of_each_type == 0):
- print("WARNING: Had site types with no sites; check clustering settings/voting!")
+ if np.any(n_sites_of_each_type == 0):
+ logger.warning("WARNING: Had site types with no sites; check clustering settings/voting!")
return sn
@@ -172,11 +182,11 @@ def plot_dvecs(self, fig = None, ax = None, **kwargs):
@plotter(is3D = False)
def plot_clustering(self, fig = None, ax = None, **kwargs):
ccycle = itertools.cycle(DEFAULT_COLORS)
- for cluster in xrange(self.n_types):
+ for cluster in range(self.n_types):
mask = self.dpc.membership == cluster
dvecs_core = self.dvecs[mask & ~self.dpc.border_member]
dvecs_border = self.dvecs[mask & self.dpc.border_member]
- color = ccycle.next()
+ color = next(ccycle)
ax.scatter(dvecs_core[:,0], dvecs_core[:,1], s = 3, color = color, label = "Type %i" % cluster)
ax.scatter(dvecs_border[:,0], dvecs_border[:,1], s = 3, color = color, alpha = 0.3)
diff --git a/sitator/site_descriptors/SiteVolumes.py b/sitator/site_descriptors/SiteVolumes.py
new file mode 100644
index 0000000..e065680
--- /dev/null
+++ b/sitator/site_descriptors/SiteVolumes.py
@@ -0,0 +1,134 @@
+import numpy as np
+
+from scipy.spatial import ConvexHull
+from scipy.spatial.qhull import QhullError
+
+from sitator import SiteNetwork, SiteTrajectory
+from sitator.util import PBCCalculator
+
+import logging
+logger = logging.getLogger(__name__)
+
+class InsufficientCoordinatingAtomsError(Exception):
+ pass
+
+class SiteVolumes(object):
+ """Compute the volumes of sites.
+
+ Args:
+ error_on_insufficient_coord (bool): To compute an ideal
+ site volume (``compute_volumes()``), at least 4 coordinating
+ atoms (because we are in 3D space) must be specified in ``vertices``.
+ If ``True``, an error will be thrown when a site with less than four
+ vertices is encountered; if ``False``, a volume of 0 and surface area
+ of NaN will be returned.
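+
+    Example (a hypothetical usage sketch)::
+
+        sv = SiteVolumes(error_on_insufficient_coord = False)
+        sv.compute_volumes(sn)             # requires ``sn.vertices``
+        sv.compute_accessable_volumes(st)  # ``st`` is a SiteTrajectory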
+ """
+ def __init__(self, error_on_insufficient_coord = True):
+ self.error_on_insufficient_coord = error_on_insufficient_coord
+
+
+ def compute_accessable_volumes(self, st, n_recenterings = 8):
+ """Computes the volumes of convex hulls around all positions associated with a site.
+
+ Uses the shift-and-wrap trick for dealing with periodicity, so sites that
+ take up the majority of the unit cell may give bogus results.
+
+ Adds the ``accessable_site_volumes`` attribute to the ``SiteNetwork``.
+
+ Args:
+ st (SiteTrajectory)
+            n_recenterings (int): How many different recenterings to try. The
+                algorithm recenters around n of the points and takes the
+                minimal resulting volume; this deals with cases where
+                recentering around an outlier gives very bad results.
+ """
+ assert isinstance(st, SiteTrajectory)
+ vols = np.empty(shape = st.site_network.n_sites, dtype = np.float)
+ areas = np.empty(shape = st.site_network.n_sites, dtype = np.float)
+
+ pbcc = PBCCalculator(st.site_network.structure.cell)
+
+ for site in range(st.site_network.n_sites):
+ pos = st.real_positions_for_site(site)
+
+ assert pos.flags['OWNDATA']
+
+ vol = np.inf
+ area = None
+ for i in range(n_recenterings):
+ # Recenter
+ offset = pbcc.cell_centroid - pos[int(i * (len(pos)/n_recenterings))]
+ pos += offset
+ pbcc.wrap_points(pos)
+
+ try:
+ hull = ConvexHull(pos)
+ except QhullError as qhe:
+ logger.warning("For site %i, iter %i: %s" % (site, i, qhe))
+ vols[site] = np.nan
+ areas[site] = np.nan
+ continue
+
+ if hull.volume < vol:
+ vol = hull.volume
+ area = hull.area
+
+ vols[site] = vol
+ areas[site] = area
+
+ st.site_network.add_site_attribute('accessable_site_volumes', vols)
+
+
+ def compute_volumes(self, sn):
+ """Computes the volume of the convex hull defined by each sites' static verticies.
+
+ Requires vertex information in the SiteNetwork.
+
+ Adds the ``site_volumes`` and ``site_surface_areas`` attributes.
+
+ Volumes can be NaN for degenerate hulls/point sets on which QHull fails.
+
+ Args:
+ - sn (SiteNetwork)
+ """
+ assert isinstance(sn, SiteNetwork)
+ if sn.vertices is None:
+ raise ValueError("SiteNetwork must have verticies to compute volumes!")
+
+ vols = np.empty(shape = sn.n_sites, dtype = np.float)
+ areas = np.empty(shape = sn.n_sites, dtype = np.float)
+
+ pbcc = PBCCalculator(sn.structure.cell)
+
+ for site in range(sn.n_sites):
+ pos = sn.static_structure.positions[list(sn.vertices[site])]
+ if len(pos) < 4:
+ if self.error_on_insufficient_coord:
+                    raise InsufficientCoordinatingAtomsError("Site %i had only %i vertices (fewer than the needed 4)" % (site, len(pos)))
+ else:
+ vols[site] = 0
+ areas[site] = np.nan
+ continue
+
+ assert pos.flags['OWNDATA'] # It should since we're indexing with index lists
+ # Recenter
+ offset = pbcc.cell_centroid - sn.centers[site]
+ pos += offset
+ pbcc.wrap_points(pos)
+
+ try:
+ hull = ConvexHull(pos)
+ vols[site] = hull.volume
+ areas[site] = hull.area
+ except QhullError as qhe:
+ logger.warning("Had QHull failure when computing volume of site %i" % site)
+ vols[site] = np.nan
+ areas[site] = np.nan
+
+ sn.add_site_attribute('site_volumes', vols)
+ sn.add_site_attribute('site_surface_areas', areas)
+
+
+ def run(self, st):
+ """For backwards compatability."""
+ self.compute_accessable_volumes(st)
diff --git a/sitator/site_descriptors/__init__.py b/sitator/site_descriptors/__init__.py
index ca132b2..fd5e7af 100644
--- a/sitator/site_descriptors/__init__.py
+++ b/sitator/site_descriptors/__init__.py
@@ -1,3 +1,3 @@
-from SiteTypeAnalysis import SiteTypeAnalysis
-
-from SOAP import SOAPCenters, SOAPSampledCenters, SOAPDescriptorAverages
+from .SiteTypeAnalysis import SiteTypeAnalysis
+from .SiteCoordinationEnvironment import SiteCoordinationEnvironment
+from .SiteVolumes import SiteVolumes
diff --git a/sitator/site_descriptors/backend/__init__.py b/sitator/site_descriptors/backend/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sitator/site_descriptors/backend/dscribe.py b/sitator/site_descriptors/backend/dscribe.py
new file mode 100644
index 0000000..66fc5d5
--- /dev/null
+++ b/sitator/site_descriptors/backend/dscribe.py
@@ -0,0 +1,40 @@
+
+import numpy as np
+
+DEFAULT_SOAP_PARAMS = {
+ 'cutoff' : 3.0,
+ 'l_max' : 6, 'n_max' : 6,
+ 'atom_sigma' : 0.4,
+ 'rbf' : 'gto',
+ 'crossover' : False,
+ 'periodic' : True,
+}
+
+def dscribe_soap_backend(soap_params = {}):
+ from dscribe.descriptors import SOAP
+
+ soap_opts = dict(DEFAULT_SOAP_PARAMS)
+ soap_opts.update(soap_params)
+
+ def backend(sn, soap_mask, tracer_atomic_number, environment_list):
+ soap = SOAP(
+ species = environment_list,
+ crossover = soap_opts['crossover'],
+ rcut = soap_opts['cutoff'],
+ nmax = soap_opts['n_max'],
+ lmax = soap_opts['l_max'],
+ rbf = soap_opts['rbf'],
+ sigma = soap_opts['atom_sigma'],
+ periodic = soap_opts['periodic'],
+ sparse = False
+ )
+
+ def dscribe_soap(structure, positions):
+ out = soap.create(structure, positions = positions).astype(np.float)
+ return out
+
+ dscribe_soap.n_dim = soap.get_number_of_features()
+
+ return dscribe_soap
+
+ return backend
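+
+# A hypothetical usage sketch: configure the factory, then hand its product to
+# a SOAP descriptor object as its ``backend``:
+#
+#     from sitator.site_descriptors.SOAP import SOAPCenters
+#     soap = SOAPCenters(3, backend = dscribe_soap_backend({'cutoff' : 4.0}))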
diff --git a/sitator/site_descriptors/backend/quip.py b/sitator/site_descriptors/backend/quip.py
new file mode 100644
index 0000000..d97459c
--- /dev/null
+++ b/sitator/site_descriptors/backend/quip.py
@@ -0,0 +1,86 @@
+"""
+quip.py: Compute SOAP vectors for given positions in a structure using the command line QUIP tool
+"""
+
+import numpy as np
+
+import os
+
+import ase
+
+from tempfile import NamedTemporaryFile
+import subprocess
+
+DEFAULT_SOAP_PARAMS = {
+ 'cutoff' : 3.0,
+ 'cutoff_transition_width' : 1.0,
+ 'l_max' : 6, 'n_max' : 6,
+ 'atom_sigma' : 0.4
+}
+
+def quip_soap_backend(soap_params = {}, quip_path = os.getenv("SITATOR_QUIP_PATH", default = 'quip')):
+ def backend(sn, soap_mask, tracer_atomic_number, environment_list):
+
+ soap_opts = dict(DEFAULT_SOAP_PARAMS)
+ soap_opts.update(soap_params)
+ soap_cmd_line = ["soap"]
+
+ # User options
+ for opt in soap_opts:
+ soap_cmd_line.append("{}={}".format(opt, soap_opts[opt]))
+
+ #
+ soap_cmd_line.append('n_Z=1 Z={{{}}}'.format(tracer_atomic_number))
+
+ soap_cmd_line.append('n_species={} species_Z={{{}}}'.format(len(environment_list), ' '.join(map(str, environment_list))))
+
+ soap_cmd_line = " ".join(soap_cmd_line)
+
+ def soaper(structure, positions):
+ structure = structure.copy()
+ for i in range(len(positions)):
+ structure.append(ase.Atom(position = tuple(positions[i]), symbol = tracer_atomic_number))
+ return _soap(soap_cmd_line, structure, quip_path = quip_path)
+
+ return soaper
+ return backend
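+
+# A hypothetical usage sketch (assumes a working `quip` binary on PATH or
+# pointed to by SITATOR_QUIP_PATH):
+#
+#     from sitator.site_descriptors.SOAP import SOAPCenters
+#     soap = SOAPCenters(3, backend = quip_soap_backend({'cutoff' : 4.0}))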
+
+
+def _soap(descriptor_str,
+ structure,
+ quip_path = 'quip'):
+ """Calculate SOAP vectors by calling `quip` as a subprocess.
+
+ Args:
+ - descriptor_str (str): The QUIP descriptor str, i.e. `soap cutoff=3 ...`
+ - structure (ase.Atoms)
+ - quip_path (str): Path to `quip` executable
+ """
+
+ with NamedTemporaryFile() as xyz:
+ structure.write(xyz.name, format = 'extxyz')
+
+ quip_cmd = [
+ quip_path,
+ "atoms_filename=" + xyz.name,
+ "descriptor_str=\"" + descriptor_str + "\""
+ ]
+
+ result = subprocess.run(quip_cmd, stdout = subprocess.PIPE, check = True, text = True).stdout
+
+ lines = result.splitlines()
+
+ soaps = []
+ for line in lines:
+ if line.startswith("DESC"):
+                soaps.append(np.fromstring(line[len("DESC"):], dtype = np.float, sep = ' ')) # slice off the prefix; lstrip("DESC") would strip characters, not a prefix
+ elif line.startswith("Error"):
+ e = subprocess.CalledProcessError(returncode = 0, cmd = quip_cmd)
+ e.stdout = result
+ raise e
+ else:
+ continue
+
+ soaps = np.asarray(soaps)
+
+ return soaps
diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx
index d933a5c..51a61c9 100644
--- a/sitator/util/DotProdClassifier.pyx
+++ b/sitator/util/DotProdClassifier.pyx
@@ -4,23 +4,13 @@ import numpy as np
import numbers
-import sys
-try:
- from IPython import get_ipython
- ipy_str = str(type(get_ipython()))
- if 'zmqshell' in ipy_str:
- from tqdm import tqdm_notebook as tqdm
- if 'terminal' in ipy_str:
- from tqdm import tqdm
-except:
- if sys.stderr.isatty():
- from tqdm import tqdm
- else:
- def tqdm(iterable, **kwargs):
- return iterable
+from sitator.util.progress import tqdm
N_SITES_ALLOC_INCREMENT = 100
+import logging
+logger = logging.getLogger(__name__)
+
class OneValueListlike(object):
def __init__(self, value, length = np.inf):
self.length = length
@@ -33,20 +23,26 @@ class OneValueListlike(object):
return self.value
class DotProdClassifier(object):
+ """Assign vectors to clusters indicated by a representative vector using a cosine metric.
+
+ Cluster centers can be given through ``set_cluster_centers()`` or approximated
+ using the custom method described in the appendix of the main landmark
+ analysis paper (``fit_centers()``).
+
+ Args:
+ threshold (float): Similarity threshold for joining a cluster.
+ In cos-of-angle-between-vectors (i.e. 1 is exactly the same, 0 is orthogonal)
+ max_converge_iters (int): Maximum number of iterations. If the algorithm
+ hasn't converged by then, it will exit with a warning.
+ min_samples (float or int): filter out clusters with low sample counts.
+ If an int, filters out clusters with fewer samples than this.
+ If a float, filters out clusters with fewer than
+ floor(min_samples * n_assigned_samples) samples assigned to them.
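+
+    Example (a hypothetical sketch; ``landmark_vectors`` is an (N, d) array)::
+
+        clf = DotProdClassifier(threshold = 0.9)
+        labels, confidences = clf.fit_predict(landmark_vectors)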
+ """
def __init__(self,
threshold = 0.9,
max_converge_iters = 10,
min_samples = 1):
- """
- :param float threshold: Similarity threshold for joining a cluster.
- In cos-of-angle-between-vectors (i.e. 1 is exactly the same, 0 is orthogonal)
- :param int max_converge_iters: Maximum number of iterations. If the algorithm hasn't converged
- by then, it will exit with a warning.
- :param int|float min_samples: filter out clusters with low sample counts.
- If an int, filters out clusters with fewer samples than this.
- If a float, filters out clusters with fewer than floor(min_samples * n_assigned_samples)
- samples assigned to them.
- """
self._threshold = threshold
self._max_iters = max_converge_iters
self._min_samples = min_samples
@@ -58,6 +54,9 @@ class DotProdClassifier(object):
def cluster_centers(self):
return self._cluster_centers
+ def set_cluster_centers(self, centers):
+ self._cluster_centers = centers
+
@property
def cluster_counts(self):
return self._cluster_counts
@@ -66,7 +65,7 @@ class DotProdClassifier(object):
def n_clusters(self):
return len(self._cluster_counts)
- def fit_predict(self, X, verbose = True, predict_threshold = None, return_info = False):
+ def fit_predict(self, X, verbose = True, predict_threshold = None, predict_normed = True, return_info = False):
""" Fit the data vectors X and return their cluster labels.
"""
@@ -80,22 +79,136 @@ class DotProdClassifier(object):
if predict_threshold is None:
predict_threshold = self._threshold
+ if self._cluster_centers is None:
+ self.fit_centers(X)
+
+ # Run a predict now:
+ labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold, predict_normed = predict_normed)
+
+ total_n_assigned = np.sum(labels >= 0)
+
+ # -- filter out low counts
+ if not self._min_samples is None:
+ self._cluster_counts = np.bincount(labels[labels >= 0], minlength = len(self._cluster_centers))
+
+ assert len(self._cluster_counts) == len(self._cluster_centers)
+
+ min_samples = None
+ if isinstance(self._min_samples, numbers.Integral):
+ min_samples = self._min_samples
+ elif isinstance(self._min_samples, numbers.Real):
+ min_samples = int(np.floor(self._min_samples * total_n_assigned))
+ else:
+ raise ValueError("Invalid value `%s` for min_samples; must be integral or float." % self._min_samples)
+ min_samples = max(min_samples, 1)
+
+ count_mask = self._cluster_counts >= min_samples
+
+ self._cluster_centers = self._cluster_centers[count_mask]
+ self._cluster_counts = self._cluster_counts[count_mask]
+
+ if len(self._cluster_centers) == 0:
+ # Then we removed everything...
+ raise ValueError("`min_samples` too large; all %i clusters under threshold." % len(count_mask))
+
+ logger.info("DotProdClassifier: %i/%i assignment counts below threshold %s (%s); %i clusters remain." % \
+ (np.sum(~count_mask), len(count_mask), self._min_samples, min_samples, len(self._cluster_counts)))
+
+ # Do another predict -- this could be more efficient, but who cares?
+ labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold, predict_normed = predict_normed)
+
+ if return_info:
+ info = {
+ 'clusters_below_min_samples' : np.sum(~count_mask),
+ 'kept_clusters_mask' : count_mask
+ }
+ return labels, confs, info
+ else:
+ return labels, confs
+
+ def predict(self, X, return_confidences = False, threshold = None, predict_normed = True, verbose = True, ignore_zeros = True):
+ """Return a predicted cluster label for vectors X.
+
+ :param float threshold: alternate threshold. Defaults to None, when self.threshold
+ is used.
+
+ :returns: an array of labels. -1 indicates no assignment.
+        :returns: an array of confidences in assignments. Normalized
+ values from 0 (no confidence, no label) to 1 (identical to cluster center).
+ """
+
+ assert len(X.shape) == 2, "Data must be 2D."
+
+ if not X.shape[1] == (self._featuredim):
+ raise TypeError("X has wrong dimension %s; should be (%i)" % (X.shape, self._featuredim))
+
+ labels = np.empty(shape = len(X), dtype = np.int)
+
+ if threshold is None:
+ threshold = self._threshold
+
+ confidences = None
+ if return_confidences:
+ confidences = np.empty(shape = len(X), dtype = np.float)
+
+ zeros_count = 0
+
+ if predict_normed:
+ center_norms = np.linalg.norm(self._cluster_centers, axis = 1)
+ normed_centers = self._cluster_centers.copy()
+ normed_centers /= center_norms[:, np.newaxis]
+ else:
+ normed_centers = self._cluster_centers
+
+ # preallocate buffers
+ diffs = np.empty(shape = len(normed_centers), dtype = np.float)
+
+ for i, x in enumerate(tqdm(X, desc = "Sample")):
+
+ if np.all(x == 0):
+ if ignore_zeros:
+                    labels[i] = -1
+                    if return_confidences:
+                        confidences[i] = 0.0
+                    zeros_count += 1
+ continue
+ else:
+ raise ValueError("Data %i is all zeros!" % i)
+
+ np.dot(normed_centers, x, out = diffs)
+ if predict_normed:
+ diffs /= np.linalg.norm(x)
+ np.abs(diffs, out = diffs)
+
+ assigned_to = np.argmax(diffs)
+ assignment_confidence = diffs[assigned_to]
+
+ if assignment_confidence < threshold:
+ assigned_to = -1
+ assignment_confidence = 0.0
+
+ labels[i] = assigned_to
+            if return_confidences:
+                confidences[i] = assignment_confidence
+
+ if zeros_count > 0:
+ logger.warning("Encountered %i zero vectors during prediction" % zeros_count)
+
+ if return_confidences:
+ return labels, confidences
+ else:
+ return labels
+
+ def fit_centers(self, X):
# Essentially hierarchical clustering that stops when no cluster *centers*
# are more similar than the threshold.
-
labels = np.empty(shape = len(X), dtype = np.int)
labels.fill(-1)
# Start with each sample as a cluster
- #old_centers = np.copy(X)
- #old_n_assigned = np.ones(shape = len(X), dtype = np.int)
# For memory's sake, no copying
old_centers = X
old_n_assigned = OneValueListlike(value = 1, length = len(X))
old_n_clusters = len(X)
# -- Classification loop
-
# Maximum number of iterations
last_n_sites = -1
did_converge = False
@@ -109,7 +222,7 @@ class DotProdClassifier(object):
first_iter = True
- for iteration in xrange(self._max_iters):
+        for iteration in tqdm(range(self._max_iters), desc = "Clustering iter.", total = float('inf')):
# This iteration's centers
# The first sample is always its own cluster
cluster_centers[0] = old_centers[0]
@@ -117,7 +230,7 @@ class DotProdClassifier(object):
n_assigned_to[0] = old_n_assigned[0]
n_clusters = 1
# skip the first sample which has already been accounted for
- for i, vec in tqdm(zip(xrange(1, old_n_clusters), old_centers[1:old_n_clusters]), desc = "Iteration %i" % iteration):
+            for i, vec in zip(range(1, old_n_clusters), old_centers[1:old_n_clusters]):
assigned_to = -1
assigned_cosang = 0.0
@@ -200,114 +313,3 @@ class DotProdClassifier(object):
raise ValueError("Clustering did not converge after %i iterations" % (self._max_iters))
self._cluster_centers = np.asarray(cluster_centers[:n_clusters])
-
- # Run a predict now:
- labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold)
-
- total_n_assigned = np.sum(labels >= 0)
-
- # -- filter out low counts
- if not self._min_samples is None:
- self._cluster_counts = np.bincount(labels[labels >= 0])
-
- assert len(self._cluster_counts) == len(self._cluster_centers)
-
- min_samples = None
- if isinstance(self._min_samples, numbers.Integral):
- min_samples = self._min_samples
- elif isinstance(self._min_samples, numbers.Real):
- min_samples = int(np.floor(self._min_samples * total_n_assigned))
- else:
- raise ValueError("Invalid value `%s` for min_samples; must be integral or float." % self._min_samples)
-
- count_mask = self._cluster_counts >= min_samples
-
- self._cluster_centers = self._cluster_centers[count_mask]
- self._cluster_counts = self._cluster_counts[count_mask]
-
- if len(self._cluster_centers) == 0:
- # Then we removed everything...
- raise ValueError("`min_samples` too large; all %i clusters under threshold." % len(count_mask))
-
- if verbose:
- print "DotProdClassifier: %i/%i assignment counts below threshold %s (%s); %i clusters remain." % \
- (np.sum(~count_mask), len(count_mask), self._min_samples, min_samples, len(self._cluster_counts))
-
- # Do another predict -- this could be more efficient, but who cares?
- labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold)
-
- if return_info:
- info = {
- 'clusters_below_min_samples' : np.sum(~count_mask)
- }
- return labels, confs, info
- else:
- return labels, confs
-
- def predict(self, X, return_confidences = False, threshold = None, verbose = True, ignore_zeros = True):
- """Return a predicted cluster label for vectors X.
-
- :param float threshold: alternate threshold. Defaults to None, when self.threshold
- is used.
-
- :returns: an array of labels. -1 indicates no assignment.
- :returns: an array of confidences in assignments. Normalzied
- values from 0 (no confidence, no label) to 1 (identical to cluster center).
- """
-
- assert len(X.shape) == 2, "Data must be 2D."
-
- if not X.shape[1] == (self._featuredim):
- raise TypeError("X has wrong dimension %s; should be (%i)" % (X.shape, self._featuredim))
-
- labels = np.empty(shape = len(X), dtype = np.int)
-
- if threshold is None:
- threshold = self._threshold
-
- confidences = None
- if return_confidences:
- confidences = np.empty(shape = len(X), dtype = np.float)
-
- zeros_count = 0
-
- center_norms = np.linalg.norm(self._cluster_centers, axis = 1)
- normed_centers = self._cluster_centers.copy()
- normed_centers /= center_norms[:, np.newaxis]
-
- # preallocate buffers
- diffs = np.empty(shape = len(center_norms), dtype = np.float)
-
- for i, x in enumerate(tqdm(X, desc = "Sample")):
-
- if np.all(x == 0):
- if ignore_zeros:
- labels[i] = -1
- zeros_count += 1
- continue
- else:
- raise ValueError("Data %i is all zeros!" % i)
-
- # diffs = np.sum(x * self._cluster_centers, axis = 1)
- # diffs /= np.linalg.norm(x) * center_norms
- np.dot(normed_centers, x, out = diffs)
- diffs /= np.linalg.norm(x)
- #diffs /= center_norms
-
- assigned_to = np.argmax(diffs)
- assignment_confidence = diffs[assigned_to]
-
- if assignment_confidence < threshold:
- assigned_to = -1
- assignment_confidence = 0.0
-
- labels[i] = assigned_to
- confidences[i] = assignment_confidence
-
- if verbose and zeros_count > 0:
- print "Encountered %i zero vectors during prediction" % zeros_count
-
- if return_confidences:
- return labels, confidences
- else:
- return labels
diff --git a/sitator/util/PBCCalculator.pyx b/sitator/util/PBCCalculator.pyx
index f34830d..caf827f 100644
--- a/sitator/util/PBCCalculator.pyx
+++ b/sitator/util/PBCCalculator.pyx
@@ -11,18 +11,20 @@ ctypedef double precision
ctypedef double cell_precision
cdef class PBCCalculator(object):
- """Performs calculations on collections of 3D points under PBC"""
+ """Performs calculations on collections of 3D points under PBC."""
cdef cell_precision [:, :] _cell_mat_array
cdef cell_precision [:, :] _cell_mat_inverse_array
cdef cell_precision [:] _cell_centroid
cdef cell_precision [:, :] _cell
+
def __init__(self, cell):
"""
:param DxD ndarray: the unit cell -- an array of cell vectors, like the
cell of an ASE :class:Atoms object.
"""
+        # Coerce the cell (e.g. an ase.cell.Cell object) into a plain array
+        cell = cell[:]
cellmat = np.matrix(cell).T
assert cell.shape[1] == cell.shape[0], "Cell must be square"
@@ -32,15 +34,39 @@ cdef class PBCCalculator(object):
self._cell_mat_inverse_array = np.asarray(cellmat.I)
self._cell_centroid = np.sum(0.5 * cell, axis = 0)
+
@property
def cell_centroid(self):
return self._cell_centroid
+
+ cpdef pairwise_distances(self, pts, out = None):
+ """Compute the pairwise distance matrix of ``pts`` with itself.
+
+ :returns ndarray (len(pts), len(pts)): distances
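+
+        A hypothetical sketch of usage::
+
+            pbcc = PBCCalculator(atoms.cell)
+            dmat = pbcc.pairwise_distances(pts)  # symmetric (N, N) matrix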
+ """
+ if out is None:
+ out = np.empty(shape = (len(pts), len(pts)), dtype = pts.dtype)
+
+ buf = pts.copy()
+
+        for i in range(len(pts) - 1):
+ out[i, i] = 0
+ self.distances(pts[i], buf[i + 1:], in_place = True, out = out[i, i + 1:])
+ out[i + 1:, i] = out[i, i + 1:]
+ buf[:] = pts
+
+ out[len(pts) - 1, len(pts) - 1] = 0
+
+ return out
+
+
cpdef distances(self, pt1, pts2, in_place = False, out = None):
- """Compute the Euclidean distances from pt1 to all points in pts2, using
- shift-and-wrap.
+ """
+ Compute the Euclidean distances from ``pt1`` to all points in
+ ``pts2``, using shift-and-wrap.
- Makes a copy of pts2 unless in_place == True.
+ Makes a copy of ``pts2`` unless ``in_place == True``.
:returns ndarray len(pts2): distances
"""
@@ -76,6 +102,7 @@ cdef class PBCCalculator(object):
#return np.linalg.norm(self._cell_centroid - pts2, axis = 1)
return np.sqrt(out, out = out)
+
cpdef average(self, points, weights = None):
"""Average position of a "cloud" of points using the shift-and-wrap hack.
@@ -83,13 +110,18 @@ cdef class PBCCalculator(object):
Assumes that the points are relatively close (within a half unit cell)
together, and that the first point is not a particular outsider (the
- cell is centered at that point).
+ cell is centered at that point). If the average is weighted, the
+ maximally weighted point will be taken as the center.
Can be a weighted average with the semantics of :func:numpy.average.
"""
assert points.shape[1] == 3 and points.ndim == 2
- offset = self._cell_centroid - points[0]
+ center_about = 0
+ if weights is not None:
+ center_about = np.argmax(weights)
+
+ offset = self._cell_centroid - points[center_about]
ptbuf = points.copy()
@@ -106,6 +138,7 @@ cdef class PBCCalculator(object):
return out
+
cpdef time_average(self, frames):
"""Do multiple PBC correct means. Frames is n_frames x n_pts x 3.
@@ -137,6 +170,7 @@ cdef class PBCCalculator(object):
del posbuf
return out
+
cpdef void wrap_point(self, precision [:] pt):
"""Wrap a single point into the unit cell, IN PLACE. 3D only."""
cdef cell_precision [:, :] cell = self._cell_mat_array
@@ -158,7 +192,8 @@ cdef class PBCCalculator(object):
pt[0] = buf[0]; pt[1] = buf[1]; pt[2] = buf[2];
- cpdef bint is_in_unit_cell(self, precision [:] pt):
+
+ cpdef bint is_in_unit_cell(self, const precision [:] pt):
cdef cell_precision [:, :] cell = self._cell_mat_array
cdef cell_precision [:, :] cell_I = self._cell_mat_inverse_array
@@ -174,13 +209,15 @@ cdef class PBCCalculator(object):
return (buf[0] < 1.0) and (buf[1] < 1.0) and (buf[2] < 1.0) and \
(buf[0] >= 0.0) and (buf[1] >= 0.0) and (buf[2] >= 0.0)
- cpdef bint all_in_unit_cell(self, precision [:, :] pts):
+
+ cpdef bint all_in_unit_cell(self, const precision [:, :] pts):
for pt in pts:
if not self.is_in_unit_cell(pt):
return False
return True
- cpdef bint is_in_image_of_cell(self, precision [:] pt, image):
+
+ cpdef bint is_in_image_of_cell(self, const precision [:] pt, image):
cdef cell_precision [:, :] cell = self._cell_mat_array
cdef cell_precision [:, :] cell_I = self._cell_mat_inverse_array
@@ -199,6 +236,7 @@ cdef class PBCCalculator(object):
return out
+
cpdef void to_cell_coords(self, precision [:, :] points):
"""Convert to cell coordinates in place."""
assert points.shape[1] == 3, "Points must be 3D"
@@ -220,11 +258,15 @@ cdef class PBCCalculator(object):
# Store into points
points[i, 0] = buf[0]; points[i, 1] = buf[1]; points[i, 2] = buf[2];
- cpdef int min_image(self, precision [:] ref, precision [:] pt):
- """Find the minimum image of `pt` relative to `ref`. In place in pt.
+
+ cpdef int min_image(self, const precision [:] ref, precision [:] pt):
+ """Find the minimum image of ``pt`` relative to ``ref``. In place in pt.
Uses the brute force algorithm for correctness; returns the minimum image.
+ Assumes that ``ref`` and ``pt`` are already in the *same* cell (though not
+ necessarily the <0,0,0> cell -- any periodic image will do).
+
:returns int[3] minimg: Which image was the minimum image.
"""
# # There are 27 possible minimum images
@@ -273,6 +315,7 @@ cdef class PBCCalculator(object):
return 100 * minimg[0] + 10 * minimg[1] + 1 * minimg[2]
+
cpdef void to_real_coords(self, precision [:, :] points):
"""Convert to real coords from crystal coords in place."""
assert points.shape[1] == 3, "Points must be 3D"
@@ -294,8 +337,9 @@ cdef class PBCCalculator(object):
# Store into points
points[i, 0] = buf[0]; points[i, 1] = buf[1]; points[i, 2] = buf[2];
+
cpdef void wrap_points(self, precision [:, :] points):
- """Wrap `points` into a unit cell, IN PLACE. 3D only.
+ """Wrap ``points`` into a unit cell, IN PLACE. 3D only.
"""
assert points.shape[1] == 3, "Points must be 3D"
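
The new ``pairwise_distances`` fills a symmetric matrix by repeatedly applying the existing shift-and-wrap ``distances`` to each row point. A minimal usage sketch (cell and points are illustrative):

```python
import numpy as np
from sitator.util import PBCCalculator

# Illustrative 10 Angstrom cubic cell.
pbcc = PBCCalculator(np.eye(3) * 10.0)

pts = np.array([
    [1.0, 1.0, 1.0],
    [9.5, 9.5, 9.5],  # neighbors the first point across the cell boundary
    [5.0, 5.0, 5.0],
])

dmat = pbcc.pairwise_distances(pts)
# Symmetric with a zero diagonal; entry (0, 1) is the wrapped distance
# (~2.6 here), not the naive Euclidean distance (~14.7).
assert np.allclose(dmat, dmat.T)
```
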
diff --git a/sitator/util/RecenterTrajectory.pyx b/sitator/util/RecenterTrajectory.pyx
index 39e5d7c..1e3a7c9 100644
--- a/sitator/util/RecenterTrajectory.pyx
+++ b/sitator/util/RecenterTrajectory.pyx
@@ -18,7 +18,7 @@ class RecenterTrajectory(object):
``static_mask``, IN PLACE.
Args:
- structure (ASE Atoms): An atoms representing the structure of the
+ structure (ase.Atoms): An atoms representing the structure of the
simulation.
static_mask (ndarray): Boolean mask indicating which atoms to recenter on
positions (ndarray): (n_frames, n_atoms, 3), modified in place
@@ -32,6 +32,8 @@ class RecenterTrajectory(object):
masses of all atoms in the system.
"""
+ assert np.any(static_mask), "Static mask all false; there must be static atoms to recenter on."
+
factors = static_mask.astype(np.float)
n_static = np.sum(static_mask)
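
The new assertion makes an all-false ``static_mask`` fail fast instead of silently dividing by zero. For context, a hypothetical call; the entry point name ``run`` is an assumption, since it is not shown in this hunk:

```python
import numpy as np
import ase
from sitator.util import RecenterTrajectory

# Toy system (made up): two static framework atoms and one mobile atom.
structure = ase.Atoms('SiSiLi', positions=np.random.rand(3, 3) * 5.0,
                      cell=np.eye(3) * 5.0)
static_mask = np.array([True, True, False])
positions = np.random.rand(100, 3, 3) * 5.0  # (n_frames, n_atoms, 3)

# Recenters `positions` on the static sublattice, in place;
# `run` is assumed here -- check the class for the actual method name.
RecenterTrajectory().run(structure, static_mask, positions)
```
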
diff --git a/sitator/util/__init__.py b/sitator/util/__init__.py
index 76ae207..b59424f 100644
--- a/sitator/util/__init__.py
+++ b/sitator/util/__init__.py
@@ -1,8 +1,8 @@
-from PBCCalculator import PBCCalculator
+from .PBCCalculator import PBCCalculator
-from DotProdClassifier import DotProdClassifier
+from .DotProdClassifier import DotProdClassifier
-from zeo import Zeopy
+from .zeo import Zeopy
-from RecenterTrajectory import RecenterTrajectory
+from .RecenterTrajectory import RecenterTrajectory
diff --git a/sitator/util/chemistry.py b/sitator/util/chemistry.py
new file mode 100644
index 0000000..c0070ea
--- /dev/null
+++ b/sitator/util/chemistry.py
@@ -0,0 +1,100 @@
+import numpy as np
+
+import ase.data
+try:
+ from ase.formula import Formula
+except ImportError:
+ from ase.util.formula import Formula
+
+from sitator.util import PBCCalculator
+
+# Key is a central atom, value is a list of the formulas for the rest.
+# See, http://www.fccj.us/PolyatomicIons/CompletePolyatomicIonList.htm
+DEFAULT_POLYATOMIC_IONS = {
+    'As' : ['O4', 'O3'],
+ 'C' : ['N', 'O3', 'O2', 'NO', 'SN'],
+ 'B' : ['O3','O2'],
+ 'Br' : ['O4', 'O3', 'O2', 'O'],
+ 'Fe' : ['O4'],
+ 'I' : ['O3', 'O4', 'O2', 'O'],
+ 'Si' : ['O4', 'O3'],
+ 'S' : ['O5', 'O4', 'O3', 'SO3', 'O2'],
+ 'Sb' : ['O4', 'O3'],
+ 'Se' : ['O4', 'O3'],
+ 'Sn' : ['O3', 'O2'],
+ 'N' : ['O3', 'O2', 'CO'],
+ 'Re' : ['O4'],
+ 'Cl' : ['O4', 'O3', 'O2', 'O'],
+ 'Mn' : ['O4'],
+ 'Mo' : ['O4'],
+ 'Cr' : ['O4', 'CrO7', 'O2'],
+ 'P' : ['O4', 'O3', 'O2'],
+ 'Tc' : ['O4'],
+ 'Te' : ['O4', 'O6', 'O3'],
+ 'Pb' : ['O3', 'O2'],
+ 'W' : ['O4']
+}
+
+def identify_polyatomic_ions(structure,
+ cutoff_factor = 1.0,
+ ion_definitions = DEFAULT_POLYATOMIC_IONS):
+ """Find polyatomic ions in a structure.
+
+ Goes to each potential polyatomic ion center and first checks if the nearest
+ neighbor is of a viable species. If it is, all nearest neighbors within the
+ maximum possible pairwise summed covalent radii are found and matched against
+ the database.
+
+ Args:
+ structure (ase.Atoms)
+ cutoff_factor (float): Coefficient for the cutoff. Allows tuning just
+ how closely the polyatomic ions must be bound. Defaults to 1.0.
+ ion_definitions (dict): Database of polyatomic ions where the key is a
+ central atom species symbol and the value is a list of the formulas
+ for the rest. Defaults to a list of common, non-hydrogen-containing
+            polyatomic ions with oxygen.
+ Returns:
+ list of tuples, one for each identified polyatomic anion containing its
+ formula, the index of the central atom, and the indexes of the remaining
+ atoms.
+ """
+ ion_definitions = {
+ k : [Formula(f) for f in v]
+ for k, v in ion_definitions.items()
+ }
+ out = []
+ # Go to each possible center atom in data, check that nearest neighbor is
+ # of right species, then get all within covalent distances and check if
+ # composition matches database...
+ pbcc = PBCCalculator(structure.cell)
+ dmat = pbcc.pairwise_distances(structure.positions) # Precompute
+ np.fill_diagonal(dmat, np.inf)
+ for center_i, center_symbol in enumerate(structure.symbols):
+ if center_symbol in ion_definitions:
+ nn_sym = structure.symbols[np.argmin(dmat[center_i])]
+ could_be = [f for f in ion_definitions[center_symbol] if nn_sym in f]
+ if len(could_be) == 0:
+ # Nearest neighbor isn't even a plausible polyatomic ion species,
+ # so skip this potential center.
+ continue
+            # Use the largest covalent radius among all the other species
+            # it could possibly be.
+ cutoff = max(
+ ase.data.covalent_radii[ase.data.atomic_numbers[other_sym]]
+ for form in could_be for other_sym in form
+ )
+ cutoff += ase.data.covalent_radii[ase.data.atomic_numbers[center_symbol]]
+ cutoff *= cutoff_factor
+ neighbors = dmat[center_i] <= cutoff
+ neighbors = np.where(neighbors)[0]
+ neighbor_formula = Formula.from_list(structure.symbols[neighbors])
+ it_is = [f for f in could_be if f == neighbor_formula]
+ if len(it_is) > 1:
+ raise ValueError("Somehow identified single center %s (atom %i) as multiple polyatomic ions %s" % (center_symbol, center_i, it_is))
+ elif len(it_is) == 1:
+ out.append((
+ center_symbol + neighbor_formula.format('hill'),
+ center_i,
+ neighbors
+ ))
+ return out
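
A sketch of how the new helper might be called, using a schematic (not physically relaxed) sulfate group; the geometry and cell are made up for illustration:

```python
import numpy as np
import ase
from sitator.util.chemistry import identify_polyatomic_ions

# A schematic SO4 unit plus a distant Li atom in a cubic box. The S-O
# distance (~1.49 A) falls well within the summed covalent radii cutoff.
d = 1.49
positions = [
    [5.0, 5.0, 5.0],      # S (center)
    [5.0 + d, 5.0, 5.0],  # O
    [5.0 - d, 5.0, 5.0],  # O
    [5.0, 5.0 + d, 5.0],  # O
    [5.0, 5.0, 5.0 + d],  # O
    [1.0, 1.0, 1.0],      # Li, far outside the cutoff
]
structure = ase.Atoms('SO4Li', positions=positions, cell=np.eye(3) * 10.0)

for formula, center_i, other_idxs in identify_polyatomic_ions(structure):
    print(formula, center_i, other_idxs)  # e.g. SO4 0 [1 2 3 4]
```
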
diff --git a/sitator/util/elbow.py b/sitator/util/elbow.py
index ec038ac..078bec9 100644
--- a/sitator/util/elbow.py
+++ b/sitator/util/elbow.py
@@ -2,7 +2,7 @@
# See discussion around this question: https://stackoverflow.com/questions/2018178/finding-the-best-trade-off-point-on-a-curve/2022348#2022348
def index_of_elbow(points):
- """Returns the index of the "elbow" in points.
+ """Returns the index of the "elbow" in ``points``.
Decently fast and pretty approximate. Performs worse with disproportionately
long "flat" tails. For example, in a dataset with a nearly right-angle elbow,
diff --git a/sitator/util/mcl.py b/sitator/util/mcl.py
new file mode 100644
index 0000000..784cf7a
--- /dev/null
+++ b/sitator/util/mcl.py
@@ -0,0 +1,60 @@
+import numpy as np
+
+def markov_clustering(transition_matrix,
+ expansion = 2,
+ inflation = 2,
+ pruning_threshold = 0.00001,
+ iterlimit = 100):
+ """Compute the Markov Clustering of a graph.
+ See https://micans.org/mcl/.
+
+    Because we're dealing with matrices that are stochastic already,
+ there's no need to add artificial loop values.
+
+ Implementation inspired by https://github.com/GuyAllard/markov_clustering
+ """
+
+ assert transition_matrix.shape[0] == transition_matrix.shape[1]
+
+ # Check for nonzero diagonal -- self loops needed to avoid div by zero and NaNs
+ assert np.count_nonzero(transition_matrix.diagonal()) == len(transition_matrix)
+
+ m1 = transition_matrix.copy()
+
+ # Normalize (though it should be close already)
+ m1 /= np.sum(m1, axis = 0)
+
+ allcols = np.arange(m1.shape[1])
+
+ converged = False
+ for i in range(iterlimit):
+ # -- Expansion
+ m2 = np.linalg.matrix_power(m1, expansion)
+ # -- Inflation
+ np.power(m2, inflation, out = m2)
+ m2 /= np.sum(m2, axis = 0)
+ # -- Prune
+ to_prune = m2 < pruning_threshold
+ # Exclude the max of every column
+ to_prune[np.argmax(m2, axis = 0), allcols] = False
+ m2[to_prune] = 0.0
+ # -- Check converged
+ if np.allclose(m1, m2):
+ converged = True
+ break
+
+ m1[:] = m2
+
+ if not converged:
+ raise ValueError("Markov Clustering couldn't converge in %i iterations" % iterlimit)
+
+ # -- Get clusters
+ attractors = m2.diagonal().nonzero()[0]
+
+ clusters = set()
+
+ for a in attractors:
+ cluster = tuple(m2[a].nonzero()[0])
+ clusters.add(cluster)
+
+ return list(clusters)
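
A minimal sketch on a column-stochastic matrix with two disconnected blocks; the nonzero diagonal (self-loops) is required by the assertion above:

```python
import numpy as np
from sitator.util.mcl import markov_clustering

# Two decoupled 2-node blocks with uniform transition probabilities.
T = np.array([
    [0.5, 0.5, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.0, 0.0, 0.5, 0.5],
    [0.0, 0.0, 0.5, 0.5],
])
# Expect the two blocks as clusters: (0, 1) and (2, 3); list order may vary.
print(markov_clustering(T))
```
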
diff --git a/sitator/util/progress.py b/sitator/util/progress.py
new file mode 100644
index 0000000..4ffe766
--- /dev/null
+++ b/sitator/util/progress.py
@@ -0,0 +1,14 @@
+import os
+
+progress = os.getenv('SITATOR_PROGRESSBAR', 'true').lower()
+progress = (progress == 'true') or (progress == 'yes') or (progress == 'on')
+
+if progress:
+ try:
+ from tqdm.autonotebook import tqdm
+    except ImportError:
+ def tqdm(iterable, **kwargs):
+ return iterable
+else:
+ def tqdm(iterable, **kwargs):
+ return iterable
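
The environment variable is read at import time, so it must be set before ``sitator`` is first imported; for example:

```python
import os

# Anything other than "true"/"yes"/"on" (case-insensitive) disables the bars.
os.environ['SITATOR_PROGRESSBAR'] = 'false'

from sitator.util.progress import tqdm

for _ in tqdm(range(10)):  # now a plain pass-through iterable
    pass
```
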
diff --git a/sitator/util/qvoronoi.py b/sitator/util/qvoronoi.py
deleted file mode 100644
index 4156a06..0000000
--- a/sitator/util/qvoronoi.py
+++ /dev/null
@@ -1,144 +0,0 @@
-
-import tempfile
-import subprocess
-import sys
-
-import re
-from collections import OrderedDict
-
-import numpy as np
-
-from sklearn.neighbors import KDTree
-
-from sitator.util import PBCCalculator
-
-def periodic_voronoi(structure, logfile = sys.stdout):
- """
- :param ASE.Atoms structure:
- """
-
- pbcc = PBCCalculator(structure.cell)
-
- # Make a 3x3x3 supercell
- supercell = structure.repeat((3, 3, 3))
-
- qhull_output = None
-
- logfile.write("Qvoronoi ---")
-
- # Run qhull
- with tempfile.NamedTemporaryFile('w',
- prefix = 'qvor',
- suffix='.in', delete = False) as infile, \
- tempfile.NamedTemporaryFile('r',
- prefix = 'qvor',
- suffix='.out',
- delete=True) as outfile:
- # -- Write input file --
- infile.write("3\n") # num of dimensions
- infile.write("%i\n" % len(supercell)) # num of points
- np.savetxt(infile, supercell.get_positions(), fmt = '%.16f')
- infile.flush()
-
- cmdline = ["qvoronoi", "TI", infile.name, "FF", "Fv", "TO", outfile.name]
- process = subprocess.Popen(cmdline, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
- retcode = process.wait()
- logfile.write(process.stdout.read())
- if retcode != 0:
- raise RuntimeError("qvoronoi returned exit code %i" % retcode)
-
- qhull_output = outfile.read()
-
-    facets_regex = re.compile(
-        """
-        -[ \t](?P<id>f[0-9]+) [\n]
-        [ \t]*-[ ]flags: .* [\n]
-        [ \t]*-[ ]normal: .* [\n]
-        [ \t]*-[ ]offset: .* [\n]
-        [ \t]*-[ ]center:(?P<center>([ ][\-]?[0-9]*[\.]?[0-9]*(e[\-]?[0-9]+)?){3}) [ \t] [\n]
-        [ \t]*-[ ]vertices:(?P<vertices>([ ]p[0-9]+\(v[0-9]+\))+) [ \t]? [\n]
-        [ \t]*-[ ]neighboring[ ]facets:(?P<neighbors>([ ]f[0-9]+)+)
-        """, re.X | re.M)
-
- vertices_re = re.compile('(?<=p)[0-9]+')
-
- # Allocate stuff
- centers = []
- vertices = []
- facet_indexes_taken = set()
-
- facet_index_to_our_index = {}
- all_facets_centers = []
-
- # ---- Read facets
- facet_index = -1
- next_our_index = 0
- for facet_match in facets_regex.finditer(qhull_output):
- center = np.asarray(map(float, facet_match.group('center').split()))
- facet_index += 1
-
- all_facets_centers.append(center)
-
- if not pbcc.is_in_image_of_cell(center, (1, 1, 1)):
- continue
-
- verts = map(int, vertices_re.findall(facet_match.group('vertices')))
- verts_in_main_cell = tuple(v % len(structure) for v in verts)
-
- facet_indexes_taken.add(facet_index)
-
- centers.append(center)
- vertices.append(verts_in_main_cell)
-
- facet_index_to_our_index[facet_index] = next_our_index
-
- next_our_index += 1
-
- end_of_facets = facet_match.end()
-
- facet_count = facet_index + 1
-
- logfile.write(" qhull gave %i vertices; kept %i" % (facet_count, len(centers)))
-
- # ---- Read ridges
- qhull_output_after_facets = qhull_output[end_of_facets:].strip()
-    ridge_re = re.compile('^\d+ \d+ \d+(?P<verts>( \d+)+)$', re.M)
-
- ridges = [
- [int(v) for v in match.group('verts').split()]
- for match in ridge_re.finditer(qhull_output_after_facets)
- ]
- # only take ridges with at least 1 facet in main unit cell.
- ridges = [
- r for r in ridges if any(f in facet_indexes_taken for f in r)
- ]
-
- # shift centers back into normal unit cell
- centers -= np.sum(structure.cell, axis = 0)
-
- nearest_center = KDTree(centers)
-
- ridges_in_main_cell = set()
- threw_out = 0
- for r in ridges:
- ridge_centers = np.asarray([all_facets_centers[f] for f in r if f < len(all_facets_centers)])
- if not pbcc.all_in_unit_cell(ridge_centers):
- continue
-
- pbcc.wrap_points(ridge_centers)
- dists, ridge_centers_in_main = nearest_center.query(ridge_centers, return_distance = True)
-
- if np.any(dists > 0.00001):
- threw_out += 1
- continue
-
- assert ridge_centers_in_main.shape == (len(ridge_centers), 1), "%s" % ridge_centers_in_main.shape
- ridge_centers_in_main = ridge_centers_in_main[:,0]
-
- ridges_in_main_cell.add(frozenset(ridge_centers_in_main))
-
- logfile.write(" Threw out %i ridges" % threw_out)
-
- logfile.flush()
-
- return centers, vertices, ridges_in_main_cell
diff --git a/sitator/util/zeo.py b/sitator/util/zeo.py
index 430e4bd..19bddf0 100644
--- a/sitator/util/zeo.py
+++ b/sitator/util/zeo.py
@@ -1,9 +1,6 @@
#zeopy: simple Python interface to the Zeo++ `network` tool.
# Alby Musaelian 2018
-from __future__ import (absolute_import, division,
- print_function)
-
import os
import sys
import tempfile
@@ -19,10 +16,13 @@
from sitator.util import PBCCalculator
+import logging
+logger = logging.getLogger(__name__)
+
# TODO: benchmark CUC vs CIF
class Zeopy(object):
- """A wrapper for the Zeo++ `network` tool.
+ """A wrapper for the Zeo++ ``network`` tool.
:warning: Do not use a single instance of Zeopy in parallel.
"""
@@ -30,7 +30,7 @@ class Zeopy(object):
def __init__(self, path_to_zeo):
"""Create a Zeopy.
- :param str path_to_zeo: Path to the `network` executable.
+ :param str path_to_zeo: Path to the ``network`` executable.
"""
if not (os.path.exists(path_to_zeo) and os.access(path_to_zeo, os.X_OK)):
raise ValueError("`%s` doesn't seem to be the path to an executable file." % path_to_zeo)
@@ -43,7 +43,7 @@ def __enter__(self):
def __exit__(self, *args):
shutil.rmtree(self._tmpdir)
- def voronoi(self, structure, radial = False, verbose=True):
+ def voronoi(self, structure, radial = False):
"""
:param Atoms structure: The ASE Atoms to compute the Voronoi decomposition of.
"""
@@ -55,7 +55,7 @@ def voronoi(self, structure, radial = False, verbose=True):
outp = os.path.join(self._tmpdir, "out.nt2")
v1out = os.path.join(self._tmpdir, "out.v1")
- ase.io.write(inp, structure)
+ ase.io.write(inp, structure, parallel = False)
# with open(inp, "w") as inf:
# inf.write(self.ase2cuc(structure))
@@ -69,12 +69,11 @@ def voronoi(self, structure, radial = False, verbose=True):
output = subprocess.check_output([self._exe] + args + ["-v1", v1out, "-nt2", outp, inp],
stderr = subprocess.STDOUT)
except subprocess.CalledProcessError as e:
- print("Zeo++ returned an error:", file = sys.stderr)
- print(e.output, file = sys.stderr)
+ logger.error("Zeo++ returned an error:")
+ logger.error(str(e.output))
raise
- if verbose:
- print(output)
+ logger.debug(output)
with open(outp, "r") as outf:
verts, edges = self.parse_nt2(outf.readlines())
@@ -130,12 +129,12 @@ def parse_v1_cell(v1lines):
# remove blank lines:
v1lines = iter(filter(None, v1lines))
# First line is just "Unit cell vectors:"
- assert v1lines.next().strip() == "Unit cell vectors:"
+ assert next(v1lines).strip() == "Unit cell vectors:"
# Unit cell:
cell = np.empty(shape = (3, 3), dtype = np.float)
cellvec_re = re.compile('v[abc]=')
- for i in xrange(3):
- cellvec = v1lines.next().strip().split()
+ for i in range(3):
+ cellvec = next(v1lines).strip().split()
assert cellvec_re.match(cellvec[0])
cell[i] = [float(e) for e in cellvec[1:]]
# number of atoms, etc.
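
With the prints replaced by module-level logging, Zeo++'s output is now opt-in; something like this surfaces it when debugging a Voronoi run:

```python
import logging

# Zeo++'s stdout now goes to the module logger at DEBUG level.
logging.basicConfig()
logging.getLogger('sitator.util.zeo').setLevel(logging.DEBUG)
```
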
diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py
index 157da1d..8c529e8 100644
--- a/sitator/visualization/SiteNetworkPlotter.py
+++ b/sitator/visualization/SiteNetworkPlotter.py
@@ -1,7 +1,3 @@
-from __future__ import (absolute_import, division,
- print_function, unicode_literals)
-from builtins import *
-
import numpy as np
import itertools
@@ -10,21 +6,35 @@
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from sitator.util import PBCCalculator
-from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS
+from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS, set_axes_equal
class SiteNetworkPlotter(object):
- """Plot a SiteNetwork.
-
- site_mappings defines how to show different properties. Each entry maps a
- visual aspect ('marker', 'color', 'size') to the name of a site attribute
- including 'site_type'.
-
- Likewise for edge_mappings, each key maps a visual property ('intensity', 'color',
- 'width', 'linestyle') to an edge attribute in the SiteNetwork.
+ """Plot a ``SiteNetwork``.
Note that for edges, the average of the edge property for i -> j and j -> i
is often used for visual clarity; if your edge properties are not almost symmetric,
the visualization might not be useful.
+
+ Args:
+ site_mappings (dict): defines how to show different properties. Each
+ entry maps a visual aspect ('marker', 'color', 'size') to the name
+            of a site attribute (including 'site_type'). The marker can also be
+            arbitrary text (key `"text"`), in which case the value may be a
+            2-tuple of an attribute name and a `%` format string.
+ edge_mappings (dict): each key maps a visual property ('intensity',
+ 'color', 'width', 'linestyle') to an edge attribute in the SiteNetwork.
+ markers (list of str): What `matplotlib` markers to use for sites.
+ plot_points_params (dict): User options for plotting site points.
+ minmax_linewidth (2-tuple): Minimum and maximum linewidth to use.
+ minmax_edge_alpha (2-tuple): Similar, for edge line alphas.
+ minmax_markersize (2-tuple): Similar, for markersize.
+ min_color_threshold (float): Minimum (normalized) color intensity for
+ the corresponding line to be shown. Defaults to zero, i.e., all
+ nonzero edges will be drawn.
+ min_width_threshold (float): Minimum normalized edge width for the
+ corresponding edge to be shown. Defaults to zero, i.e., all
+ nonzero edges will be drawn.
+ title (str): Title for the figure.
"""
DEFAULT_SITE_MAPPINGS = {
@@ -34,7 +44,7 @@ class SiteNetworkPlotter(object):
DEFAULT_MARKERS = ['x', '+', 'v', '<', '^', '>', '*', 'd', 'h', 'p']
DEFAULT_LINESTYLES = ['--', ':', '-.', '-']
- EDGE_GROUP_COLORS = ['b', 'g', 'm', 'lightseagreen', 'crimson'] + ['gray'] # gray last for -1's
+ EDGE_GROUP_COLORS = ['b', 'g', 'm', 'crimson', 'lightseagreen', 'darkorange', 'sandybrown', 'gold', 'hotpink'] + ['gray'] # gray last for -1's
def __init__(self,
site_mappings = DEFAULT_SITE_MAPPINGS,
@@ -43,11 +53,12 @@ def __init__(self,
plot_points_params = {},
minmax_linewidth = (1.5, 7),
minmax_edge_alpha = (0.15, 0.75),
- minmax_markersize = (80.0, 180.0),
- min_color_threshold = 0.005,
- min_width_threshold = 0.005,
+ minmax_markersize = (20.0, 80.0),
+ min_color_threshold = 0.0,
+ min_width_threshold = 0.0,
title = ""):
self.site_mappings = site_mappings
+ assert not ("marker" in site_mappings and "text" in site_mappings)
self.edge_mappings = edge_mappings
self.markers = markers
self.plot_points_params = plot_points_params
@@ -66,7 +77,6 @@ def __call__(self, sn, *args, **kwargs):
# -- Plot actual SiteNetwork --
l = [(plot_atoms, {'atoms' : sn.static_structure})]
l += self._site_layers(sn, self.plot_points_params)
-
l += self._plot_edges(sn, *args, **kwargs)
# -- Some visual clean up --
@@ -74,7 +84,6 @@ def __call__(self, sn, *args, **kwargs):
ax.set_title(self.title)
- ax.set_aspect('equal')
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
@@ -90,26 +99,47 @@ def __call__(self, sn, *args, **kwargs):
# -- Put it all together --
layers(*l, **kwargs)
- def _site_layers(self, sn, plot_points_params):
+ def _site_layers(self, sn, plot_points_params, same_normalization = False):
pts_arrays = {'points' : sn.centers}
- pts_params = {'cmap' : 'rainbow'}
+ pts_params = {'cmap' : 'winter'}
# -- Apply mapping
# - other mappings
markers = None
for key in self.site_mappings:
- val = getattr(sn, self.site_mappings[key])
+ val = self.site_mappings[key]
+ if isinstance(val, tuple):
+ val, param = val
+ else:
+ param = None
+ val = getattr(sn, val)
if key == 'marker':
if not val is None:
markers = val.copy()
+ istextmarker = False
+ elif key == 'text':
+ istextmarker = True
+ format_str = "%s" if param is None else param
+ format_str = "$" + format_str + "$"
+ markers = val.copy()
elif key == 'color':
pts_arrays['c'] = val.copy()
- pts_params['norm'] = matplotlib.colors.Normalize(vmin = np.min(val), vmax = np.max(val))
+ if not same_normalization:
+ self._color_minmax = [np.min(val), np.max(val)]
+ if self._color_minmax[0] == self._color_minmax[1]:
+ self._color_minmax[0] -= 1 # Just to avoid div by zero
+ color_minmax = self._color_minmax
+ pts_params['norm'] = matplotlib.colors.Normalize(vmin = color_minmax[0], vmax = color_minmax[1])
elif key == 'size':
+ if not same_normalization:
+ self._size_minmax = [np.min(val), np.max(val)]
+ if self._size_minmax[0] == self._size_minmax[1]:
+ self._size_minmax[0] -= 1 # Just to avoid div by zero
+ size_minmax = self._size_minmax
s = val.copy()
- s += np.min(s)
- s /= np.max(s)
+ s -= size_minmax[0]
+ s /= size_minmax[1] - size_minmax[0]
s *= self.minmax_markersize[1]
s += self.minmax_markersize[0]
pts_arrays['s'] = s
@@ -122,17 +152,27 @@ def _site_layers(self, sn, plot_points_params):
# Just one layer with all points and one marker
marker_layers[SiteNetworkPlotter.DEFAULT_MARKERS[0]] = np.ones(shape = sn.n_sites, dtype = np.bool)
else:
- markers = self._make_discrete(markers)
+ if not istextmarker:
+ markers = self._make_discrete(markers)
unique_markers = np.unique(markers)
- marker_i = 0
+ if not same_normalization:
+ if istextmarker:
+ self._marker_table = dict(zip(unique_markers, (format_str % um for um in unique_markers)))
+ else:
+ if len(unique_markers) > len(self.markers):
+ raise ValueError("Too many distinct values of the site property mapped to markers (there are %i) for the %i markers in `self.markers`" % (len(unique_markers), len(self.markers)))
+ self._marker_table = dict(zip(unique_markers, self.markers[:len(unique_markers)]))
+
for um in unique_markers:
- marker_layers[SiteNetworkPlotter.DEFAULT_MARKERS[marker_i]] = (markers == um)
- marker_i += 1
+ marker_layers[self._marker_table[um]] = (markers == um)
# -- Do plot
# If no color info provided, a fallback
if not 'color' in pts_params and not 'c' in pts_arrays:
pts_params['color'] = 'k'
+        # If no size info provided, a fallback
+ if not 's' in pts_params and not 's' in pts_arrays:
+ pts_params['s'] = sum(self.minmax_markersize) / 2
# Add user options for `plot_points`
pts_params.update(plot_points_params)
@@ -201,8 +241,8 @@ def _plot_edges(self, sn, ax = None, *args, **kwargs):
sites_to_plot = []
sites_to_plot_positions = []
- for i in xrange(n_sites):
- for j in xrange(n_sites):
+ for i in range(n_sites):
+ for j in range(n_sites):
# No self edges
if i == j:
continue
@@ -210,9 +250,9 @@ def _plot_edges(self, sn, ax = None, *args, **kwargs):
if done_already[i, j]:
continue
# Ignore anything below the threshold
- if all_cs[i, j] < self.min_color_threshold:
+ if all_cs[i, j] <= self.min_color_threshold:
continue
- if do_widths and all_linewidths[i, j] < self.min_width_threshold:
+ if do_widths and all_linewidths[i, j] <= self.min_width_threshold:
continue
segment = np.empty(shape = (2, 3), dtype = centers.dtype)
@@ -254,7 +294,9 @@ def _plot_edges(self, sn, ax = None, *args, **kwargs):
lccolors = np.empty(shape = (len(cs), 4), dtype = np.float)
# Group colors
if do_groups:
- for i in xrange(len(cs)):
+ for i in range(len(cs)):
+ if groups[i] >= len(SiteNetworkPlotter.EDGE_GROUP_COLORS) - 1:
+ raise ValueError("Too many groups, not enough group colors")
lccolors[i] = matplotlib.colors.to_rgba(SiteNetworkPlotter.EDGE_GROUP_COLORS[groups[i]])
else:
lccolors[:] = matplotlib.colors.to_rgba(SiteNetworkPlotter.EDGE_GROUP_COLORS[0])
@@ -273,12 +315,14 @@ def _plot_edges(self, sn, ax = None, *args, **kwargs):
ax.add_collection(lc)
# -- Plot new sites
- sn2 = sn[sites_to_plot]
- sn2.update_centers(np.asarray(sites_to_plot_positions))
-
- pts_params = dict(self.plot_points_params)
- pts_params['alpha'] = 0.2
- return self._site_layers(sn2, pts_params)
+ if len(sites_to_plot) > 0:
+ sn2 = sn[sites_to_plot]
+ sn2.update_centers(np.asarray(sites_to_plot_positions))
+ pts_params = dict(self.plot_points_params)
+ pts_params['alpha'] = 0.2
+ return self._site_layers(sn2, pts_params, same_normalization = True)
+ else:
+ return []
else:
return []
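
A sketch of the new text-marker mapping described in the docstring; ``occupancies`` is an assumed site attribute name here, and ``sn`` a prepared ``SiteNetwork``:

```python
from sitator.visualization import SiteNetworkPlotter

plotter = SiteNetworkPlotter(
    site_mappings={
        # Render each site's (assumed) 'occupancies' attribute as a
        # "$...$" text label using the given % format string.
        'text': ('occupancies', '%0.2f'),
        'color': 'site_types',
    },
    title='Site network',
)
# plotter(sn)  # draws atoms, sites, and edges into one 3D figure
```
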
diff --git a/sitator/visualization/SiteTrajectoryPlotter.py b/sitator/visualization/SiteTrajectoryPlotter.py
new file mode 100644
index 0000000..9994dfe
--- /dev/null
+++ b/sitator/visualization/SiteTrajectoryPlotter.py
@@ -0,0 +1,158 @@
+import numpy as np
+
+import matplotlib
+from matplotlib.collections import LineCollection
+
+from sitator.util import PBCCalculator
+from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS
+
+
+class SiteTrajectoryPlotter(object):
+ """Produce various plots of a ``SiteTrajectory``."""
+
+ @plotter(is3D = True)
+ def plot_frame(self, st, frame, **kwargs):
+ """Plot sites and instantaneous positions from a given frame.
+
+ Args:
+ st (SiteTrajectory)
+ frame (int)
+ """
+ sites_of_frame = np.unique(st._traj[frame])
+ frame_sn = st._sn[sites_of_frame]
+
+ frame_sn.plot(**kwargs)
+
+ if not st._real_traj is None:
+ mobile_atoms = st._sn.structure.copy()
+ del mobile_atoms[~st._sn.mobile_mask]
+
+ mobile_atoms.positions[:] = st._real_traj[frame, st._sn.mobile_mask]
+ plot_atoms(atoms = mobile_atoms, **kwargs)
+
+ kwargs['ax'].set_title("Frame %i/%i" % (frame, st.n_frames))
+
+
+ @plotter(is3D = True)
+ def plot_site(self, st, site, **kwargs):
+ """Plot all real space positions associated with a site.
+
+ Args:
+ st (SiteTrajectory)
+ site (int)
+ """
+ pbcc = PBCCalculator(st._sn.structure.cell)
+ pts = st.real_positions_for_site(site).copy()
+ offset = pbcc.cell_centroid - pts[3]
+ pts += offset
+ pbcc.wrap_points(pts)
+ lattice_pos = st._sn.static_structure.positions.copy()
+ lattice_pos += offset
+ pbcc.wrap_points(lattice_pos)
+ site_pos = st._sn.centers[site:site+1].copy()
+ site_pos += offset
+ pbcc.wrap_points(site_pos)
+ # Plot point cloud
+ plot_points(points = pts, alpha = 0.3, marker = '.', color = 'k', **kwargs)
+ # Plot site
+ plot_points(points = site_pos, color = 'cyan', **kwargs)
+ # Plot everything else
+ plot_atoms(st._sn.static_structure, positions = lattice_pos, **kwargs)
+
+ title = "Site %i/%i" % (site, len(st._sn))
+
+ if not st._sn.site_types is None:
+ title += " (type %i)" % st._sn.site_types[site]
+
+ kwargs['ax'].set_title(title)
+
+
+ @plotter(is3D = False)
+ def plot_particle_trajectory(self, st, particle, ax = None, fig = None, **kwargs):
+ """Plot the sites occupied by a mobile particle over time.
+
+ Args:
+ st (SiteTrajectory)
+ particle (int)
+ """
+ types = not st._sn.site_types is None
+ if types:
+ type_height_percent = 0.1
+ axpos = ax.get_position()
+ typeax_height = type_height_percent * axpos.height
+ typeax = fig.add_axes([axpos.x0, axpos.y0, axpos.width, typeax_height], sharex = ax)
+ ax.set_position([axpos.x0, axpos.y0 + typeax_height, axpos.width, axpos.height - typeax_height])
+ type_height = 1
+ # Draw trajectory
+ segments = []
+ linestyles = []
+ colors = []
+
+ traj = st._traj[:, particle]
+ current_value = traj[0]
+ last_value = traj[0]
+ if types:
+ last_type = None
+ current_segment_start = 0
+ puttext = False
+
+ for i, f in enumerate(traj):
+ if f != current_value or i == len(traj) - 1:
+ val = last_value if current_value == -1 else current_value
+ if last_value == -1:
+ segments.append([[current_segment_start, val], [i, val]])
+ else:
+ segments.append([[current_segment_start, last_value], [current_segment_start, val], [i, val]])
+ linestyles.append(':' if current_value == -1 else '-')
+ if current_value == -1:
+ if val == -1:
+ c = 'red' # Uncorrected unknown
+ else:
+ c = 'lightgray' # Unknown but reassigned
+ else:
+ c = 'k' # Known
+ colors.append(c)
+
+ if types:
+ rxy = (current_segment_start, 0)
+ this_type = st._sn.site_types[val]
+ typerect = matplotlib.patches.Rectangle(rxy, i - current_segment_start, type_height,
+ color = DEFAULT_COLORS[this_type], linewidth = 0)
+ typeax.add_patch(typerect)
+ if this_type != last_type:
+ typeax.annotate("T%i" % this_type,
+ xy = (rxy[0], rxy[1] + 0.5 * type_height),
+ xytext = (3, -1),
+ textcoords = 'offset points',
+ fontsize = 'xx-small',
+ va = 'center',
+ fontweight = 'bold')
+ last_type = this_type
+
+ last_value = val
+ current_segment_start = i
+ current_value = f
+
+ lc = LineCollection(segments, linestyles = linestyles, colors = colors, linewidth=1.5)
+ ax.add_collection(lc)
+
+ if types:
+ typeax.set_xlabel("Frame")
+ ax.tick_params(axis = 'x', which = 'both', bottom = False, top = False, labelbottom = False)
+ typeax.tick_params(axis = 'y', which = 'both', left = False, right = False, labelleft = False)
+ typeax.annotate("Type", xy = (0, 0.5), xytext = (-25, 0), xycoords = 'axes fraction', textcoords = 'offset points', va = 'center', fontsize = 'x-small')
+ else:
+ ax.set_xlabel("Frame")
+ ax.set_ylabel("Atom %i's site" % particle)
+
+ ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
+ ax.grid()
+
+ ax.set_xlim((0, st.n_frames - 1))
+ margin_percent = 0.04
+ ymargin = (margin_percent * st._sn.n_sites)
+ ax.set_ylim((-ymargin, st._sn.n_sites - 1.0 + ymargin))
+
+ if types:
+ typeax.set_xlim((0, st.n_frames - 1))
+ typeax.set_ylim((0, type_height))
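
Typical usage of the new plotter, assuming ``st`` is a ``SiteTrajectory`` produced by a full analysis run:

```python
import matplotlib.pyplot as plt
from sitator.visualization import SiteTrajectoryPlotter

plotter = SiteTrajectoryPlotter()
plotter.plot_particle_trajectory(st, particle=0)  # occupied site vs. frame
plotter.plot_site(st, site=0)  # real-space point cloud around one site
plt.show()
```
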
diff --git a/sitator/visualization/__init__.py b/sitator/visualization/__init__.py
index 4624148..e96d843 100644
--- a/sitator/visualization/__init__.py
+++ b/sitator/visualization/__init__.py
@@ -1,5 +1,7 @@
-from common import layers, grid, plotter, DEFAULT_COLORS
+from .common import layers, grid, plotter, DEFAULT_COLORS, set_axes_equal
-from atoms import plot_atoms, plot_points
+from .atoms import plot_atoms, plot_points
-from SiteNetworkPlotter import SiteNetworkPlotter
+from .SiteNetworkPlotter import SiteNetworkPlotter
+
+from .SiteTrajectoryPlotter import SiteTrajectoryPlotter
diff --git a/sitator/visualization/atoms.py b/sitator/visualization/atoms.py
index 1d0cc9c..b4f76cd 100644
--- a/sitator/visualization/atoms.py
+++ b/sitator/visualization/atoms.py
@@ -36,9 +36,20 @@ def plot_atoms(atoms, positions = None, hide_species = (), wrap = False, fig = N
all_cvecs = []
- whos_left = set(xrange(len(atoms.cell)))
+ whos_left = set(range(len(atoms.cell)))
+ cvec_labels = ["$\\vec{a}$", "$\\vec{b}$", "$\\vec{c}$"]
for i, cvec1 in enumerate(atoms.cell):
all_cvecs.append(np.array([[0.0, 0.0, 0.0], cvec1]))
+ ax.text(
+ cvec1[0] * 0.25,
+ cvec1[1] * 0.25,
+ cvec1[2] * 0.25,
+ cvec_labels[i],
+ size = 9,
+ color = 'gray',
+ ha = 'left',
+ va = 'center'
+ )
for j, cvec2 in enumerate(atoms.cell[list(whos_left - {i})]):
all_cvecs.append(np.array([cvec1, cvec1 + cvec2]))
for i, cvec1 in enumerate(atoms.cell):
diff --git a/sitator/visualization/common.py b/sitator/visualization/common.py
index 9c4e702..74db185 100644
--- a/sitator/visualization/common.py
+++ b/sitator/visualization/common.py
@@ -33,8 +33,10 @@ def plotter_wrapper(func):
def plotter_wraped(*args, **kwargs):
fig = None
ax = None
+ toplevel = False
if not ('ax' in kwargs and 'fig' in kwargs):
- # No existing axis/figure
+ # No existing axis/figure - toplevel
+ toplevel = True
fig = plt.figure(**outer)
if is3D:
ax = fig.add_subplot(111, projection = '3d')
@@ -54,14 +56,19 @@ def plotter_wraped(*args, **kwargs):
kwargs['i'] = 0
func(*args, fig = fig, ax = ax, **kwargs)
+
+ if is3D and toplevel:
+ set_axes_equal(ax)
+
return fig, ax
+
return plotter_wraped
+
return plotter_wrapper
@plotter(is3D = True)
def layers(*args, **fax):
i = fax['i']
- print i
for p, kwargs in args:
p(fig = fax['fig'], ax = fax['ax'], i = i, **kwargs)
i += 1
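
Since ``set_axes_equal`` now runs only for top-level 3D figures, composing everything into a single ``layers`` call keeps the final axes equalized. A sketch with illustrative inputs:

```python
import numpy as np
import ase
from sitator.visualization import layers, plot_atoms, plot_points

# Illustrative structure and site positions.
structure = ase.Atoms('Si2', positions=[[0, 0, 0], [2.7, 2.7, 2.7]],
                      cell=np.eye(3) * 5.4)
site_centers = np.random.rand(8, 3) * 5.4

# Each layer is a (callable, kwargs) pair; layers() creates the 3D figure,
# threads fig/ax through each callable, and equalizes the axes at the end.
fig, ax = layers(
    (plot_atoms, {'atoms': structure}),
    (plot_points, {'points': site_centers}),
)
```
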
diff --git a/sitator/voronoi.py b/sitator/voronoi.py
new file mode 100644
index 0000000..748e819
--- /dev/null
+++ b/sitator/voronoi.py
@@ -0,0 +1,50 @@
+
+import numpy as np
+
+import os
+
+from sitator import SiteNetwork
+from sitator.util import Zeopy
+
+class VoronoiSiteGenerator(object):
+ """Given an empty SiteNetwork, use the Voronoi decomposition to predict/generate sites.
+
+ :param str zeopp_path: Path to the Zeo++ ``network`` executable
+ :param bool radial: Whether to use the radial Voronoi transform. Defaults to,
+ and should typically be, ``False``.
+ """
+
+ def __init__(self,
+ zeopp_path = os.getenv("SITATOR_ZEO_PATH", default = "network"),
+ radial = False):
+ self._radial = radial
+ self._zeopy = Zeopy(zeopp_path)
+
+ def run(self, sn, seed_mask = None):
+ """
+ Args:
+ sn (SiteNetwork): Any sites will be ignored; needed for structure
+                and static mask.
+            seed_mask (ndarray, optional): Boolean mask over all atoms selecting
+                which static atoms seed the decomposition; defaults to the full
+                static mask.
+        Returns:
+ A ``SiteNetwork``.
+ """
+ assert isinstance(sn, SiteNetwork)
+
+ if seed_mask is None:
+ seed_mask = sn.static_mask
+ assert not np.any(seed_mask & sn.mobile_mask), "Seed mask must not overlap with mobile mask"
+ assert not np.any(seed_mask & ~sn.static_mask), "All seed atoms must be static."
+ voro_struct = sn.structure[seed_mask]
+ translation = np.zeros(shape = len(sn.static_mask), dtype = np.int)
+ translation[sn.static_mask] = np.arange(sn.n_static)
+ translation = translation[seed_mask]
+
+ with self._zeopy:
+ nodes, verts, edges, _ = self._zeopy.voronoi(voro_struct,
+ radial = self._radial)
+
+ out = sn.copy()
+ out.centers = nodes
+ out.vertices = [translation[v] for v in verts]
+
+ return out
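
Usage of the relocated generator; the ``SiteNetwork`` constructor arguments are assumed from the rest of the package, and ``structure``, ``static_mask``, and ``mobile_mask`` are prepared elsewhere (an ``ase.Atoms`` plus two boolean arrays over its atoms):

```python
from sitator import SiteNetwork
from sitator.voronoi import VoronoiSiteGenerator

# Assumed constructor signature: structure plus static/mobile boolean masks.
sn = SiteNetwork(structure, static_mask, mobile_mask)

# Resolves the Zeo++ binary from $SITATOR_ZEO_PATH, falling back to
# `network` on PATH.
gen = VoronoiSiteGenerator()
sn_with_sites = gen.run(sn)
```
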
diff --git a/sitator/voronoi/VoronoiSiteGenerator.py b/sitator/voronoi/VoronoiSiteGenerator.py
deleted file mode 100644
index 64c2058..0000000
--- a/sitator/voronoi/VoronoiSiteGenerator.py
+++ /dev/null
@@ -1,34 +0,0 @@
-
-import numpy as np
-
-from sitator import SiteNetwork
-from sitator.util import Zeopy
-
-class VoronoiSiteGenerator(object):
- """Given an empty SiteNetwork, use the Voronoi decomposition to predict/generate sites.
-
- :param str zeopp_path: Path to the Zeo++ `network` executable
- :param bool radial: Whether to use the radial Voronoi transform. Defaults to,
- and should typically be, False.
- :param bool verbose:
- """
-
- def __init__(self, zeopp_path = "network", radial = False, verbose = True):
- self._radial = radial
- self._verbose = verbose
- self._zeopy = Zeopy(zeopp_path)
-
- def run(self, sn):
- """SiteNetwork -> SiteNetwork"""
- assert isinstance(sn, SiteNetwork)
-
- with self._zeopy:
- nodes, verts, edges, _ = self._zeopy.voronoi(sn.static_structure,
- radial = self._radial,
- verbose = self._verbose)
-
- out = sn.copy()
- out.centers = nodes
- out.vertices = verts
-
- return out
diff --git a/sitator/voronoi/__init__.py b/sitator/voronoi/__init__.py
deleted file mode 100644
index 17c3b95..0000000
--- a/sitator/voronoi/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-
-from VoronoiSiteGenerator import VoronoiSiteGenerator