diff --git a/.gitignore b/.gitignore index 4800c95..f71864f 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,8 @@ build/ /dist/ /*.egg-info + +*.bak + +.tags +.tags_swap diff --git a/README.md b/README.md index 607bcd1..c93db2c 100644 --- a/README.md +++ b/README.md @@ -9,41 +9,54 @@ A modular framework for conducting and visualizing site analysis of molecular dy `sitator` contains an efficient implementation of our method, landmark analysis, as well as visualization tools, generic data structures for site analysis, pre- and post-processing tools, and more. -For details on the method and its application, please see our paper: +For details on landmark analysis and its application, please see our paper: > L. Kahle, A. Musaelian, N. Marzari, and B. Kozinsky
> [Unsupervised landmark analysis for jump detection in molecular dynamics simulations](https://doi.org/10.1103/PhysRevMaterials.3.055404)
> Phys. Rev. Materials 3, 055404 – 21 May 2019 -If you use `sitator` in your research, please consider citing this paper. The BibTex citation can be found in [`CITATION.bib`](CITATION.bib). +If you use `sitator` in your research, please consider citing this paper. The BibTeX citation can be found in [`CITATION.bib`](CITATION.bib). ## Installation -`sitator` is currently built for Python 2.7. We recommend the use of a Python 2.7 virtual environment (`virtualenv`, `conda`, etc.). `sitator` has two external dependencies: +`sitator` is built for Python >=3.2 (the older version supports Python 2.7). We recommend the use of a virtual environment (`virtualenv`, `conda`, etc.). `sitator` has a number of optional dependencies that enable various features: - - The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does *not* have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.) - - If you want to use the site type analysis features, a working installation of [Quippy](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) and Python bindings is required for computing SOAP descriptors. +* Landmark Analysis + * The `network` executable from [Zeo++](http://www.maciejharanczyk.info/Zeopp/examples.html) is required for computing the Voronoi decomposition. (It does not have to be installed in `PATH`; the path to it can be given with the `zeopp_path` option of `VoronoiSiteGenerator`.) +* Site Type Analysis + * For SOAP-based site types: either the `quip` binary from [QUIP](https://libatoms.github.io/QUIP/) with [GAP](http://www.libatoms.org/gap/gap_download.html) **or** the [`DScribe`](https://singroup.github.io/dscribe/index.html) Python library. + * The Python 2.7 bindings for QUIP (`quippy`) are **not** required. Generally, `DScribe` is much simpler to install than QUIP. **Please note**, however, that the SOAP descriptor vectors **differ** between QUIP and `DScribe` and one or the other may give better results depending on the system you are analyzing. + * For coordination environment analysis (`sitator.site_descriptors.SiteCoordinationEnvironment`), we integrate the `pymatgen.analysis.chemenv` package; a somewhat recent installation of `pymatgen` is required. After downloading, the package is installed with `pip`: -``` +```bash # git clone ... OR unzip ... OR ... cd sitator pip install . ``` -To enable site type analysis, add the `[SiteTypeAnalysis]` option: +To enable site type analysis, add the `[SiteTypeAnalysis]` option (this adds two dependencies -- Python packages `pydpc` and `dscribe`): -``` +```bash pip install ".[SiteTypeAnalysis]" ``` ## Examples and Documentation -Two example Jupyter notebooks for conducting full landmark analyses of LiAlSiO4 and Li12La3Zr2O12, including data files, can be found [on Materials Cloud](https://archive.materialscloud.org/2019.0008/). +Two example Jupyter notebooks for conducting full landmark analyses of LiAlSiO4 and Li12La3Zr2O12 as in our paper, including data files, can be found [on Materials Cloud](https://archive.materialscloud.org/2019.0008/). + +Full API documentation can be found at [ReadTheDocs](https://sitator.readthedocs.io/en/py3/). + +`sitator` generally assumes units of femtoseconds for time, Angstroms for space, +and Cartesian (not crystal) coordinates. 
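+
+As a quick orientation, a full analysis chains a few of the core objects
+together. The sketch below is illustrative -- the class names are real, but
+constructor options are omitted and the exact `run()` signatures may differ;
+see the Materials Cloud notebooks for complete working code:
+
+```python
+from sitator import SiteNetwork
+from sitator.voronoi import VoronoiSiteGenerator
+from sitator.landmark import LandmarkAnalysis
+from sitator.dynamics import JumpAnalysis
+
+# `atoms` is an ase.Atoms for the simulated system; `traj` is an
+# (n_frames, n_atoms, 3) ndarray of Cartesian positions in Angstroms.
+mobile_mask = atoms.get_atomic_numbers() == 3   # e.g. mobile Li
+static_mask = ~mobile_mask
+
+sn = SiteNetwork(atoms, static_mask, mobile_mask)
+sn = VoronoiSiteGenerator().run(sn)    # needs the Zeo++ `network` executable
+st = LandmarkAnalysis().run(sn, traj)  # produces a SiteTrajectory
+
+st.compute_site_occupancies()
+JumpAnalysis().run(st)                 # adds n_ij, p_ij, jump_lag, ...
+```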
+ +## Global Options + +`sitator` uses the `tqdm.autonotebook` tool to automatically produce the correct fancy progress bars for terminals and iPython notebooks. To disable all progress bars, run with the environment variable `SITATOR_PROGRESSBAR` set to `false`. -All individual classes and parameters are documented with docstrings in the source code. +The `SITATOR_ZEO_PATH` and `SITATOR_QUIP_PATH` environment variables can set the default paths to the Zeo++ `network` and QUIP `quip` executables, respectively. ## License -This software is made available under the MIT License. See `LICENSE` for more details. +This software is made available under the MIT License. See [`LICENSE`](LICENSE) for more details. diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d0c3cbf --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..6247f7e --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..8e7ec1b --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,55 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# http://www.sphinx-doc.org/en/master/config + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- Project information ----------------------------------------------------- + +project = 'sitator' +copyright = '2019, Alby Musaelian' +author = 'Alby Musaelian' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. 
They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.napoleon' +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + +master_doc = 'index' + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..e2f4d80 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,20 @@ +.. sitator documentation master file, created by + sphinx-quickstart on Tue Jul 16 15:15:41 2019. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to sitator's documentation! +=================================== + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/source/modules.rst b/docs/source/modules.rst new file mode 100644 index 0000000..244d585 --- /dev/null +++ b/docs/source/modules.rst @@ -0,0 +1,7 @@ +sitator +======= + +.. toctree:: + :maxdepth: 4 + + sitator diff --git a/docs/source/sitator.descriptors.rst b/docs/source/sitator.descriptors.rst new file mode 100644 index 0000000..aeb7b2e --- /dev/null +++ b/docs/source/sitator.descriptors.rst @@ -0,0 +1,15 @@ +sitator.descriptors package +=========================== + +Module contents +--------------- + +.. automodule:: sitator.descriptors + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.descriptors.ConfigurationalEntropy + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.dynamics.rst b/docs/source/sitator.dynamics.rst new file mode 100644 index 0000000..c2510ac --- /dev/null +++ b/docs/source/sitator.dynamics.rst @@ -0,0 +1,40 @@ +sitator.dynamics package +======================== + +Module contents +--------------- + +.. automodule:: sitator.dynamics + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.dynamics.AverageVibrationalFrequency + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.dynamics.JumpAnalysis + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.dynamics.MergeSitesByDynamics + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.dynamics.MergeSitesByThreshold + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.dynamics.RemoveShortJumps + :members: + :undoc-members: + :show-inheritance: + +.. 
autoclass:: sitator.dynamics.RemoveUnoccupiedSites + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.landmark.cluster.rst b/docs/source/sitator.landmark.cluster.rst new file mode 100644 index 0000000..c332d86 --- /dev/null +++ b/docs/source/sitator.landmark.cluster.rst @@ -0,0 +1,30 @@ +sitator.landmark.cluster package +================================ + +Submodules +---------- + +sitator.landmark.cluster.dotprod module +--------------------------------------- + +.. automodule:: sitator.landmark.cluster.dotprod + :members: + :undoc-members: + :show-inheritance: + +sitator.landmark.cluster.mcl module +----------------------------------- + +.. automodule:: sitator.landmark.cluster.mcl + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: sitator.landmark.cluster + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.landmark.rst b/docs/source/sitator.landmark.rst new file mode 100644 index 0000000..8bf446a --- /dev/null +++ b/docs/source/sitator.landmark.rst @@ -0,0 +1,42 @@ +sitator.landmark package +======================== + +Subpackages +----------- + +.. toctree:: + + sitator.landmark.cluster + +Submodules +---------- + +sitator.landmark.errors module +------------------------------ + +.. automodule:: sitator.landmark.errors + :members: + :undoc-members: + :show-inheritance: + +sitator.landmark.helpers module +------------------------------- + +.. automodule:: sitator.landmark.helpers + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: sitator.landmark + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.landmark.LandmarkAnalysis + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.misc.rst b/docs/source/sitator.misc.rst new file mode 100644 index 0000000..3b441d0 --- /dev/null +++ b/docs/source/sitator.misc.rst @@ -0,0 +1,34 @@ +sitator.misc package +==================== + +sitator.misc.oldio module +------------------------- + +.. automodule:: sitator.misc.oldio + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: sitator.misc + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.misc.GenerateAroundSites + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.misc.GenerateClampedTrajectory + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.misc.NAvgsPerSite + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.network.rst b/docs/source/sitator.network.rst new file mode 100644 index 0000000..be7fb28 --- /dev/null +++ b/docs/source/sitator.network.rst @@ -0,0 +1,25 @@ +sitator.network package +======================= + +Module contents +--------------- + +.. automodule:: sitator.network + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.network.DiffusionPathwayAnalysis + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.network.MergeSitesByBarrier + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.network.merging + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.rst b/docs/source/sitator.rst new file mode 100644 index 0000000..8662328 --- /dev/null +++ b/docs/source/sitator.rst @@ -0,0 +1,38 @@ +sitator package +=============== + +Subpackages +----------- + +.. 
toctree:: + + sitator.descriptors + sitator.dynamics + sitator.landmark + sitator.misc + sitator.network + sitator.site_descriptors + sitator.util + sitator.visualization + sitator.voronoi + +Submodules +---------- + +Module contents +--------------- + +.. automodule:: sitator + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.SiteNetwork + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.SiteTrajectory + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.site_descriptors.backend.rst b/docs/source/sitator.site_descriptors.backend.rst new file mode 100644 index 0000000..a7afe2f --- /dev/null +++ b/docs/source/sitator.site_descriptors.backend.rst @@ -0,0 +1,30 @@ +sitator.site\_descriptors.backend package +========================================= + +Submodules +---------- + +sitator.site\_descriptors.backend.dscribe module +------------------------------------------------ + +.. automodule:: sitator.site_descriptors.backend.dscribe + :members: + :undoc-members: + :show-inheritance: + +sitator.site\_descriptors.backend.quip module +--------------------------------------------- + +.. automodule:: sitator.site_descriptors.backend.quip + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: sitator.site_descriptors.backend + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.site_descriptors.rst b/docs/source/sitator.site_descriptors.rst new file mode 100644 index 0000000..4275695 --- /dev/null +++ b/docs/source/sitator.site_descriptors.rst @@ -0,0 +1,43 @@ +sitator.site\_descriptors package +================================= + +Subpackages +----------- + +.. toctree:: + + sitator.site_descriptors.backend + +Submodules +---------- + +sitator.site\_descriptors.SOAP module +------------------------------------- + +.. automodule:: sitator.site_descriptors.SOAP + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: sitator.site_descriptors + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.site_descriptors.SiteCoordinationEnvironment + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.site_descriptors.SiteTypeAnalysis + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.site_descriptors.SiteVolumes + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.util.rst b/docs/source/sitator.util.rst new file mode 100644 index 0000000..1f95c20 --- /dev/null +++ b/docs/source/sitator.util.rst @@ -0,0 +1,61 @@ +sitator.util package +==================== + +Submodules +---------- + +sitator.util.elbow module +------------------------- + +.. automodule:: sitator.util.elbow + :members: + :undoc-members: + :show-inheritance: + +sitator.util.mcl module +----------------------- + +.. automodule:: sitator.util.mcl + :members: + :undoc-members: + :show-inheritance: + +sitator.util.progress module +---------------------------- + +.. automodule:: sitator.util.progress + :members: + :undoc-members: + :show-inheritance: + +sitator.util.zeo module +----------------------- + +.. automodule:: sitator.util.zeo + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: sitator.util + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.util.DotProdClassifier + :members: + :undoc-members: + :show-inheritance: + +.. 
autoclass:: sitator.util.PBCCalculator + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.util.RecenterTrajectory + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.visualization.rst b/docs/source/sitator.visualization.rst new file mode 100644 index 0000000..9596867 --- /dev/null +++ b/docs/source/sitator.visualization.rst @@ -0,0 +1,40 @@ +sitator.visualization package +============================= + +Submodules +---------- + +sitator.visualization.atoms module +---------------------------------- + +.. automodule:: sitator.visualization.atoms + :members: + :undoc-members: + :show-inheritance: + +sitator.visualization.common module +----------------------------------- + +.. automodule:: sitator.visualization.common + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: sitator.visualization + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.visualization.SiteNetworkPlotter + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: sitator.visualization.SiteTrajectoryPlotter + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/sitator.voronoi.rst b/docs/source/sitator.voronoi.rst new file mode 100644 index 0000000..adad362 --- /dev/null +++ b/docs/source/sitator.voronoi.rst @@ -0,0 +1,10 @@ +sitator.voronoi package +======================= + +Module contents +--------------- + +.. automodule:: sitator.voronoi + :members: + :undoc-members: + :show-inheritance: diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e719012 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +numpy +Cython +scipy +matplotlib +ase +tqdm +sklearn diff --git a/setup.py b/setup.py index f6067cd..9194cd1 100644 --- a/setup.py +++ b/setup.py @@ -1,33 +1,43 @@ from setuptools import setup, find_packages from Cython.Build import cythonize +import Cython.Compiler import numpy as np -setup(name = 'sitator', - version = '1.0.1', - description = 'Unsupervised landmark analysis for jump detection in molecular dynamics simulations.', - download_url = "https://github.com/Linux-cpp-lisp/sitator", - author = 'Alby Musaelian', - license = "MIT", - python_requires = '>=2.7, <3', - packages = find_packages(), - ext_modules = cythonize([ +# Allows cimport'ing PBCCalculator +Cython.Compiler.Options.cimport_from_pyx = True + +setup( + name = 'sitator', + version = '2.0.0', + description = 'Unsupervised landmark analysis for jump detection in molecular dynamics simulations.', + download_url = "https://github.com/Linux-cpp-lisp/sitator", + author = 'Alby Musaelian', + license = "MIT", + python_requires = '>=3.2', + packages = find_packages(), + ext_modules = cythonize( + [ "sitator/landmark/helpers.pyx", - "sitator/util/*.pyx" - ]), - include_dirs=[np.get_include()], - install_requires = [ + "sitator/util/*.pyx", + "sitator/dynamics/*.pyx", + "sitator/misc/*.pyx" + ], + language_level = 3 + ), + include_dirs=[np.get_include()], + install_requires = [ "numpy", "scipy", "matplotlib", "ase", "tqdm", - "backports.tempfile", - "future", "sklearn" - ], - extras_require = { + ], + extras_require = { "SiteTypeAnalysis" : [ - "pydpc" + "pydpc", + "dscribe" ] - }, - zip_safe = True) + }, + zip_safe = True +) diff --git a/sitator/SiteNetwork.py b/sitator/SiteNetwork.py index 28909dc..bd71be4 100644 --- a/sitator/SiteNetwork.py +++ b/sitator/SiteNetwork.py @@ -1,14 +1,11 @@ -from __future__ import (absolute_import, division, - print_function, 
unicode_literals) -from builtins import * - import numpy as np import re import os import tarfile -from backports import tempfile +import tempfile +import ase import ase.io import matplotlib @@ -17,19 +14,33 @@ class SiteNetwork(object): """A network of mobile particle sites in a static lattice. - Stores the locations of sites (`centers`), their defining static atoms (`vertices`), - and their "types" (`site_types`). + Stores the locations of sites (``centers``) for some indicated mobile atoms + (``mobile_mask``) in a structure (``structure``). Optionally includes their + defining static atoms (``vertices``) and "types" (``site_types``). Arbitrary data can also be associated with each site and with each edge between sites. Site data can be any array of length n_sites; edge data can be any matrix of shape (n_sites, n_sites) where entry i, j is the value for the - edge from site i to site j. + edge from site i to site j (edge attributes can be asymmetric). + + Attributes can be marked as "computed"; this is a hint that the attribute + was computed based on a ``SiteTrajectory``. Most ``sitator`` algorithms that + modify/process ``SiteTrajectory``s will clear "computed" attrbutes, + assuming that they are invalidated by the changes to the ``SiteTrajectory``. Attributes: centers (ndarray): (n_sites, 3) coordinates of each site. vertices (list, optional): list of lists of indexes of static atoms defining each site. site_types (ndarray, optional): (n_sites,) values grouping sites into types. + + Args: + structure (ase.Atoms): an ASE ``Atoms`` containging whatever atoms exist + in the MD trajectory. + static_mask (ndarray): Boolean mask indicating which atoms make up the + host lattice. + mobile_mask (ndarray): Boolean mask indicating which atoms' movement we + are interested in. """ ATTR_NAME_REGEX = re.compile("^[a-zA-Z][a-zA-Z0-9_]*$") @@ -38,16 +49,6 @@ def __init__(self, structure, static_mask, mobile_mask): - """ - Args: - structure (Atoms): an ASE/Quippy ``Atoms`` containging whatever atoms exist - in the MD trajectory. - static_mask (ndarray): Boolean mask indicating which atoms make up the - host lattice. - mobile_mask (ndarray): Boolean mask indicating which atoms' movement we - are interested in. - """ - assert static_mask.ndim == mobile_mask.ndim == 1, "The masks must be one-dimensional" assert len(structure) == len(static_mask) == len(mobile_mask), "The masks must have the same length as the # of atoms in the strucutre." @@ -72,20 +73,54 @@ def __init__(self, self._site_attrs = {} self._edge_attrs = {} + self._attr_computed = {} + + def copy(self, with_computed = True): + """Returns a (shallowish) copy of self. - def copy(self): - """Returns a (shallowish) copy of self.""" + Args: + with_computed (bool): If ``False``, attributes marked "computed" will + not be included in the copy. + Returns: + A ``SiteNetwork``. 
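+
+        Example (illustrative)::
+
+            # Copy the network, dropping attributes computed from a
+            # SiteTrajectory (e.g. ``occupancies``):
+            sn_clean = sn.copy(with_computed = False)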
+ """ # Use a mask to force a copy msk = np.ones(shape = self.n_sites, dtype = np.bool) - return self[msk] + sn = self[msk] + if not with_computed: + sn.clear_computed_attributes() + return sn def __len__(self): return self.n_sites + def __str__(self): + return ( + "{}: {:d} sites for {:d} mobile particles in static lattice of {:d} particles\n" + " Has vertices: {}\n" + " Has types: {}\n" + " Has site attributes: {}\n" + " Has edge attributes: {}\n" + "" + ).format( + type(self).__name__, + self.n_sites, + self.n_mobile, + self.n_static, + self._vertices is not None, + self._types is not None, + ", ".join(self._site_attrs.keys()), + ", ".join(self._edge_attrs.keys()) + ) + def __getitem__(self, key): - sn = type(self)(self.structure, - self.static_mask, - self.mobile_mask) + sn = self.__new__(type(self)) + SiteNetwork.__init__( + sn, + self.structure, + self.static_mask, + self.mobile_mask + ) if not self._centers is None: sn.centers = self._centers[key] @@ -108,80 +143,14 @@ def __getitem__(self, key): return sn - _STRUCT_FNAME = "structure.xyz" - _SMASK_FNAME = "static_mask.npy" - _MMASK_FNAME = "mobile_mask.npy" - _MAIN_FNAMES = ['centers', 'vertices', 'site_types'] - - def save(self, file): - """Save this SiteNetwork to a tar archive.""" - with tempfile.TemporaryDirectory() as tmpdir: - # -- Write the structure - ase.io.write(os.path.join(tmpdir, self._STRUCT_FNAME), self.structure) - # -- Write masks - np.save(os.path.join(tmpdir, self._SMASK_FNAME), self.static_mask) - np.save(os.path.join(tmpdir, self._MMASK_FNAME), self.mobile_mask) - # -- Write what we have - for arrname in self._MAIN_FNAMES: - if not getattr(self, arrname) is None: - np.save(os.path.join(tmpdir, "%s.npy" % arrname), getattr(self, arrname)) - # -- Write all site/edge attributes - for atype, attrs in zip(("site_attr", "edge_attr"), (self._site_attrs, self._edge_attrs)): - for attr in attrs: - np.save(os.path.join(tmpdir, "%s-%s.npy" % (atype, attr)), attrs[attr]) - # -- Write final archive - with tarfile.open(file, mode = 'w:gz', format = tarfile.PAX_FORMAT) as outf: - outf.add(tmpdir, arcname = "") - - @classmethod - def from_file(cls, file): - """Load a SiteNetwork from a tar file/file descriptor.""" - all_others = {} - site_attrs = {} - edge_attrs = {} - structure = None - with tarfile.open(file, mode = 'r:gz', format = tarfile.PAX_FORMAT) as input: - # -- Load everything - for member in input.getmembers(): - if member.name == '': - continue - f = input.extractfile(member) - if member.name == cls._STRUCT_FNAME: - with tempfile.TemporaryDirectory() as tmpdir: - input.extract(member, path = tmpdir) - structure = ase.io.read(os.path.join(tmpdir, member.name), format = 'xyz') - else: - basename = os.path.splitext(os.path.basename(member.name))[0] - data = np.load(f) - if basename.startswith("site_attr"): - site_attrs[basename.split('-')[1]] = data - elif basename.startswith("edge_attr"): - edge_attrs[basename.split('-')[1]] = data - else: - all_others[basename] = data - - # Create SiteNetwork - assert not structure is None - assert all(k in all_others for k in ("static_mask", "mobile_mask")), "Malformed SiteNetwork file." 
- sn = SiteNetwork(structure, - all_others['static_mask'], - all_others['mobile_mask']) - if 'centers' in all_others: - sn.centers = all_others['centers'] - for key in all_others: - if key in ('centers', 'static_mask', 'mobile_mask'): - continue - setattr(sn, key, all_others[key]) - - assert all(len(sa) == sn.n_sites for sa in site_attrs.values()) - assert all(ea.shape == (sn.n_sites, sn.n_sites) for ea in edge_attrs.values()) - sn._site_attrs = site_attrs - sn._edge_attrs = edge_attrs - - return sn - def of_type(self, stype): - """Returns a "view" to this SiteNetwork with only sites of a certain type.""" + """Returns a subset of this ``SiteNetwork`` with only sites of a certain type. + + Args: + stype (int) + Returns: + A ``SiteNetwork``. + """ if self._types is None: raise ValueError("This SiteNetwork has no type information.") @@ -190,18 +159,45 @@ def of_type(self, stype): return self[self._types == stype] + def get_structure_with_sites(self, site_atomic_number = None): + """Get an ``ase.Atoms`` with the sites included. + + Sites are appended to the static structure; the first ``np.sum(static_mask)`` + atoms in the returned object are the static structure. + + Args: + site_atomic_number: If ``None``, the species of the first mobile + atom will be used. + Returns: + ``ase.Atoms`` and final ``site_atomic_number`` + """ + out = self.static_structure.copy() + if site_atomic_number is None: + site_atomic_number = self.structure.get_atomic_numbers()[self.mobile_mask][0] + numbers = np.full(len(self), site_atomic_number) + sites_atoms = ase.Atoms( + positions = self.centers, + numbers = numbers + ) + site_idexes = len(out) + np.arange(self.n_sites) + out.extend(sites_atoms) + return out, site_atomic_number + @property def n_sites(self): + """The number of sites.""" if self._centers is None: return 0 return len(self._centers) @property def n_total(self): + """The total number of atoms in the system.""" return len(self.static_mask) @property def centers(self): + """The positions of the sites.""" view = self._centers.view() view.flags.writeable = False return view @@ -215,27 +211,47 @@ def centers(self, value): self._types = None self._site_attrs = {} self._edge_attrs = {} + self._attr_computed = {} # Set centers self._centers = value def update_centers(self, newcenters): - """Update the SiteNetwork's centers *without* reseting all other information.""" + """Update the ``SiteNetwork``'s centers *without* reseting all other information. + + Args: + newcenters (ndarray): Must have same length as current number of sites. 
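+
+        Example (illustrative; ``displacement`` is a hypothetical
+        (n_sites, 3) array of offsets)::
+
+            sn.update_centers(sn.centers + displacement)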
+ """ if newcenters.shape != self._centers.shape: raise ValueError("New `centers` must have same shape as old; try using the setter `.centers = ...`") self._centers = newcenters @property def vertices(self): + """The static atoms defining each site.""" return self._vertices + @property + def site_ids(self): + """Convenience property giving the index of each site.""" + return np.arange(self.n_sites) + @vertices.setter def vertices(self, value): if not len(value) == len(self._centers): raise ValueError("Wrong # of vertices %i; expected %i" % (len(value), len(self._centers))) self._vertices = value + @property + def number_of_vertices(self): + """The number of vertices of each site.""" + if self._vertices is None: + return None + else: + return [len(v) for v in self._vertices] + @property def site_types(self): + """The type IDs of each site.""" if self._types is None: return None view = self._types.view() @@ -250,24 +266,40 @@ def site_types(self, value): @property def n_types(self): + """The number of site types in the ``SiteNetwork``.""" return len(np.unique(self.site_types)) @property def types(self): + """The unique site type IDs in the ``SiteNetwork``.""" return np.unique(self.site_types) @property def site_attributes(self): - return self._site_attrs.keys() + """The names of the ``SiteNetwork``'s site attributes.""" + return list(self._site_attrs.keys()) @property def edge_attributes(self): - return self._edge_attrs.keys() + """The names of the ``SiteNetwork``'s edge attributes.""" + return list(self._edge_attrs.keys()) def has_attribute(self, attr): + """Whether the ``SiteNetwork`` has a given site or edge attrbute. + + Args: + attr (str) + Returns: + bool + """ return (attr in self._site_attrs) or (attr in self._edge_attrs) def remove_attribute(self, attr): + """Remove a site or edge attribute. + + Args: + attr (str) + """ if attr in self._site_attrs: del self._site_attrs[attr] elif attr in self._edge_attrs: @@ -275,10 +307,22 @@ def remove_attribute(self, attr): else: raise AttributeError("This SiteNetwork has no site or edge attribute `%s`" % attr) + def clear_attributes(self): + """Remove all site and edge attributes.""" + self._site_attrs = {} + self._edge_attrs = {} + + def clear_computed_attributes(self): + """Remove all attributes marked "computed".""" + for k, computed in self._attr_computed.items(): + if computed: + self.remove_attribute(k) + def __getattr__(self, attrkey): - if attrkey in self._site_attrs: + v = vars(self) + if '_site_attrs' in v and attrkey in self._site_attrs: return self._site_attrs[attrkey] - elif attrkey in self._edge_attrs: + elif '_edge_attrs' in v and attrkey in self._edge_attrs: return self._edge_attrs[attrkey] else: raise AttributeError("This SiteNetwork has no site or edge attribute `%s`" % attrkey) @@ -320,21 +364,37 @@ def get_edge(self, edge): return out - def add_site_attribute(self, name, attr): + def add_site_attribute(self, name, attr, computed = True): + """Add a site attribute. + + Args: + name (str) + attr (ndarray): Must be of length ``n_sites``. + computed (bool): Whether to mark this attribute as "computed". + """ self._check_name(name) attr = np.asarray(attr) if not attr.shape[0] == self.n_sites: raise ValueError("Attribute array has only %i entries; need one for all %i sites." % (len(attr), self.n_sites)) self._site_attrs[name] = attr + self._attr_computed[name] = computed - def add_edge_attribute(self, name, attr): + def add_edge_attribute(self, name, attr, computed = True): + """Add an edge attribute. 
+ + Args: + name (str) + attr (ndarray): Must be of shape ``(n_sites, n_sites)``. + computed (bool): Whether to mark this attribute as "computed". + """ self._check_name(name) attr = np.asarray(attr) if not (attr.shape[0] == attr.shape[1] == self.n_sites): raise ValueError("Attribute matrix has shape %s; need first two dimensions to be %i" % (attr.shape, self.n_sites)) self._edge_attrs[name] = attr + self._attr_computed[name] = computed def _check_name(self, name): if not SiteNetwork.ATTR_NAME_REGEX.match(name): @@ -345,6 +405,6 @@ def _check_name(self, name): raise ValueError("Attribute name `%s` reserved." % name) def plot(self, *args, **kwargs): - """Convenience method -- constructs a defualt SiteNetworkPlotter and calls it.""" + """Convenience method -- constructs a defualt ``SiteNetworkPlotter`` and calls it.""" p = SiteNetworkPlotter(title = "Sites") p(self, *args, **kwargs) diff --git a/sitator/SiteTrajectory.py b/sitator/SiteTrajectory.py index f934edd..a994e28 100644 --- a/sitator/SiteTrajectory.py +++ b/sitator/SiteTrajectory.py @@ -1,14 +1,11 @@ -from __future__ import (absolute_import, division, - print_function, unicode_literals) -from builtins import * - import numpy as np from sitator.util import PBCCalculator -from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS +from sitator.visualization import SiteTrajectoryPlotter +from sitator.util.progress import tqdm -import matplotlib -from matplotlib.collections import LineCollection +import logging +logger = logging.getLogger(__name__) class SiteTrajectory(object): """A trajectory capturing the dynamics of particles through a SiteNetwork.""" @@ -42,6 +39,8 @@ def __init__(self, self._real_traj = None + self._default_plotter = None + def __len__(self): return self.n_frames @@ -51,28 +50,45 @@ def __getitem__(self, key): confidences = None if self._confs is None else self._confs[key]) if not self._real_traj is None: st.set_real_traj(self._real_traj[key]) - return st + def __getstate__(self): + # Copy the object's state from self.__dict__ which contains + # all our instance attributes. Always use the dict.copy() + # method to avoid modifying the original state. + state = self.__dict__.copy() + # Don't want to pickle giant trajectories or uninteresting plotters + state['_real_traj'] = None + state['_default_plotter'] = None + return state + @property def traj(self): - """The underlying trajectory.""" + """The site assignments over time.""" return self._traj + @property + def confidences(self): + return self._confs + @property def n_frames(self): + """The number of frames in the trajectory.""" return len(self._traj) @property def n_unassigned(self): + """The total number of times a mobile particle is unassigned.""" return np.sum(self._traj < 0) @property def n_assigned(self): + """The total number of times a mobile particle was assigned to a site.""" return self._sn.n_mobile * self.n_frames - self.n_unassigned @property def percent_unassigned(self): + """Proportion of particle positions that are unassigned over all time.""" return float(self.n_unassigned) / (self._sn.n_mobile * self.n_frames) @property @@ -88,25 +104,49 @@ def site_network(self, value): @property def real_trajectory(self): + """The real-space trajectory this ``SiteTrajectory`` is based on.""" return self._real_traj + def copy(self, with_computed = True): + """Return a copy. + + Args: + with_computed (bool): See ``SiteNetwork.copy()``. 
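+
+        Example (illustrative)::
+
+            st_clean = st.copy(with_computed = False)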
+ """ + st = self[:] + st.site_network = st.site_network.copy(with_computed = with_computed) + return st + def set_real_traj(self, real_traj): """Assocaite this SiteTrajectory with a trajectory of points in real space. - The trajectory is not copied, and should have shape (n_frames, n_total) + The trajectory is not copied. + + Args: + real_traj (ndarray of shape (n_frames, n_total)) """ expected_shape = (self.n_frames, self._sn.n_total, 3) if not real_traj.shape == expected_shape: raise ValueError("real_traj of shape %s does not have expected shape %s" % (real_traj.shape, expected_shape)) self._real_traj = real_traj + def remove_real_traj(self): """Forget associated real trajectory.""" del self._real_traj self._real_traj = None + def trajectory_for_particle(self, i, return_confidences = False): - """Returns the array of sites particle i is assigned to over time.""" + """Returns the array of sites particle i is assigned to over time. + + Args: + i (int) + return_confidences (bool): If ``True``, also return the confidences + with which those assignments were made. + Returns: + ndarray (int) of length ``n_frames``[, ndarray (float) length ``n_frames``] + """ if return_confidences and self._confs is None: raise ValueError("This SiteTrajectory has no confidences") if return_confidences: @@ -114,7 +154,19 @@ def trajectory_for_particle(self, i, return_confidences = False): else: return self._traj[:, i] + def real_positions_for_site(self, site, return_confidences = False): + """Get all real-space positions assocated with a site. + + Args: + site (int) + return_confidences (bool): If ``True``, the confidences with which + each real-space position was assigned to ``site`` are also + returned. + + Returns: + ndarray (N, 3)[, ndarray (N)] + """ if self._real_traj is None: raise ValueError("This SiteTrajectory has no real trajectory") if return_confidences and self._confs is None: @@ -131,22 +183,69 @@ def real_positions_for_site(self, site, return_confidences = False): else: return pts + def compute_site_occupancies(self): - """Computes site occupancies and adds site attribute `occupancies` to site_network.""" - occ = np.true_divide(np.bincount(self._traj[self._traj >= 0]), self.n_frames) + """Computes site occupancies. + + Adds site attribute ``occupancies`` to ``site_network``. + + In cases of multiple occupancy, this will be higher than the number of + frames in which the site is occupied and could be over 1.0. + + Returns: + ndarray of occupancies (length ``n_sites``) + """ + occ = np.true_divide(np.bincount(self._traj[self._traj >= 0], minlength = self._sn.n_sites), self.n_frames) + if self.site_network.has_attribute('occupancies'): + self.site_network.remove_attribute('occupancies') self.site_network.add_site_attribute('occupancies', occ) return occ - def assign_to_last_known_site(self, frame_threshold = 1, verbose = True): - """Assign unassigned mobile particles to their last known site within - `frame_threshold` frames. - :returns: information dictionary of debugging/diagnostic information. + def check_multiple_occupancy(self, max_mobile_per_site = 1): + """Count cases of "multiple occupancy" where more than one mobile share the same site at the same time. + + These cases usually indicate bad site analysis. + + Returns: + int: the total number of multiple assignment incidents; and + float: the average number of mobile atoms at any site at any one time. 
+ """ + from sitator.errors import MultipleOccupancyError + n_more_than_ones = 0 + avg_mobile_per_site = 0 + divisor = 0 + for frame_i, site_frame in enumerate(self._traj): + sites, counts = np.unique(site_frame[site_frame >= 0], return_counts = True) + count_msk = counts > max_mobile_per_site + if np.any(count_msk): + first_multi_site = sites[count_msk][0] + raise MultipleOccupancyError( + mobile = np.where(site_frame == first_multi_site)[0], + site = first_multi_site, + frame = frame_i + ) + n_more_than_ones += np.sum(counts > 1) + avg_mobile_per_site += np.sum(counts) + divisor += len(counts) + avg_mobile_per_site /= divisor + return n_more_than_ones, avg_mobile_per_site + + + def assign_to_last_known_site(self, frame_threshold = 1): + """Assign unassigned mobile particles to their last known site. + + Args: + frame_threshold (int): The maximum number of frames between the last + known site and the present frame up to which the last known site + can be used. + + Returns: + information dictionary of debugging/diagnostic information. """ total_unknown = self.n_unassigned - if verbose: - print("%i unassigned positions (%i%%); assigning unassigned mobile particles to last known positions within %i frames..." % (total_unknown, 100.0 * self.percent_unassigned, frame_threshold)) + logger.info("%i unassigned positions (%i%%); assigning unassigned mobile particles to last known positions within %s frames..." % (total_unknown, 100.0 * self.percent_unassigned, frame_threshold)) last_known = np.empty(shape = self._sn.n_mobile, dtype = np.int) last_known.fill(-1) @@ -156,7 +255,7 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True): max_time_unknown = 0 total_reassigned = 0 - for i in xrange(self.n_frames): + for i in range(self.n_frames): # All those unknown this frame unknown = self._traj[i] == -1 # Update last_known for assigned sites @@ -184,10 +283,9 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True): if avg_time_unknown_div > 0: # We corrected some unknowns avg_time_unknown = float(avg_time_unknown) / avg_time_unknown_div - if verbose: - print(" Maximum # of frames any mobile particle spent unassigned: %i" % max_time_unknown) - print(" Avg. # of frames spent unassigned: %f" % avg_time_unknown) - print(" Assigned %i/%i unassigned positions, leaving %i (%i%%) unknown" % (total_reassigned, total_unknown, self.n_unassigned, self.percent_unassigned)) + logger.info(" Maximum # of frames any mobile particle spent unassigned: %i" % max_time_unknown) + logger.info(" Avg. # of frames spent unassigned: %f" % avg_time_unknown) + logger.info(" Assigned %i/%i unassigned positions, leaving %i (%i%%) unknown" % (total_reassigned, total_unknown, self.n_unassigned, self.percent_unassigned)) res = { 'max_time_unknown' : max_time_unknown, @@ -195,8 +293,7 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True): 'total_reassigned' : total_reassigned } else: - if self.verbose: - print(" None to correct.") + logger.info(" None to correct.") res = { 'max_time_unknown' : 0, @@ -206,119 +303,87 @@ def assign_to_last_known_site(self, frame_threshold = 1, verbose = True): return res - @plotter(is3D = True) - def plot_frame(self, frame, **kwargs): - sites_of_frame = np.unique(self._traj[frame]) - frame_sn = self._sn[sites_of_frame] - frame_sn.plot(**kwargs) + def jumps(self, **kwargs): + """Iterate over all jumps in the trajectory, jump by jump. 
- if not self._real_traj is None: - mobile_atoms = self._sn.structure.copy() - del mobile_atoms[~self._sn.mobile_mask] - - mobile_atoms.positions[:] = self._real_traj[frame, self._sn.mobile_mask] - plot_atoms(atoms = mobile_atoms, **kwargs) - - kwargs['ax'].set_title("Frame %i/%i" % (frame, self.n_frames)) - - @plotter(is3D = True) - def plot_site(self, site, **kwargs): - pbcc = PBCCalculator(self._sn.structure.cell) - pts = self.real_positions_for_site(site).copy() - offset = pbcc.cell_centroid - pts[3] - pts += offset - pbcc.wrap_points(pts) - lattice_pos = self._sn.static_structure.positions.copy() - lattice_pos += offset - pbcc.wrap_points(lattice_pos) - site_pos = self._sn.centers[site:site+1].copy() - site_pos += offset - pbcc.wrap_points(site_pos) - # Plot point cloud - plot_points(points = pts, alpha = 0.3, marker = '.', color = 'k', **kwargs) - # Plot site - plot_points(points = site_pos, color = 'cyan', **kwargs) - # Plot everything else - plot_atoms(self._sn.static_structure, positions = lattice_pos, **kwargs) - - title = "Site %i/%i" % (site, len(self._sn)) - - if not self._sn.site_types is None: - title += " (type %i)" % self._sn.site_types[site] - - kwargs['ax'].set_title(title) - - @plotter(is3D = False) - def plot_particle_trajectory(self, particle, ax = None, fig = None, **kwargs): - types = not self._sn.site_types is None - if types: - type_height_percent = 0.1 - axpos = ax.get_position() - typeax_height = type_height_percent * axpos.height - typeax = fig.add_axes([axpos.x0, axpos.y0, axpos.width, typeax_height], sharex = ax) - ax.set_position([axpos.x0, axpos.y0 + typeax_height, axpos.width, axpos.height - typeax_height]) - type_height = 1 - # Draw trajectory - segments = [] - linestyles = [] - colors = [] - - traj = self._traj[:, particle] - current_value = traj[0] - last_value = traj[0] - if types: - last_type = None - current_segment_start = 0 - puttext = False - - for i, f in enumerate(traj): - if f != current_value or i == len(traj) - 1: - val = last_value if current_value == -1 else current_value - segments.append([[current_segment_start, last_value], [current_segment_start, val], [i, val]]) - linestyles.append(':' if current_value == -1 else '-') - colors.append('lightgray' if current_value == -1 else 'k') - - if types: - rxy = (current_segment_start, 0) - this_type = self._sn.site_types[val] - typerect = matplotlib.patches.Rectangle(rxy, i - current_segment_start, type_height, - color = DEFAULT_COLORS[this_type], linewidth = 0) - typeax.add_patch(typerect) - if this_type != last_type: - typeax.annotate("T%i" % this_type, - xy = (rxy[0], rxy[1] + 0.5 * type_height), - xytext = (3, -1), - textcoords = 'offset points', - fontsize = 'xx-small', - va = 'center', - fontweight = 'bold') - last_type = this_type - - last_value = val - current_segment_start = i - current_value = f - - lc = LineCollection(segments, linestyles = linestyles, colors = colors, linewidth=1.5) - ax.add_collection(lc) - - if types: - typeax.set_xlabel("Frame") - ax.tick_params(axis = 'x', which = 'both', bottom = False, top = False, labelbottom = False) - typeax.tick_params(axis = 'y', which = 'both', left = False, right = False, labelleft = False) - typeax.annotate("Type", xy = (0, 0.5), xytext = (-25, 0), xycoords = 'axes fraction', textcoords = 'offset points', va = 'center', fontsize = 'x-small') - else: - ax.set_xlabel("Frame") - ax.set_ylabel("Atom %i's site" % particle) + A jump is considered to occur "at the frame" when it first acheives its + new site. 
For example, - ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True)) - ax.grid() + - Frame 0: Atom 1 at site 4 + - Frame 1: Atom 1 at site 5 - ax.set_xlim((0, self.n_frames - 1)) - margin_percent = 0.04 - ymargin = (margin_percent * self._sn.n_sites) - ax.set_ylim((-ymargin, self._sn.n_sites - 1.0 + ymargin)) + will yield a jump ``(1, 1, 4, 5)``. - if types: - typeax.set_xlim((0, self.n_frames - 1)) - typeax.set_ylim((0, type_height)) + Args: + unknown_as_jump (bool): If ``True``, moving from a site to unknown + (or vice versa) is considered a jump; if ``False``, unassigned + mobile atoms are considered to be at their last known sites. + Yields: + tuple: (frame_number, mobile_atom_number, from_site, to_site) + """ + n_mobile = self.site_network.n_mobile + for frame_i, jumped, last_known, frame in self._jumped_generator(**kwargs): + for atom_i in range(n_mobile): + if jumped[atom_i]: + yield frame_i, atom_i, last_known[atom_i], frame[atom_i] + + def jumps_by_frame(self, **kwargs): + """Iterate over all jumps in the trajectory, frame by frame. + + A jump is considered to occur "at the frame" when it first acheives its + new site. For example, + + - Frame 0: Atom 1 at site 4 + - Frame 1: Atom 1 at site 5 + + will yield a jump ``(1, 1, 4, 5)``. + + Args: + unknown_as_jump (bool): If ``True``, moving from a site to unknown + (or vice versa) is considered a jump; if ``False``, unassigned + mobile atoms are considered to be at their last known sites. + Yields: + tuple: (frame_number, mob_that_jumped, from_sites, to_sites) + """ + n_mobile = self.site_network.n_mobile + for frame_i, jumped, last_known, frame in self._jumped_generator(**kwargs): + yield frame_i, np.where(jumped)[0], last_known[jumped], frame[jumped] + + def _jumped_generator(self, unknown_as_jump = False): + """Internal jump generator that does not create intermediate arrays. + + Wrapped by convinience functions. 
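+
+        Yields:
+            tuple: ``(frame_i, jumped, last_known, frame)``, where ``jumped``
+            is a boolean mask over the mobile atoms that changed sites at
+            this frame and ``last_known``/``frame`` are their previous and
+            current site assignments.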
+ """ + traj = self.traj + n_mobile = self.site_network.n_mobile + assert n_mobile == traj.shape[1] + last_known = traj[0].copy() + known = np.ones(shape = len(last_known), dtype = np.bool) + jumped = np.zeros(shape = len(last_known), dtype = np.bool) + for frame_i in range(1, self.n_frames): + if not unknown_as_jump: + np.not_equal(traj[frame_i], SiteTrajectory.SITE_UNKNOWN, out = known) + + np.not_equal(traj[frame_i], last_known, out = jumped) + jumped &= known # Must be currently known to have jumped + + yield frame_i, jumped, last_known, traj[frame_i] + + last_known[known] = traj[frame_i, known] + + # ---- Plotting code + def plot_frame(self, *args, **kwargs): + if self._default_plotter is None: + self._default_plotter = SiteTrajectoryPlotter() + self._default_plotter.plot_frame(self, *args, **kwargs) + + def plot_site(self, *args, **kwargs): + if self._default_plotter is None: + self._default_plotter = SiteTrajectoryPlotter() + self._default_plotter.plot_site(self, *args, **kwargs) + + def plot_particle_trajectory(self, *args, **kwargs): + if self._default_plotter is None: + self._default_plotter = SiteTrajectoryPlotter() + self._default_plotter.plot_particle_trajectory(self, *args, **kwargs) diff --git a/sitator/__init__.py b/sitator/__init__.py index 0f393bc..1ca01b0 100644 --- a/sitator/__init__.py +++ b/sitator/__init__.py @@ -1,3 +1,3 @@ -from SiteNetwork import SiteNetwork -from SiteTrajectory import SiteTrajectory +from .SiteNetwork import SiteNetwork +from .SiteTrajectory import SiteTrajectory diff --git a/sitator/descriptors/ConfigurationalEntropy.py b/sitator/descriptors/ConfigurationalEntropy.py index a1c31ac..4008588 100644 --- a/sitator/descriptors/ConfigurationalEntropy.py +++ b/sitator/descriptors/ConfigurationalEntropy.py @@ -3,21 +3,28 @@ from sitator import SiteTrajectory from sitator.dynamics import JumpAnalysis +import logging +logger = logging.getLogger(__name__) + class ConfigurationalEntropy(object): - """Compute the S~ configurational entropy. + """Compute the S~ (S tilde) configurational entropy. If the SiteTrajectory lacks type information, the summation is taken over the sites rather than the site types. - Ref: - Structural, Chemical, and Dynamical Frustration: Origins of Superionic Conductivity in closo-Borate Solid Electrolytes - Kyoung E. Kweon, Joel B. Varley, Patrick Shea, Nicole Adelstein, Prateek Mehta, Tae Wook Heo, Terrence J. Udovic, Vitalie Stavila, and Brandon C. Wood + Reference: + Structural, Chemical, and Dynamical Frustration: Origins of Superionic + Conductivity in closo-Borate Solid Electrolytes + + Kyoung E. Kweon, Joel B. Varley, Patrick Shea, Nicole Adelstein, + Prateek Mehta, Tae Wook Heo, Terrence J. Udovic, Vitalie Stavila, + and Brandon C. 
Wood + Chemistry of Materials 2017 29 (21), 9142-9153 DOI: 10.1021/acs.chemmater.7b02902 """ - def __init__(self, acceptable_overshoot = 0.0001, verbose = True): + def __init__(self, acceptable_overshoot = 0.0): self.acceptable_overshoot = acceptable_overshoot - self.verbose = verbose def compute(self, st): assert isinstance(st, SiteTrajectory) @@ -52,16 +59,13 @@ def compute(self, st): size_of_problems = p2 - 1.0 forgivable = problems & (size_of_problems < self.acceptable_overshoot) - if self.verbose: - print "n_i " + ("{:5.3} " * len(n_i)).format(*n_i) - print "N_i " + ("{:>5} " * len(N_i)).format(*N_i) - print " " + ("------" * len(n_i)) - print "P_2 " + ("{:5.3} " * len(p2)).format(*p2) + logger.info("n_i " + ("{:5.3} " * len(n_i)).format(*n_i)) + logger.info("N_i " + ("{:>5} " * len(N_i)).format(*N_i)) + logger.info(" " + ("------" * len(n_i))) + logger.info("P_2 " + ("{:5.3} " * len(p2)).format(*p2)) if not np.all(problems == forgivable): raise ValueError("P_2 values for site types %s larger than 1.0 + acceptable_overshoot (%f)" % (np.where(problems)[0], self.acceptable_overshoot)) - elif np.any(problems) and self.verbose: - print "" # Correct forgivable problems p2[forgivable] = 1.0 diff --git a/sitator/descriptors/__init__.py b/sitator/descriptors/__init__.py index 0f2d416..31d3db7 100644 --- a/sitator/descriptors/__init__.py +++ b/sitator/descriptors/__init__.py @@ -1 +1 @@ -from ConfigurationalEntropy import ConfigurationalEntropy +from .ConfigurationalEntropy import ConfigurationalEntropy diff --git a/sitator/dynamics/AverageVibrationalFrequency.py b/sitator/dynamics/AverageVibrationalFrequency.py new file mode 100644 index 0000000..88b6d34 --- /dev/null +++ b/sitator/dynamics/AverageVibrationalFrequency.py @@ -0,0 +1,61 @@ +import numpy as np + +class AverageVibrationalFrequency(object): + """Compute the average vibrational frequency of indicated atoms in a trajectory. + + Uses the method described in section 2.2 of this paper: + + Klerk, Niek J.J. de, Eveline van der Maas, and Marnix Wagemaker. + + Analysis of Diffusion in Solid-State Electrolytes through MD Simulations, + Improvement of the Li-Ion Conductivity in β-Li3PS4 as an Example. + + ACS Applied Energy Materials 1, no. 7 (July 23, 2018): 3230–42. + https://doi.org/10.1021/acsaem.8b00457. + + Args: + min_frequency (float, units: timestep^-1): Compute mean frequency of power + spectrum above this frequency. + max_frequency (float, units: timestep^-1): Compute mean frequency of power + spectrum below this frequency. + """ + def __init__(self, + min_frequency = 0, + max_frequency = np.inf): + # Always want to exclude DC frequency + assert min_frequency >= 0 + self.min_frequency = min_frequency + self.max_frequency = max_frequency + + def compute_avg_vibrational_freq(self, traj, mask, return_stdev = False): + """Compute the average vibrational frequency. + + Args: + traj (ndarray n_frames x n_atoms x 3): An MD trajectory. + mask (ndarray n_atoms bool): Which atoms to average over. + Returns: + A frequency in units of (timestep)^-1 + """ + speeds = traj[1:, mask] + speeds -= traj[:-1, mask] + speeds = np.linalg.norm(speeds, axis = 2) + + freqs = np.fft.rfftfreq(speeds.shape[0]) + fmask = (freqs > self.min_frequency) & (freqs < self.max_frequency) + assert np.any(fmask), "Trajectory too short?" + freqs = freqs[fmask] + + # de Klerk et. al. 
do an average over the atom-by-atom averages + n_mob = speeds.shape[1] + avg_freqs = np.empty(shape = n_mob) + + for mob in range(n_mob): + ps = np.abs(np.fft.rfft(speeds[:, mob])) ** 2 + avg_freqs[mob] = np.average(freqs, weights = ps[fmask]) + + avg_freq = np.mean(avg_freqs) + + if return_stdev: + return avg_freq, np.std(avg_freqs) + else: + return avg_freq diff --git a/sitator/dynamics/DiffusionPathwayAnalysis.py b/sitator/dynamics/DiffusionPathwayAnalysis.py deleted file mode 100644 index 245f278..0000000 --- a/sitator/dynamics/DiffusionPathwayAnalysis.py +++ /dev/null @@ -1,78 +0,0 @@ - -import numpy as np - -import numbers - -from scipy.sparse.csgraph import connected_components - -class DiffusionPathwayAnalysis(object): - """Find connected diffusion pathways in a SiteNetwork. - - :param float|int connectivity_threshold: The percentage of the total number of - (non-self) jumps, or absolute number of jumps, that must occur over an edge - for it to be considered connected. - :param int minimum_n_sites: The minimum number of sites that must be part of - a pathway for it to be considered as such. - """ - - NO_PATHWAY = -1 - - def __init__(self, - connectivity_threshold = 0.001, - minimum_n_sites = 4, - verbose = True): - assert minimum_n_sites >= 0 - - self.connectivity_threshold = connectivity_threshold - self.minimum_n_sites = minimum_n_sites - - self.verbose = verbose - - def run(self, sn): - """ - Expects a SiteNetwork that has had a JumpAnalysis run on it. - """ - if not sn.has_attribute('n_ij'): - raise ValueError("SiteNetwork has no `n_ij`; run a JumpAnalysis on it first.") - - nondiag = np.ones(shape = sn.n_ij.shape, dtype = np.bool) - np.fill_diagonal(nondiag, False) - n_non_self_jumps = np.sum(sn.n_ij[nondiag]) - - if isinstance(self.connectivity_threshold, numbers.Integral): - threshold = self.connectivity_threshold - elif isinstance(self.connectivity_threshold, numbers.Real): - threshold = self.connectivity_threshold * n_non_self_jumps - else: - raise TypeError("Don't know how to interpret connectivity_threshold `%s`" % self.connectivity_threshold) - - connectivity_matrix = sn.n_ij >= threshold - - n_ccs, ccs = connected_components(connectivity_matrix, - directed = False, # even though the matrix is symmetric - connection = 'weak') # diffusion could be unidirectional - - _, counts = np.unique(ccs, return_counts = True) - - is_pathway = counts >= self.minimum_n_sites - - if self.verbose: - print "Taking all edges with at least %i/%i jumps..." % (threshold, n_non_self_jumps) - print "Found %i connected components, of which %i are large enough to qualify as pathways." 
% (n_ccs, np.sum(is_pathway)) - - translation = np.empty(n_ccs, dtype = np.int) - translation[~is_pathway] = DiffusionPathwayAnalysis.NO_PATHWAY - translation[is_pathway] = np.arange(np.sum(is_pathway)) - - node_pathways = translation[ccs] - - outmat = np.empty(shape = (sn.n_sites, sn.n_sites), dtype = np.int) - - for i in xrange(sn.n_sites): - rowmask = node_pathways[i] == node_pathways - outmat[i, rowmask] = node_pathways[i] - outmat[i, ~rowmask] = DiffusionPathwayAnalysis.NO_PATHWAY - - sn.add_site_attribute('site_diffusion_pathway', node_pathways) - sn.add_edge_attribute('edge_diffusion_pathway', outmat) - return sn diff --git a/sitator/dynamics/JumpAnalysis.py b/sitator/dynamics/JumpAnalysis.py index af5085c..945fd7d 100644 --- a/sitator/dynamics/JumpAnalysis.py +++ b/sitator/dynamics/JumpAnalysis.py @@ -5,31 +5,39 @@ from sitator import SiteNetwork, SiteTrajectory from sitator.visualization import plotter, plot_atoms, layers +import logging +logger = logging.getLogger(__name__) + class JumpAnalysis(object): """Given a SiteTrajectory, compute various statistics about the jumps it contains. Adds these edge attributes to the SiteTrajectory's SiteNetwork: - - `n_ij`: total number of jumps from i to j. - - `p_ij`: being at i, the probability of jumping to j. - - `jump_lag`: The average number of frames a particle spends at i before jumping + - ``n_ij``: total number of jumps from i to j. + - ``p_ij``: being at i, the probability of jumping to j. + - ``jump_lag``: The average number of frames a particle spends at i before jumping to j. Can be +inf if no such jumps every occur. And these site attributes: - - `residence_times`: Avg. number of frames a particle spends at a site before jumping. - - `total_corrected_residences`: Total number of frames when a particle was at the site, + - ``residence_times``: Avg. number of frames a particle spends at a site before jumping. + - ``total_corrected_residences``: Total number of frames when a particle was at the site, *including* frames when an unassigned particle's last known site was this site. """ - def __init__(self, verbose = True): - self.verbose = verbose + def __init__(self): + pass def run(self, st): """Run the analysis. - Adds edge attributes to st's SiteNetwork and returns st. + Adds edge attributes to ``st``'s ``SiteNetwork``. + + Args: + st (SiteTrajectory) + + Returns: + ``st`` """ assert isinstance(st, SiteTrajectory) - if self.verbose: - print "Running JumpAnalysis..." + logger.info("Running JumpAnalysis...") n_mobile = st.site_network.n_mobile n_frames = st.n_frames @@ -60,17 +68,13 @@ def run(self, st): unassigned = frame == SiteTrajectory.SITE_UNKNOWN # Reassign unassigned frame[unassigned] = last_known[unassigned] - fknown = frame >= 0 + fknown = (frame >= 0) & (last_known >= 0) - if np.any(~fknown) and self.verbose: - print " at frame %i, %i uncorrectable unassigned particles" % (i, np.sum(~fknown)) + n_problems += np.sum(~fknown) # -- Update stats total_time_spent_at_site[frame[fknown]] += 1 jumped = (frame != last_known) & fknown - problems = last_known[jumped] == -1 - jumped[np.where(jumped)[0][problems]] = False - n_problems += np.sum(problems) n_ij[last_known[fknown], frame[fknown]] += 1 @@ -94,8 +98,8 @@ def run(self, st): # The time before jumping to self should always be inf assert not np.any(np.nonzero(avg_time_before_jump.diagonal())) - if self.verbose and n_problems != 0: - print "Came across %i times where assignment and last known assignment were unassigned." 
% n_problems
+        if n_problems != 0:
+            logger.warning("Came across %i times where assignment and last known assignment were unassigned." % n_problems)
 
         msk = avg_time_before_jump_n > 0 # Zeros -- i.e. no jumps -- should actualy be infs
@@ -103,12 +107,20 @@ def run(self, st):
         # Do mean
         avg_time_before_jump[msk] /= avg_time_before_jump_n[msk]
 
+        if st.site_network.has_attribute('n_ij'):
+            st.site_network.remove_attribute('n_ij')
+            st.site_network.remove_attribute('p_ij')
+            st.site_network.remove_attribute('jump_lag')
+            st.site_network.remove_attribute('residence_times')
+            st.site_network.remove_attribute('occupancy_freqs')
+            st.site_network.remove_attribute('total_corrected_residences')
+
         st.site_network.add_edge_attribute('jump_lag', avg_time_before_jump)
         st.site_network.add_edge_attribute('n_ij', n_ij)
         st.site_network.add_edge_attribute('p_ij', n_ij / total_time_spent_at_site)
 
         res_times = np.empty(shape = n_sites, dtype = np.float)
-        for site in xrange(n_sites):
+        for site in range(n_sites):
             times = avg_time_before_jump[site]
             noninf = times < np.inf
             if np.any(noninf):
@@ -125,18 +137,23 @@ def run(self, st):
 
     def jump_lag_by_type(self,
                          sn,
                          return_counts = False):
-        """Given a SiteNetwork with jump_lag info, compute avg. residence times by type
+        """Given a SiteNetwork with jump_lag info, compute avg. residence times by type.
 
         Computes the average number of frames a mobile particle spends at each
         type of site before jumping to each other type of site.
 
-        Returns an (n_types, n_types) matrix. If no jumps of a given type occured,
+        Args:
+            sn (SiteNetwork)
+            return_counts (bool): Whether to also return a matrix giving the
+                number of each type of jump that occurred.
+
+        Returns:
+            An (n_types, n_types) matrix. If no jumps of a given type occurred,
             the corresponding entry is +inf.
 
-        If ``return_counts``, then also returns a matrix giving the number of
-        each type of jump that occured.
+            If ``return_counts``, two such matrices.
         """
-
         if sn.site_types is None:
             raise ValueError("SiteNetwork has no type information.")
@@ -148,7 +165,7 @@ def jump_lag_by_type(self,
         if return_counts:
             countmat = np.empty(shape = outmat.shape, dtype = np.int)
 
-        for stype_from, stype_to in itertools.product(xrange(len(all_types)), repeat = 2):
+        for stype_from, stype_to in itertools.product(range(len(all_types)), repeat = 2):
             lags = sn.jump_lag[site_types == all_types[stype_from]][:, site_types == all_types[stype_to]]
             # Only take things that aren't inf
             lags = lags[lags < np.inf]
@@ -172,13 +189,11 @@ def plot_jump_lag(self, sn, mode = 'site', min_n_events = 1, ax = None, fig = No
         """Plot the jump lag of a site network.
 
         :param SiteNetwork sn:
-        :param str mode: If 'site', show jump lag between individual sites.
-            If 'type', show jump lag between types of sites (see :func:jump_lag_by_type)
-            Default: 'site'
+        :param str mode: If ``'site'``, show jump lag between individual sites.
+            If ``'type'``, show jump lag between types of sites (see :func:jump_lag_by_type)
         :param int min_n_events: Minimum number of jump events of a given
            type (i -> j or type -> type) to show a jump lag. If a given jump has
-            occured less than min_n_events times, no jump lag will be shown.
-            Default: 1
+            occurred less than ``min_n_events`` times, no jump lag will be shown.
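+
+        A minimal usage sketch (hypothetical variable names; ``mode = 'type'``
+        additionally assumes the network has site type information)::
+
+            ja = JumpAnalysis()
+            st = ja.run(st)
+            ja.plot_jump_lag(st.site_network, mode = 'type')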
""" if mode == 'site': mat = np.copy(sn.jump_lag) @@ -192,7 +207,7 @@ def plot_jump_lag(self, sn, mode = 'site', min_n_events = 1, ax = None, fig = No mat[counts < min_n_events] = np.nan # Show diagonal - ax.plot(*zip([0.0, 0.0], mat.shape), color = 'k', alpha = 0.5, linewidth = 1, linestyle = '--') + ax.plot(*list(zip([0.0, 0.0], mat.shape)), color = 'k', alpha = 0.5, linewidth = 1, linestyle = '--') ax.grid() im = ax.matshow(mat, zorder = 10, cmap = 'plasma') diff --git a/sitator/dynamics/MergeSitesByDynamics.py b/sitator/dynamics/MergeSitesByDynamics.py index 5288cca..c8950e5 100644 --- a/sitator/dynamics/MergeSitesByDynamics.py +++ b/sitator/dynamics/MergeSitesByDynamics.py @@ -1,16 +1,23 @@ import numpy as np -from sitator import SiteNetwork, SiteTrajectory from sitator.dynamics import JumpAnalysis from sitator.util import PBCCalculator +from sitator.network.merging import MergeSites +from sitator.util.mcl import markov_clustering -class MergeSitesByDynamics(object): +import logging +logger = logging.getLogger(__name__) + + +class MergeSitesByDynamics(MergeSites): """Merges sites using dynamical data. Given a SiteTrajectory, merges sites using Markov Clustering. :param float distance_threshold: Don't merge sites further than this - in real space. + in real space. Zeros out the connectivity_matrix at distances greater than + this; a hard, step function style cutoff. For a more gentle cutoff, try + changing `connectivity_matrix_generator` to incorporate distance. :param float post_check_thresh_factor: Throw an error if proposed merge sites are further than this * distance_threshold away. Only a sanity check; not a hard guerantee. Can be `None`; defaults to `1.5`. Can be loosely @@ -25,38 +32,92 @@ class MergeSitesByDynamics(object): Valid keys are ``'inflation'``, ``'expansion'``, and ``'pruning_threshold'``. """ def __init__(self, + connectivity_matrix_generator = None, distance_threshold = 1.0, post_check_thresh_factor = 1.5, check_types = True, - verbose = True, - iterlimit = 100, markov_parameters = {}): - self.verbose = verbose + super().__init__( + maximum_merge_distance = post_check_thresh_factor * distance_threshold, + check_types = check_types + ) + + if connectivity_matrix_generator is None: + connectivity_matrix_generator = MergeSitesByDynamics.connectivity_n_ij + assert callable(connectivity_matrix_generator) + + self.connectivity_matrix_generator = connectivity_matrix_generator self.distance_threshold = distance_threshold self.post_check_thresh_factor = post_check_thresh_factor self.check_types = check_types self.iterlimit = iterlimit self.markov_parameters = markov_parameters - def run(self, st): - """Takes a SiteTrajectory and returns a SiteTrajectory, including a new SiteNetwork.""" + # Connectivity Matrix Generation Schemes: + + @staticmethod + def connectivity_n_ij(sn): + """Basic default connectivity scheme: uses n_ij directly as connectivity matrix. + + Works well for systems with sufficient statistics. + """ + return sn.n_ij + + @staticmethod + def connectivity_jump_lag_biased(jump_lag_coeff = 1.0, + jump_lag_sigma = 20.0, + jump_lag_cutoff = np.inf, + distance_coeff = 0.5, + distance_sigma = 1.0): + """Bias the typical connectivity matrix p_ij with jump lag and distance contributions. + + The jump lag and distance are processed through Gaussian functions with + the given sigmas (i.e. higher jump lag/larger distance => lower + connectivity value). These matrixes are then added to p_ij, with a prefactor + of ``jump_lag_coeff`` and ``distance_coeff``. 
+ + Site pairs with jump lags greater than ``jump_lag_cutoff`` have their bias + set to zero regardless of ``jump_lag_sigma``. Defaults to ``inf``. + """ + def cfunc(sn): + jl = sn.jump_lag.copy() + jl -= 1.0 # Center it around 1 since that's the minimum lag, 1 frame + jl /= jump_lag_sigma + np.square(jl, out = jl) + jl *= -0.5 + np.exp(jl, out = jl) # exp correctly takes the -infs to 0 + + jl[sn.jump_lag > jump_lag_cutoff] = 0. + + # Distance term + pbccalc = PBCCalculator(sn.structure.cell) + dists = pbccalc.pairwise_distances(sn.centers) + dmat = dists.copy() + + # We want to strongly boost the similarity of *very* close sites + dmat /= distance_sigma + np.square(dmat, out = dmat) + dmat *= -0.5 + np.exp(dmat, out = dmat) - if self.check_types and st.site_network.site_types is None: - raise ValueError("Cannot run a check_types=True MergeSitesByDynamics on a SiteTrajectory without type information.") + return (sn.p_ij + jump_lag_coeff * jl) * (distance_coeff * dmat + (1 - distance_coeff)) + return cfunc + + # Real methods + + def _get_sites_to_merge(self, st): # -- Compute jump statistics - if not st.site_network.has_attribute('p_ij'): - ja = JumpAnalysis(verbose = self.verbose) + if not st.site_network.has_attribute('n_ij'): + ja = JumpAnalysis() ja.run(st) pbcc = PBCCalculator(st.site_network.structure.cell) site_centers = st.site_network.centers - if self.check_types: - site_types = st.site_network.site_types # -- Build connectivity_matrix - connectivity_matrix = st.site_network.n_ij.copy() + connectivity_matrix = self.connectivity_matrix_generator(st.site_network).copy() n_sites_before = st.site_network.n_sites assert n_sites_before == connectivity_matrix.shape[0] @@ -71,7 +132,7 @@ def run(self, st): n_alarming_ignored_edges = 0 # Apply distance threshold - for i in xrange(n_sites_before): + for i in range(n_sites_before): dists = pbcc.distances(centers_before[i], centers_before[i + 1:]) js_too_far = np.where(dists > self.distance_threshold)[0] js_too_far += i + 1 @@ -83,120 +144,10 @@ def run(self, st): connectivity_matrix[i, js_too_far] = 0 connectivity_matrix[js_too_far, i] = 0 # Symmetry - if self.verbose and n_alarming_ignored_edges > 0: - print(" At least %i site pairs with high (z-score > 3) fluxes were over the given distance cutoff.\n" - " This may or may not be a problem; but if `distance_threshold` is low, consider raising it." % n_alarming_ignored_edges) + if n_alarming_ignored_edges > 0: + logger.warning(" At least %i site pairs with high (z-score > 3) fluxes were over the given distance cutoff.\n" + " This may or may not be a problem; but if `distance_threshold` is low, consider raising it." % n_alarming_ignored_edges) # -- Do Markov Clustering - clusters = self._markov_clustering(connectivity_matrix, **self.markov_parameters) - - new_n_sites = len(clusters) - - if self.verbose: - print "After merge there will be %i sites" % new_n_sites - - if self.check_types: - new_types = np.empty(shape = new_n_sites, dtype = np.int) - - # -- Merge Sites - new_centers = np.empty(shape = (new_n_sites, 3), dtype = st.site_network.centers.dtype) - translation = np.empty(shape = st.site_network.n_sites, dtype = np.int) - translation.fill(-1) - - for newsite in xrange(new_n_sites): - mask = list(clusters[newsite]) - # Update translation table - if np.any(translation[mask] != -1): - # We've assigned a different cluster for this before... 
weird - # degeneracy - raise ValueError("Markov clustering tried to merge site(s) into more than one new site") - translation[mask] = newsite - - to_merge = site_centers[mask] - - # Check distances - if not self.post_check_thresh_factor is None: - dists = pbcc.distances(to_merge[0], to_merge[1:]) - assert np.all(dists < self.post_check_thresh_factor * self.distance_threshold), \ - "Markov clustering tried to merge sites more than %f * %f apart. Lower your distance_threshold?" % (self.post_check_thresh_factor, self.distance_threshold) - - # New site center - new_centers[newsite] = pbcc.average(to_merge) - if self.check_types: - assert np.all(site_types[mask] == site_types[mask][0]) - new_types[newsite] = site_types[mask][0] - - newsn = st.site_network.copy() - newsn.centers = new_centers - if self.check_types: - newsn.site_types = new_types - - newtraj = translation[st._traj] - newtraj[st._traj == SiteTrajectory.SITE_UNKNOWN] = SiteTrajectory.SITE_UNKNOWN - - # It doesn't make sense to propagate confidence information through a - # transform that might completely invalidate it - newst = SiteTrajectory(newsn, newtraj, confidences = None) - - if not st.real_trajectory is None: - newst.set_real_traj(st.real_trajectory) - - return newst - - def _markov_clustering(self, - transition_matrix, - expansion = 2, - inflation = 2, - pruning_threshold = 0.00001): - """ - See https://micans.org/mcl/. - - Because we're dealing with matrixes that are stochastic already, - there's no need to add artificial loop values. - - Implementation inspired by https://github.com/GuyAllard/markov_clustering - """ - - assert transition_matrix.shape[0] == transition_matrix.shape[1] - - m1 = transition_matrix.copy() - - # Normalize (though it should be close already) - m1 /= np.sum(m1, axis = 0) - - allcols = np.arange(m1.shape[1]) - - converged = False - for i in xrange(self.iterlimit): - # -- Expansion - m2 = np.linalg.matrix_power(m1, expansion) - # -- Inflation - np.power(m2, inflation, out = m2) - m2 /= np.sum(m2, axis = 0) - # -- Prune - to_prune = m2 < pruning_threshold - # Exclude the max of every column - to_prune[np.argmax(m2, axis = 0), allcols] = False - m2[to_prune] = 0.0 - # -- Check converged - if np.allclose(m1, m2): - converged = True - if self.verbose: - print "Markov Clustering converged in %i iterations" % i - break - - m1[:] = m2 - - if not converged: - raise ValueError("Markov Clustering couldn't converge in %i iterations" % self.iterlimit) - - # -- Get clusters - attractors = m2.diagonal().nonzero()[0] - - clusters = set() - - for a in attractors: - cluster = tuple(m2[a].nonzero()[0]) - clusters.add(cluster) - - return list(clusters) + clusters = markov_clustering(connectivity_matrix, **self.markov_parameters) + return clusters diff --git a/sitator/dynamics/MergeSitesByThreshold.py b/sitator/dynamics/MergeSitesByThreshold.py new file mode 100644 index 0000000..449c236 --- /dev/null +++ b/sitator/dynamics/MergeSitesByThreshold.py @@ -0,0 +1,87 @@ +import numpy as np + +import operator + +from scipy.sparse.csgraph import connected_components + +from sitator.util import PBCCalculator +from sitator.network.merging import MergeSites + + +class MergeSitesByThreshold(MergeSites): + """Merge sites using a strict threshold on any edge property. + + Takes the edge property matrix given by ``attrname``, applys ``relation`` to it + with ``threshold``, and merges all connected components in the graph represented + by the resulting boolean adjacency matrix. 
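+
+    A minimal usage sketch (hypothetical attribute and threshold value; assumes
+    a ``JumpAnalysis`` has already added ``n_ij`` to the ``SiteNetwork``)::
+
+        merger = MergeSitesByThreshold('n_ij', directed = False, connection = 'weak')
+        merged_st = merger.run(st, threshold = 10)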
+
+    Threshold is given by a keyword argument to ``run()``.
+
+    Args:
+        attrname (str): Name of the edge attribute to merge on.
+        relation (func, default: ``operator.ge``): The relation to use for the
+            thresholding.
+        directed, connection (bool, str): Parameters for ``scipy.sparse.csgraph``'s
+            ``connected_components``.
+        **kwargs: Passed to ``MergeSites``.
+    """
+    def __init__(self,
+                 attrname,
+                 relation = operator.ge,
+                 directed = True,
+                 connection = 'strong',
+                 distance_threshold = np.inf,
+                 forbid_multiple_occupancy = False,
+                 **kwargs):
+        self.attrname = attrname
+        self.relation = relation
+        self.directed = directed
+        self.connection = connection
+        self.distance_threshold = distance_threshold
+        self.forbid_multiple_occupancy = forbid_multiple_occupancy
+        super().__init__(**kwargs)
+
+
+    def _get_sites_to_merge(self, st, threshold = 0):
+        sn = st.site_network
+
+        attrmat = getattr(sn, self.attrname)
+        assert attrmat.shape == (sn.n_sites, sn.n_sites), "`attrname` doesn't seem to indicate an edge property."
+        connmat = self.relation(attrmat, threshold)
+
+        # Apply distance threshold
+        if self.distance_threshold < np.inf:
+            pbcc = PBCCalculator(sn.structure.cell)
+            centers = sn.centers
+            for i in range(sn.n_sites):
+                dists = pbcc.distances(centers[i], centers[i + 1:])
+                js_too_far = np.where(dists > self.distance_threshold)[0]
+                js_too_far += i + 1
+
+                connmat[i, js_too_far] = False
+                connmat[js_too_far, i] = False # Symmetry
+
+        if self.forbid_multiple_occupancy:
+            n_mobile = sn.n_mobile
+            for frame in st.traj:
+                frame = [s for s in frame if s >= 0]
+                for site in frame: # only known
+                    # can't merge an occupied site with simultaneously occupied sites
+                    connmat[site, frame] = False
+
+        # Everything is always mergeable with itself.
+        np.fill_diagonal(connmat, True)
+
+        # Get mergeable groups
+        n_merged_sites, labels = connected_components(
+            connmat,
+            directed = self.directed,
+            connection = self.connection
+        )
+        # MergeSites will check pairwise distances; we just need to make it the
+        # right format.
+        merge_groups = []
+        for lbl in range(n_merged_sites):
+            merge_groups.append(np.where(labels == lbl)[0])
+
+        return merge_groups
diff --git a/sitator/dynamics/RemoveUnoccupiedSites.py b/sitator/dynamics/RemoveUnoccupiedSites.py
new file mode 100644
index 0000000..22f3a41
--- /dev/null
+++ b/sitator/dynamics/RemoveUnoccupiedSites.py
@@ -0,0 +1,71 @@
+import numpy as np
+
+from sitator import SiteTrajectory
+from sitator.errors import InsufficientSitesError
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+class RemoveUnoccupiedSites(object):
+    """Remove unoccupied sites."""
+    def __init__(self):
+        pass
+
+    def run(self, st, return_kept_sites = False):
+        """
+        Args:
+            return_kept_sites (bool): If ``True``, a list of the sites from ``st``
+                that were kept will be returned.
+
+        Returns:
+            A ``SiteTrajectory``, or ``st`` itself if it has no unoccupied sites.
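+
+        A minimal usage sketch (hypothetical names)::
+
+            pruned_st, kept_sites = RemoveUnoccupiedSites().run(st, return_kept_sites = True)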
+ """ + assert isinstance(st, SiteTrajectory) + + old_sn = st.site_network + + # Allow for the -1 to affect nothing + seen_mask = np.zeros(shape = old_sn.n_sites + 1, dtype = np.bool) + + for frame in st.traj: + seen_mask[frame] = True + if np.all(seen_mask[:-1]): + return st + + seen_mask = seen_mask[:-1] + + logger.info("Removing unoccupied sites %s" % np.where(~seen_mask)[0]) + + n_new_sites = np.sum(seen_mask) + + if n_new_sites < old_sn.n_mobile: + raise InsufficientSitesError( + verb = "Removing unoccupied sites", + n_sites = n_new_sites, + n_mobile = old_sn.n_mobile + ) + + translation = np.empty(shape = old_sn.n_sites + 1, dtype = np.int) + translation[:-1][seen_mask] = np.arange(n_new_sites) + translation[:-1][~seen_mask] = -4321 + translation[-1] = SiteTrajectory.SITE_UNKNOWN # Map unknown to unknown + + newtraj = translation[st.traj.reshape(-1)] + newtraj.shape = st.traj.shape + + assert -4321 not in newtraj + + # We don't clear computed attributes since nothing is changing for other sites. + newsn = old_sn[seen_mask] + + new_st = SiteTrajectory( + site_network = newsn, + particle_assignments = newtraj + ) + if st.real_trajectory is not None: + new_st.set_real_traj(st.real_trajectory) + if return_kept_sites: + return new_st, np.where(seen_mask) + else: + return new_st diff --git a/sitator/dynamics/ReplaceUnassignedPositions.py b/sitator/dynamics/ReplaceUnassignedPositions.py new file mode 100644 index 0000000..1ee2a21 --- /dev/null +++ b/sitator/dynamics/ReplaceUnassignedPositions.py @@ -0,0 +1,117 @@ +import numpy as np + +from sitator import SiteTrajectory +from sitator.util import PBCCalculator + +import logging +logger = logging.getLogger(__name__) + + +# See https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi +def rle(inarray): + """ run length encoding. Partial credit to R rle function. + Multi datatype arrays catered for including non Numpy + returns: tuple (runlengths, startpositions, values) """ + ia = np.asarray(inarray) # force numpy + n = len(ia) + if n == 0: + return (None, None, None) + else: + y = np.array(ia[1:] != ia[:-1]) # pairwise unequal (string safe) + i = np.append(np.where(y), n - 1) # must include last element posi + z = np.diff(np.append(-1, i)) # run lengths + p = np.cumsum(np.append(0, z))[:-1] # positions + return(z, p, ia[i]) + +class ReplaceUnassignedPositions(object): + """Fill in missing site assignments in a SiteTrajectory. + + Args: + replacement_function (callable): Callable that takes + ``(st, mobile_atom, before_site, start_frame, after_site, end_frame)`` and + returns either: + - A single site assignment with which the unassigned will be replaced + - A timeseries of length ``end_frame - start_frame`` of site + assignments with which the unassigned will be replaced. + If ``None``, defaults to + ``AssignUnassignedPositions.replace_with_from``. 
+ """ + def __init__(self, + replacement_function = None): + if replacement_function is None: + replacement_function = RemoveShortJumps.replace_with_last_known + self.replacement_function = replacement_function + + @staticmethod + def replace_with_last_known(st, mobile_atom, before_site, start_frame, after_site, end_frame): + """Replace unassigned with the last known site.""" + return before_site + + @staticmethod + def replace_with_next_known(st, mobile_atom, before_site, start_frame, after_site, end_frame): + """Replace unassigned with the next known site.""" + return after_site + + @staticmethod + def replace_with_closer(): + """Create function to replace unknown with closest site over time. + + Assigns each of the positions during an unassigned run to whichever of + the before and after sites it is closer to in real space. + """ + pbcc = None + ptbuf = np.empty(shape = (2, 3)) + distbuf = np.empty(shape = 2) + def replace_with_closer(st, mobile_atom, before_site, start_frame, after_site, end_frame): + if before_site == SiteTrajectory.SITE_UNKNOWN or \ + after_site == SiteTrajectory.SITE_UNKNOWN: + return SiteTrajectory.SITE_UNKNOWN + + if pbcc is None: + pbcc = PBCCalculator(st.site_network.structure.cell) + n_frames = end_frame - start_frame + out = np.empty(shape = n_frames) + for i in range(n_frames): + ptbuf[0] = st.site_network.centers[before_site] + ptbuf[1] = st.site_network.centers[after_site] + pbcc.distances( + st.real_trajectory[start_frame + i, mobile_atom], + ptbuf, + in_place = True, + out = distbuf + ) + if distbuf[0] < distbuf[1]: + out[i] = before_site + else: + out[i] = after_site + return out + + + def run(self, + st): + n_mobile = st.site_network.n_mobile + n_frames = st.n_frames + n_sites = st.site_network.n_sites + + traj = st.traj + out = st.traj.copy() + + for mob in range(n_mobile): + runlen, start, runsites = rle(traj[:, mob]) + for runi in range(len(runlen)): + if runsites[runi] == SiteTrajectory.SITE_UNKNOWN: + unknown_start = start[runi] + unknown_end = unknown_start + runlen[runi] + unknown_before = runsites[runi - 1] if runi > 0 else SiteTrajectory.SITE_UNKNOWN + unknown_after = runsites[runi + 1] if runi < len(runlen) - 1 else SiteTrajectory.SITE_UNKNOWN + replace = self.replacement_function( + st, mob, + unknown_before, unknown_start, + unknown_after, unknown_end + ) + out[unknown_start:unknown_end, mob] = replace + + st = st.copy(with_computed = False) + st._traj = out + + return st diff --git a/sitator/dynamics/SmoothSiteTrajectory.pyx b/sitator/dynamics/SmoothSiteTrajectory.pyx new file mode 100644 index 0000000..379b2d7 --- /dev/null +++ b/sitator/dynamics/SmoothSiteTrajectory.pyx @@ -0,0 +1,111 @@ +# cython: language_level=3 + +import numpy as np + +from sitator import SiteTrajectory +from sitator.dynamics import RemoveUnoccupiedSites + +import logging +logger = logging.getLogger(__name__) + +ctypedef Py_ssize_t site_int + +class SmoothSiteTrajectory(object): + """"Smooth" a SiteTrajectory by applying a rolling mode. + + For each mobile particle, the assignmet at each frame is replaced by the + mode of its site assignments over some number of frames centered around it. + + Can be thought of as a discrete lowpass filter. + + The ``set_unassigned_under_threshold`` parameter allows the user to control + how the smoothing handles "transitions" vs. 
"attempts"; setting it to True, + the default, will mark as unassigned transitional moments where neither + the source nor destination site have a sufficient (``threshold``) majority + in the window, while setting it to False will maintain the assignment to a + transitional site. + + Args: + window_threshold_factor (float): The total width of the rolling window, + in terms of the threshold. + remove_unoccupied_sites (bool): If True, sites that are unoccupied after + the smoothing will be removed. + set_unassigned_under_threshold (bool): If True, if the multiplicity of + the mode is less than the threshold, the particle is marked + unassigned at that frame. If False, the particle's assignment will + not be modified. + """ + def __init__(self, + window_threshold_factor = 2.1, + remove_unoccupied_sites = True, + set_unassigned_under_threshold = True): + self.window_threshold_factor = window_threshold_factor + self.remove_unoccupied_sites = remove_unoccupied_sites + self.set_unassigned_under_threshold = set_unassigned_under_threshold + + def run(self, + st, + threshold): + n_mobile = st.site_network.n_mobile + n_frames = st.n_frames + n_sites = st.site_network.n_sites + + traj = st.traj + out = st.traj.copy() + + window = self.window_threshold_factor * threshold + wleft, wright = int(np.floor(window / 2)), int(np.ceil(window / 2)) + + running_windowed_mode( + traj, + out, + wleft, + wright, + threshold, + n_sites, + self.set_unassigned_under_threshold + ) + + st = st.copy(with_computed = False) + st._traj = out + if self.remove_unoccupied_sites: + # Removing short jumps could have made some sites completely unoccupied + st = RemoveUnoccupiedSites().run(st) + st.site_network.clear_attributes() + + return st + + +cpdef running_windowed_mode(site_int [:, :] traj, + site_int [:, :] out, + Py_ssize_t wleft, + Py_ssize_t wright, + Py_ssize_t threshold, + Py_ssize_t n_sites, + bint replace_no_winner_unknown): + countbuf_np = np.zeros(shape = n_sites + 1, dtype = np.int) + cdef Py_ssize_t [:] countbuf = countbuf_np + cdef Py_ssize_t n_mobile = traj.shape[1] + cdef Py_ssize_t n_frames = traj.shape[0] + cdef site_int s_unknown = SiteTrajectory.SITE_UNKNOWN + cdef site_int winner + cdef Py_ssize_t best_count + + for mob in range(n_mobile): + for frame in range(n_frames): + for wi in range(max(frame - wleft, 0), min(frame + wright, n_frames)): + countbuf[traj[wi, mob] + 1] += 1 + winner = 0 # THis is actually -1, so unknown by default + best_count = 0 + for site in range(n_sites + 1): + if countbuf[site] > best_count: + winner = site + best_count = countbuf[site] + if best_count >= threshold: + out[frame, mob] = winner - 1 + else: + if replace_no_winner_unknown: + out[frame, mob] = s_unknown + else: + out[frame, mob] = traj[frame, mob] + countbuf[:] = 0 diff --git a/sitator/dynamics/__init__.py b/sitator/dynamics/__init__.py index ce898fc..812c503 100644 --- a/sitator/dynamics/__init__.py +++ b/sitator/dynamics/__init__.py @@ -1,3 +1,9 @@ -from JumpAnalysis import JumpAnalysis +from .JumpAnalysis import JumpAnalysis +from .MergeSitesByDynamics import MergeSitesByDynamics +from .MergeSitesByThreshold import MergeSitesByThreshold +from .RemoveUnoccupiedSites import RemoveUnoccupiedSites +from .SmoothSiteTrajectory import SmoothSiteTrajectory +from .AverageVibrationalFrequency import AverageVibrationalFrequency -from MergeSitesByDynamics import MergeSitesByDynamics +# For backwards compatability, since this used to be in this module +from sitator.network import DiffusionPathwayAnalysis diff --git 
a/sitator/errors.py b/sitator/errors.py new file mode 100644 index 0000000..06572a4 --- /dev/null +++ b/sitator/errors.py @@ -0,0 +1,21 @@ + +class SiteAnaysisError(Exception): + """An error occuring as part of site analysis.""" + pass + +class MultipleOccupancyError(SiteAnaysisError): + """Error raised when multiple mobile atoms are assigned to the same site at the same time.""" + def __init__(self, mobile, site, frame): + super().__init__( + "Multiple mobile particles %s were assigned to site %i at frame %i." % (mobile, site, frame) + ) + self.mobile_particles = mobile + self.site = site + self.frame = frame + +class InsufficientSitesError(SiteAnaysisError): + """Site detection/merging/etc. resulted in fewer sites than mobile particles.""" + def __init__(self, verb, n_sites, n_mobile): + super().__init__("%s resulted in only %i sites for %i mobile particles." % (verb, n_sites, n_mobile)) + self.n_sites = n_sites + self.n_mobile = n_mobile diff --git a/sitator/ionic.py b/sitator/ionic.py new file mode 100644 index 0000000..538618d --- /dev/null +++ b/sitator/ionic.py @@ -0,0 +1,139 @@ +import numpy as np + +from sitator import SiteNetwork + +import ase.data + +try: + from pymatgen.io.ase import AseAtomsAdaptor + import pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder as cgf + from pymatgen.analysis.chemenv.coordination_environments.structure_environments import \ + LightStructureEnvironments + from pymatgen.analysis.chemenv.utils.defs_utils import AdditionalConditions + from pymatgen.analysis.bond_valence import BVAnalyzer + has_pymatgen = True +except ImportError: + has_pymatgen = False + + +class IonicSiteNetwork(SiteNetwork): + """Site network for a species of mobile charged ions in a static lattice. + + Imposes more restrictions than a plain ``SiteNetwork``: + + - Has a mobile species in place of an arbitrary mobile mask + - Has a set of static species in place of an arbitraty static mask + - All atoms of the mobile species must have the same charge + + And contains more information... + + Attributes: + opposite_ion_mask (ndarray): A mask on ``structure`` indicating all + anions and neutrally charged atoms if the mobile species is a cation, + or vice versa if the mobile species is an anion. + opposite_ion_structure (ase.Atoms): An ``Atoms`` containing the atoms + indicated by ``opposite_ion_mask``. + same_ion_mask (ndarray): A mask on ``structure`` indicating all + atoms whose charge has the same sign as the mobile species. + same_ion_structure (ase.Atoms): An ``Atoms`` containing the atoms + indicated by ``same_ion_mask``. + n_opposite_charge (int): The number of opposite charge static atoms. + n_same_charge (int): The number of same charge static atoms. + + Args: + structure (ase.Atoms) + mobile_species (int): Atomic number of the mobile species. + static_species (list of int): Atomic numbers of the static species. + mobile_charge (int): Charge of mobile atoms. If ``None``, + ``pymatgen``'s ``BVAnalyzer`` will be used to estimate valences. + static_charges (ndarray int): Charges of the atoms in the static + structure. If ``None``, ``sitator`` will try to use + ``pymatgen``'s ``BVAnalyzer`` to estimate valences. 
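+
+    A minimal construction sketch (hypothetical species and charge for a
+    lithium-ion conductor; with the charges left as ``None``, ``pymatgen``'s
+    ``BVAnalyzer`` would be used to estimate them instead)::
+
+        sn = IonicSiteNetwork(
+            structure = atoms,        # an ase.Atoms
+            mobile_species = 3,       # Li
+            static_species = [8, 15], # O, P
+            mobile_charge = 1
+        )
+    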
+ """ + def __init__(self, + structure, + mobile_species, + static_species, + mobile_charge = None, + static_charges = None): + if mobile_species in static_species: + raise ValueError("Mobile species %i cannot also be one of static species %s" % (mobile_species, static_species)) + mobile_mask = structure.numbers == mobile_species + static_mask = np.in1d(structure.numbers, static_species) + super().__init__( + structure = structure, + mobile_mask = mobile_mask, + static_mask = static_mask + ) + + self.mobile_species = mobile_species + self.static_species = static_species + # Estimate bond valences if necessary + if mobile_charge is None or static_charges is None: + if not has_pymatgen: + raise ImportError("Pymatgen could not be imported, and is required for guessing charges.") + sim_struct = AseAtomsAdaptor.get_structure(structure) + bv = BVAnalyzer() + struct_valences = np.asarray(bv.get_valences(sim_struct)) + if static_charges is None: + static_charges = struct_valences[static_mask] + if mobile_charge is None: + mob_val = struct_valences[mobile_mask] + if np.any(mob_val != mob_val[0]): + raise ValueError("Mobile atom estimated valences (%s) not uniform; arbitrarily taking first." % mob_val) + mobile_charge = mob_val[0] + self.mobile_charge = mobile_charge + self.static_charges = static_charges + + # Create oposite ion stuff + mobile_sign = np.sign(mobile_charge) + static_signs = np.sign(static_charges) + self.opposite_ion_mask = np.empty_like(static_mask) + self.opposite_ion_mask.fill(False) + self.opposite_ion_mask[static_mask] = static_signs != mobile_sign + self.opposite_ion_structure = structure[self.opposite_ion_mask] + + self.same_ion_mask = np.empty_like(static_mask) + self.same_ion_mask.fill(False) + self.same_ion_mask[static_mask] = static_signs == mobile_sign + self.same_ion_structure = structure[self.same_ion_mask] + + @property + def n_opposite_charge(self): + return np.sum(self.opposite_ion_mask) + + @property + def n_same_charge(self): + return np.sum(self.same_ion_mask) + + def __getitem__(self, key): + out = super().__getitem__(key) + out.mobile_species = self.mobile_species + out.static_species = self.static_species + out.mobile_charge = self.mobile_charge + out.static_charges = self.static_charges + out.opposite_ion_mask = self.opposite_ion_mask + out.opposite_ion_structure = self.opposite_ion_structure + out.same_ion_mask = self.same_ion_mask + out.same_ion_structure = self.same_ion_structure + return out + + def __str__(self): + out = super().__str__() + static_nums = self.static_structure.numbers + out += ( + " Mobile species: {:2} (charge {:+d})\n" + " Static species: {}\n" + " # opposite charge: {}\n" + ).format( + ase.data.chemical_symbols[self.mobile_species], + self.mobile_charge, + ", ".join( + "{} (avg. 
charge {:+.1f})".format( + ase.data.chemical_symbols[s], + np.mean(self.static_charges[static_nums == s]) + ) for s in self.static_species + ), + self.n_opposite_charge + ) + return out diff --git a/sitator/landmark/LandmarkAnalysis.py b/sitator/landmark/LandmarkAnalysis.py index a972ca3..85dd5d4 100644 --- a/sitator/landmark/LandmarkAnalysis.py +++ b/sitator/landmark/LandmarkAnalysis.py @@ -1,28 +1,20 @@ import numpy as np from sitator.util import PBCCalculator +from sitator.util.progress import tqdm -# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698 import sys -try: - ipy_str = str(type(get_ipython())) - if 'zmqshell' in ipy_str: - from tqdm import tqdm_notebook as tqdm - if 'terminal' in ipy_str: - from tqdm import tqdm -except: - if sys.stderr.isatty(): - from tqdm import tqdm - else: - def tqdm(iterable, **kwargs): - return iterable import importlib import tempfile -import helpers +from . import helpers from sitator import SiteNetwork, SiteTrajectory +from sitator.errors import MultipleOccupancyError, InsufficientSitesError +from sitator.landmark.anchor import to_origin +import logging +logger = logging.getLogger(__name__) from functools import wraps def analysis_result(func): @@ -35,7 +27,76 @@ def wrapper(self, *args, **kwargs): return wrapper class LandmarkAnalysis(object): - """Track a mobile species through a fixed lattice using landmark vectors.""" + """Site analysis of mobile atoms in a static lattice with landmark analysis. + + :param double cutoff_center: The midpoint for the logistic function used + as the landmark cutoff function. (unitless) + :param double cutoff_steepness: Steepness of the logistic cutoff function. + :param double minimum_site_occupancy = 0.1: Minimum occupancy (% of time occupied) + for a site to qualify as such. + :param str clustering_algorithm: The landmark clustering algorithm. ``sitator`` + supplies two: + - ``"dotprod"``: The method described in our "Unsupervised landmark + analysis for jump detection in molecular dynamics simulations" paper. + - ``"mcl"``: A newer method we are developing. + :param dict clustering_params: Parameters for the chosen ``clustering_algorithm``. + :param str site_centers_method: The method to use for computing the real + space positions of the sites. Options: + - ``SITE_CENTERS_REAL_UNWEIGHTED``: A spatial average of all real-space + mobile atom positions assigned to the site is taken. + - ``SITE_CENTERS_REAL_WEIGHTED``: A spatial average of all real-space + mobile atom positions assigned to the site is taken, weighted + by the confidences with which they assigned to the site. + - ``SITE_CENTERS_REPRESENTATIVE_LANDMARK``: A spatial average over + all landmarks' centers is taken, weighted by the representative + or "typical" landmark vector at the site. + The "real" methods will generally be more faithful to the simulation, + but the representative landmark method can work better in cases with + short trajectories, producing a more "ideal" site location. + :param bool check_for_zero_landmarks: Whether to check for and raise exceptions + when all-zero landmark vectors are computed. + :param float static_movement_threshold: (Angstrom) the maximum allowed + distance between an instantanous static atom position and it's ideal position. 
+    :param bool/callable dynamic_lattice_mapping: Whether to dynamically decide
+        each frame which static atom represents each average lattice position;
+        this allows the LandmarkAnalysis to deal with, say, a rare exchange of
+        two static atoms that does not change the structure of the lattice.
+
+        It does NOT allow LandmarkAnalysis to deal with lattices whose structures
+        actually change over the course of the trajectory.
+
+        In certain cases this is better dealt with by ``MergeSitesByDynamics``.
+
+        If ``False``, no mapping will occur. Otherwise, a callable taking a
+        ``SiteNetwork`` should be provided. The callable should return a list
+        of static atom indexes that can be validly assigned to each static lattice
+        position. If ``True``, ``sitator.landmark.dynamic_mapping.within_species``
+        is used.
+    :param int max_mobile_per_site: The maximum number of mobile atoms that can
+        be assigned to a single site without throwing an error. Regardless of the
+        value, assignments of more than one mobile atom to a single site will
+        be recorded and reported.
+
+        Setting this to 2 can be necessary for very diffusive, liquid-like
+        materials at high temperatures.
+
+        Statistics related to this are reported in ``self.avg_mobile_per_site``
+        and ``self.n_multiple_assignments``.
+    :param bool force_no_memmap: if True, landmark vectors will be stored only in memory.
+        Only useful if access to landmark vectors after the analysis has run is desired.
+    :param bool verbose: Verbosity for the ``clustering_algorithm``. Other output
+        controlled through ``logging``.
+    """
+
+    SITE_CENTERS_REAL_UNWEIGHTED = 'real-unweighted'
+    SITE_CENTERS_REAL_WEIGHTED = 'real-weighted'
+    SITE_CENTERS_REPRESENTATIVE_LANDMARK = 'representative-landmark'
+
+    CLUSTERING_CLUSTER_SIZE = 'cluster-size'
+    CLUSTERING_LABELS = 'cluster-labels'
+    CLUSTERING_CONFIDENCES = 'cluster-confs'
+    CLUSTERING_LANDMARK_GROUPINGS = 'cluster-landmark-groupings'
+    CLUSTERING_REPRESENTATIVE_LANDMARKS = 'cluster-representative-lvecs'
 
     def __init__(self,
                  clustering_algorithm = 'dotprod',
@@ -43,57 +104,15 @@ def __init__(self,
                  cutoff_midpoint = 1.5,
                  cutoff_steepness = 30,
                  minimum_site_occupancy = 0.01,
-                 peak_evening = 'none',
-                 weighted_site_positions = True,
+                 site_centers_method = SITE_CENTERS_REAL_WEIGHTED,
                  check_for_zero_landmarks = True,
                  static_movement_threshold = 1.0,
                  dynamic_lattice_mapping = False,
+                 static_anchoring = to_origin,
                  relaxed_lattice_checks = False,
                  max_mobile_per_site = 1,
                  force_no_memmap = False,
                  verbose = True):
-        """
-        :param double cutoff_center: The midpoint for the logistic function used
-            as the landmark cutoff function. (unitless)
-        :param double cutoff_steepness: Steepness of the logistic cutoff function.
-        :param double minimum_site_occupancy = 0.1: Minimum occupancy (% of time occupied)
-            for a site to qualify as such.
-        :param dict clustering_params: Parameters for the chosen clustering_algorithm
-        :param str peak_evening: Whether and what kind of peak "evening" to apply;
-            that is, processing that makes all large peaks in the landmark vector
-            more similar in magnitude. This can help in site clustering.
-
-            Valid options: 'none', 'clip'
-        :param bool weighted_site_positions: When computing site positions, whether
-            to weight the average by assignment confidence.
-        :param bool check_for_zero_landmarks: Whether to check for and raise exceptions
-            when all-zero landmark vectors are computed.
-        :param float static_movement_threshold: (Angstrom) the maximum allowed
-            distance between an instantanous static atom position and it's ideal position.
-        :param bool dynamic_lattice_mapping: Whether to dynamically decide each
-            frame which static atom represents each average lattice position;
-            this allows the LandmarkAnalysis to deal with, say, a rare exchage of
-            two static atoms that does not change the structure of the lattice.
-
-            It does NOT allow LandmarkAnalysis to deal with lattices whose structures
-            actually change over the course of the trajectory.
-
-            In certain cases this is better delt with by MergeSitesByDynamics.
-        :param int max_mobile_per_site: The maximum number of mobile atoms that can
-            be assigned to a single site without throwing an error. Regardless of the
-            value, assignments of more than one mobile atom to a single site will
-            be recorded and reported.
-
-            Setting this to 2 can be necessary for very diffusive, liquid-like
-            materials at high temperatures.
-
-            Statistics related to this are reported in self.avg_mobile_per_site
-            and self.n_multiple_assignments.
-        :param bool force_no_memmap: if True, landmark vectors will be stored only in memory.
-            Only useful if access to landmark vectors after the analysis has run is desired.
-        :param bool verbose: If `True`, progress bars and messages will be printed to stdout.
-        """
-
         self._cutoff_midpoint = cutoff_midpoint
         self._cutoff_steepness = cutoff_steepness
         self._minimum_site_occupancy = minimum_site_occupancy
@@ -101,14 +120,16 @@ def __init__(self,
         self._cluster_algo = clustering_algorithm
         self._clustering_params = clustering_params
 
-        if not peak_evening in ['none', 'clip']:
-            raise ValueError("Invalid value `%s` for peak_evening" % peak_evening)
-        self._peak_evening = peak_evening
-
         self.verbose = verbose
         self.check_for_zero_landmarks = check_for_zero_landmarks
 
-        self.weighted_site_positions = weighted_site_positions
+        self.site_centers_method = site_centers_method
+
+        if dynamic_lattice_mapping is True:
+            from sitator.landmark.dynamic_mapping import within_species
+            dynamic_lattice_mapping = within_species
         self.dynamic_lattice_mapping = dynamic_lattice_mapping
+        self.static_anchoring = static_anchoring
+
         self.relaxed_lattice_checks = relaxed_lattice_checks
 
         self._landmark_vectors = None
@@ -123,26 +144,34 @@ def __init__(self,
 
     @property
     def cutoff(self):
-        return self._cutoff
+        return self._cutoff
 
     @analysis_result
     def landmark_vectors(self):
+        """Landmark vectors from the last invocation of ``run()``"""
         view = self._landmark_vectors[:]
         view.flags.writeable = False
         return view
 
     @analysis_result
     def landmark_dimension(self):
+        """Number of components in a single landmark vector."""
         return self._landmark_dimension
 
-
     def run(self, sn, frames):
         """Run the landmark analysis.
 
-        The input SiteNetwork is a network of predicted sites; it's sites will
+        The input ``SiteNetwork`` is a network of predicted sites; its sites will
         be used as the "basis" for the landmark vectors.
 
-        Takes a SiteNetwork and returns a SiteTrajectory.
+        Wraps a copy of ``frames`` into the unit cell.
+
+        Args:
+            sn (SiteNetwork): The landmark basis. Each site is a landmark defined
+                by its vertex static atoms, as indicated by `sn.vertices`.
+                (Typically from ``VoronoiSiteGenerator``.)
+            frames (ndarray n_frames x n_atoms x 3): A trajectory. Can be unwrapped;
+                a copy will be wrapped before the analysis.
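+
+        Returns:
+            A ``SiteTrajectory``.
+
+        A minimal usage sketch (hypothetical names; ``sn`` would typically come
+        from ``VoronoiSiteGenerator`` and ``frames`` from an MD trajectory)::
+
+            la = LandmarkAnalysis()
+            st = la.run(sn, frames)
+        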
""" assert isinstance(sn, SiteNetwork) @@ -150,24 +179,33 @@ def run(self, sn, frames): raise ValueError("Cannot rerun LandmarkAnalysis!") if frames.shape[1:] != (sn.n_total, 3): - raise ValueError("Wrong shape %s for frames." % frames.shape) + raise ValueError("Wrong shape %s for frames." % (frames.shape,)) if sn.vertices is None: raise ValueError("Input SiteNetwork must have vertices") n_frames = len(frames) - if self.verbose: - print "--- Running Landmark Analysis ---" + logger.info("--- Running Landmark Analysis ---") # Create PBCCalculator self._pbcc = PBCCalculator(sn.structure.cell) + # -- Step 0: Wrap to Unit Cell + orig_frames = frames # Keep a reference around + frames = frames.copy() + # Flatten to list of points for wrapping + orig_frame_shape = frames.shape + frames.shape = (orig_frame_shape[0] * orig_frame_shape[1], 3) + self._pbcc.wrap_points(frames) + # Back to list of frames + frames.shape = orig_frame_shape + # -- Step 1: Compute site-to-vertex distances self._landmark_dimension = sn.n_sites longest_vert_set = np.max([len(v) for v in sn.vertices]) - verts_np = np.array([v + [-1] * (longest_vert_set - len(v)) for v in sn.vertices]) + verts_np = np.array([np.concatenate((v, [-1] * (longest_vert_set - len(v)))) for v in sn.vertices], dtype = np.int) site_vert_dists = np.empty(shape = verts_np.shape, dtype = np.float) site_vert_dists.fill(np.nan) @@ -177,8 +215,17 @@ def run(self, sn, frames): site_vert_dists[i, :len(polyhedron)] = dists # -- Step 2: Compute landmark vectors - if self.verbose: print " - computing landmark vectors -" - # Compute landmark vectors + logger.info(" - computing landmark vectors -") + + if self.dynamic_lattice_mapping: + dynmap_compat = self.dynamic_lattice_mapping(sn) + else: + # If no dynamic mapping, each is only compatable with itself. + dynmap_compat = np.arange(sn.n_static)[:, np.newaxis] + assert len(dynmap_compat) == sn.n_static + static_anchors = self.static_anchoring(sn) + # This also validates the anchors + lattice_pt_order = self._get_lattice_order_from_anchors(static_anchors) # The dimension of one landmark vector is the number of Voronoi regions shape = (n_frames * sn.n_mobile, self._landmark_dimension) @@ -193,25 +240,48 @@ def run(self, sn, frames): shape = shape) helpers._fill_landmark_vectors(self, sn, verts_np, site_vert_dists, - frames, check_for_zeros = self.check_for_zero_landmarks, - tqdm = tqdm) + frames, + dynmap_compat = dynmap_compat, + lattice_pt_anchors = static_anchors, + lattice_pt_order = lattice_pt_order, + check_for_zeros = self.check_for_zero_landmarks, + tqdm = tqdm, logger = logger) + + if not self.check_for_zero_landmarks and self.n_all_zero_lvecs > 0: + logger.warning(" Had %i all-zero landmark vectors; no error because `check_for_zero_landmarks = False`." % self.n_all_zero_lvecs) + elif self.check_for_zero_landmarks: + assert self.n_all_zero_lvecs == 0 # -- Step 3: Cluster landmark vectors - if self.verbose: print " - clustering landmark vectors -" - # - Preprocess - - self._do_peak_evening() + logger.info(" - clustering landmark vectors -") # - Cluster - - cluster_func = importlib.import_module("..cluster." + self._cluster_algo, package = __name__).do_landmark_clustering + # FIXME: remove reload after development done + clustermod = importlib.import_module("..cluster." 
+ self._cluster_algo, package = __name__)
+        importlib.reload(clustermod)
+        cluster_func = clustermod.do_landmark_clustering
 
-        cluster_counts, lmk_lbls, lmk_confs = \
+        clustering = \
             cluster_func(self._landmark_vectors,
                          clustering_params = self._clustering_params,
                          min_samples = self._minimum_site_occupancy / float(sn.n_mobile),
                          verbose = self.verbose)
 
-        if self.verbose:
-            print " Failed to assign %i%% of mobile particle positions to sites." % (100.0 * np.sum(lmk_lbls < 0) / float(len(lmk_lbls)))
+        cluster_counts = clustering[LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE]
+        lmk_lbls = clustering[LandmarkAnalysis.CLUSTERING_LABELS]
+        lmk_confs = clustering[LandmarkAnalysis.CLUSTERING_CONFIDENCES]
+        if LandmarkAnalysis.CLUSTERING_LANDMARK_GROUPINGS in clustering:
+            landmark_clusters = clustering[LandmarkAnalysis.CLUSTERING_LANDMARK_GROUPINGS]
+            assert len(cluster_counts) == len(landmark_clusters)
+        else:
+            landmark_clusters = None
+        if LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS in clustering:
+            rep_lvecs = np.asarray(clustering[LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS])
+            assert rep_lvecs.shape == (len(cluster_counts), self._landmark_vectors.shape[1])
+        else:
+            rep_lvecs = None
+
+        logger.info(" Failed to assign %i%% of mobile particle positions to sites." % (100.0 * np.sum(lmk_lbls < 0) / float(len(lmk_lbls))))
 
         # reshape lables and confidences
         lmk_lbls.shape = (n_frames, sn.n_mobile)
@@ -219,61 +289,77 @@
 
         n_sites = len(cluster_counts)
 
-        if n_sites < sn.n_mobile:
-            raise ValueError("There are %i mobile particles, but only identified %i sites. Check clustering_params." % (sn.n_mobile, n_sites))
+        if n_sites < (sn.n_mobile / self.max_mobile_per_site):
+            raise InsufficientSitesError(
+                verb = "Landmark analysis",
+                n_sites = n_sites,
+                n_mobile = sn.n_mobile
+            )
 
-        if self.verbose:
-            print " Identified %i sites with assignment counts %s" % (n_sites, cluster_counts)
-
-        # Check that multiple particles are never assigned to one site at the
-        # same time, cause that would be wrong.
-        n_more_than_ones = 0
-        avg_mobile_per_site = 0
-        divisor = 0
-        for frame_i, site_frame in enumerate(lmk_lbls):
-            _, counts = np.unique(site_frame[site_frame >= 0], return_counts = True)
-            count_msk = counts > self.max_mobile_per_site
-            if np.any(count_msk):
-                raise ValueError("%i mobile particles were assigned to only %i site(s) (%s) at frame %i." % (np.sum(counts[count_msk]), np.sum(count_msk), np.where(count_msk)[0], frame_i))
-            n_more_than_ones += np.sum(counts > 1)
-            avg_mobile_per_site += np.sum(counts)
-            divisor += len(counts)
-
-        self.n_multiple_assignments = n_more_than_ones
-        self.avg_mobile_per_site = avg_mobile_per_site / float(divisor)
+        logger.info(" Identified %i sites with assignment counts %s" % (n_sites, cluster_counts))
 
         # -- Do output
+        out_sn = sn.copy()
 
         # - Compute site centers
         site_centers = np.empty(shape = (n_sites, 3), dtype = frames.dtype)
-
-        for site in xrange(n_sites):
-            mask = lmk_lbls == site
-            pts = frames[:, sn.mobile_mask][mask]
-            if self.weighted_site_positions:
-                site_centers[site] = self._pbcc.average(pts, weights = lmk_confs[mask])
-            else:
-                site_centers[site] = self._pbcc.average(pts)
-
-        # Build output obejcts
-        out_sn = sn.copy()
-
+        if self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REAL_WEIGHTED or \
+           self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REAL_UNWEIGHTED:
+            for site in range(n_sites):
+                mask = lmk_lbls == site
+                pts = frames[:, sn.mobile_mask][mask]
+                if self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REAL_WEIGHTED:
+                    site_centers[site] = self._pbcc.average(pts, weights = lmk_confs[mask])
+                else:
+                    site_centers[site] = self._pbcc.average(pts)
+        elif self.site_centers_method == LandmarkAnalysis.SITE_CENTERS_REPRESENTATIVE_LANDMARK:
+            if rep_lvecs is None:
+                raise ValueError("Chosen clustering method (with current parameters) didn't return representative landmark vectors; can't use SITE_CENTERS_REPRESENTATIVE_LANDMARK.")
+            for site in range(n_sites):
+                weights_nonzero = rep_lvecs[site] > 0
+                site_centers[site] = self._pbcc.average(
+                    sn.centers[weights_nonzero],
+                    weights = rep_lvecs[site, weights_nonzero]
+                )
+        else:
+            raise ValueError("Invalid site centers method '%s'" % self.site_centers_method)
         out_sn.centers = site_centers
-        assert out_sn.vertices is None
+        # - If clustering gave us that, compute site vertices
+        if landmark_clusters is not None:
+            vertices = []
+            for lclust in landmark_clusters:
+                vertices.append(set.union(*[set(sn.vertices[l]) for l in lclust]))
+            out_sn.vertices = vertices
 
         out_st = SiteTrajectory(out_sn, lmk_lbls, lmk_confs)
-        out_st.set_real_traj(frames)
+
+        # Check that multiple particles are never assigned to one site at the
+        # same time, cause that would be wrong.
+        self.n_multiple_assignments, self.avg_mobile_per_site = out_st.check_multiple_occupancy(
+            max_mobile_per_site = self.max_mobile_per_site
+        )
+
+        out_st.set_real_traj(orig_frames)
         self._has_run = True
 
         return out_st
 
-    # -------- "private" methods --------
-
-    def _do_peak_evening(self):
-        if self._peak_evening == 'none':
-            return
-        elif self._peak_evening == 'clip':
-            lvec_peaks = np.max(self._landmark_vectors, axis = 1)
-            # Clip all peaks to the lowest "normal" (stdev.) peak
-            lvec_clip = np.mean(lvec_peaks) - np.std(lvec_peaks)
-            # Do the clipping
-            self._landmark_vectors[self._landmark_vectors > lvec_clip] = lvec_clip
+    def _get_lattice_order_from_anchors(self, lattice_pt_anchors):
+        absolute_lattice_mask = lattice_pt_anchors == -1
+        lattice_pt_order = []
+        # -1 (absolute anchor of origin) is always known
+        known = np.zeros(shape = len(lattice_pt_anchors) + 1, dtype = np.bool)
+        known[-1] = True
+
+        while True:
+            can_know = known[lattice_pt_anchors]
+            new_know = can_know & ~known[:-1]
+            known[:-1] |= can_know
+            lattice_pt_order.extend(np.where(new_know)[0])
+            if not np.any(new_know):
+                break
+        if len(lattice_pt_order) < len(lattice_pt_anchors):
+            raise ValueError("Lattice point anchors %s contain an unsatisfiable dependency (likely a circular dependency)." % lattice_pt_anchors)
+        # Remove points with absolute anchors from the order; there's no need to
+        # do any computations for them.
+        lattice_pt_order = lattice_pt_order[np.sum(absolute_lattice_mask):]
+        return lattice_pt_order
diff --git a/sitator/landmark/__init__.py b/sitator/landmark/__init__.py
index e100d24..a084133 100644
--- a/sitator/landmark/__init__.py
+++ b/sitator/landmark/__init__.py
@@ -1,4 +1,4 @@
-from errors import StaticLatticeError, ZeroLandmarkError, LandmarkAnalysisError
+from .errors import StaticLatticeError, ZeroLandmarkError, LandmarkAnalysisError
 
-from LandmarkAnalysis import LandmarkAnalysis
+from .LandmarkAnalysis import LandmarkAnalysis
diff --git a/sitator/landmark/anchor.py b/sitator/landmark/anchor.py
new file mode 100644
index 0000000..067c859
--- /dev/null
+++ b/sitator/landmark/anchor.py
@@ -0,0 +1,49 @@
+import numpy as np
+
+from collections import Counter
+
+import ase.data
+
+from sitator.util.chemistry import identify_polyatomic_ions
+
+import logging
+logger = logging.getLogger(__name__)
+
+def to_origin(sn):
+    """Anchor all static atoms to the origin; i.e. their positions are absolute."""
+    return np.full(
+        shape = sn.n_static,
+        fill_value = -1,
+        dtype = np.int
+    )
+
+# ------
+
+def within_polyatomic_ions(**kwargs):
+    """Anchor the auxiliary atoms of a polyatomic ion to the central atom.
+
+    In phosphate (PO4), for example, the four coordinating oxygen atoms
+    will be anchored to the central phosphorus.
+
+    Args:
+        **kwargs: passed to ``sitator.util.chemistry.identify_polyatomic_ions``.
+    """
+    def func(sn):
+        anchors = np.full(shape = sn.n_static, fill_value = -1, dtype = np.int)
+        polyions = identify_polyatomic_ions(sn.static_structure, **kwargs)
+        logger.info("Identified %i polyatomic anions: %s" % (len(polyions), Counter(i[0] for i in polyions)))
+        for _, center, others in polyions:
+            anchors[others] = center
+        return anchors
+    return func
+
+# ------
+# TODO
+def to_heavy_elements(minimum_mass, maximum_distance = np.inf):
+    """Anchor "light" elements to their nearest "heavy" element.
+ + Lightness/heaviness is determined by ``minimum_mass``, a cutoff in + """ + def func(sn): + pass + return func diff --git a/sitator/landmark/cluster/dbscan.py b/sitator/landmark/cluster/dbscan.py deleted file mode 100644 index 0ff4856..0000000 --- a/sitator/landmark/cluster/dbscan.py +++ /dev/null @@ -1,68 +0,0 @@ - -import numpy as np - -import numbers -from sklearn.cluster import DBSCAN - -DEFAULT_PARAMS = { - 'eps' : 0.05, - 'min_samples' : 5, - 'n_jobs' : -1 -} - -def do_landmark_clustering(landmark_vectors, - clustering_params, - min_samples, - verbose): - - tmp = DEFAULT_PARAMS.copy() - tmp.update(clustering_params) - clustering_params = tmp - - landmark_classifier = \ - DBSCAN(eps = clustering_params['eps'], - min_samples = clustering_params['min_samples'], - n_jobs = clustering_params['n_jobs'], - metric = 'cosine') - - lmk_lbls = \ - landmark_classifier.fit_predict(landmark_vectors) - - # - Filter low occupancy sites - cluster_counts = np.bincount(lmk_lbls[lmk_lbls >= 0]) - n_assigned = np.sum(cluster_counts) - - min_n_samples_cluster = None - if isinstance(min_samples, numbers.Integral): - min_n_samples_cluster = min_samples - elif isinstance(min_samples, numbers.Real): - min_n_samples_cluster = int(np.floor(min_samples * n_assigned)) - else: - raise ValueError("Invalid value `%s` for min_samples; must be integral or float." % self._min_samples) - - to_remove_mask = cluster_counts < min_n_samples_cluster - to_remove = np.where(to_remove_mask)[0] - - trans_table = np.empty(shape = len(cluster_counts) + 1, dtype = np.int) - # Map unknown to unknown - trans_table[-1] = -1 - # Map removed to unknwon - trans_table[:-1][to_remove_mask] = -1 - # Map known to rescaled known - trans_table[:-1][~to_remove_mask] = np.arange(len(cluster_counts) - len(to_remove)) - # Do the remapping - lmk_lbls = trans_table[lmk_lbls] - - if verbose: - print "DBSCAN landmark: %i/%i assignment counts below threshold %f (%i); %i clusters remain." % \ - (len(to_remove), len(cluster_counts), min_samples, min_n_samples_cluster, len(cluster_counts) - len(to_remove)) - - # Remove counts - cluster_counts = cluster_counts[~to_remove_mask] - - # There are no confidences with DBSCAN, so just give everything confidence 1 - # so as not to screw up later weighting. 
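The new ``_get_lattice_order_from_anchors`` helper above resolves anchor dependencies as a fixed-point iteration: an anchor value of ``-1`` marks an absolute (origin-anchored) static point, any other value is the index of the static atom it is anchored to, and a point can be resolved once its anchor has been. A minimal standalone sketch of that loop with a toy anchor array (the in-tree version works with NumPy buffers and strips the absolute points by count):

```python
import numpy as np

def lattice_order(anchors):
    # anchors[i] == -1: point i is absolute; anchors[i] == j: i is anchored to j.
    known = np.zeros(len(anchors) + 1, dtype=bool)
    known[-1] = True  # the origin "anchor" (-1) is always known
    order = []
    while True:
        can_know = known[anchors]         # points whose anchor is now resolved
        new_know = can_know & ~known[:-1]
        known[:-1] |= can_know
        order.extend(np.where(new_know)[0])
        if not np.any(new_know):
            break
    if len(order) < len(anchors):
        raise ValueError("Unsatisfiable (likely circular) anchor dependency.")
    # Absolute points need no resolution step, so drop them from the order.
    return [i for i in order if anchors[i] != -1]

# PO4-like case: atom 0 is absolute, its four oxygens are anchored to it.
print(lattice_order(np.array([-1, 0, 0, 0, 0])))  # -> [1, 2, 3, 4]
```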
- confs = np.ones(shape = lmk_lbls.shape, dtype = np.float) - confs[lmk_lbls == -1] = 0.0 - - return cluster_counts, lmk_lbls, confs diff --git a/sitator/landmark/cluster/dotprod.py b/sitator/landmark/cluster/dotprod.py index 20cfc81..5b98e90 100644 --- a/sitator/landmark/cluster/dotprod.py +++ b/sitator/landmark/cluster/dotprod.py @@ -1,5 +1,7 @@ +"""Cluster landmark vectors using the custom online algorithm from the original paper.""" from sitator.util import DotProdClassifier +from sitator.landmark import LandmarkAnalysis DEFAULT_PARAMS = { 'clustering_threshold' : 0.45, @@ -23,4 +25,9 @@ def do_landmark_clustering(landmark_vectors, predict_threshold = clustering_params['assignment_threshold'], verbose = verbose) - return landmark_classifier.cluster_counts, lmk_lbls, lmk_confs + return { + LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE : landmark_classifier.cluster_counts, + LandmarkAnalysis.CLUSTERING_LABELS: lmk_lbls, + LandmarkAnalysis.CLUSTERING_CONFIDENCES : lmk_confs, + LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS : landmark_classifier.cluster_centers + } diff --git a/sitator/landmark/cluster/mcl.py b/sitator/landmark/cluster/mcl.py new file mode 100644 index 0000000..741cfcf --- /dev/null +++ b/sitator/landmark/cluster/mcl.py @@ -0,0 +1,131 @@ +"""Cluster landmarks into sites using Markov Clustering and then assign each landmark vector. + +Valid clustering params include: + - ``"assignment_threshold"`` (float between 0 and 1): The similarity threshold + below which a landmark vector will be marked unassigned. + - ``"good_site_normed_threshold"`` (float between 0 and 1): The minimum for + the cosine similarity between a good site's representative unit vector and + its best match landmark vector. + - ``"good_site_projected_threshold"`` (positive float): The minimum inner product + between a good site's representative unit vector and its best match + landmark vector. + - All other params are passed along to `sitator.util.mcl.markov_clustering`. +""" + +import numpy as np + +from sitator.util.progress import tqdm +from sitator.util.mcl import markov_clustering +from sitator.util import DotProdClassifier +from sitator.landmark import LandmarkAnalysis + +from sklearn.covariance import empirical_covariance + +from scipy.sparse.linalg import eigsh + +import logging +logger = logging.getLogger(__name__) + +DEFAULT_PARAMS = { + 'inflation' : 4, + 'assignment_threshold' : 0.7, +} + +def cov2corr( A ): + """ + Convert a covariance matrix to a correlation matrix. + """ + d = np.sqrt(A.diagonal()) + d[d == 0] = np.inf # Forces correlations to zero where variance is 0 + A = ((A.T/d).T)/d + return A + +def do_landmark_clustering(landmark_vectors, + clustering_params, + min_samples, + verbose): + tmp = DEFAULT_PARAMS.copy() + tmp.update(clustering_params) + clustering_params = tmp + + n_lmk = landmark_vectors.shape[1] + # Compute (uncentered) correlations between the landmark vectors + seen_ntimes = np.count_nonzero(landmark_vectors, axis = 0) + cov = np.dot(landmark_vectors.T, landmark_vectors) / landmark_vectors.shape[0] + corr = cov2corr(cov) + graph = np.clip(corr, 0, None) + for i in range(n_lmk): + if graph[i, i] == 0: # i.e. no self correlation = 0 variance = landmark never seen + graph[i, i] = 1 # Needs a self loop for Markov clustering not to degenerate. Arbitrary value, shouldn't affect anyone else.
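The hunk above builds the graph that ``mcl.py`` feeds to Markov clustering: an uncentered landmark correlation matrix, clipped to non-negative values, with self-loops added for never-seen landmarks. A toy restatement with made-up landmark vectors (values and shapes are illustrative only):

```python
import numpy as np

def cov2corr(A):
    # Same normalization as above; zero-variance rows get zero correlation.
    d = np.sqrt(A.diagonal())
    d[d == 0] = np.inf
    return ((A.T / d).T) / d

# 5 "frames" x 3 "landmarks": landmarks 0 and 1 fire together, landmark 2 alone.
X = np.array([[1.0, 0.9, 0.0],
              [0.8, 1.0, 0.0],
              [0.9, 0.8, 0.1],
              [0.0, 0.1, 1.0],
              [0.1, 0.0, 0.9]])
cov = np.dot(X.T, X) / X.shape[0]        # uncentered second-moment matrix
graph = np.clip(cov2corr(cov), 0, None)  # negative correlation -> no edge
print(graph.round(2))                    # strong 0-1 edge, weak edges to 2
```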
+ + predict_threshold = clustering_params.pop('assignment_threshold') + good_site_normed_threshold = clustering_params.pop('good_site_normed_threshold', predict_threshold) + good_site_project_thresh = clustering_params.pop('good_site_projected_threshold', predict_threshold) + + # -- Cluster Landmarks + clusters = markov_clustering(graph, **clustering_params) + # Filter out single element clusters of landmarks that never appear. + clusters = [list(c) for c in clusters if seen_ntimes[c[0]] > 0] + n_clusters = len(clusters) + centers = np.zeros(shape = (n_clusters, n_lmk)) + maxbuf = np.empty(shape = len(landmark_vectors)) + good_clusters = np.zeros(shape = n_clusters, dtype = np.bool) + for i, cluster in enumerate(clusters): + if len(cluster) == 1: + eigenvec = [1.0] # Eigenvec is trivial + else: + # PCA inspired: + _, eigenvec = eigsh(cov[cluster][:, cluster], k = 1) + eigenvec = eigenvec.T + centers[i, cluster] = eigenvec + np.dot(landmark_vectors, centers[i], out = maxbuf) + np.abs(maxbuf, out = maxbuf) + best_match = np.argmax(maxbuf) + best_match_lvec = landmark_vectors[best_match] + best_match_dot = np.abs(np.dot(best_match_lvec, centers[i])) + best_match_dot_norm = best_match_dot / np.linalg.norm(best_match_lvec) + good_clusters[i] = best_match_dot_norm >= good_site_normed_threshold + good_clusters[i] &= best_match_dot >= good_site_project_thresh + centers[i] /= best_match_dot + + logger.debug("Kept %i/%i landmark clusters as good sites" % (np.sum(good_clusters), len(good_clusters))) + + # Filter out "bad" sites + clusters = [c for i, c in enumerate(clusters) if good_clusters[i]] + centers = centers[good_clusters] + n_clusters = len(clusters) + + landmark_classifier = \ + DotProdClassifier(threshold = np.nan, # We're not fitting + min_samples = min_samples) + + landmark_classifier.set_cluster_centers(centers) + + lmk_lbls, lmk_confs, info = \ + landmark_classifier.fit_predict(landmark_vectors, + predict_threshold = predict_threshold, + predict_normed = False, + verbose = verbose, + return_info = True) + + msk = info['kept_clusters_mask'] + clusters = [c for i, c in enumerate(clusters) if msk[i]] # Only need the ones above the threshold + + # Find the average landmark vector at each site + weighted_reps = clustering_params.get('weighted_representative_landmarks', True) + centers = np.zeros(shape = (len(clusters), n_lmk)) + weights = np.empty(shape = lmk_lbls.shape) + for site in range(len(clusters)): + np.equal(lmk_lbls, site, out = weights) + if weighted_reps: + weights *= lmk_confs + centers[site] = np.average(landmark_vectors, weights = weights, axis = 0) + + + return { + LandmarkAnalysis.CLUSTERING_CLUSTER_SIZE : landmark_classifier.cluster_counts, + LandmarkAnalysis.CLUSTERING_LABELS : lmk_lbls, + LandmarkAnalysis.CLUSTERING_CONFIDENCES: lmk_confs, + LandmarkAnalysis.CLUSTERING_LANDMARK_GROUPINGS : clusters, + LandmarkAnalysis.CLUSTERING_REPRESENTATIVE_LANDMARKS : centers + } diff --git a/sitator/landmark/dynamic_mapping.py b/sitator/landmark/dynamic_mapping.py new file mode 100644 index 0000000..be7ce83 --- /dev/null +++ b/sitator/landmark/dynamic_mapping.py @@ -0,0 +1,5 @@ +import numpy as np + +def within_species(sn): + nums = sn.static_structure.numbers + return [(nums == nums[s]).nonzero()[0] for s in range(sn.n_static)] diff --git a/sitator/landmark/errors.py b/sitator/landmark/errors.py index 55dcfb2..69f06eb 100644 --- a/sitator/landmark/errors.py +++ b/sitator/landmark/errors.py @@ -2,16 +2,14 @@ class LandmarkAnalysisError(Exception): pass -class 
StaticLatticeError(Exception): +class StaticLatticeError(LandmarkAnalysisError): """Error raised when static lattice atoms break any limits on their movement/position. Attributes: lattice_atoms (list, optional): The indexes of the atoms in the static lattice that caused the error. frame (int, optional): The frame in the trajectory at which the error occured. - """ - TRY_RECENTERING_MSG = "Try recentering the input trajectory (sitator.util.RecenterTrajectory)" def __init__(self, message, lattice_atoms = None, frame = None, try_recentering = False): @@ -25,7 +23,13 @@ def __init__(self, message, lattice_atoms = None, frame = None, try_recentering self.lattice_atoms = lattice_atoms self.frame = frame -class ZeroLandmarkError(Exception): +class ZeroLandmarkError(LandmarkAnalysisError): + """Error raised when a landmark vector containing only zeros is encountered. + + Attributes: + mobile_index (int): Which mobile atom had the all-zero vector. + frame (int): At which frame it was encountered. + """ def __init__(self, mobile_index, frame): message = "Encountered a zero landmark vector for mobile ion %i at frame %i. Try increasing `cutoff_midpoint` and/or decreasing `cutoff_steepness`." % (mobile_index, frame) diff --git a/sitator/landmark/helpers.pyx b/sitator/landmark/helpers.pyx index d54bbfd..2ed9eb9 100644 --- a/sitator/landmark/helpers.pyx +++ b/sitator/landmark/helpers.pyx @@ -9,7 +9,17 @@ from sitator.landmark import StaticLatticeError, ZeroLandmarkError ctypedef double precision -def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_for_zeros = True, tqdm = lambda i: i): +def _fill_landmark_vectors(self, + sn, + verts_np, + site_vert_dists, + frames, + dynmap_compat, + lattice_pt_anchors, + lattice_pt_order, + check_for_zeros = True, + tqdm = lambda i: i, + logger = None): if self._landmark_dimension is None: raise ValueError("_fill_landmark_vectors called before Voronoi!") @@ -19,20 +29,31 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo cdef pbcc = self._pbcc - frame_shift = np.empty(shape = (sn.n_static, 3)) + frame_shift = np.empty(shape = (sn.n_static, 3), dtype = frames.dtype) temp_distbuff = np.empty(shape = sn.n_static, dtype = frames.dtype) mobile_idexes = np.where(sn.mobile_mask)[0] - - if self.dynamic_lattice_mapping: - lattice_map = np.empty(shape = sn.n_static, dtype = np.int) - else: - # Otherwise just map to themselves - lattice_map = np.arange(sn.n_static, dtype = np.int) - - lattice_pt = np.empty(shape = 3, dtype = sn.static_structure.positions.dtype) - lattice_pt_dists = np.empty(shape = sn.n_static, dtype = np.float) + # Static lattice point buffers + lattice_pts_resolved = np.empty(shape = (sn.n_static, 3), dtype = sn.static_structure.positions.dtype) + # Determine resolution order + absolute_lattice_mask = lattice_pt_anchors == -1 + assert len(lattice_pt_order) == len(lattice_pt_anchors) - np.sum(absolute_lattice_mask), "Order must contain all non-absolute anchored static lattice points" + cdef Py_ssize_t [:] lattice_pt_order_c = np.asarray(lattice_pt_order, dtype = np.int) + # Absolute (relative to origin) ones never need to be resolved, put it in + # the buffer now + lattice_pts_resolved[absolute_lattice_mask] = sn.static_structure.positions[absolute_lattice_mask] + assert not np.any(absolute_lattice_mask[lattice_pt_order]), "None of the absolute lattice points should be in the resolution order" + # Precompute the offsets for relative static lattice points: + relative_lattice_offsets = 
sn.static_structure.positions[lattice_pt_order] - sn.static_structure.positions[lattice_pt_anchors[lattice_pt_order]] + # Buffers for dynamic mapping + max_n_dynmat_compat = max(len(dm) for dm in dynmap_compat) + lattice_pt_dists = np.empty(shape = max_n_dynmat_compat, dtype = np.float) + static_pos_buffer = np.empty(shape = (max_n_dynmat_compat, 3), dtype = lattice_pts_resolved.dtype) static_positions_seen = np.empty(shape = sn.n_static, dtype = np.bool) + lattice_map = np.empty(shape = sn.n_static, dtype = np.int) + # Instant static position buffers + static_positions = np.empty(shape = (sn.n_static, 3), dtype = frames.dtype) + static_mask_idexes = sn.static_mask.nonzero()[0] # - Precompute cutoff function rounding point # TODO: Think about the 0.0001 value @@ -42,43 +63,65 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo self._cutoff_steepness, 0.0001) + cdef Py_ssize_t n_all_zero_lvecs = 0 + cdef Py_ssize_t landmark_dim = self._landmark_dimension cdef Py_ssize_t current_landmark_i = 0 - # Iterate through time - for i, frame in enumerate(tqdm(frames, desc = "Frame")): - static_positions = frame[sn.static_mask] + cdef Py_ssize_t nearest_static_position + cdef precision nearest_static_distance + cdef Py_ssize_t n_dynmap_allowed + # Iterate through time + for i, frame in enumerate(tqdm(frames, desc = "Landmark Frame")): + # Copy static positions to buffer + np.take(frame, + static_mask_idexes, + out = static_positions, + axis = 0, + mode = 'clip') # Every frame, update the lattice map static_positions_seen.fill(False) - for lattice_index in xrange(sn.n_static): - lattice_pt = sn.static_structure.positions[lattice_index] + # - Resolve static lattice positions from their origins + for order_i, lattice_index in enumerate(lattice_pt_order_c): + lattice_pts_resolved[lattice_index] = static_positions[lattice_pt_anchors[lattice_index]] + lattice_pts_resolved[lattice_index] += relative_lattice_offsets[order_i] - if self.dynamic_lattice_mapping: - # Only compute all distances if dynamic remapping is on - pbcc.distances(lattice_pt, static_positions, out = lattice_pt_dists) - nearest_static_position = np.argmin(lattice_pt_dists) - nearest_static_distance = lattice_pt_dists[nearest_static_position] - else: - nearest_static_position = lattice_index - nearest_static_distance = pbcc.distances(lattice_pt, static_positions[nearest_static_position:nearest_static_position+1])[0] + # - Map static positions to static lattice sites + for lattice_index in xrange(sn.n_static): + dynmap_allowed = dynmap_compat[lattice_index] + n_dynmap_allowed = len(dynmap_allowed) + + np.take(static_positions, + dynmap_allowed, + out = static_pos_buffer[:n_dynmap_allowed], + axis = 0, + mode = 'clip') + + pbcc.distances( + lattice_pts_resolved[lattice_index], + static_pos_buffer[:n_dynmap_allowed], + out = lattice_pt_dists[:n_dynmap_allowed] + ) + nearest_static_position = np.argmin(lattice_pt_dists[:n_dynmap_allowed]) + nearest_static_distance = lattice_pt_dists[nearest_static_position] + nearest_static_position = dynmap_allowed[nearest_static_position] if static_positions_seen[nearest_static_position]: # We've already seen this one... 
error - print "Static atom %i is the closest to more than one static lattice position" % nearest_static_position + logger.warning("Static atom %i is the closest to more than one static lattice position" % nearest_static_position) #raise ValueError("Static atom %i is the closest to more than one static lattice position" % nearest_static_position) static_positions_seen[nearest_static_position] = True if nearest_static_distance > self.static_movement_threshold: - raise StaticLatticeError("No static atom position within %f A threshold of static lattice position %i" % (self.static_movement_threshold, lattice_index), + raise StaticLatticeError("Nearest static atom to lattice position %i is %.2fÅ away, above threshold of %.2fÅ" % (lattice_index, nearest_static_distance, self.static_movement_threshold), lattice_atoms = [lattice_index], frame = i, try_recentering = True) - if self.dynamic_lattice_mapping: - lattice_map[lattice_index] = nearest_static_position + lattice_map[lattice_index] = nearest_static_position # In normal circumstances, every current static position should be assigned. # Just a sanity check @@ -89,13 +132,14 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo frame = i, try_recentering = True) - + # - Compute landmark vectors for mobile for j in xrange(sn.n_mobile): mobile_pt = frame[mobile_idexes[j]] # Shift the Li in question to the center of the unit cell - np.copyto(frame_shift, frame[sn.static_mask]) - frame_shift += (pbcc.cell_centroid - mobile_pt) + frame_shift[:] = static_positions + frame_shift -= mobile_pt + frame_shift += pbcc.cell_centroid # Wrap all positions into the unit cell pbcc.wrap_points(frame_shift) @@ -111,17 +155,24 @@ def _fill_landmark_vectors(self, sn, verts_np, site_vert_dists, frames, check_fo cutoff_round_to_zero, temp_distbuff) - if check_for_zeros and (np.count_nonzero(self._landmark_vectors[current_landmark_i]) == 0): - raise ZeroLandmarkError(mobile_index = j, frame = i) + if np.count_nonzero(self._landmark_vectors[current_landmark_i]) == 0: + if check_for_zeros: + raise ZeroLandmarkError(mobile_index = j, frame = i) + else: + n_all_zero_lvecs += 1 current_landmark_i += 1 + self.n_all_zero_lvecs = n_all_zero_lvecs + + cdef precision cutoff_round_to_zero_point(precision cutoff_midpoint, precision cutoff_steepness, precision threshold): # Computed by solving for x: return cutoff_midpoint + log((1/threshold) - 1.) / cutoff_steepness + @cython.boundscheck(False) @cython.wraparound(False) cdef void fill_landmark_vec(precision [:,:] landmark_vectors, @@ -139,18 +190,6 @@ cdef void fill_landmark_vec(precision [:,:] landmark_vectors, precision cutoff_round_to_zero, precision [:] distbuff) nogil: - # Pure Python equiv: - # for k in xrange(landmark_dim): - # lvec = np.linalg.norm(lattice_positions[verts[k]] - cell_centroid, axis = 1) - # past_cutoff = lvec > cutoff - - # # Short circut it, since the product then goes to zero too. 
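The ``cutoff_round_to_zero_point`` helper above inverts the logistic cutoff applied to each landmark component, f(r) = 1/(1 + exp(steepness * (r - midpoint))), to find the distance beyond which a component can be rounded to zero. Solving f(x) = t for x gives x = midpoint + ln(1/t - 1)/steepness. A quick pure-Python check (the midpoint and steepness values here are made up; 0.0001 is the threshold used above):

```python
import math

def cutoff(r, midpoint, steepness):
    # The logistic cutoff applied to each static-point distance.
    return 1.0 / (1.0 + math.exp(steepness * (r - midpoint)))

def round_to_zero_point(midpoint, steepness, threshold=0.0001):
    # Solve cutoff(x) == threshold for x.
    return midpoint + math.log(1.0 / threshold - 1.0) / steepness

m, s = 1.5, 30.0  # illustrative parameters
x = round_to_zero_point(m, s)
print(x, cutoff(x, m, s))  # cutoff(x) == 0.0001 up to rounding
```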
- # if np.any(past_cutoff): - # landmark_vectors[(i * n_li) + j, k] = 0 - # else: - # lvec = (np.cos((np.pi / cutoff) * lvec) + 1.0) / 2.0 - # landmark_vectors[(i * n_li) + j, k] = np.product(lvec) - # Fill the landmark vector cdef int [:] vert cdef Py_ssize_t v @@ -168,11 +207,6 @@ cdef void fill_landmark_vec(precision [:,:] landmark_vectors, distbuff[idex] = temp - # if temp > cutoff: - # distbuff[idex] = 0.0 - # else: - # distbuff[idex] = (cos((M_PI / cutoff) * temp) + 1.0) * 0.5 - # For each component for k in xrange(landmark_dim): ci = 1.0 @@ -196,7 +230,6 @@ cdef void fill_landmark_vec(precision [:,:] landmark_vectors, temp = 1.0 / (1.0 + exp(cutoff_steepness * (temp - cutoff_midpoint))) # Multiply into accumulator - #ci *= distbuff[v] ci *= temp # "Normalize" to number of vertices diff --git a/sitator/landmark/pointmerge.py b/sitator/landmark/pointmerge.py deleted file mode 100644 index ece3107..0000000 --- a/sitator/landmark/pointmerge.py +++ /dev/null @@ -1,86 +0,0 @@ - -import numpy as np - -# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698 -import sys -try: - ipy_str = str(type(get_ipython())) - if 'zmqshell' in ipy_str: - from tqdm import tqdm_notebook as tqdm - if 'terminal' in ipy_str: - from tqdm import tqdm -except: - if sys.stderr.isatty(): - from tqdm import tqdm - else: - def tqdm(iterable, **kwargs): - return iterable - -def merge_points_soap_paths(tsoap, - pbcc, - points, - connectivity_dict, - threshold, - n_steps = 5, - sanity_check_cutoff = np.inf, - verbose = True): - """Merge points using SOAP paths method. - - :param SOAP tsoap: to compute SOAPs with. - :param dict connectivity_dict: Maps a point index to a set of point indexes - it is connected to, however defined. - :param threshold: Similarity threshold, 0 < threshold <= 1 - """ - - merge_sets = set() - - points_along = np.empty(shape = (n_steps, 3), dtype = np.float) - step_vec_mult = np.linspace(0.0, 1.0, num = n_steps)[:, np.newaxis] - - for pt_idex in tqdm(connectivity_dict.keys()): - merge_set = set() - current_pts = [pt_idex] - from_soap = None - keep_going = True - while keep_going: - added_this_iter = set() - for edge_from in current_pts: - offset = pbcc.cell_centroid - points[edge_from] - edge_from_pt = pbcc.cell_centroid - - for edge_to in connectivity_dict[edge_from] - merge_set: - edge_to_pt = points[edge_to].copy() - edge_to_pt += offset - pbcc.wrap_point(edge_to_pt) - - step_vec = edge_to_pt - edge_from_pt - edge_length = np.linalg.norm(step_vec) - - assert edge_length <= sanity_check_cutoff, "edge_length %s" % edge_length - - # Points along the line - for i in xrange(n_steps): - points_along[i] = step_vec - points_along *= step_vec_mult - points_along += edge_from_pt - # Re-center back to original center - points_along -= offset - # Wrap back into original unit cell - the one frame_atoms has - pbcc.wrap_points(points_along) - - merge = tsoap.soaps_similar_for_points(points_along, threshold = threshold) - - if merge: - added_this_iter.add(edge_from) - added_this_iter.add(edge_to) - - if len(added_this_iter) == 0: - keep_going = False - else: - current_pts = added_this_iter - merge_set - merge_set.update(added_this_iter) - - if len(merge_set) > 0: - merge_sets.add(frozenset(merge_set)) - - return merge_sets diff --git a/sitator/misc/GenerateAroundSites.py b/sitator/misc/GenerateAroundSites.py index dea0306..c668239 100644 --- a/sitator/misc/GenerateAroundSites.py +++ b/sitator/misc/GenerateAroundSites.py @@ -4,7 +4,12 @@ from sitator.util import PBCCalculator class 
GenerateAroundSites(object): - """Generate n normally distributed sites around each input site""" + """Generate ``n`` normally distributed sites around each site. + + Args: + n (int): How many sites to produce for each input site. + sigma (float): Standard deviation of the spatial Gaussian. + """ def __init__(self, n, sigma): self.n = n self.sigma = sigma @@ -14,9 +19,9 @@ def run(self, sn): out = sn.copy() pbcc = PBCCalculator(sn.structure.cell) - print out.centers.shape + newcenters = out.centers.repeat(self.n, axis = 0) - print newcenters.shape + assert len(newcenters) == self.n * len(out.centers) newcenters += self.sigma * np.random.standard_normal(size = newcenters.shape) pbcc.wrap_points(newcenters) diff --git a/sitator/misc/GenerateClampedTrajectory.pyx b/sitator/misc/GenerateClampedTrajectory.pyx new file mode 100644 index 0000000..d442664 --- /dev/null +++ b/sitator/misc/GenerateClampedTrajectory.pyx @@ -0,0 +1,128 @@ +# cython: language_level=3 + +import numpy as np + +from sitator import SiteTrajectory +from sitator.util.PBCCalculator cimport PBCCalculator, precision +from sitator.util.progress import tqdm + +from libc.math cimport floor + +ctypedef Py_ssize_t site_int + +class GenerateClampedTrajectory(object): + """Create a real-space trajectory with the fixed site/static structure positions. + + Generate a real-space trajectory where the atoms are clamped to the fixed + positions of the current site/their fixed static position. + + Args: + wrap (bool): If ``True``, all clamped positions will be in + the unit cell; if ``False``, the clamped position will be the minimum + image of the clamped position with respect to the real-space position. + (This can generate a clamped, unwrapped real-space trajectory + from an unwrapped real space trajectory.) + pass_through_unassigned (bool): If ``True``, when a + mobile atom is supposed to be clamped but is unassigned at some + frame, its real-space position will be passed through from the + real trajectory. If False, an error will be raised. + """ + def __init__(self, wrap = False, pass_through_unassigned = False): + self.wrap = wrap + self.pass_through_unassigned = pass_through_unassigned + + + def run(self, st, clamp_mask = None): + """Create a real-space trajectory with the fixed site/static structure positions. + + Generate a real-space trajectory where the atoms indicated in ``clamp_mask`` -- + any mixture of static and mobile -- are clamped to: (1) the fixed position of + their current site, if mobile, or (2) the corresponding fixed position in + the ``SiteNetwork``'s static structure, if static. + + Atoms not indicated in ``clamp_mask`` will have their positions from + ``real_traj`` passed through. 
+ + Args: + clamp_mask (ndarray, len(sn.structure)) + Returns: + ndarray (n_frames x n_atoms x 3) + """ + wrap = self.wrap + pass_through_unassigned = self.pass_through_unassigned + cell = st._sn.structure.cell + cdef PBCCalculator pbcc = PBCCalculator(cell) + + n_atoms = len(st._sn.structure) + if clamp_mask is None: + clamp_mask = np.ones(shape = n_atoms, dtype = np.bool) + if st._real_traj is None and not np.all(clamp_mask): + raise RuntimeError("This `SiteTrajectory` has no real-space trajectory, but the given clamp mask leaves some atoms unclamped.") + + clamptrj = np.empty(shape = (st.n_frames, n_atoms, 3)) + # Pass through unclamped positions + if not np.all(clamp_mask): + clamptrj[:, ~clamp_mask, :] = st._real_traj[:, ~clamp_mask, :] + # Clamp static atoms + static_clamp = clamp_mask & st._sn.static_mask + clamptrj[:, static_clamp, :] = st._sn.structure.get_positions()[static_clamp] + # Clamp mobile atoms + mobile_clamp = clamp_mask & st._sn.mobile_mask + selected_sitetraj = st._traj[:, mobile_clamp] + mobile_clamp_indexes = np.where(mobile_clamp)[0] + if not pass_through_unassigned and np.min(selected_sitetraj) < 0: + raise RuntimeError("The mobile atoms indicated for clamping are unassigned at some point during the trajectory and `pass_through_unassigned` is set to False. Try `assign_to_last_known_site()`?") + + cdef site_int at_site + cdef Py_ssize_t frame_i + cdef Py_ssize_t mobile_i + cdef Py_ssize_t [:] mobile_clamp_indexes_c = mobile_clamp_indexes + cdef precision [:, :] buf + cdef precision [:] site_pt + cdef site_int site_unknown = SiteTrajectory.SITE_UNKNOWN + cdef const site_int [:, :] sitetrj_c = st._traj + cdef precision [:, :, :] clamptrj_c = clamptrj + cdef const precision [:, :, :] realtrj_c = st._real_traj + cdef const precision [:, :] centers_c = st.site_network.centers + cdef int site_mic_int + cdef int [3] site_mic + cdef int [3] pt_in_image + cdef precision [:, :] centers_crystal_c + cdef Py_ssize_t dim + if wrap: + for frame_i in tqdm(range(len(clamptrj))): + for mobile_i in mobile_clamp_indexes_c: + at_site = sitetrj_c[frame_i, mobile_i] + if at_site == site_unknown: # we already know that this means pass_through_unassigned = True + clamptrj_c[frame_i, mobile_i] = realtrj_c[frame_i, mobile_i] + else: + clamptrj_c[frame_i, mobile_i] = centers_c[at_site] + else: + buf = np.empty(shape = (1, 3)) + site_pt = np.empty(shape = 3) + centers_crystal_c = st.site_network.centers.copy() + pbcc.to_cell_coords(centers_crystal_c) + for frame_i in tqdm(range(len(clamptrj))): + for mobile_i in mobile_clamp_indexes_c: + buf[:, :] = realtrj_c[frame_i, mobile_i] + at_site = sitetrj_c[frame_i, mobile_i] + if at_site == site_unknown: # we already know that this means pass_through_unassigned = True + clamptrj_c[frame_i, mobile_i] = realtrj_c[frame_i, mobile_i] + continue + site_pt[:] = centers_c[at_site] + pbcc.wrap_point(site_pt) + pbcc.wrap_points(buf) + site_mic_int = pbcc.min_image(buf[0], site_pt) + for dim in range(3): + site_mic[dim] = (site_mic_int // 10**(2 - dim) % 10) - 1 + buf[:, :] = realtrj_c[frame_i, mobile_i] + pbcc.to_cell_coords(buf) + for dim in range(3): + pt_in_image[dim] = floor(buf[0, dim]) + site_mic[dim] + buf[0] = centers_crystal_c[at_site] + for dim in range(3): + buf[0, dim] += pt_in_image[dim] + pbcc.to_real_coords(buf) + clamptrj_c[frame_i, mobile_i] = buf[0] + + return clamptrj diff --git a/sitator/misc/NAvgsPerSite.py b/sitator/misc/NAvgsPerSite.py index 41d74bf..b4d1f01 100644 --- a/sitator/misc/NAvgsPerSite.py +++ b/sitator/misc/NAvgsPerSite.py @@ 
-5,15 +5,15 @@ from sitator.util import PBCCalculator class NAvgsPerSite(object): - """Given a SiteTrajectory, return a SiteNetwork containing n avg. positions per site. + """Given a ``SiteTrajectory``, return a ``SiteNetwork`` containing n avg. positions per site. - The `types` of sites in the output are the index of the site in the input that generated - that average. + The ``site_types`` of sites in the output are the index of the site in the + input that generated that average. :param int n: How many averages to take. :param bool error_on_insufficient: Whether to throw an error if n points cannot be provided for a site, or just take all that are available. - :param bool weighted: Use SiteTrajectory confidences to weight the averages. + :param bool weighted: Use ``SiteTrajectory`` confidences to weight the averages. """ def __init__(self, n, @@ -25,6 +25,12 @@ def __init__(self, n, self.weighted = weighted def run(self, st): + """ + Args: + st (SiteTrajectory) + Returns: + A ``SiteNetwork``. + """ assert isinstance(st, SiteTrajectory) if st.real_trajectory is None: raise ValueError("SiteTrajectory must have associated real trajectory.") @@ -32,10 +38,12 @@ def run(self, st): pbcc = PBCCalculator(st.site_network.structure.cell) # Maximum length centers = np.empty(shape = (self.n * st.site_network.n_sites, 3), dtype = st.real_trajectory.dtype) + centers.fill(np.nan) types = np.empty(shape = centers.shape[0], dtype = np.int) + types.fill(-1) # int arrays cannot hold NaN; use -1 as the unset sentinel current_idex = 0 - for site in xrange(st.site_network.n_sites): + for site in range(st.site_network.n_sites): if self.weighted: pts, confs = st.real_positions_for_site(site, return_confidences = True) else: @@ -46,7 +54,7 @@ def run(self, st): if len(pts) > self.n: sanity = 0 - for i in xrange(self.n): + for i in range(self.n): ps = pts[i::self.n] sanity += len(ps) c = confs[i::self.n] @@ -64,7 +72,9 @@ def run(self, st): types[old_idex:current_idex] = site sn = st.site_network.copy() - sn.centers = centers - sn.site_types = types + sn.centers = centers[:current_idex] + sn.site_types = types[:current_idex] + + assert not (np.isnan(np.sum(sn.centers)) or np.any(sn.site_types < 0)) return sn diff --git a/sitator/misc/SiteVolumes.py b/sitator/misc/SiteVolumes.py deleted file mode 100644 index 7a8f48c..0000000 --- a/sitator/misc/SiteVolumes.py +++ /dev/null @@ -1,52 +0,0 @@ -import numpy as np - -from scipy.spatial import ConvexHull -from scipy.spatial.qhull import QhullError - -from sitator import SiteTrajectory -from sitator.util import PBCCalculator - -class SiteVolumes(object): - """Computes the volumes of convex hulls around all positions associated with a site. - - Adds the `site_volumes` and `site_surface_areas` attributes to the SiteNetwork.
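The ``pts[i::self.n]`` striding in ``NAvgsPerSite.run`` above splits a site's recorded positions into ``n`` interleaved groups and averages each, yielding ``n`` sub-centers per site. A stripped-down sketch of just that step, using a plain Euclidean mean and fake positions (the real code averages with ``PBCCalculator`` and optional confidence weights):

```python
import numpy as np

def n_averages(points, n):
    # One average per interleaved group; every point lands in exactly one group.
    assert len(points) >= n
    return np.stack([points[i::n].mean(axis=0) for i in range(n)])

pts = np.random.default_rng(0).normal(size=(100, 3))  # fake site positions
print(n_averages(pts, 4).shape)  # (4, 3)
```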
- """ - def __init__(self, n_recenterings = 8): - self.n_recenterings = n_recenterings - - def run(self, st): - vols = np.empty(shape = st.site_network.n_sites, dtype = np.float) - areas = np.empty(shape = st.site_network.n_sites, dtype = np.float) - - pbcc = PBCCalculator(st.site_network.structure.cell) - - for site in xrange(st.site_network.n_sites): - pos = st.real_positions_for_site(site) - - assert pos.flags['OWNDATA'] - - vol = np.inf - area = None - for i in xrange(self.n_recenterings): - # Recenter - offset = pbcc.cell_centroid - pos[int(i * (len(pos)/self.n_recenterings))] - pos += offset - pbcc.wrap_points(pos) - - try: - hull = ConvexHull(pos) - except QhullError as qhe: - print "For site %i, iter %i: %s" % (site, i, qhe) - vols[site] = np.nan - areas[site] = np.nan - continue - - if hull.volume < vol: - vol = hull.volume - area = hull.area - - vols[site] = vol - areas[site] = area - - st.site_network.add_site_attribute('site_volumes', vols) - st.site_network.add_site_attribute('site_surface_areas', areas) diff --git a/sitator/misc/__init__.py b/sitator/misc/__init__.py index 66901f4..8240d62 100644 --- a/sitator/misc/__init__.py +++ b/sitator/misc/__init__.py @@ -1,4 +1,4 @@ -from NAvgsPerSite import NAvgsPerSite -from GenerateAroundSites import GenerateAroundSites -from SiteVolumes import SiteVolumes +from .NAvgsPerSite import NAvgsPerSite +from .GenerateAroundSites import GenerateAroundSites +from .GenerateClampedTrajectory import GenerateClampedTrajectory diff --git a/sitator/misc/oldio.py b/sitator/misc/oldio.py new file mode 100644 index 0000000..ddd7ce2 --- /dev/null +++ b/sitator/misc/oldio.py @@ -0,0 +1,82 @@ +import numpy as np + +import tempfile +import tarfile +import os + +import ase +import ase.io + +from sitator import SiteNetwork + +_STRUCT_FNAME = "structure.xyz" +_SMASK_FNAME = "static_mask.npy" +_MMASK_FNAME = "mobile_mask.npy" +_MAIN_FNAMES = ['centers', 'vertices', 'site_types'] + +def save(sn, file): + """Save this SiteNetwork to a tar archive.""" + with tempfile.TemporaryDirectory() as tmpdir: + # -- Write the structure + ase.io.write(os.path.join(tmpdir, _STRUCT_FNAME), sn.structure, parallel = False) + # -- Write masks + np.save(os.path.join(tmpdir, _SMASK_FNAME), sn.static_mask) + np.save(os.path.join(tmpdir, _MMASK_FNAME), sn.mobile_mask) + # -- Write what we have + for arrname in _MAIN_FNAMES: + if not getattr(sn, arrname) is None: + np.save(os.path.join(tmpdir, "%s.npy" % arrname), getattr(sn, arrname)) + # -- Write all site/edge attributes + for atype, attrs in zip(("site_attr", "edge_attr"), (sn._site_attrs, sn._edge_attrs)): + for attr in attrs: + np.save(os.path.join(tmpdir, "%s-%s.npy" % (atype, attr)), attrs[attr]) + # -- Write final archive + with tarfile.open(file, mode = 'w:gz', format = tarfile.PAX_FORMAT) as outf: + outf.add(tmpdir, arcname = "") + + +def from_file(file): + """Load a SiteNetwork from a tar file/file descriptor.""" + all_others = {} + site_attrs = {} + edge_attrs = {} + structure = None + with tarfile.open(file, mode = 'r:gz', format = tarfile.PAX_FORMAT) as input: + # -- Load everything + for member in input.getmembers(): + if member.name == '': + continue + f = input.extractfile(member) + if member.name == _STRUCT_FNAME: + with tempfile.TemporaryDirectory() as tmpdir: + input.extract(member, path = tmpdir) + structure = ase.io.read(os.path.join(tmpdir, member.name), format = 'xyz') + else: + basename = os.path.splitext(os.path.basename(member.name))[0] + data = np.load(f) + if basename.startswith("site_attr"): + 
site_attrs[basename.split('-')[1]] = data + elif basename.startswith("edge_attr"): + edge_attrs[basename.split('-')[1]] = data + else: + all_others[basename] = data + + # Create SiteNetwork + assert not structure is None + assert all(k in all_others for k in ("static_mask", "mobile_mask")), "Malformed SiteNetwork file." + sn = SiteNetwork(structure, + all_others['static_mask'], + all_others['mobile_mask']) + if 'centers' in all_others: + sn.centers = all_others['centers'] + for key in all_others: + if key in ('centers', 'static_mask', 'mobile_mask'): + continue + setattr(sn, key, all_others[key]) + + assert all(len(sa) == sn.n_sites for sa in site_attrs.values()) + assert all(ea.shape == (sn.n_sites, sn.n_sites) for ea in edge_attrs.values()) + sn._site_attrs = site_attrs + sn._edge_attrs = edge_attrs + + return sn diff --git a/sitator/network/DiffusionPathwayAnalysis.py b/sitator/network/DiffusionPathwayAnalysis.py new file mode 100644 index 0000000..7dc714c --- /dev/null +++ b/sitator/network/DiffusionPathwayAnalysis.py @@ -0,0 +1,228 @@ + +import numpy as np + +import numbers +import itertools + +from scipy.sparse import lil_matrix +from scipy.sparse.csgraph import connected_components + +from sitator import SiteNetwork +from sitator.util import PBCCalculator + +import logging +logger = logging.getLogger(__name__) + +class DiffusionPathwayAnalysis(object): + """Find connected diffusion pathways in a SiteNetwork. + + :param float|int connectivity_threshold: The fraction of the total number of + (non-self) jumps (if a float), or the absolute number of jumps (if an int), + that must occur over an edge for it to be considered connected. + :param int minimum_n_sites: The minimum number of sites that must be part of + a pathway for it to be considered as such. + :param bool true_periodic_pathways: Whether to return only true periodic + pathways that include sites and their periodic images (i.e. conductive + in the bulk) rather than just connected components. If ``True``, + ``minimum_n_sites`` is NOT respected. + """ + + NO_PATHWAY = -1 + + def __init__(self, + connectivity_threshold = 1, + true_periodic_pathways = True, + minimum_n_sites = 0): + assert minimum_n_sites >= 0 + + self.true_periodic_pathways = true_periodic_pathways + self.connectivity_threshold = connectivity_threshold + self.minimum_n_sites = minimum_n_sites + + def run(self, sn, return_count = False, return_direction = False): + """ + Expects a ``SiteNetwork`` that has had a ``JumpAnalysis`` run on it. + + Adds information to ``sn`` in place. + + Args: + sn (SiteNetwork): Must have jump statistics from a ``JumpAnalysis``. + return_count (bool): Return the number of connected pathways. + return_direction (bool): If True and ``true_periodic_pathways``, + return for each pathway the set of directions in which it + connects across periodic boundaries.
+ Returns: + sn, [n_pathways], [list of set of tuple] + """ + if not sn.has_attribute('n_ij'): + raise ValueError("SiteNetwork has no `n_ij`; run a JumpAnalysis on it first.") + + nondiag = np.ones(shape = sn.n_ij.shape, dtype = np.bool) + np.fill_diagonal(nondiag, False) + n_non_self_jumps = np.sum(sn.n_ij[nondiag]) + + if isinstance(self.connectivity_threshold, numbers.Integral): + threshold = self.connectivity_threshold + elif isinstance(self.connectivity_threshold, numbers.Real): + threshold = self.connectivity_threshold * n_non_self_jumps + else: + raise TypeError("Don't know how to interpret connectivity_threshold `%s`" % self.connectivity_threshold) + + connectivity_matrix = sn.n_ij >= threshold + + if self.true_periodic_pathways: + connectivity_matrix, mask_000, images = self._build_mic_connmat(sn, connectivity_matrix) + + n_ccs, ccs = connected_components(connectivity_matrix, + directed = False, # even though the matrix is symmetric + connection = 'weak') # diffusion could be unidirectional + + _, counts = np.unique(ccs, return_counts = True) + + if self.true_periodic_pathways: + # is_pathway = np.ones(shape = n_ccs, dtype = np.bool) + # We have to check that the pathways include a site and its periodic + # image, and throw out those that don't + new_n_ccs = 1 + new_ccs = np.zeros(shape = len(sn), dtype = np.int) + + # Add a non-path (contains no sites, all False) so the broadcasting works + site_masks = [np.zeros(shape = len(sn), dtype = np.bool)] + + pathway_dirs = [set()] + + for pathway_i in np.arange(n_ccs): + path_mask = ccs == pathway_i + + if not np.any(path_mask & mask_000): + # If the pathway is entirely outside the unit cell, we don't care + continue + + # Sum along each site's periodic images, giving a count site-by-site + site_counts = np.sum(path_mask.reshape((-1, sn.n_sites)).astype(np.int), axis = 0) + if not np.any(site_counts > 1): + # Not percolating; doesn't contain any site and its periodic image. + continue + + pdirs = set() + for periodic_site in np.where(site_counts > 1)[0]: + at_images = images[path_mask[periodic_site::len(sn)]] + # The direction from 0 to 1 should be the same as any other pair. + # Cause periodic. + direction = (at_images[0] - at_images[1]) != 0 + pdirs.add(tuple(direction)) + + cur_site_mask = site_counts > 0 + + intersects_with = np.where(np.any(np.logical_and(site_masks, cur_site_mask), axis = 1))[0] + # Merge them: + if len(intersects_with) > 0: + path_mask = cur_site_mask | np.logical_or.reduce([site_masks[i] for i in intersects_with], axis = 0) + pdirs = pdirs.union(*[pathway_dirs[i] for i in intersects_with]) + else: + path_mask = cur_site_mask + # Remove individual merged paths + # Going in reverse order means indexes don't become invalid as deletes happen + for i in sorted(intersects_with, reverse=True): + del site_masks[i] + del pathway_dirs[i] + # Add new (super)path + site_masks.append(path_mask) + pathway_dirs.append(pdirs) + + new_ccs[path_mask] = new_n_ccs + new_n_ccs += 1 + + n_ccs = new_n_ccs + ccs = new_ccs + # Only actually take the ones that were assigned to in the end + # This will deal with the ones that were merged. + is_pathway = np.in1d(np.arange(n_ccs), ccs) + is_pathway[0] = False # Cause this was the "unassigned" value, we initialized with zeros up above + pathway_dirs = pathway_dirs[1:] # Get rid of the dummy pathway's direction + assert len(pathway_dirs) == np.sum(is_pathway) + else: + is_pathway = counts >= self.minimum_n_sites + + logging.info("Taking all edges with at least %i/%i jumps..." 
% (threshold, n_non_self_jumps)) + logging.info("Found %i connected components, of which %i are large enough to qualify as pathways (%i sites)." % (n_ccs, np.sum(is_pathway), self.minimum_n_sites)) + + n_pathway = np.sum(is_pathway) + translation = np.empty(n_ccs, dtype = np.int) + translation[~is_pathway] = DiffusionPathwayAnalysis.NO_PATHWAY + translation[is_pathway] = np.arange(n_pathway) + + node_pathways = translation[ccs] + + outmat = np.empty(shape = (sn.n_sites, sn.n_sites), dtype = np.int) + + for i in range(sn.n_sites): + rowmask = node_pathways[i] == node_pathways + outmat[i, rowmask] = node_pathways[i] + outmat[i, ~rowmask] = DiffusionPathwayAnalysis.NO_PATHWAY + + sn.add_site_attribute('site_diffusion_pathway', node_pathways) + sn.add_edge_attribute('edge_diffusion_pathway', outmat) + + retval = [sn] + if return_count: + retval.append(n_pathway) + if return_direction: + retval.append(pathway_dirs) + return tuple(retval) + + + def _build_mic_connmat(self, sn, connectivity_matrix): + # We use a 3x3x3 = 27 supercell, so there are 27x as many sites + assert len(sn) == connectivity_matrix.shape[0] + + images = np.asarray(list(itertools.product(range(-1, 2), repeat = 3))) + image_to_idex = dict((100 * (image[0] + 1) + 10 * (image[1] + 1) + (image[2] + 1), i) for i, image in enumerate(images)) + n_images = len(images) + assert n_images == 27 + + n_sites = len(sn) + pos = sn.centers #.copy() # TODO: copy not needed after reinstall of sitator! + n_total_sites = len(images) * n_sites + newmat = lil_matrix((n_total_sites, n_total_sites), dtype = np.bool) + + mask_000 = np.zeros(shape = n_total_sites, dtype = np.bool) + index_000 = image_to_idex[111] + mask_000[index_000:index_000 + n_sites] = True + assert np.sum(mask_000) == len(sn) + + pbcc = PBCCalculator(sn.structure.cell) + buf = np.empty(shape = 3) + + internal_mat = np.zeros_like(connectivity_matrix) + external_connections = [] + for from_site, to_site in zip(*np.where(connectivity_matrix)): + buf[:] = pos[to_site] + if pbcc.min_image(pos[from_site], buf) == 111: + # If we're in the main image, keep the connection: it's internal + internal_mat[from_site, to_site] = True + #internal_mat[to_site, from_site] = True # fake FIXME + else: + external_connections.append((from_site, to_site)) + #external_connections.append((to_site, from_site)) # FAKE FIXME + + for image_idex, image in enumerate(images): + # Make the block diagonal + newmat[image_idex * n_sites:(image_idex + 1) * n_sites, + image_idex * n_sites:(image_idex + 1) * n_sites] = internal_mat + + # Check all external connections from this image; add other sparse entries + for from_site, to_site in external_connections: + buf[:] = pos[to_site] + to_mic = pbcc.min_image(pos[from_site], buf) + to_in_image = image + [(to_mic // 10**(2 - i) % 10) - 1 for i in range(3)] # FIXME: is the -1 right + assert to_in_image is not None, "%s" % to_in_image + assert np.max(np.abs(to_in_image)) <= 2 + if not np.any(np.abs(to_in_image) > 1): + to_in_image = 100 * (to_in_image[0] + 1) + 10 * (to_in_image[1] + 1) + 1 * (to_in_image[2] + 1) + newmat[image_idex * n_sites + from_site, + image_to_idex[to_in_image] * n_sites + to_site] = True + + assert np.sum(newmat) >= n_images * np.sum(internal_mat) # Lowest it can be is if every one is internal + + return newmat, mask_000, images diff --git a/sitator/network/MergeSitesByBarrier.py b/sitator/network/MergeSitesByBarrier.py new file mode 100644 index 0000000..3865952 --- /dev/null +++ b/sitator/network/MergeSitesByBarrier.py @@ -0,0 +1,139 @@ +import 
numpy as np + +import itertools +import math + +from scipy.sparse.csgraph import connected_components + +from ase.calculators.calculator import all_changes + +from sitator.util import PBCCalculator +from sitator.util.progress import tqdm +from sitator.network.merging import MergeSites +from sitator.dynamics import JumpAnalysis + +import logging +logger = logging.getLogger(__name__) + +class MergeSitesByBarrier(MergeSites): + """Merge sites based on the energy barrier between them. + + Uses a cheap coordinate driving system; this may not be sophisticated enough + for complex cases. For each pair of sites within the pairwise distance cutoff, + a linear spatial interpolation is applied to produce ``n_driven_images``. + Two sites are considered mergable if the barrier between them is below + ``barrier_threshold``. The barrier in each direction is defined as the maximum + image energy minus the initial (forward) or final (backward) image energy. + + The energies of the mobile atom are calculated in a static lattice given + by ``coordinating_mask``; if ``None``, this is set to the system's + ``static_mask``. + + For reasonable performance, ``calculator`` should be something simple like + ``ase.calculators.lj.LennardJones``. + + The species of the first mobile atom is taken as the mobile species. + + Args: + calculator (ase.Calculator): For computing total potential energies. + barrier_threshold (float, eV): The barrier value above which two sites + are not mergable. + n_driven_images (int, default: 20): The number of evenly distributed + driven images to use. + maximum_pairwise_distance (float, Angstrom): The maximum distance + between two sites for them to be considered for merging. + minimum_jumps_mergable (int): The minimum number of observed jumps + between two sites for their merging to be considered. Setting this + higher can avoid unnecessary computations. + maximum_merge_distance (float, Angstrom): The maximum pairwise distance + among a group of sites chosen to be merged. + """ + def __init__(self, + calculator, + barrier_threshold, + n_driven_images = 20, + maximum_pairwise_distance = 2, + minimum_jumps_mergable = 1, + maximum_merge_distance = 2, + **kwargs): + super().__init__(maximum_merge_distance = maximum_merge_distance, **kwargs) + self.barrier_threshold = barrier_threshold + self.maximum_pairwise_distance = maximum_pairwise_distance + self.minimum_jumps_mergable = minimum_jumps_mergable + assert n_driven_images >= 3, "Must have at least initial, transition, and final."
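The coordinate driving described in the docstring above reduces to: place the mobile atom at evenly spaced points on the segment between two site centers and evaluate the potential energy at each. A sketch under the same assumptions (an ``ase.Atoms`` with an attached calculator; minimum-image handling is omitted here but done by the real code):

```python
import numpy as np

def drive_energies(atoms, mobile_index, start, end, n_images=20):
    # Linearly interpolate the mobile atom from `start` to `end` and
    # record the total potential energy at each driven image.
    coeffs = np.linspace(0.0, 1.0, n_images)
    energies = np.empty(n_images)
    for k, c in enumerate(coeffs):
        atoms.positions[mobile_index] = start + c * (end - start)
        energies[k] = atoms.get_potential_energy()
    return energies

# energies = drive_energies(structure, -1, site_i_center, site_j_center)
# forward_barrier = energies.max() - energies[0]   # as in the code below
```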
+ self.n_driven_images = n_driven_images + self.calculator = calculator + + + def _get_sites_to_merge(self, st, coordinating_mask = None): + sn = st.site_network + + # -- Compute jump statistics + if not sn.has_attribute('n_ij'): + ja = JumpAnalysis() + ja.run(st) + + pos = sn.centers + if coordinating_mask is None: + coordinating_mask = sn.static_mask + else: + assert not np.any(coordinating_mask & sn.mobile_mask) + # -- Build images + mobile_idex = np.where(sn.mobile_mask)[0][0] + one_mobile_structure = sn.structure[coordinating_mask] + one_mobile_structure.extend(sn.structure[mobile_idex]) + mobile_idex = -1 + one_mobile_structure.set_calculator(self.calculator) + interpolation_coeffs = np.linspace(0, 1, self.n_driven_images) + energies = np.empty(shape = self.n_driven_images) + + # -- Decide on pairs to check + pbcc = PBCCalculator(sn.structure.cell) + dists = pbcc.pairwise_distances(pos) + # At the start, all within distance cutoff are mergable + mergable = dists <= self.maximum_pairwise_distance + mergable &= sn.n_ij >= self.minimum_jumps_mergable + + # -- Check pairs' barriers + # Symmetric, and diagonal is trivially true. Combinations avoids those cases. + jbuf = pos[0].copy() + first_calculate = True + mergable_pairs = (p for p in itertools.combinations(range(sn.n_sites), r = 2) if mergable[p] or mergable[p[1], p[0]]) + n_mergable = (np.sum(mergable) - sn.n_sites) // 2 + for i, j in tqdm(mergable_pairs, total = n_mergable): + jbuf[:] = pos[j] + # Get minimage + _ = pbcc.min_image(pos[i], jbuf) + # Do coordinate driving + vector = jbuf - pos[i] + for image_i in range(self.n_driven_images): + one_mobile_structure.positions[mobile_idex] = vector + one_mobile_structure.positions[mobile_idex] *= interpolation_coeffs[image_i] + one_mobile_structure.positions[mobile_idex] += pos[i] + energies[image_i] = one_mobile_structure.get_potential_energy() + first_calculate = False + # Check barrier + barrier_idex = np.argmax(energies) + forward_barrier = energies[barrier_idex] - energies[0] + backward_barrier = energies[barrier_idex] - energies[-1] + # If it's an actual maxima barrier between them, then we want to + # check its height + if barrier_idex != 0 and barrier_idex != self.n_driven_images - 1: + mergable[i, j] = forward_barrier <= self.barrier_threshold + mergable[j, i] = backward_barrier <= self.barrier_threshold + # Otherwise, if there's no maxima between them, they are in the same + # basin. + + # Get mergable groups + n_merged_sites, labels = connected_components( + mergable, + directed = True, + connection = 'strong' + ) + # MergeSites will check pairwise distances; we just need to make it the + # right format. + merge_groups = [] + for lbl in range(n_merged_sites): + merge_groups.append(np.where(labels == lbl)[0]) + + return merge_groups diff --git a/sitator/network/__init__.py b/sitator/network/__init__.py new file mode 100644 index 0000000..0dd8174 --- /dev/null +++ b/sitator/network/__init__.py @@ -0,0 +1,4 @@ +from .DiffusionPathwayAnalysis import DiffusionPathwayAnalysis + +from . 
import merging +from .MergeSitesByBarrier import MergeSitesByBarrier diff --git a/sitator/network/merging.py b/sitator/network/merging.py new file mode 100644 index 0000000..c714e7c --- /dev/null +++ b/sitator/network/merging.py @@ -0,0 +1,145 @@ +import numpy as np + +import abc + +from sitator.util import PBCCalculator +from sitator import SiteNetwork, SiteTrajectory +from sitator.errors import InsufficientSitesError + +import logging +logger = logging.getLogger(__name__) + +class MergeSitesError(Exception): + pass + +class MergedSitesTooDistantError(MergeSitesError): + pass + + +class MergeSites(abc.ABC): + """Abstract base class for merging sites. + + :param bool check_types: If ``True``, only sites of the same type are candidates to + be merged; if ``False``, type information is ignored. Merged sites will only + be assigned types if this is ``True``. + :param float maximum_merge_distance: Maximum distance between two sites + that are in a merge group, above which an error will be raised. + :param bool set_merged_into: If ``True``, a site attribute ``"merged_into"`` + will be added to the original ``SiteNetwork`` indicating which new site + each old site was merged into. + :param bool weighted_spatial_average: If ``True``, the spatial average giving + the position of the merged site will be weighted by occupancy. + """ + def __init__(self, + check_types = True, + maximum_merge_distance = None, + set_merged_into = False, + weighted_spatial_average = True): + self.check_types = check_types + self.maximum_merge_distance = maximum_merge_distance + self.set_merged_into = set_merged_into + self.weighted_spatial_average = weighted_spatial_average + + + def run(self, st, **kwargs): + """Takes a ``SiteTrajectory`` and returns a new ``SiteTrajectory``.""" + + if self.check_types and st.site_network.site_types is None: + raise ValueError("Cannot run a check_types=True MergeSites on a SiteTrajectory without type information.") + + # -- Compute jump statistics + pbcc = PBCCalculator(st.site_network.structure.cell) + site_centers = st.site_network.centers + if self.check_types: + site_types = st.site_network.site_types + + clusters = self._get_sites_to_merge(st, **kwargs) + + old_n_sites = st.site_network.n_sites + new_n_sites = len(clusters) + + logger.info("After merging %i sites there will be %i sites for %i mobile particles" % (len(site_centers), new_n_sites, st.site_network.n_mobile)) + + if new_n_sites < st.site_network.n_mobile: + raise InsufficientSitesError( + verb = "Merging", + n_sites = new_n_sites, + n_mobile = st.site_network.n_mobile + ) + + if self.check_types: + new_types = np.empty(shape = new_n_sites, dtype = np.int) + merge_verts = st.site_network.vertices is not None + if merge_verts: + new_verts = [] + + # -- Merge Sites + new_centers = np.empty(shape = (new_n_sites, 3), dtype = st.site_network.centers.dtype) + translation = np.empty(shape = st.site_network.n_sites, dtype = np.int) + translation.fill(-1) + + for newsite in range(new_n_sites): + mask = list(clusters[newsite]) + # Update translation table + if np.any(translation[mask] != -1): + # We've assigned a different cluster for this before... weird + # degeneracy + raise ValueError("Site merging tried to merge site(s) into more than one new site. 
This shouldn't happen.") + translation[mask] = newsite + + to_merge = site_centers[mask] + + # Check distances + if not self.maximum_merge_distance is None: + dists = pbcc.distances(to_merge[0], to_merge[1:]) + if not np.all(dists <= self.maximum_merge_distance): + raise MergedSitesTooDistantError("Tried to merge sites more than %.2f Å apart. Lower your maximum_merge_distance?" % self.maximum_merge_distance) + + # New site center + if self.weighted_spatial_average: + occs = st.site_network.occupancies[mask] + new_centers[newsite] = pbcc.average(to_merge, weights = occs) + else: + new_centers[newsite] = pbcc.average(to_merge) + + if self.check_types: + assert np.all(site_types[mask] == site_types[mask][0]) + new_types[newsite] = site_types[mask][0] + if merge_verts: + new_verts.append(set.union(*[set(st.site_network.vertices[i]) for i in mask])) + + newsn = st.site_network.copy() + newsn.centers = new_centers + if self.check_types: + newsn.site_types = new_types + if merge_verts: + newsn.vertices = new_verts + + newtraj = translation[st._traj] + newtraj[st._traj == SiteTrajectory.SITE_UNKNOWN] = SiteTrajectory.SITE_UNKNOWN + + # It doesn't make sense to propagate confidence information through a + # transform that might completely invalidate it + newst = SiteTrajectory(newsn, newtraj, confidences = None) + + if not st.real_trajectory is None: + newst.set_real_traj(st.real_trajectory) + + if self.set_merged_into: + if st.site_network.has_attribute("merged_into"): + st.site_network.remove_attribute("merged_into") + st.site_network.add_site_attribute("merged_into", translation) + + return newst + + @abc.abstractmethod + def _get_sites_to_merge(self, st, **kwargs): + """Get the groups of sites to merge. + + Returns a list of list/tuples each containing the indexes of the sites to be merged. + There should be no overlap, and every site should be mentioned in at most + one site merging group. + + If not mentioned in any, the site will disappear. + """ + pass diff --git a/sitator/plotting.py b/sitator/plotting.py deleted file mode 100644 index cff27bd..0000000 --- a/sitator/plotting.py +++ /dev/null @@ -1,97 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import matplotlib -from mpl_toolkits.mplot3d import Axes3D - -import ase - -from samos.analysis.jumps.voronoi import collapse_into_unit_cell - -DEFAULT_COLORS = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'] * 2 - -# From https://stackoverflow.com/questions/13685386/matplotlib-equal-unit-length-with-equal-aspect-ratio-z-axis-is-not-equal-to -def set_axes_equal(ax): - '''Make axes of 3D plot have equal scale so that spheres appear as spheres, - cubes as cubes, etc.. This is one possible solution to Matplotlib's - ax.set_aspect('equal') and ax.axis('equal') not working for 3D. - - Input - ax: a matplotlib axis, e.g., as output from plt.gca(). - ''' - - x_limits = ax.get_xlim3d() - y_limits = ax.get_ylim3d() - z_limits = ax.get_zlim3d() - - x_range = abs(x_limits[1] - x_limits[0]) - x_middle = np.mean(x_limits) - y_range = abs(y_limits[1] - y_limits[0]) - y_middle = np.mean(y_limits) - z_range = abs(z_limits[1] - z_limits[0]) - z_middle = np.mean(z_limits) - - # The plot bounding box is a sphere in the sense of the infinity - # norm, hence I call half the max range the plot radius.
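The label remapping at the heart of ``MergeSites.run`` above is a single fancy-indexing step: ``translation`` maps each old site to its merged site, and unassigned frames are restored afterwards (since ``SITE_UNKNOWN`` is ``-1``, it would otherwise index the last entry). A toy demonstration:

```python
import numpy as np

SITE_UNKNOWN = -1  # sitator.SiteTrajectory's sentinel

translation = np.array([0, 0, 1])             # old sites 0 and 1 merge into 0
traj = np.array([[0, 2],
                 [1, SITE_UNKNOWN]])          # frames x mobile atoms
newtraj = translation[traj]
newtraj[traj == SITE_UNKNOWN] = SITE_UNKNOWN  # keep unassigned unassigned
print(newtraj)  # [[ 0  1]
                #  [ 0 -1]]
```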
- plot_radius = 0.5*max([x_range, y_range, z_range]) - - ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius]) - ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius]) - ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius]) - -def plot_atoms(atms, species, - pts=None, pts_cs = None, pts_marker = 'x', - cell = None, - hide_species = (), - wrap = False, - title = ""): - fig = plt.figure() - ax = fig.add_subplot(111, projection="3d") - cs = { - 'Li' : "blue", 'O' : "red", 'Ta' : "gray", 'Ge' : "darkgray", 'P' : "orange", - 'point' : 'black' - } - - if wrap and not cell is None: - atms = np.asarray([collapse_into_unit_cell(pt, cell) for pt in atms]) - if not pts is None: - pts = np.asarray([collapse_into_unit_cell(pt, cell) for pt in pts]) - - for s in hide_species: - cs[s] = 'none' - - ax.scatter(atms[:,0], - atms[:,1], - atms[:,2], - c = [cs.get(e, 'gray') for e in species], - s = [10.0*(ase.data.atomic_numbers[s])**0.5 for s in species]) - - - if not pts is None: - c = None - if pts_cs is None: - c = cs['point'] - else: - c = pts_cs - ax.scatter(pts[:,0], - pts[:,1], - pts[:,2], - marker = pts_marker, - c = c, - cmap=matplotlib.cm.Dark2) - - if not cell is None: - for cvec in cell: - cvec = np.array([[0, 0, 0], cvec]) - ax.plot(cvec[:,0], - cvec[:,1], - cvec[:,2], - color = "gray", - alpha=0.5, - linestyle="--") - - ax.set_title(title) - - set_axes_equal(ax) - - return ax diff --git a/sitator/site_descriptors/SOAP.py b/sitator/site_descriptors/SOAP.py index f2575ae..df0c128 100644 --- a/sitator/site_descriptors/SOAP.py +++ b/sitator/site_descriptors/SOAP.py @@ -2,40 +2,15 @@ import numpy as np from abc import ABCMeta, abstractmethod -from sitator.SiteNetwork import SiteNetwork -from sitator.SiteTrajectory import SiteTrajectory - -try: - import quippy as qp - from quippy import descriptors -except ImportError: - raise ImportError("Quippy with GAP is required for using SOAP descriptors.") +from sitator import SiteNetwork, SiteTrajectory +from sitator.util.progress import tqdm from ase.data import atomic_numbers -DEFAULT_SOAP_PARAMS = { - 'cutoff' : 3.0, - 'cutoff_transition_width' : 1.0, - 'l_max' : 6, 'n_max' : 6, - 'atom_sigma' : 0.4 -} - -# From https://github.com/tqdm/tqdm/issues/506#issuecomment-373126698 -import sys -try: - ipy_str = str(type(get_ipython())) - if 'zmqshell' in ipy_str: - from tqdm import tqdm_notebook as tqdm - if 'terminal' in ipy_str: - from tqdm import tqdm -except: - if sys.stderr.isatty(): - from tqdm import tqdm - else: - def tqdm(iterable, **kwargs): - return iterable - -class SOAP(object): +import logging +logger = logging.getLogger(__name__) + +class SOAP(object, metaclass=ABCMeta): """Abstract base class for computing SOAP vectors in a SiteNetwork. SOAP computations are *not* thread-safe; use one SOAP object per thread. @@ -45,42 +20,42 @@ class SOAP(object): of the environment to consider. I.e. for Li2CO3, can be set to ['O'] or [8] for oxygen only, or ['C', 'O'] / ['C', 8] / [6,8] if carbon and oxygen are considered an environment. - Defaults to `None`, in which case all non-mobile atoms are considered + Defaults to ``None``, in which case all non-mobile atoms are considered regardless of species. - :param soap_mask: Which atoms in the SiteNetwork's structure + :param soap_mask: Which atoms in the ``SiteNetwork``'s structure to use in SOAP calculations. Can be either a boolean mask ndarray or a tuple of species. - If `None`, the entire static_structure of the SiteNetwork will be used. 
+ If ``None``, the entire ``static_structure`` of the ``SiteNetwork`` will be used. Mobile atoms cannot be used for the SOAP host structure. Even not masked, species not considered in environment will be not accounted for. - For ideal performance: Specify environment and soap_mask correctly! + For ideal performance: Specify environment and ``soap_mask`` correctly! :param dict soap_params = {}: Any custom SOAP params. + :param func backend: A function that can be called with + ``sn, soap_mask, tracer_atomic_number, environment_list`` as + parameters, returning a function that, given the current soap structure + along with tracer atoms, returns SOAP vectors in a numpy array. (i.e. + its signature is ``soap(structure, positions)``). The returned function + can also have a property, ``n_dim``, giving the length of a single SOAP + vector. """ - __metaclass__ = ABCMeta + + from .backend.quip import quip_soap_backend as backend_quip + from .backend.dscribe import dscribe_soap_backend as backend_dscribe + def __init__(self, tracer_atomic_number, environment = None, - soap_mask=None, soap_params={}, verbose =True): + soap_mask = None, + backend = None): from ase.data import atomic_numbers - # Creating a dictionary for convenience, to check the types and values: - self.tracer_atomic_number = 3 - centers_list = [self.tracer_atomic_number] + self.tracer_atomic_number = tracer_atomic_number self._soap_mask = soap_mask - # -- Create the descriptor object - soap_opts = dict(DEFAULT_SOAP_PARAMS) - soap_opts.update(soap_params) - soap_cmd_line = ["soap"] - - # User options - for opt in soap_opts: - soap_cmd_line.append("{}={}".format(opt, soap_opts[opt])) + if backend is None: + backend = SOAP.dscribe_soap_backend + self._backend = backend - # - soap_cmd_line.append('n_Z={} Z={{{}}}'.format(len(centers_list), ' '.join(map(str, centers_list)))) - - # - Add environment species controls if given - self._environment = None - if not environment is None: + # - Standardize environment species controls if given + if not environment is None: # User given environment if not isinstance(environment, (list, tuple)): raise TypeError('environment has to be a list or tuple of species (atomic number' ' or symbol of the environment to consider') @@ -99,40 +74,38 @@ def __init__(self, tracer_atomic_number, environment = None, raise TypeError("Environment has to be a list of atomic numbers or atomic symbols") self._environment = environment_list - soap_cmd_line.append('n_species={} species_Z={{{}}}'.format(len(environment_list), ' '.join(map(str, environment_list)))) - - soap_cmd_line = " ".join(soap_cmd_line) - - if verbose: - print("SOAP command line: %s" % soap_cmd_line) - - self._soaper = descriptors.Descriptor(soap_cmd_line) - self._verbose = verbose - self._cutoff = soap_opts['cutoff'] - - - - @property - def n_dim(self): - return self._soaper.n_dim + else: + self._environment = None def get_descriptors(self, stn): - """ - Get the descriptors. - :param stn: A valid instance of SiteTrajectory or SiteNetwork - :returns: an array of descriptor vectors and an equal length array of - labels indicating which descriptors correspond to which sites. + """Get the descriptors. + + Args: + stn (SiteTrajectory or SiteNetwork) + Returns: + An array of descriptor vectors and an equal length array of labels + indicating which descriptors correspond to which sites. 
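The `backend` contract described above is just a factory returning a callable; a toy sketch makes its shape concrete. Everything here is illustrative: the zero vectors and the `n_dim` value of 16 are arbitrary placeholders, and only the call signatures come from the docstring.

```python
import numpy as np

def constant_soap_backend(sn, soap_mask, tracer_atomic_number, environment_list):
    """Toy backend for illustration only: returns zero descriptor vectors."""
    def soap(structure, positions):
        # One descriptor vector per query position; a real backend would
        # dispatch to QUIP or DScribe here.
        return np.zeros(shape = (len(positions), soap.n_dim))
    soap.n_dim = 16  # arbitrary placeholder for the descriptor length
    return soap
```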
""" # Build SOAP host structure if isinstance(stn, SiteTrajectory): - structure, tracer_index, soap_mask = self._make_structure(stn.site_network) + sn = stn.site_network elif isinstance(stn, SiteNetwork): - structure, tracer_index, soap_mask = self._make_structure(stn) + sn = stn else: raise TypeError("`stn` must be SiteNetwork or SiteTrajectory") + structure, tracer_atomic_number, soap_mask = self._make_structure(sn) + + if self._environment is not None: + environment_list = self._environment + else: + # Set it to all species represented by the soap_mask + environment_list = np.unique(sn.structure.get_atomic_numbers()[soap_mask]) + + soaper = self._backend(sn, soap_mask, tracer_atomic_number, environment_list) + # Compute descriptors - return self._get_descriptors(stn, structure, tracer_index, soap_mask) + return self._get_descriptors(stn, structure, tracer_atomic_number, soap_mask, soaper) # ---- @@ -140,7 +113,7 @@ def _make_structure(self, sn): if self._soap_mask is None: # Make a copy of the static structure - structure = qp.Atoms(sn.static_structure) + structure = sn.static_structure.copy() soap_mask = sn.static_mask # soap mask is the else: if isinstance(self._soap_mask, tuple): @@ -149,7 +122,7 @@ def _make_structure(self, sn): soap_mask = self._soap_mask assert not np.any(soap_mask & sn.mobile_mask), "Error for atoms %s; No atom can be both static and mobile" % np.where(soap_mask & sn.mobile_mask)[0] - structure = qp.Atoms(sn.structure[soap_mask]) + structure = sn.structure[soap_mask] assert np.any(soap_mask), "Given `soap_mask` excluded all host atoms." if not self._environment is None: @@ -161,14 +134,16 @@ def _make_structure(self, sn): else: tracer_atomic_number = self.tracer_atomic_number - structure.add_atoms((0.0, 0.0, 0.0), tracer_atomic_number) + if np.any(structure.get_atomic_numbers() == tracer_atomic_number): + raise ValueError("Structure cannot have static atoms (that are enabled in the SOAP mask) of the same species as `tracer_atomic_number`.") + structure.set_pbc([True, True, True]) - tracer_index = len(structure) - 1 - return structure, tracer_index, soap_mask + return structure, tracer_atomic_number, soap_mask + @abstractmethod - def _get_descriptors(self, stn, structure, tracer_index): + def _get_descriptors(self, stn, structure, tracer_atomic_number, soaper): pass @@ -177,46 +152,35 @@ def _get_descriptors(self, stn, structure, tracer_index): class SOAPCenters(SOAP): """Compute the SOAPs of the site centers in the fixed host structure. - Requires a SiteNetwork as input. + Requires a ``SiteNetwork`` as input. 
""" - def _get_descriptors(self, sn, structure, tracer_index, soap_mask): - assert isinstance(sn, SiteNetwork), "SOAPCenters requires a SiteNetwork, not `%s`" % sn + def _get_descriptors(self, sn, structure, tracer_atomic_number, soap_mask, soaper): + if isinstance(sn, SiteTrajectory): + sn = sn.site_network + assert isinstance(sn, SiteNetwork), "SOAPCenters requires a SiteNetwork or SiteTrajectory, not `%s`" % sn pts = sn.centers - out = np.empty(shape = (len(pts), self.n_dim), dtype = np.float) - - structure.set_cutoff(self._soaper.cutoff()) - - for i, pt in enumerate(tqdm(pts, desc="SOAP") if self._verbose else pts): - # Move tracer - structure.positions[tracer_index] = pt - - # SOAP requires connectivity data to be computed first - structure.calc_connect() - - #There should only be one descriptor, since there should only be one Li - out[i] = self._soaper.calc(structure)['descriptor'][0] + out = soaper(structure, pts) return out, np.arange(sn.n_sites) class SOAPSampledCenters(SOAPCenters): - """Compute the SOAPs of representative points for each site, as determined by `sampling_transform`. + """Compute the SOAPs of representative points for each site, as determined by ``sampling_transform``. - Takes either a SiteNetwork or SiteTrajectory as input; requires that - `sampling_transform` produce a SiteNetwork where `site_types` indicates - which site in the original SiteNetwork/SiteTrajectory it was sampled from. + Takes either a ``SiteNetwork`` or ``SiteTrajectory`` as input; requires that + ``sampling_transform`` produce a ``SiteNetwork`` where ``site_types`` indicates + which site in the original ``SiteNetwork``/``SiteTrajectory`` it was sampled from. - Typical sampling transforms are `sitator.misc.NAvgsPerSite` (for a SiteTrajectory) - and `sitator.misc.GenerateAroundSites` (for a SiteNetwork). + Typical sampling transforms are ``sitator.misc.NAvgsPerSite`` (for a ``SiteTrajectory``) + and ``sitator.misc.GenerateAroundSites`` (for a ``SiteNetwork``). """ def __init__(self, *args, **kwargs): self.sampling_transform = kwargs.pop('sampling_transform', 1) super(SOAPSampledCenters, self).__init__(*args, **kwargs) def get_descriptors(self, stn): - # Do sampling sampled = self.sampling_transform.run(stn) assert isinstance(sampled, SiteNetwork), "Sampling transform returned `%s`, not a SiteNetwork" % sampled @@ -237,11 +201,11 @@ class SOAPDescriptorAverages(SOAP): then averaged in SOAP space to give the final SOAP vectors for each site. This method often performs better than SOAPSampledCenters on more dynamic - systems, but requires significantly more computation. + systems, but requires significantly more computation. :param int stepsize: Stride (in frames) when computing SOAPs. Default 1. :param int averaging: Number of SOAP vectors to average for each output vector. - :param int avg_descriptors_per_site: Can be specified instead of `averaging`. + :param int avg_descriptors_per_site: Can be specified instead of ``averaging``. Specifies the _average_ number of average SOAP vectors to compute for each site. 
This does not guerantee that number of SOAP vectors for any site, rather, it allows a trajectory-size agnostic way to specify approximately @@ -281,7 +245,7 @@ def __init__(self, *args, **kwargs): super(SOAPDescriptorAverages, self).__init__(*args, **kwargs) - def _get_descriptors(self, site_trajectory, structure, tracer_index, soap_mask): + def _get_descriptors(self, site_trajectory, structure, tracer_atomic_number, soap_mask, soaper): """ calculate descriptors """ @@ -291,63 +255,66 @@ def _get_descriptors(self, site_trajectory, structure, tracer_index, soap_mask): mob_indices = np.where(site_trajectory.site_network.mobile_mask)[0] # real_traj is the real space positions, site_traj the site trajectory # (i.e. for every mobile species the site index) - # I load into new variable, only the steps I need (memory???) real_traj = site_trajectory._real_traj[::self._stepsize] site_traj = site_trajectory.traj[::self._stepsize] # Now, I need to allocate the output # so for each site, I count how much data there is! - counts = np.array([np.count_nonzero(site_traj==site_idx) for site_idx in xrange(nsit)], dtype=int) + counts = np.zeros(shape = nsit + 1, dtype = np.int) + for frame in site_traj: + counts[frame] += 1 # A duplicate in `frame` will still only cause a single addition due to Python rules. + counts = counts[:-1] if self._averaging is not None: averaging = self._averaging else: averaging = int(np.floor(np.mean(counts) / self._avg_desc_per_site)) + logger.debug("Will average %i SOAP vectors for every output vector" % averaging) - nr_of_descs = counts // averaging + if averaging == 0: + logger.warning("Asking for too many average descriptors per site; got averaging = 0; setting averaging = 1") + averaging = 1 - if np.any(nr_of_descs == 0): - raise ValueError("You are asking too much, averaging with {} gives a problem".format(averaging)) + nr_of_descs = counts // averaging + insufficient = nr_of_descs == 0 + if np.any(insufficient): + logger.warning("You're asking to average %i SOAP vectors, but at this stepsize, %i sites are insufficiently occupied. 
Num occ./averaging: %s" % (averaging, np.sum(insufficient), counts[insufficient] / averaging)) + averagings = np.full(shape = len(nr_of_descs), fill_value = averaging) + averagings[insufficient] = counts[insufficient] + nr_of_descs = np.maximum(nr_of_descs, 1) # If it's 0, just make one with whatever we've got + assert np.all(nr_of_descs >= 1) + logger.debug("Minimum # of descriptors/site: %i; maximum: %i" % (np.min(nr_of_descs), np.max(nr_of_descs))) # This is where I load the descriptor: - descs = np.zeros((np.sum(nr_of_descs), self.n_dim)) + descs = np.zeros(shape = (np.sum(nr_of_descs), soaper.n_dim)) # An array that tells me the index I'm at for each site type - desc_index = [np.sum(nr_of_descs[:i]) for i in range(len(nr_of_descs))] - max_index = [np.sum(nr_of_descs[:i+1]) for i in range(len(nr_of_descs))] + desc_index = np.asarray([np.sum(nr_of_descs[:i]) for i in range(len(nr_of_descs))]) + max_index = np.asarray([np.sum(nr_of_descs[:i+1]) for i in range(len(nr_of_descs))]) count_of_site = np.zeros(len(nr_of_descs), dtype=int) - blocked = np.empty(nsit, dtype=bool) - blocked[:] = False - structure.set_cutoff(self._soaper.cutoff()) - for site_traj_t, pos in tqdm(zip(site_traj, real_traj), desc="SOAP"): + allowed = np.ones(nsit, dtype = np.bool) + + for site_traj_t, pos in zip(tqdm(site_traj, desc="SOAP Frame"), real_traj): # I update the host lattice positions here, once for every timestep - structure.positions[:tracer_index] = pos[soap_mask] - for mob_idx, site_idx in enumerate(site_traj_t): - if site_idx >= 0 and not blocked[site_idx]: - # Now, for every lithium that has been associated to a site of index site_idx, - # I take my structure and load the position of this mobile atom: - structure.positions[tracer_index] = pos[mob_indices[mob_idx]] - # calc_connect to calculated distance -# structure.calc_connect() - #There should only be one descriptor, since there should only be one mobile - # I also divide by averaging, to avoid getting into large numbers. 
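To make the counting and averaging bookkeeping in the new code above concrete, here is a small worked example with made-up per-site visit counts:

```python
import numpy as np

# Hypothetical per-site visit counts after striding the trajectory:
counts = np.array([40, 12, 3, 25])
avg_desc_per_site = 5
# `averaging` as derived above: mean visits per site divided by the
# requested average number of output descriptors per site.
averaging = int(np.floor(np.mean(counts) / avg_desc_per_site))  # -> 4
nr_of_descs = counts // averaging                               # -> [10, 3, 0, 6]
# Site 2 is "insufficiently occupied": per the fallback above it still gets
# one descriptor, averaged over only its 3 visits instead of 4.
```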
-# soapv = self._soaper.calc(structure)['descriptor'][0] / self._averaging - structure.set_cutoff(self._cutoff) - structure.calc_connect() - soapv = self._soaper.calc(structure, grad=False)["descriptor"] - - #~ soapv ,_,_ = get_fingerprints([structure], d) - # So, now I need to figure out where to load the soapv into desc - idx_to_add_desc = desc_index[site_idx] - descs[idx_to_add_desc, :] += soapv[0] / averaging - count_of_site[site_idx] += 1 - # Now, if the count reaches the averaging I want, I augment - if count_of_site[site_idx] == averaging: - desc_index[site_idx] += 1 - count_of_site[site_idx] = 0 - # Now I check whether I have to block this site from accumulating more descriptors - if max_index[site_idx] == desc_index[site_idx]: - blocked[site_idx] = True - - desc_to_site = np.repeat(range(nsit), nr_of_descs) + structure.positions[:] = pos[soap_mask] + + to_describe = (site_traj_t != SiteTrajectory.SITE_UNKNOWN) & allowed[site_traj_t] + + if np.any(to_describe): + sites_to_describe = site_traj_t[to_describe] + soaps = soaper(structure, pos[mob_indices[to_describe]]) + soaps /= averagings[sites_to_describe][:, np.newaxis] + idx_to_add_desc = desc_index[sites_to_describe] + descs[idx_to_add_desc] += soaps + count_of_site[sites_to_describe] += 1 + + # Reset and increment full averages + full_average = count_of_site == averagings + desc_index[full_average] += 1 + count_of_site[full_average] = 0 + allowed[max_index == desc_index] = False + + assert not np.any(allowed) # We should have maxed out all of them after processing all frames. + + desc_to_site = np.repeat(list(range(nsit)), nr_of_descs) return descs, desc_to_site diff --git a/sitator/site_descriptors/SiteCoordinationEnvironment.py b/sitator/site_descriptors/SiteCoordinationEnvironment.py new file mode 100644 index 0000000..8ad2ef8 --- /dev/null +++ b/sitator/site_descriptors/SiteCoordinationEnvironment.py @@ -0,0 +1,148 @@ +import numpy as np + +from sitator.util.progress import tqdm + +try: + from pymatgen.io.ase import AseAtomsAdaptor + import pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder as cgf + from pymatgen.analysis.chemenv.coordination_environments.structure_environments import \ + LightStructureEnvironments + from pymatgen.analysis.chemenv.utils.defs_utils import AdditionalConditions + from pymatgen.analysis.bond_valence import BVAnalyzer + has_pymatgen = True +except ImportError: + has_pymatgen = False + +import logging +logger = logging.getLogger(__name__) + + +class SiteCoordinationEnvironment(object): + """Determine site types based on local coordination environments. + + Determine site types using the method from the following paper: + + David Waroquiers, Xavier Gonze, Gian-Marco Rignanese, + Cathrin Welker-Nieuwoudt, Frank Rosowski, Michael Goebel, Stephan Schenk, + Peter Degelmann, Rute Andre, Robert Glaum, and Geoffroy Hautier + + Statistical analysis of coordination environments in oxides + + Chem. Mater., 2017, 29 (19), pp 8346–8360, DOI: 10.1021/acs.chemmater.7b02766 + + as implement in ``pymatgen``'s + ``pymatgen.analysis.chemenv.coordination_environments``. + + Adds three site attributes: + - ``coordination_environments``: The name of the coordination environment, + as returned by ``pymatgen``. Example: ``"T:4"`` (tetrahedral, coordination + of 4). + - ``site_type_confidences``: The ``ce_fraction`` of the best match chemical + environment (from 0 to 1). + - ``coordination_numbers``: The coordination number of the site. 
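Before the argument list below, a hypothetical usage sketch for this class; it assumes, as with other site attributes in `sitator`, that attributes added via `add_site_attribute` are readable directly off the `SiteNetwork`:

```python
from sitator.site_descriptors import SiteCoordinationEnvironment

# `sn` is an existing SiteNetwork; only_ionic_bonds=False avoids needing
# an IonicSiteNetwork with charge information.
sce = SiteCoordinationEnvironment(only_ionic_bonds = False)
sn = sce.run(sn)
print(sn.site_types)                  # integer type labels
print(sn.coordination_environments)   # e.g. "T:4" for a tetrahedral site
print(sn.coordination_numbers)
```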
+
+    Args:
+        only_ionic_bonds (bool): If True, assumes ``sn`` is an ``IonicSiteNetwork``
+            and uses its charge information to only consider as neighbors
+            atoms with compatible (anion and cation, ion and neutral) charges.
+        full_chemenv_site_types (bool): If True, ``sitator`` site types on the
+            final ``SiteNetwork`` will be assigned based on unique chemical
+            environments, including shape. If False, they will be assigned
+            solely based on coordination number. Either way, both sets of
+            information are included in the ``SiteNetwork``; this just changes
+            which determines the ``site_types``.
+        **kwargs: passed to ``compute_structure_environments``.
+    """
+    def __init__(self,
+                 only_ionic_bonds = True,
+                 full_chemenv_site_types = False,
+                 **kwargs):
+        if not has_pymatgen:
+            raise ImportError("Pymatgen (or a recent enough version including `pymatgen.analysis.chemenv.coordination_environments`) cannot be imported.")
+        self._kwargs = kwargs
+        self._only_ionic_bonds = only_ionic_bonds
+        self._full_chemenv_site_types = full_chemenv_site_types
+
+    def run(self, sn):
+        """
+        Args:
+            sn (SiteNetwork)
+        Returns:
+            ``sn``, with type information.
+        """
+        # -- Determine local environments
+        # Get an ASE structure with a single mobile site that we'll move around
+        if sn.n_sites == 0:
+            logger.warning("Site network had no sites.")
+            return sn
+        site_struct, site_species = sn[0:1].get_structure_with_sites()
+        pymat_struct = AseAtomsAdaptor.get_structure(site_struct)
+        lgf = cgf.LocalGeometryFinder()
+        site_atom_index = len(site_struct) - 1
+
+        coord_envs = []
+        vertices = []
+
+        valences = 'undefined'
+        if self._only_ionic_bonds:
+            valences = list(sn.static_charges) + [sn.mobile_charge]
+
+        logger.info("Running site coordination environment analysis...")
+        # Do this once.
+        # __init__ here defaults to disabling structure refinement, so all this
+        # method is doing is making a copy of the structure and setting some
+        # variables to None.
+ lgf.setup_structure(structure = pymat_struct) + + for site in tqdm(range(sn.n_sites), desc = 'Site'): + # Update the position of the site + lgf.structure[site_atom_index].coords = sn.centers[site] + # Compute structure environments for the site + struct_envs = lgf.compute_structure_environments( + only_indices = [site_atom_index], + valences = valences, + additional_conditions = [AdditionalConditions.ONLY_ANION_CATION_BONDS], + **self._kwargs + ) + struct_envs = LightStructureEnvironments.from_structure_environments( + strategy = cgf.LocalGeometryFinder.DEFAULT_STRATEGY, + structure_environments = struct_envs + ) + # Store the results + # We take the first environment for each site since it's the most likely + this_site_envs = struct_envs.coordination_environments[site_atom_index] + if len(this_site_envs) > 0: + coord_envs.append(this_site_envs[0]) + vertices.append( + [n['index'] for n in struct_envs.neighbors_sets[site_atom_index][0].neighb_indices_and_images] + ) + else: + coord_envs.append({'ce_symbol' : 'BAD:0', 'ce_fraction' : 0.}) + vertices.append([]) + + del lgf + del struct_envs + + # -- Postprocess + # TODO: allow user to ask for full fractional breakdown + str_coord_environments = [env['ce_symbol'] for env in coord_envs] + # The closer to 1 this is, the better + site_type_confidences = np.array([env['ce_fraction'] for env in coord_envs]) + coordination_numbers = np.array([int(env['ce_symbol'].split(':')[1]) for env in coord_envs]) + assert np.all(coordination_numbers == [len(v) for v in vertices]) + + typearr = str_coord_environments if self._full_chemenv_site_types else coordination_numbers + unique_envs = list(set(typearr)) + site_types = np.array([unique_envs.index(t) for t in typearr]) + + n_types = len(unique_envs) + logger.info(("Type " + "{:<8}" * n_types).format(*unique_envs)) + logger.info(("# of sites " + "{:<8}" * n_types).format(*np.bincount(site_types))) + + sn.site_types = site_types + sn.vertices = vertices + sn.add_site_attribute("coordination_environments", str_coord_environments) + sn.add_site_attribute("site_type_confidences", site_type_confidences) + sn.add_site_attribute("coordination_numbers", coordination_numbers) + + return sn diff --git a/sitator/site_descriptors/SiteTypeAnalysis.py b/sitator/site_descriptors/SiteTypeAnalysis.py index 30a7143..26efa8f 100644 --- a/sitator/site_descriptors/SiteTypeAnalysis.py +++ b/sitator/site_descriptors/SiteTypeAnalysis.py @@ -1,7 +1,3 @@ -from __future__ import (absolute_import, division, - print_function, unicode_literals) -from builtins import * - import numpy as np from sitator.misc import GenerateAroundSites @@ -14,39 +10,59 @@ import itertools +import logging +logger = logging.getLogger(__name__) + +has_pydpc = False try: import pydpc + has_pydpc = True except ImportError: - raise ImportError("SiteTypeAnalysis requires the `pydpc` package") + pass class SiteTypeAnalysis(object): - """Cluster sites into types using a descriptor and DPCLUS. - - -- descriptor -- - Some kind of object implementing: - - n_dim: the number of components in a descriptor vector - - get_descriptors(site_traj or site_network): returns an array of descriptor vectors - of dimension (M, n_dim) and an array of length M indicating which - descriptor vectors correspond to which sites in (site_traj.)site_network. + """Cluster sites into types using a continuous descriptor and Density Peak Clustering. + + Computes descriptor vectors, processes them with Principal Component Analysis, + and then clusters using Density Peak Clustering. 
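A sketch of the whole descriptor, PCA, and clustering pipeline this class implements, with placeholder parameters, assuming a `SOAPSampledCenters` descriptor with `sitator.misc.NAvgsPerSite` as the sampling route (the exact `NAvgsPerSite` arguments are illustrative):

```python
from sitator.misc import NAvgsPerSite
from sitator.site_descriptors import SiteTypeAnalysis
from sitator.site_descriptors.SOAP import SOAPSampledCenters

# `st` is an existing SiteTrajectory.
descriptor = SOAPSampledCenters(
    tracer_atomic_number = 3,
    sampling_transform = NAvgsPerSite(n = 3),
)
sta = SiteTypeAnalysis(descriptor,
                       min_pca_variance = 0.9,
                       n_site_types_max = 10)
sn_with_types = sta.run(st)
```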
+ + Args: + descriptor (object): Must implement ``get_descriptors(st|sn)``, which + returns an array of descriptor vectors of dimension (M, n_dim) and + an array of length M indicating which descriptor vectors correspond + to which sites in (``site_traj.``)``site_network``. + min_pca_variance (float): The minimum proportion of the total variance + that the taken principal components of the descriptor must explain. + min_pca_dimensions (int): Force taking at least this many principal + components. + n_site_types_max (int): Maximum number of clusters. Must be set reasonably + for the automatic selection of cluster number to work. """ def __init__(self, descriptor, min_pca_variance = 0.9, min_pca_dimensions = 2, - verbose = True, n_site_types_max = 20): + n_site_types_max = 20): + if not has_pydpc: + raise ImportError("SiteTypeAnalysis requires the `pydpc` package") + self.descriptor = descriptor self.min_pca_variance = min_pca_variance self.min_pca_dimensions = min_pca_dimensions - self.verbose = verbose self.n_site_types_max = n_site_types_max self._n_dvecs = None def run(self, descriptor_input, **kwargs): + """ + Args: + descriptor_input (SiteNetwork or SiteTrajectory) + Returns: + ``SiteNetwork`` + """ if not self._n_dvecs is None: raise ValueError("Can't run SiteTypeAnalysis more than once!") # -- Sample enough points - if self.verbose: - print(" -- Running SiteTypeAnalysis --") + logger.info(" -- Running SiteTypeAnalysis --") if isinstance(descriptor_input, SiteNetwork): sn = descriptor_input.copy() @@ -56,30 +72,26 @@ def run(self, descriptor_input, **kwargs): raise RuntimeError("Input {}".format(type(descriptor_input))) # -- Compute descriptor vectors - if self.verbose: - print(" - Computing Descriptor Vectors") + logger.info(" - Computing Descriptor Vectors") self.dvecs, dvecs_to_site = self.descriptor.get_descriptors(descriptor_input, **kwargs) assert len(self.dvecs) == len(dvecs_to_site), "Length mismatch in descriptor return values" assert np.min(dvecs_to_site) == 0 and np.max(dvecs_to_site) < sn.n_sites # -- Dimensionality Reduction - if self.verbose: - print(" - Clustering Descriptor Vectors") + logger.info(" - Clustering Descriptor Vectors") self.pca = PCA(self.min_pca_variance) pca_dvecs = self.pca.fit_transform(self.dvecs) if pca_dvecs.shape[1] < self.min_pca_dimensions: - if self.verbose: - print(" PCA accounted for %.0f%% variance in only %i dimensions; less than minimum of %.0f." % (100.0 * np.sum(self.pca.explained_variance_ratio_), pca_dvecs.shape[1], self.min_pca_dimensions)) - print(" Forcing PCA to use %i dimensions." % self.min_pca_dimensions) - self.pca = PCA(n_components = self.min_pca_dimensions) - pca_dvecs = self.pca.fit_transform(self.dvecs) + logger.info(" PCA accounted for %.0f%% variance in only %i dimensions; less than minimum of %.0f." % (100.0 * np.sum(self.pca.explained_variance_ratio_), pca_dvecs.shape[1], self.min_pca_dimensions)) + logger.info(" Forcing PCA to use %i dimensions." 
% self.min_pca_dimensions) + self.pca = PCA(n_components = self.min_pca_dimensions) + pca_dvecs = self.pca.fit_transform(self.dvecs) self.dvecs = pca_dvecs - if self.verbose: - print(" Accounted for %.0f%% of variance in %i dimensions" % (100.0 * np.sum(self.pca.explained_variance_ratio_), self.dvecs.shape[1])) + logger.info((" Accounted for %.0f%% of variance in %i dimensions" % (100.0 * np.sum(self.pca.explained_variance_ratio_), self.dvecs.shape[1]))) # -- Do clustering # pydpc requires a C-contiguous array @@ -122,14 +134,13 @@ def run(self, descriptor_input, **kwargs): assert self.n_types == len(site_type_counts), "Got %i types from pydpc, but counted %i" % (self.n_types, len(site_type_counts)) - if self.verbose: - print(" Found %i site type clusters" % self.n_types ) - print(" Failed to assign %i/%i descriptor vectors to clusters." % (self._n_unassigned, self._n_dvecs)) + logger.info(" Found %i site type clusters" % self.n_types ) + logger.info(" Failed to assign %i/%i descriptor vectors to clusters." % (self._n_unassigned, self._n_dvecs)) # -- Voting types = np.empty(shape = sn.n_sites, dtype = np.int) self.winning_vote_percentages = np.empty(shape = sn.n_sites, dtype = np.float) - for site in xrange(sn.n_sites): + for site in range(sn.n_sites): corresponding_samples = dvecs_to_site == site votes = assignments[corresponding_samples] n_votes = len(votes) @@ -145,12 +156,11 @@ def run(self, descriptor_input, **kwargs): n_sites_of_each_type = np.bincount(types, minlength = self.n_types) sn.site_types = types - if self.verbose: - print((" " + "Type {:<2} " * self.n_types).format(*xrange(self.n_types))) - print(("# of sites " + "{:<8}" * self.n_types).format(*n_sites_of_each_type)) + logger.info((" " + "Type {:<2} " * self.n_types).format(*range(self.n_types))) + logger.info(("# of sites " + "{:<8}" * self.n_types).format(*n_sites_of_each_type)) - if np.any(n_sites_of_each_type == 0): - print("WARNING: Had site types with no sites; check clustering settings/voting!") + if np.any(n_sites_of_each_type == 0): + logger.warning("WARNING: Had site types with no sites; check clustering settings/voting!") return sn @@ -172,11 +182,11 @@ def plot_dvecs(self, fig = None, ax = None, **kwargs): @plotter(is3D = False) def plot_clustering(self, fig = None, ax = None, **kwargs): ccycle = itertools.cycle(DEFAULT_COLORS) - for cluster in xrange(self.n_types): + for cluster in range(self.n_types): mask = self.dpc.membership == cluster dvecs_core = self.dvecs[mask & ~self.dpc.border_member] dvecs_border = self.dvecs[mask & self.dpc.border_member] - color = ccycle.next() + color = next(ccycle) ax.scatter(dvecs_core[:,0], dvecs_core[:,1], s = 3, color = color, label = "Type %i" % cluster) ax.scatter(dvecs_border[:,0], dvecs_border[:,1], s = 3, color = color, alpha = 0.3) diff --git a/sitator/site_descriptors/SiteVolumes.py b/sitator/site_descriptors/SiteVolumes.py new file mode 100644 index 0000000..e065680 --- /dev/null +++ b/sitator/site_descriptors/SiteVolumes.py @@ -0,0 +1,134 @@ +import numpy as np + +from scipy.spatial import ConvexHull +from scipy.spatial.qhull import QhullError + +from sitator import SiteNetwork, SiteTrajectory +from sitator.util import PBCCalculator + +import logging +logger = logging.getLogger(__name__) + +class InsufficientCoordinatingAtomsError(Exception): + pass + +class SiteVolumes(object): + """Compute the volumes of sites. 
+ + Args: + error_on_insufficient_coord (bool): To compute an ideal + site volume (``compute_volumes()``), at least 4 coordinating + atoms (because we are in 3D space) must be specified in ``vertices``. + If ``True``, an error will be thrown when a site with less than four + vertices is encountered; if ``False``, a volume of 0 and surface area + of NaN will be returned. + """ + def __init__(self, error_on_insufficient_coord = True): + self.error_on_insufficient_coord = error_on_insufficient_coord + + + def compute_accessable_volumes(self, st, n_recenterings = 8): + """Computes the volumes of convex hulls around all positions associated with a site. + + Uses the shift-and-wrap trick for dealing with periodicity, so sites that + take up the majority of the unit cell may give bogus results. + + Adds the ``accessable_site_volumes`` attribute to the ``SiteNetwork``. + + Args: + st (SiteTrajectory) + n_recenterings (int): How many different recenterings to try (the + algorithm will recenter around n of the points and take the minimal + resulting volume; this deals with cases where there is one outlier + where recentering around it gives very bad results.) + """ + assert isinstance(st, SiteTrajectory) + vols = np.empty(shape = st.site_network.n_sites, dtype = np.float) + areas = np.empty(shape = st.site_network.n_sites, dtype = np.float) + + pbcc = PBCCalculator(st.site_network.structure.cell) + + for site in range(st.site_network.n_sites): + pos = st.real_positions_for_site(site) + + assert pos.flags['OWNDATA'] + + vol = np.inf + area = None + for i in range(n_recenterings): + # Recenter + offset = pbcc.cell_centroid - pos[int(i * (len(pos)/n_recenterings))] + pos += offset + pbcc.wrap_points(pos) + + try: + hull = ConvexHull(pos) + except QhullError as qhe: + logger.warning("For site %i, iter %i: %s" % (site, i, qhe)) + vols[site] = np.nan + areas[site] = np.nan + continue + + if hull.volume < vol: + vol = hull.volume + area = hull.area + + vols[site] = vol + areas[site] = area + + st.site_network.add_site_attribute('accessable_site_volumes', vols) + + + def compute_volumes(self, sn): + """Computes the volume of the convex hull defined by each sites' static verticies. + + Requires vertex information in the SiteNetwork. + + Adds the ``site_volumes`` and ``site_surface_areas`` attributes. + + Volumes can be NaN for degenerate hulls/point sets on which QHull fails. 
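A hypothetical usage sketch for both volume computations on a `SiteTrajectory` `st` whose network carries vertex information; it assumes the added site attributes are readable directly off the `SiteNetwork`:

```python
from sitator.site_descriptors import SiteVolumes

sv = SiteVolumes(error_on_insufficient_coord = False)
sv.compute_accessable_volumes(st, n_recenterings = 8)
sv.compute_volumes(st.site_network)
# Results land in site attributes (note the spelling of the attribute names,
# which is part of the existing API):
vols = st.site_network.site_volumes
accessible = st.site_network.accessable_site_volumes
```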
+ + Args: + - sn (SiteNetwork) + """ + assert isinstance(sn, SiteNetwork) + if sn.vertices is None: + raise ValueError("SiteNetwork must have verticies to compute volumes!") + + vols = np.empty(shape = sn.n_sites, dtype = np.float) + areas = np.empty(shape = sn.n_sites, dtype = np.float) + + pbcc = PBCCalculator(sn.structure.cell) + + for site in range(sn.n_sites): + pos = sn.static_structure.positions[list(sn.vertices[site])] + if len(pos) < 4: + if self.error_on_insufficient_coord: + raise InsufficientCoordinatingAtomsError("Site %i had only %i vertices (less than needed 4)" % (site, len(pos))) + else: + vols[site] = 0 + areas[site] = np.nan + continue + + assert pos.flags['OWNDATA'] # It should since we're indexing with index lists + # Recenter + offset = pbcc.cell_centroid - sn.centers[site] + pos += offset + pbcc.wrap_points(pos) + + try: + hull = ConvexHull(pos) + vols[site] = hull.volume + areas[site] = hull.area + except QhullError as qhe: + logger.warning("Had QHull failure when computing volume of site %i" % site) + vols[site] = np.nan + areas[site] = np.nan + + sn.add_site_attribute('site_volumes', vols) + sn.add_site_attribute('site_surface_areas', areas) + + + def run(self, st): + """For backwards compatability.""" + self.compute_accessable_volumes(st) diff --git a/sitator/site_descriptors/__init__.py b/sitator/site_descriptors/__init__.py index ca132b2..fd5e7af 100644 --- a/sitator/site_descriptors/__init__.py +++ b/sitator/site_descriptors/__init__.py @@ -1,3 +1,3 @@ -from SiteTypeAnalysis import SiteTypeAnalysis - -from SOAP import SOAPCenters, SOAPSampledCenters, SOAPDescriptorAverages +from .SiteTypeAnalysis import SiteTypeAnalysis +from .SiteCoordinationEnvironment import SiteCoordinationEnvironment +from .SiteVolumes import SiteVolumes diff --git a/sitator/site_descriptors/backend/__init__.py b/sitator/site_descriptors/backend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sitator/site_descriptors/backend/dscribe.py b/sitator/site_descriptors/backend/dscribe.py new file mode 100644 index 0000000..66fc5d5 --- /dev/null +++ b/sitator/site_descriptors/backend/dscribe.py @@ -0,0 +1,40 @@ + +import numpy as np + +DEFAULT_SOAP_PARAMS = { + 'cutoff' : 3.0, + 'l_max' : 6, 'n_max' : 6, + 'atom_sigma' : 0.4, + 'rbf' : 'gto', + 'crossover' : False, + 'periodic' : True, +} + +def dscribe_soap_backend(soap_params = {}): + from dscribe.descriptors import SOAP + + soap_opts = dict(DEFAULT_SOAP_PARAMS) + soap_opts.update(soap_params) + + def backend(sn, soap_mask, tracer_atomic_number, environment_list): + soap = SOAP( + species = environment_list, + crossover = soap_opts['crossover'], + rcut = soap_opts['cutoff'], + nmax = soap_opts['n_max'], + lmax = soap_opts['l_max'], + rbf = soap_opts['rbf'], + sigma = soap_opts['atom_sigma'], + periodic = soap_opts['periodic'], + sparse = False + ) + + def dscribe_soap(structure, positions): + out = soap.create(structure, positions = positions).astype(np.float) + return out + + dscribe_soap.n_dim = soap.get_number_of_features() + + return dscribe_soap + + return backend diff --git a/sitator/site_descriptors/backend/quip.py b/sitator/site_descriptors/backend/quip.py new file mode 100644 index 0000000..d97459c --- /dev/null +++ b/sitator/site_descriptors/backend/quip.py @@ -0,0 +1,86 @@ +""" +quip.py: Compute SOAP vectors for given positions in a structure using the command line QUIP tool +""" + +import numpy as np + +import os + +import ase + +from tempfile import NamedTemporaryFile +import subprocess + +DEFAULT_SOAP_PARAMS = 
{ + 'cutoff' : 3.0, + 'cutoff_transition_width' : 1.0, + 'l_max' : 6, 'n_max' : 6, + 'atom_sigma' : 0.4 +} + +def quip_soap_backend(soap_params = {}, quip_path = os.getenv("SITATOR_QUIP_PATH", default = 'quip')): + def backend(sn, soap_mask, tracer_atomic_number, environment_list): + + soap_opts = dict(DEFAULT_SOAP_PARAMS) + soap_opts.update(soap_params) + soap_cmd_line = ["soap"] + + # User options + for opt in soap_opts: + soap_cmd_line.append("{}={}".format(opt, soap_opts[opt])) + + # + soap_cmd_line.append('n_Z=1 Z={{{}}}'.format(tracer_atomic_number)) + + soap_cmd_line.append('n_species={} species_Z={{{}}}'.format(len(environment_list), ' '.join(map(str, environment_list)))) + + soap_cmd_line = " ".join(soap_cmd_line) + + def soaper(structure, positions): + structure = structure.copy() + for i in range(len(positions)): + structure.append(ase.Atom(position = tuple(positions[i]), symbol = tracer_atomic_number)) + return _soap(soap_cmd_line, structure, quip_path = quip_path) + + return soaper + return backend + + +def _soap(descriptor_str, + structure, + quip_path = 'quip'): + """Calculate SOAP vectors by calling `quip` as a subprocess. + + Args: + - descriptor_str (str): The QUIP descriptor str, i.e. `soap cutoff=3 ...` + - structure (ase.Atoms) + - quip_path (str): Path to `quip` executable + """ + + with NamedTemporaryFile() as xyz: + structure.write(xyz.name, format = 'extxyz') + + quip_cmd = [ + quip_path, + "atoms_filename=" + xyz.name, + "descriptor_str=\"" + descriptor_str + "\"" + ] + + result = subprocess.run(quip_cmd, stdout = subprocess.PIPE, check = True, text = True).stdout + + lines = result.splitlines() + + soaps = [] + for line in lines: + if line.startswith("DESC"): + soaps.append(np.fromstring(line.lstrip("DESC"), dtype = np.float, sep = ' ')) + elif line.startswith("Error"): + e = subprocess.CalledProcessError(returncode = 0, cmd = quip_cmd) + e.stdout = result + raise e + else: + continue + + soaps = np.asarray(soaps) + + return soaps diff --git a/sitator/util/DotProdClassifier.pyx b/sitator/util/DotProdClassifier.pyx index d933a5c..51a61c9 100644 --- a/sitator/util/DotProdClassifier.pyx +++ b/sitator/util/DotProdClassifier.pyx @@ -4,23 +4,13 @@ import numpy as np import numbers -import sys -try: - from IPython import get_ipython - ipy_str = str(type(get_ipython())) - if 'zmqshell' in ipy_str: - from tqdm import tqdm_notebook as tqdm - if 'terminal' in ipy_str: - from tqdm import tqdm -except: - if sys.stderr.isatty(): - from tqdm import tqdm - else: - def tqdm(iterable, **kwargs): - return iterable +from sitator.util.progress import tqdm N_SITES_ALLOC_INCREMENT = 100 +import logging +logger = logging.getLogger(__name__) + class OneValueListlike(object): def __init__(self, value, length = np.inf): self.length = length @@ -33,20 +23,26 @@ class OneValueListlike(object): return self.value class DotProdClassifier(object): + """Assign vectors to clusters indicated by a representative vector using a cosine metric. + + Cluster centers can be given through ``set_cluster_centers()`` or approximated + using the custom method described in the appendix of the main landmark + analysis paper (``fit_centers()``). + + Args: + threshold (float): Similarity threshold for joining a cluster. + In cos-of-angle-between-vectors (i.e. 1 is exactly the same, 0 is orthogonal) + max_converge_iters (int): Maximum number of iterations. If the algorithm + hasn't converged by then, it will exit with a warning. + min_samples (float or int): filter out clusters with low sample counts. 
+ If an int, filters out clusters with fewer samples than this. + If a float, filters out clusters with fewer than + floor(min_samples * n_assigned_samples) samples assigned to them. + """ def __init__(self, threshold = 0.9, max_converge_iters = 10, min_samples = 1): - """ - :param float threshold: Similarity threshold for joining a cluster. - In cos-of-angle-between-vectors (i.e. 1 is exactly the same, 0 is orthogonal) - :param int max_converge_iters: Maximum number of iterations. If the algorithm hasn't converged - by then, it will exit with a warning. - :param int|float min_samples: filter out clusters with low sample counts. - If an int, filters out clusters with fewer samples than this. - If a float, filters out clusters with fewer than floor(min_samples * n_assigned_samples) - samples assigned to them. - """ self._threshold = threshold self._max_iters = max_converge_iters self._min_samples = min_samples @@ -58,6 +54,9 @@ class DotProdClassifier(object): def cluster_centers(self): return self._cluster_centers + def set_cluster_centers(self, centers): + self._cluster_centers = centers + @property def cluster_counts(self): return self._cluster_counts @@ -66,7 +65,7 @@ class DotProdClassifier(object): def n_clusters(self): return len(self._cluster_counts) - def fit_predict(self, X, verbose = True, predict_threshold = None, return_info = False): + def fit_predict(self, X, verbose = True, predict_threshold = None, predict_normed = True, return_info = False): """ Fit the data vectors X and return their cluster labels. """ @@ -80,22 +79,136 @@ class DotProdClassifier(object): if predict_threshold is None: predict_threshold = self._threshold + if self._cluster_centers is None: + self.fit_centers(X) + + # Run a predict now: + labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold, predict_normed = predict_normed) + + total_n_assigned = np.sum(labels >= 0) + + # -- filter out low counts + if not self._min_samples is None: + self._cluster_counts = np.bincount(labels[labels >= 0], minlength = len(self._cluster_centers)) + + assert len(self._cluster_counts) == len(self._cluster_centers) + + min_samples = None + if isinstance(self._min_samples, numbers.Integral): + min_samples = self._min_samples + elif isinstance(self._min_samples, numbers.Real): + min_samples = int(np.floor(self._min_samples * total_n_assigned)) + else: + raise ValueError("Invalid value `%s` for min_samples; must be integral or float." % self._min_samples) + min_samples = max(min_samples, 1) + + count_mask = self._cluster_counts >= min_samples + + self._cluster_centers = self._cluster_centers[count_mask] + self._cluster_counts = self._cluster_counts[count_mask] + + if len(self._cluster_centers) == 0: + # Then we removed everything... + raise ValueError("`min_samples` too large; all %i clusters under threshold." % len(count_mask)) + + logger.info("DotProdClassifier: %i/%i assignment counts below threshold %s (%s); %i clusters remain." % \ + (np.sum(~count_mask), len(count_mask), self._min_samples, min_samples, len(self._cluster_counts))) + + # Do another predict -- this could be more efficient, but who cares? 
+ labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold, predict_normed = predict_normed) + + if return_info: + info = { + 'clusters_below_min_samples' : np.sum(~count_mask), + 'kept_clusters_mask' : count_mask + } + return labels, confs, info + else: + return labels, confs + + def predict(self, X, return_confidences = False, threshold = None, predict_normed = True, verbose = True, ignore_zeros = True): + """Return a predicted cluster label for vectors X. + + :param float threshold: alternate threshold. Defaults to None, when self.threshold + is used. + + :returns: an array of labels. -1 indicates no assignment. + :returns: an array of confidences in assignments. Normalzied + values from 0 (no confidence, no label) to 1 (identical to cluster center). + """ + + assert len(X.shape) == 2, "Data must be 2D." + + if not X.shape[1] == (self._featuredim): + raise TypeError("X has wrong dimension %s; should be (%i)" % (X.shape, self._featuredim)) + + labels = np.empty(shape = len(X), dtype = np.int) + + if threshold is None: + threshold = self._threshold + + confidences = None + if return_confidences: + confidences = np.empty(shape = len(X), dtype = np.float) + + zeros_count = 0 + + if predict_normed: + center_norms = np.linalg.norm(self._cluster_centers, axis = 1) + normed_centers = self._cluster_centers.copy() + normed_centers /= center_norms[:, np.newaxis] + else: + normed_centers = self._cluster_centers + + # preallocate buffers + diffs = np.empty(shape = len(normed_centers), dtype = np.float) + + for i, x in enumerate(tqdm(X, desc = "Sample")): + + if np.all(x == 0): + if ignore_zeros: + labels[i] = -1 + zeros_count += 1 + continue + else: + raise ValueError("Data %i is all zeros!" % i) + + np.dot(normed_centers, x, out = diffs) + if predict_normed: + diffs /= np.linalg.norm(x) + np.abs(diffs, out = diffs) + + assigned_to = np.argmax(diffs) + assignment_confidence = diffs[assigned_to] + + if assignment_confidence < threshold: + assigned_to = -1 + assignment_confidence = 0.0 + + labels[i] = assigned_to + confidences[i] = assignment_confidence + + if zeros_count > 0: + logger.warning("Encountered %i zero vectors during prediction" % zeros_count) + + if return_confidences: + return labels, confidences + else: + return labels + + def fit_centers(self, X): # Essentially hierarchical clustering that stops when no cluster *centers* # are more similar than the threshold. 
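The similarity being thresholded throughout `DotProdClassifier` is plain cosine similarity. A tiny standalone illustration with made-up vectors:

```python
import numpy as np

center = np.array([1.0, 0.0, 1.0])   # a cluster center (landmark vector)
sample = np.array([0.9, 0.1, 1.1])   # a sample to assign
cos_sim = np.dot(center, sample) / (np.linalg.norm(center) * np.linalg.norm(sample))
# Compare against `threshold`; here cos_sim is about 0.99, so with the
# default threshold of 0.9 the sample would be assigned to this cluster.
assigned = cos_sim >= 0.9
```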
- labels = np.empty(shape = len(X), dtype = np.int) labels.fill(-1) # Start with each sample as a cluster - #old_centers = np.copy(X) - #old_n_assigned = np.ones(shape = len(X), dtype = np.int) # For memory's sake, no copying old_centers = X old_n_assigned = OneValueListlike(value = 1, length = len(X)) old_n_clusters = len(X) # -- Classification loop - # Maximum number of iterations last_n_sites = -1 did_converge = False @@ -109,7 +222,7 @@ class DotProdClassifier(object): first_iter = True - for iteration in xrange(self._max_iters): + for iteration in tqdm(xrange(self._max_iters), desc = "Clustering iter.", total = float('inf')): # This iteration's centers # The first sample is always its own cluster cluster_centers[0] = old_centers[0] @@ -117,7 +230,7 @@ class DotProdClassifier(object): n_assigned_to[0] = old_n_assigned[0] n_clusters = 1 # skip the first sample which has already been accounted for - for i, vec in tqdm(zip(xrange(1, old_n_clusters), old_centers[1:old_n_clusters]), desc = "Iteration %i" % iteration): + for i, vec in zip(xrange(1, old_n_clusters), old_centers[1:old_n_clusters]): assigned_to = -1 assigned_cosang = 0.0 @@ -200,114 +313,3 @@ class DotProdClassifier(object): raise ValueError("Clustering did not converge after %i iterations" % (self._max_iters)) self._cluster_centers = np.asarray(cluster_centers[:n_clusters]) - - # Run a predict now: - labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold) - - total_n_assigned = np.sum(labels >= 0) - - # -- filter out low counts - if not self._min_samples is None: - self._cluster_counts = np.bincount(labels[labels >= 0]) - - assert len(self._cluster_counts) == len(self._cluster_centers) - - min_samples = None - if isinstance(self._min_samples, numbers.Integral): - min_samples = self._min_samples - elif isinstance(self._min_samples, numbers.Real): - min_samples = int(np.floor(self._min_samples * total_n_assigned)) - else: - raise ValueError("Invalid value `%s` for min_samples; must be integral or float." % self._min_samples) - - count_mask = self._cluster_counts >= min_samples - - self._cluster_centers = self._cluster_centers[count_mask] - self._cluster_counts = self._cluster_counts[count_mask] - - if len(self._cluster_centers) == 0: - # Then we removed everything... - raise ValueError("`min_samples` too large; all %i clusters under threshold." % len(count_mask)) - - if verbose: - print "DotProdClassifier: %i/%i assignment counts below threshold %s (%s); %i clusters remain." % \ - (np.sum(~count_mask), len(count_mask), self._min_samples, min_samples, len(self._cluster_counts)) - - # Do another predict -- this could be more efficient, but who cares? - labels, confs = self.predict(X, return_confidences = True, verbose = verbose, threshold = predict_threshold) - - if return_info: - info = { - 'clusters_below_min_samples' : np.sum(~count_mask) - } - return labels, confs, info - else: - return labels, confs - - def predict(self, X, return_confidences = False, threshold = None, verbose = True, ignore_zeros = True): - """Return a predicted cluster label for vectors X. - - :param float threshold: alternate threshold. Defaults to None, when self.threshold - is used. - - :returns: an array of labels. -1 indicates no assignment. - :returns: an array of confidences in assignments. Normalzied - values from 0 (no confidence, no label) to 1 (identical to cluster center). - """ - - assert len(X.shape) == 2, "Data must be 2D." 
- - if not X.shape[1] == (self._featuredim): - raise TypeError("X has wrong dimension %s; should be (%i)" % (X.shape, self._featuredim)) - - labels = np.empty(shape = len(X), dtype = np.int) - - if threshold is None: - threshold = self._threshold - - confidences = None - if return_confidences: - confidences = np.empty(shape = len(X), dtype = np.float) - - zeros_count = 0 - - center_norms = np.linalg.norm(self._cluster_centers, axis = 1) - normed_centers = self._cluster_centers.copy() - normed_centers /= center_norms[:, np.newaxis] - - # preallocate buffers - diffs = np.empty(shape = len(center_norms), dtype = np.float) - - for i, x in enumerate(tqdm(X, desc = "Sample")): - - if np.all(x == 0): - if ignore_zeros: - labels[i] = -1 - zeros_count += 1 - continue - else: - raise ValueError("Data %i is all zeros!" % i) - - # diffs = np.sum(x * self._cluster_centers, axis = 1) - # diffs /= np.linalg.norm(x) * center_norms - np.dot(normed_centers, x, out = diffs) - diffs /= np.linalg.norm(x) - #diffs /= center_norms - - assigned_to = np.argmax(diffs) - assignment_confidence = diffs[assigned_to] - - if assignment_confidence < threshold: - assigned_to = -1 - assignment_confidence = 0.0 - - labels[i] = assigned_to - confidences[i] = assignment_confidence - - if verbose and zeros_count > 0: - print "Encountered %i zero vectors during prediction" % zeros_count - - if return_confidences: - return labels, confidences - else: - return labels diff --git a/sitator/util/PBCCalculator.pyx b/sitator/util/PBCCalculator.pyx index f34830d..caf827f 100644 --- a/sitator/util/PBCCalculator.pyx +++ b/sitator/util/PBCCalculator.pyx @@ -11,18 +11,20 @@ ctypedef double precision ctypedef double cell_precision cdef class PBCCalculator(object): - """Performs calculations on collections of 3D points under PBC""" + """Performs calculations on collections of 3D points under PBC.""" cdef cell_precision [:, :] _cell_mat_array cdef cell_precision [:, :] _cell_mat_inverse_array cdef cell_precision [:] _cell_centroid cdef cell_precision [:, :] _cell + def __init__(self, cell): """ :param DxD ndarray: the unit cell -- an array of cell vectors, like the cell of an ASE :class:Atoms object. """ + cell = cell[:] cellmat = np.matrix(cell).T assert cell.shape[1] == cell.shape[0], "Cell must be square" @@ -32,15 +34,39 @@ cdef class PBCCalculator(object): self._cell_mat_inverse_array = np.asarray(cellmat.I) self._cell_centroid = np.sum(0.5 * cell, axis = 0) + @property def cell_centroid(self): return self._cell_centroid + + cpdef pairwise_distances(self, pts, out = None): + """Compute the pairwise distance matrix of ``pts`` with itself. + + :returns ndarray (len(pts), len(pts)): distances + """ + if out is None: + out = np.empty(shape = (len(pts), len(pts)), dtype = pts.dtype) + + buf = pts.copy() + + for i in xrange(len(pts) - 1): + out[i, i] = 0 + self.distances(pts[i], buf[i + 1:], in_place = True, out = out[i, i + 1:]) + out[i + 1:, i] = out[i, i + 1:] + buf[:] = pts + + out[len(pts) - 1, len(pts) - 1] = 0 + + return out + + cpdef distances(self, pt1, pts2, in_place = False, out = None): - """Compute the Euclidean distances from pt1 to all points in pts2, using - shift-and-wrap. + """ + Compute the Euclidean distances from ``pt1`` to all points in + ``pts2``, using shift-and-wrap. - Makes a copy of pts2 unless in_place == True. + Makes a copy of ``pts2`` unless ``in_place == True``. 
:returns ndarray len(pts2): distances """ @@ -76,6 +102,7 @@ cdef class PBCCalculator(object): #return np.linalg.norm(self._cell_centroid - pts2, axis = 1) return np.sqrt(out, out = out) + cpdef average(self, points, weights = None): """Average position of a "cloud" of points using the shift-and-wrap hack. @@ -83,13 +110,18 @@ cdef class PBCCalculator(object): Assumes that the points are relatively close (within a half unit cell) together, and that the first point is not a particular outsider (the - cell is centered at that point). + cell is centered at that point). If the average is weighted, the + maximally weighted point will be taken as the center. Can be a weighted average with the semantics of :func:numpy.average. """ assert points.shape[1] == 3 and points.ndim == 2 - offset = self._cell_centroid - points[0] + center_about = 0 + if weights is not None: + center_about = np.argmax(weights) + + offset = self._cell_centroid - points[center_about] ptbuf = points.copy() @@ -106,6 +138,7 @@ cdef class PBCCalculator(object): return out + cpdef time_average(self, frames): """Do multiple PBC correct means. Frames is n_frames x n_pts x 3. @@ -137,6 +170,7 @@ cdef class PBCCalculator(object): del posbuf return out + cpdef void wrap_point(self, precision [:] pt): """Wrap a single point into the unit cell, IN PLACE. 3D only.""" cdef cell_precision [:, :] cell = self._cell_mat_array @@ -158,7 +192,8 @@ cdef class PBCCalculator(object): pt[0] = buf[0]; pt[1] = buf[1]; pt[2] = buf[2]; - cpdef bint is_in_unit_cell(self, precision [:] pt): + + cpdef bint is_in_unit_cell(self, const precision [:] pt): cdef cell_precision [:, :] cell = self._cell_mat_array cdef cell_precision [:, :] cell_I = self._cell_mat_inverse_array @@ -174,13 +209,15 @@ cdef class PBCCalculator(object): return (buf[0] < 1.0) and (buf[1] < 1.0) and (buf[2] < 1.0) and \ (buf[0] >= 0.0) and (buf[1] >= 0.0) and (buf[2] >= 0.0) - cpdef bint all_in_unit_cell(self, precision [:, :] pts): + + cpdef bint all_in_unit_cell(self, const precision [:, :] pts): for pt in pts: if not self.is_in_unit_cell(pt): return False return True - cpdef bint is_in_image_of_cell(self, precision [:] pt, image): + + cpdef bint is_in_image_of_cell(self, const precision [:] pt, image): cdef cell_precision [:, :] cell = self._cell_mat_array cdef cell_precision [:, :] cell_I = self._cell_mat_inverse_array @@ -199,6 +236,7 @@ cdef class PBCCalculator(object): return out + cpdef void to_cell_coords(self, precision [:, :] points): """Convert to cell coordinates in place.""" assert points.shape[1] == 3, "Points must be 3D" @@ -220,11 +258,15 @@ cdef class PBCCalculator(object): # Store into points points[i, 0] = buf[0]; points[i, 1] = buf[1]; points[i, 2] = buf[2]; - cpdef int min_image(self, precision [:] ref, precision [:] pt): - """Find the minimum image of `pt` relative to `ref`. In place in pt. + + cpdef int min_image(self, const precision [:] ref, precision [:] pt): + """Find the minimum image of ``pt`` relative to ``ref``. In place in pt. Uses the brute force algorithm for correctness; returns the minimum image. + Assumes that ``ref`` and ``pt`` are already in the *same* cell (though not + necessarily the <0,0,0> cell -- any periodic image will do). + :returns int[3] minimg: Which image was the minimum image. 
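A small usage sketch of `PBCCalculator` on a hypothetical 10 Angstrom cubic cell (`sitator` works in Cartesian Angstrom coordinates), showing minimum-image distances and in-place wrapping:

```python
import numpy as np
from sitator.util import PBCCalculator

pbcc = PBCCalculator(np.eye(3) * 10.0)
pts = np.array([[1.0, 1.0, 1.0],
                [9.5, 9.5, 9.5]])
# Minimum-image distance: about 2.6, not the naive 14.7.
print(pbcc.distances(pts[0], pts[1:]))

pt = np.array([11.0, -0.5, 3.0])
pbcc.wrap_point(pt)   # in place -> [1.0, 9.5, 3.0]
```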
""" # # There are 27 possible minimum images @@ -273,6 +315,7 @@ cdef class PBCCalculator(object): return 100 * minimg[0] + 10 * minimg[1] + 1 * minimg[2] + cpdef void to_real_coords(self, precision [:, :] points): """Convert to real coords from crystal coords in place.""" assert points.shape[1] == 3, "Points must be 3D" @@ -294,8 +337,9 @@ cdef class PBCCalculator(object): # Store into points points[i, 0] = buf[0]; points[i, 1] = buf[1]; points[i, 2] = buf[2]; + cpdef void wrap_points(self, precision [:, :] points): - """Wrap `points` into a unit cell, IN PLACE. 3D only. + """Wrap ``points`` into a unit cell, IN PLACE. 3D only. """ assert points.shape[1] == 3, "Points must be 3D" diff --git a/sitator/util/RecenterTrajectory.pyx b/sitator/util/RecenterTrajectory.pyx index 39e5d7c..1e3a7c9 100644 --- a/sitator/util/RecenterTrajectory.pyx +++ b/sitator/util/RecenterTrajectory.pyx @@ -18,7 +18,7 @@ class RecenterTrajectory(object): ``static_mask``, IN PLACE. Args: - structure (ASE Atoms): An atoms representing the structure of the + structure (ase.Atoms): An atoms representing the structure of the simulation. static_mask (ndarray): Boolean mask indicating which atoms to recenter on positions (ndarray): (n_frames, n_atoms, 3), modified in place @@ -32,6 +32,8 @@ class RecenterTrajectory(object): masses of all atoms in the system. """ + assert np.any(static_mask), "Static mask all false; there must be static atoms to recenter on." + factors = static_mask.astype(np.float) n_static = np.sum(static_mask) diff --git a/sitator/util/__init__.py b/sitator/util/__init__.py index 76ae207..b59424f 100644 --- a/sitator/util/__init__.py +++ b/sitator/util/__init__.py @@ -1,8 +1,8 @@ -from PBCCalculator import PBCCalculator +from .PBCCalculator import PBCCalculator -from DotProdClassifier import DotProdClassifier +from .DotProdClassifier import DotProdClassifier -from zeo import Zeopy +from .zeo import Zeopy -from RecenterTrajectory import RecenterTrajectory +from .RecenterTrajectory import RecenterTrajectory diff --git a/sitator/util/chemistry.py b/sitator/util/chemistry.py new file mode 100644 index 0000000..c0070ea --- /dev/null +++ b/sitator/util/chemistry.py @@ -0,0 +1,100 @@ +import numpy as np + +import ase.data +try: + from ase.formula import Formula +except ImportError: + from ase.util.formula import Formula + +from sitator.util import PBCCalculator + +# Key is a central atom, value is a list of the formulas for the rest. +# See, http://www.fccj.us/PolyatomicIons/CompletePolyatomicIonList.htm +DEFAULT_POLYATOMIC_IONS = { + 'Ar' : ['O4', 'O3'], + 'C' : ['N', 'O3', 'O2', 'NO', 'SN'], + 'B' : ['O3','O2'], + 'Br' : ['O4', 'O3', 'O2', 'O'], + 'Fe' : ['O4'], + 'I' : ['O3', 'O4', 'O2', 'O'], + 'Si' : ['O4', 'O3'], + 'S' : ['O5', 'O4', 'O3', 'SO3', 'O2'], + 'Sb' : ['O4', 'O3'], + 'Se' : ['O4', 'O3'], + 'Sn' : ['O3', 'O2'], + 'N' : ['O3', 'O2', 'CO'], + 'Re' : ['O4'], + 'Cl' : ['O4', 'O3', 'O2', 'O'], + 'Mn' : ['O4'], + 'Mo' : ['O4'], + 'Cr' : ['O4', 'CrO7', 'O2'], + 'P' : ['O4', 'O3', 'O2'], + 'Tc' : ['O4'], + 'Te' : ['O4', 'O6', 'O3'], + 'Pb' : ['O3', 'O2'], + 'W' : ['O4'] +} + +def identify_polyatomic_ions(structure, + cutoff_factor = 1.0, + ion_definitions = DEFAULT_POLYATOMIC_IONS): + """Find polyatomic ions in a structure. + + Goes to each potential polyatomic ion center and first checks if the nearest + neighbor is of a viable species. If it is, all nearest neighbors within the + maximum possible pairwise summed covalent radii are found and matched against + the database. 
+
+    Args:
+        structure (ase.Atoms)
+        cutoff_factor (float): Coefficient for the cutoff. Allows tuning just
+            how closely the polyatomic ions must be bound. Defaults to 1.0.
+        ion_definitions (dict): Database of polyatomic ions where the key is a
+            central atom species symbol and the value is a list of the formulas
+            for the rest. Defaults to a list of common, non-hydrogen-containing
+            polyatomic ions with oxygen.
+    Returns:
+        list of tuples, one for each identified polyatomic ion, containing its
+        formula, the index of the central atom, and the indexes of the remaining
+        atoms.
+    """
+    ion_definitions = {
+        k : [Formula(f) for f in v]
+        for k, v in ion_definitions.items()
+    }
+    out = []
+    # Go to each possible center atom in data, check that nearest neighbor is
+    # of right species, then get all within covalent distances and check if
+    # composition matches database...
+    pbcc = PBCCalculator(structure.cell)
+    dmat = pbcc.pairwise_distances(structure.positions) # Precompute
+    np.fill_diagonal(dmat, np.inf)
+    for center_i, center_symbol in enumerate(structure.symbols):
+        if center_symbol in ion_definitions:
+            nn_sym = structure.symbols[np.argmin(dmat[center_i])]
+            could_be = [f for f in ion_definitions[center_symbol] if nn_sym in f]
+            if len(could_be) == 0:
+                # Nearest neighbor isn't even a plausible polyatomic ion species,
+                # so skip this potential center.
+                continue
+            # Take the largest possible covalent radius over all
+            # other species it possibly could be.
+            cutoff = max(
+                ase.data.covalent_radii[ase.data.atomic_numbers[other_sym]]
+                for form in could_be for other_sym in form
+            )
+            cutoff += ase.data.covalent_radii[ase.data.atomic_numbers[center_symbol]]
+            cutoff *= cutoff_factor
+            neighbors = dmat[center_i] <= cutoff
+            neighbors = np.where(neighbors)[0]
+            neighbor_formula = Formula.from_list(structure.symbols[neighbors])
+            it_is = [f for f in could_be if f == neighbor_formula]
+            if len(it_is) > 1:
+                raise ValueError("Somehow identified single center %s (atom %i) as multiple polyatomic ions %s" % (center_symbol, center_i, it_is))
+            elif len(it_is) == 1:
+                out.append((
+                    center_symbol + neighbor_formula.format('hill'),
+                    center_i,
+                    neighbors
+                ))
+    return out
diff --git a/sitator/util/elbow.py b/sitator/util/elbow.py
index ec038ac..078bec9 100644
--- a/sitator/util/elbow.py
+++ b/sitator/util/elbow.py
@@ -2,7 +2,7 @@
 # See discussion around this question: https://stackoverflow.com/questions/2018178/finding-the-best-trade-off-point-on-a-curve/2022348#2022348
 
 def index_of_elbow(points):
-    """Returns the index of the "elbow" in points.
+    """Returns the index of the "elbow" in ``points``.
 
     Decently fast and pretty approximate. Performs worse with disproportionately
    long "flat" tails. For example, in a dataset with a nearly right-angle elbow,
diff --git a/sitator/util/mcl.py b/sitator/util/mcl.py
new file mode 100644
index 0000000..784cf7a
--- /dev/null
+++ b/sitator/util/mcl.py
@@ -0,0 +1,60 @@
+import numpy as np
+
+def markov_clustering(transition_matrix,
+                      expansion = 2,
+                      inflation = 2,
+                      pruning_threshold = 0.00001,
+                      iterlimit = 100):
+    """Compute the Markov Clustering of a graph.
+    See https://micans.org/mcl/.
+
+    Because we're dealing with matrices that are stochastic already,
+    there's no need to add artificial loop values.
+ + Implementation inspired by https://github.com/GuyAllard/markov_clustering + """ + + assert transition_matrix.shape[0] == transition_matrix.shape[1] + + # Check for nonzero diagonal -- self loops needed to avoid div by zero and NaNs + assert np.count_nonzero(transition_matrix.diagonal()) == len(transition_matrix) + + m1 = transition_matrix.copy() + + # Normalize (though it should be close already) + m1 /= np.sum(m1, axis = 0) + + allcols = np.arange(m1.shape[1]) + + converged = False + for i in range(iterlimit): + # -- Expansion + m2 = np.linalg.matrix_power(m1, expansion) + # -- Inflation + np.power(m2, inflation, out = m2) + m2 /= np.sum(m2, axis = 0) + # -- Prune + to_prune = m2 < pruning_threshold + # Exclude the max of every column + to_prune[np.argmax(m2, axis = 0), allcols] = False + m2[to_prune] = 0.0 + # -- Check converged + if np.allclose(m1, m2): + converged = True + break + + m1[:] = m2 + + if not converged: + raise ValueError("Markov Clustering couldn't converge in %i iterations" % iterlimit) + + # -- Get clusters + attractors = m2.diagonal().nonzero()[0] + + clusters = set() + + for a in attractors: + cluster = tuple(m2[a].nonzero()[0]) + clusters.add(cluster) + + return list(clusters) diff --git a/sitator/util/progress.py b/sitator/util/progress.py new file mode 100644 index 0000000..4ffe766 --- /dev/null +++ b/sitator/util/progress.py @@ -0,0 +1,14 @@ +import os + +progress = os.getenv('SITATOR_PROGRESSBAR', 'true').lower() +progress = (progress == 'true') or (progress == 'yes') or (progress == 'on') + +if progress: + try: + from tqdm.autonotebook import tqdm + except: + def tqdm(iterable, **kwargs): + return iterable +else: + def tqdm(iterable, **kwargs): + return iterable diff --git a/sitator/util/qvoronoi.py b/sitator/util/qvoronoi.py deleted file mode 100644 index 4156a06..0000000 --- a/sitator/util/qvoronoi.py +++ /dev/null @@ -1,144 +0,0 @@ - -import tempfile -import subprocess -import sys - -import re -from collections import OrderedDict - -import numpy as np - -from sklearn.neighbors import KDTree - -from sitator.util import PBCCalculator - -def periodic_voronoi(structure, logfile = sys.stdout): - """ - :param ASE.Atoms structure: - """ - - pbcc = PBCCalculator(structure.cell) - - # Make a 3x3x3 supercell - supercell = structure.repeat((3, 3, 3)) - - qhull_output = None - - logfile.write("Qvoronoi ---") - - # Run qhull - with tempfile.NamedTemporaryFile('w', - prefix = 'qvor', - suffix='.in', delete = False) as infile, \ - tempfile.NamedTemporaryFile('r', - prefix = 'qvor', - suffix='.out', - delete=True) as outfile: - # -- Write input file -- - infile.write("3\n") # num of dimensions - infile.write("%i\n" % len(supercell)) # num of points - np.savetxt(infile, supercell.get_positions(), fmt = '%.16f') - infile.flush() - - cmdline = ["qvoronoi", "TI", infile.name, "FF", "Fv", "TO", outfile.name] - process = subprocess.Popen(cmdline, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) - retcode = process.wait() - logfile.write(process.stdout.read()) - if retcode != 0: - raise RuntimeError("qvoronoi returned exit code %i" % retcode) - - qhull_output = outfile.read() - - facets_regex = re.compile( - """ - -[ \t](?Pf[0-9]+) [\n] - [ \t]*-[ ]flags: .* [\n] - [ \t]*-[ ]normal: .* [\n] - [ \t]*-[ ]offset: .* [\n] - [ \t]*-[ ]center:(?P
([ ][\-]?[0-9]*[\.]?[0-9]*(e[-?[0-9]+)?){3}) [ \t] [\n] - [ \t]*-[ ]vertices:(?P([ ]p[0-9]+\(v[0-9]+\))+) [ \t]? [\n] - [ \t]*-[ ]neighboring[ ]facets:(?P([ ]f[0-9]+)+) - """, re.X | re.M) - - vertices_re = re.compile('(?<=p)[0-9]+') - - # Allocate stuff - centers = [] - vertices = [] - facet_indexes_taken = set() - - facet_index_to_our_index = {} - all_facets_centers = [] - - # ---- Read facets - facet_index = -1 - next_our_index = 0 - for facet_match in facets_regex.finditer(qhull_output): - center = np.asarray(map(float, facet_match.group('center').split())) - facet_index += 1 - - all_facets_centers.append(center) - - if not pbcc.is_in_image_of_cell(center, (1, 1, 1)): - continue - - verts = map(int, vertices_re.findall(facet_match.group('vertices'))) - verts_in_main_cell = tuple(v % len(structure) for v in verts) - - facet_indexes_taken.add(facet_index) - - centers.append(center) - vertices.append(verts_in_main_cell) - - facet_index_to_our_index[facet_index] = next_our_index - - next_our_index += 1 - - end_of_facets = facet_match.end() - - facet_count = facet_index + 1 - - logfile.write(" qhull gave %i vertices; kept %i" % (facet_count, len(centers))) - - # ---- Read ridges - qhull_output_after_facets = qhull_output[end_of_facets:].strip() - ridge_re = re.compile('^\d+ \d+ \d+(?P( \d+)+)$', re.M) - - ridges = [ - [int(v) for v in match.group('verts').split()] - for match in ridge_re.finditer(qhull_output_after_facets) - ] - # only take ridges with at least 1 facet in main unit cell. - ridges = [ - r for r in ridges if any(f in facet_indexes_taken for f in r) - ] - - # shift centers back into normal unit cell - centers -= np.sum(structure.cell, axis = 0) - - nearest_center = KDTree(centers) - - ridges_in_main_cell = set() - threw_out = 0 - for r in ridges: - ridge_centers = np.asarray([all_facets_centers[f] for f in r if f < len(all_facets_centers)]) - if not pbcc.all_in_unit_cell(ridge_centers): - continue - - pbcc.wrap_points(ridge_centers) - dists, ridge_centers_in_main = nearest_center.query(ridge_centers, return_distance = True) - - if np.any(dists > 0.00001): - threw_out += 1 - continue - - assert ridge_centers_in_main.shape == (len(ridge_centers), 1), "%s" % ridge_centers_in_main.shape - ridge_centers_in_main = ridge_centers_in_main[:,0] - - ridges_in_main_cell.add(frozenset(ridge_centers_in_main)) - - logfile.write(" Threw out %i ridges" % threw_out) - - logfile.flush() - - return centers, vertices, ridges_in_main_cell diff --git a/sitator/util/zeo.py b/sitator/util/zeo.py index 430e4bd..19bddf0 100644 --- a/sitator/util/zeo.py +++ b/sitator/util/zeo.py @@ -1,9 +1,6 @@ #zeopy: simple Python interface to the Zeo++ `network` tool. # Alby Musaelian 2018 -from __future__ import (absolute_import, division, - print_function) - import os import sys import tempfile @@ -19,10 +16,13 @@ from sitator.util import PBCCalculator +import logging +logger = logging.getLogger(__name__) + # TODO: benchmark CUC vs CIF class Zeopy(object): - """A wrapper for the Zeo++ `network` tool. + """A wrapper for the Zeo++ ``network`` tool. :warning: Do not use a single instance of Zeopy in parallel. """ @@ -30,7 +30,7 @@ class Zeopy(object): def __init__(self, path_to_zeo): """Create a Zeopy. - :param str path_to_zeo: Path to the `network` executable. + :param str path_to_zeo: Path to the ``network`` executable. """ if not (os.path.exists(path_to_zeo) and os.access(path_to_zeo, os.X_OK)): raise ValueError("`%s` doesn't seem to be the path to an executable file." 
% path_to_zeo) @@ -43,7 +43,7 @@ def __enter__(self): def __exit__(self, *args): shutil.rmtree(self._tmpdir) - def voronoi(self, structure, radial = False, verbose=True): + def voronoi(self, structure, radial = False): """ :param Atoms structure: The ASE Atoms to compute the Voronoi decomposition of. """ @@ -55,7 +55,7 @@ def voronoi(self, structure, radial = False, verbose=True): outp = os.path.join(self._tmpdir, "out.nt2") v1out = os.path.join(self._tmpdir, "out.v1") - ase.io.write(inp, structure) + ase.io.write(inp, structure, parallel = False) # with open(inp, "w") as inf: # inf.write(self.ase2cuc(structure)) @@ -69,12 +69,11 @@ def voronoi(self, structure, radial = False, verbose=True): output = subprocess.check_output([self._exe] + args + ["-v1", v1out, "-nt2", outp, inp], stderr = subprocess.STDOUT) except subprocess.CalledProcessError as e: - print("Zeo++ returned an error:", file = sys.stderr) - print(e.output, file = sys.stderr) + logger.error("Zeo++ returned an error:") + logger.error(str(e.output)) raise - if verbose: - print(output) + logger.debug(output) with open(outp, "r") as outf: verts, edges = self.parse_nt2(outf.readlines()) @@ -130,12 +129,12 @@ def parse_v1_cell(v1lines): # remove blank lines: v1lines = iter(filter(None, v1lines)) # First line is just "Unit cell vectors:" - assert v1lines.next().strip() == "Unit cell vectors:" + assert next(v1lines).strip() == "Unit cell vectors:" # Unit cell: cell = np.empty(shape = (3, 3), dtype = np.float) cellvec_re = re.compile('v[abc]=') - for i in xrange(3): - cellvec = v1lines.next().strip().split() + for i in range(3): + cellvec = next(v1lines).strip().split() assert cellvec_re.match(cellvec[0]) cell[i] = [float(e) for e in cellvec[1:]] # number of atoms, etc. diff --git a/sitator/visualization/SiteNetworkPlotter.py b/sitator/visualization/SiteNetworkPlotter.py index 157da1d..8c529e8 100644 --- a/sitator/visualization/SiteNetworkPlotter.py +++ b/sitator/visualization/SiteNetworkPlotter.py @@ -1,7 +1,3 @@ -from __future__ import (absolute_import, division, - print_function, unicode_literals) -from builtins import * - import numpy as np import itertools @@ -10,21 +6,35 @@ from mpl_toolkits.mplot3d.art3d import Line3DCollection from sitator.util import PBCCalculator -from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS +from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS, set_axes_equal class SiteNetworkPlotter(object): - """Plot a SiteNetwork. - - site_mappings defines how to show different properties. Each entry maps a - visual aspect ('marker', 'color', 'size') to the name of a site attribute - including 'site_type'. - - Likewise for edge_mappings, each key maps a visual property ('intensity', 'color', - 'width', 'linestyle') to an edge attribute in the SiteNetwork. + """Plot a ``SiteNetwork``. Note that for edges, the average of the edge property for i -> j and j -> i is often used for visual clarity; if your edge properties are not almost symmetric, the visualization might not be useful. + + Args: + site_mappings (dict): defines how to show different properties. Each + entry maps a visual aspect ('marker', 'color', 'size') to the name + of a site attribute including 'site_type'. The markers can also be + arbitrary text (key `"text"`) in which case the value can also be a + 2-tuple of an attribute name and a `%` format string. 
+ edge_mappings (dict): each key maps a visual property ('intensity', + 'color', 'width', 'linestyle') to an edge attribute in the SiteNetwork. + markers (list of str): What `matplotlib` markers to use for sites. + plot_points_params (dict): User options for plotting site points. + minmax_linewidth (2-tuple): Minimum and maximum linewidth to use. + minmax_edge_alpha (2-tuple): Similar, for edge line alphas. + minmax_markersize (2-tuple): Similar, for markersize. + min_color_threshold (float): Minimum (normalized) color intensity for + the corresponding line to be shown. Defaults to zero, i.e., all + nonzero edges will be drawn. + min_width_threshold (float): Minimum normalized edge width for the + corresponding edge to be shown. Defaults to zero, i.e., all + nonzero edges will be drawn. + title (str): Title for the figure. """ DEFAULT_SITE_MAPPINGS = { @@ -34,7 +44,7 @@ class SiteNetworkPlotter(object): DEFAULT_MARKERS = ['x', '+', 'v', '<', '^', '>', '*', 'd', 'h', 'p'] DEFAULT_LINESTYLES = ['--', ':', '-.', '-'] - EDGE_GROUP_COLORS = ['b', 'g', 'm', 'lightseagreen', 'crimson'] + ['gray'] # gray last for -1's + EDGE_GROUP_COLORS = ['b', 'g', 'm', 'crimson', 'lightseagreen', 'darkorange', 'sandybrown', 'gold', 'hotpink'] + ['gray'] # gray last for -1's def __init__(self, site_mappings = DEFAULT_SITE_MAPPINGS, @@ -43,11 +53,12 @@ def __init__(self, plot_points_params = {}, minmax_linewidth = (1.5, 7), minmax_edge_alpha = (0.15, 0.75), - minmax_markersize = (80.0, 180.0), - min_color_threshold = 0.005, - min_width_threshold = 0.005, + minmax_markersize = (20.0, 80.0), + min_color_threshold = 0.0, + min_width_threshold = 0.0, title = ""): self.site_mappings = site_mappings + assert not ("marker" in site_mappings and "text" in site_mappings) self.edge_mappings = edge_mappings self.markers = markers self.plot_points_params = plot_points_params @@ -66,7 +77,6 @@ def __call__(self, sn, *args, **kwargs): # -- Plot actual SiteNetwork -- l = [(plot_atoms, {'atoms' : sn.static_structure})] l += self._site_layers(sn, self.plot_points_params) - l += self._plot_edges(sn, *args, **kwargs) # -- Some visual clean up -- @@ -74,7 +84,6 @@ def __call__(self, sn, *args, **kwargs): ax.set_title(self.title) - ax.set_aspect('equal') ax.xaxis.pane.fill = False ax.yaxis.pane.fill = False ax.zaxis.pane.fill = False @@ -90,26 +99,47 @@ def __call__(self, sn, *args, **kwargs): # -- Put it all together -- layers(*l, **kwargs) - def _site_layers(self, sn, plot_points_params): + def _site_layers(self, sn, plot_points_params, same_normalization = False): pts_arrays = {'points' : sn.centers} - pts_params = {'cmap' : 'rainbow'} + pts_params = {'cmap' : 'winter'} # -- Apply mapping # - other mappings markers = None for key in self.site_mappings: - val = getattr(sn, self.site_mappings[key]) + val = self.site_mappings[key] + if isinstance(val, tuple): + val, param = val + else: + param = None + val = getattr(sn, val) if key == 'marker': if not val is None: markers = val.copy() + istextmarker = False + elif key == 'text': + istextmarker = True + format_str = "%s" if param is None else param + format_str = "$" + format_str + "$" + markers = val.copy() elif key == 'color': pts_arrays['c'] = val.copy() - pts_params['norm'] = matplotlib.colors.Normalize(vmin = np.min(val), vmax = np.max(val)) + if not same_normalization: + self._color_minmax = [np.min(val), np.max(val)] + if self._color_minmax[0] == self._color_minmax[1]: + self._color_minmax[0] -= 1 # Just to avoid div by zero + color_minmax = self._color_minmax + 
pts_params['norm'] = matplotlib.colors.Normalize(vmin = color_minmax[0], vmax = color_minmax[1])
             elif key == 'size':
+                if not same_normalization:
+                    self._size_minmax = [np.min(val), np.max(val)]
+                    if self._size_minmax[0] == self._size_minmax[1]:
+                        self._size_minmax[0] -= 1 # Just to avoid div by zero
+                size_minmax = self._size_minmax
                 s = val.copy()
-                s += np.min(s)
-                s /= np.max(s)
+                s -= size_minmax[0]
+                s /= size_minmax[1] - size_minmax[0]
                 s *= self.minmax_markersize[1]
                 s += self.minmax_markersize[0]
                 pts_arrays['s'] = s
@@ -122,17 +152,27 @@ class SiteNetworkPlotter(object):
             # Just one layer with all points and one marker
             marker_layers[SiteNetworkPlotter.DEFAULT_MARKERS[0]] = np.ones(shape = sn.n_sites, dtype = np.bool)
         else:
-            markers = self._make_discrete(markers)
+            if not istextmarker:
+                markers = self._make_discrete(markers)
             unique_markers = np.unique(markers)
-            marker_i = 0
+            if not same_normalization:
+                if istextmarker:
+                    self._marker_table = dict(zip(unique_markers, (format_str % um for um in unique_markers)))
+                else:
+                    if len(unique_markers) > len(self.markers):
+                        raise ValueError("Too many distinct values of the site property mapped to markers (there are %i) for the %i markers in `self.markers`" % (len(unique_markers), len(self.markers)))
+                    self._marker_table = dict(zip(unique_markers, self.markers[:len(unique_markers)]))
+
             for um in unique_markers:
-                marker_layers[SiteNetworkPlotter.DEFAULT_MARKERS[marker_i]] = (markers == um)
-                marker_i += 1
+                marker_layers[self._marker_table[um]] = (markers == um)
 
         # -- Do plot
         # If no color info provided, a fallback
         if not 'color' in pts_params and not 'c' in pts_arrays:
             pts_params['color'] = 'k'
+        # If no size info provided, a fallback
+        if not 's' in pts_params and not 's' in pts_arrays:
+            pts_params['s'] = sum(self.minmax_markersize) / 2
 
         # Add user options for `plot_points`
         pts_params.update(plot_points_params)
@@ -201,8 +241,8 @@ class SiteNetworkPlotter(object):
         sites_to_plot = []
         sites_to_plot_positions = []
 
-        for i in xrange(n_sites):
-            for j in xrange(n_sites):
+        for i in range(n_sites):
+            for j in range(n_sites):
                 # No self edges
                 if i == j:
                     continue
@@ -210,9 +250,9 @@ class SiteNetworkPlotter(object):
                 if done_already[i, j]:
                     continue
                 # Ignore anything below the threshold
-                if all_cs[i, j] < self.min_color_threshold:
+                if all_cs[i, j] <= self.min_color_threshold:
                     continue
-                if do_widths and all_linewidths[i, j] < self.min_width_threshold:
+                if do_widths and all_linewidths[i, j] <= self.min_width_threshold:
                     continue
 
                 segment = np.empty(shape = (2, 3), dtype = centers.dtype)
@@ -254,7 +294,9 @@ class SiteNetworkPlotter(object):
             lccolors = np.empty(shape = (len(cs), 4), dtype = np.float)
             # Group colors
             if do_groups:
-                for i in xrange(len(cs)):
+                for i in range(len(cs)):
+                    if groups[i] >= len(SiteNetworkPlotter.EDGE_GROUP_COLORS) - 1:
+                        raise ValueError("Too many groups, not enough group colors")
                     lccolors[i] = matplotlib.colors.to_rgba(SiteNetworkPlotter.EDGE_GROUP_COLORS[groups[i]])
             else:
                 lccolors[:] = matplotlib.colors.to_rgba(SiteNetworkPlotter.EDGE_GROUP_COLORS[0])
@@ -273,12 +315,14 @@ class SiteNetworkPlotter(object):
         ax.add_collection(lc)
 
         # -- Plot new sites
-        sn2 = sn[sites_to_plot]
-        sn2.update_centers(np.asarray(sites_to_plot_positions))
-
-        pts_params = dict(self.plot_points_params)
-        pts_params['alpha'] = 0.2
-        return self._site_layers(sn2, pts_params)
+        if len(sites_to_plot) > 0:
+            sn2 = sn[sites_to_plot]
+
sn2.update_centers(np.asarray(sites_to_plot_positions)) + pts_params = dict(self.plot_points_params) + pts_params['alpha'] = 0.2 + return self._site_layers(sn2, pts_params, same_normalization = True) + else: + return [] else: return [] diff --git a/sitator/visualization/SiteTrajectoryPlotter.py b/sitator/visualization/SiteTrajectoryPlotter.py new file mode 100644 index 0000000..9994dfe --- /dev/null +++ b/sitator/visualization/SiteTrajectoryPlotter.py @@ -0,0 +1,158 @@ +import numpy as np + +import matplotlib +from matplotlib.collections import LineCollection + +from sitator.util import PBCCalculator +from sitator.visualization import plotter, plot_atoms, plot_points, layers, DEFAULT_COLORS + + +class SiteTrajectoryPlotter(object): + """Produce various plots of a ``SiteTrajectory``.""" + + @plotter(is3D = True) + def plot_frame(self, st, frame, **kwargs): + """Plot sites and instantaneous positions from a given frame. + + Args: + st (SiteTrajectory) + frame (int) + """ + sites_of_frame = np.unique(st._traj[frame]) + frame_sn = st._sn[sites_of_frame] + + frame_sn.plot(**kwargs) + + if not st._real_traj is None: + mobile_atoms = st._sn.structure.copy() + del mobile_atoms[~st._sn.mobile_mask] + + mobile_atoms.positions[:] = st._real_traj[frame, st._sn.mobile_mask] + plot_atoms(atoms = mobile_atoms, **kwargs) + + kwargs['ax'].set_title("Frame %i/%i" % (frame, st.n_frames)) + + + @plotter(is3D = True) + def plot_site(self, st, site, **kwargs): + """Plot all real space positions associated with a site. + + Args: + st (SiteTrajectory) + site (int) + """ + pbcc = PBCCalculator(st._sn.structure.cell) + pts = st.real_positions_for_site(site).copy() + offset = pbcc.cell_centroid - pts[3] + pts += offset + pbcc.wrap_points(pts) + lattice_pos = st._sn.static_structure.positions.copy() + lattice_pos += offset + pbcc.wrap_points(lattice_pos) + site_pos = st._sn.centers[site:site+1].copy() + site_pos += offset + pbcc.wrap_points(site_pos) + # Plot point cloud + plot_points(points = pts, alpha = 0.3, marker = '.', color = 'k', **kwargs) + # Plot site + plot_points(points = site_pos, color = 'cyan', **kwargs) + # Plot everything else + plot_atoms(st._sn.static_structure, positions = lattice_pos, **kwargs) + + title = "Site %i/%i" % (site, len(st._sn)) + + if not st._sn.site_types is None: + title += " (type %i)" % st._sn.site_types[site] + + kwargs['ax'].set_title(title) + + + @plotter(is3D = False) + def plot_particle_trajectory(self, st, particle, ax = None, fig = None, **kwargs): + """Plot the sites occupied by a mobile particle over time. 
+ + Args: + st (SiteTrajectory) + particle (int) + """ + types = not st._sn.site_types is None + if types: + type_height_percent = 0.1 + axpos = ax.get_position() + typeax_height = type_height_percent * axpos.height + typeax = fig.add_axes([axpos.x0, axpos.y0, axpos.width, typeax_height], sharex = ax) + ax.set_position([axpos.x0, axpos.y0 + typeax_height, axpos.width, axpos.height - typeax_height]) + type_height = 1 + # Draw trajectory + segments = [] + linestyles = [] + colors = [] + + traj = st._traj[:, particle] + current_value = traj[0] + last_value = traj[0] + if types: + last_type = None + current_segment_start = 0 + puttext = False + + for i, f in enumerate(traj): + if f != current_value or i == len(traj) - 1: + val = last_value if current_value == -1 else current_value + if last_value == -1: + segments.append([[current_segment_start, val], [i, val]]) + else: + segments.append([[current_segment_start, last_value], [current_segment_start, val], [i, val]]) + linestyles.append(':' if current_value == -1 else '-') + if current_value == -1: + if val == -1: + c = 'red' # Uncorrected unknown + else: + c = 'lightgray' # Unknown but reassigned + else: + c = 'k' # Known + colors.append(c) + + if types: + rxy = (current_segment_start, 0) + this_type = st._sn.site_types[val] + typerect = matplotlib.patches.Rectangle(rxy, i - current_segment_start, type_height, + color = DEFAULT_COLORS[this_type], linewidth = 0) + typeax.add_patch(typerect) + if this_type != last_type: + typeax.annotate("T%i" % this_type, + xy = (rxy[0], rxy[1] + 0.5 * type_height), + xytext = (3, -1), + textcoords = 'offset points', + fontsize = 'xx-small', + va = 'center', + fontweight = 'bold') + last_type = this_type + + last_value = val + current_segment_start = i + current_value = f + + lc = LineCollection(segments, linestyles = linestyles, colors = colors, linewidth=1.5) + ax.add_collection(lc) + + if types: + typeax.set_xlabel("Frame") + ax.tick_params(axis = 'x', which = 'both', bottom = False, top = False, labelbottom = False) + typeax.tick_params(axis = 'y', which = 'both', left = False, right = False, labelleft = False) + typeax.annotate("Type", xy = (0, 0.5), xytext = (-25, 0), xycoords = 'axes fraction', textcoords = 'offset points', va = 'center', fontsize = 'x-small') + else: + ax.set_xlabel("Frame") + ax.set_ylabel("Atom %i's site" % particle) + + ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True)) + ax.grid() + + ax.set_xlim((0, st.n_frames - 1)) + margin_percent = 0.04 + ymargin = (margin_percent * st._sn.n_sites) + ax.set_ylim((-ymargin, st._sn.n_sites - 1.0 + ymargin)) + + if types: + typeax.set_xlim((0, st.n_frames - 1)) + typeax.set_ylim((0, type_height)) diff --git a/sitator/visualization/__init__.py b/sitator/visualization/__init__.py index 4624148..e96d843 100644 --- a/sitator/visualization/__init__.py +++ b/sitator/visualization/__init__.py @@ -1,5 +1,7 @@ -from common import layers, grid, plotter, DEFAULT_COLORS +from .common import layers, grid, plotter, DEFAULT_COLORS, set_axes_equal -from atoms import plot_atoms, plot_points +from .atoms import plot_atoms, plot_points -from SiteNetworkPlotter import SiteNetworkPlotter +from .SiteNetworkPlotter import SiteNetworkPlotter + +from .SiteTrajectoryPlotter import SiteTrajectoryPlotter diff --git a/sitator/visualization/atoms.py b/sitator/visualization/atoms.py index 1d0cc9c..b4f76cd 100644 --- a/sitator/visualization/atoms.py +++ b/sitator/visualization/atoms.py @@ -36,9 +36,20 @@ def plot_atoms(atoms, positions = None, hide_species = 
(), wrap = False, fig = N all_cvecs = [] - whos_left = set(xrange(len(atoms.cell))) + whos_left = set(range(len(atoms.cell))) + cvec_labels = ["$\\vec{a}$", "$\\vec{b}$", "$\\vec{c}$"] for i, cvec1 in enumerate(atoms.cell): all_cvecs.append(np.array([[0.0, 0.0, 0.0], cvec1])) + ax.text( + cvec1[0] * 0.25, + cvec1[1] * 0.25, + cvec1[2] * 0.25, + cvec_labels[i], + size = 9, + color = 'gray', + ha = 'left', + va = 'center' + ) for j, cvec2 in enumerate(atoms.cell[list(whos_left - {i})]): all_cvecs.append(np.array([cvec1, cvec1 + cvec2])) for i, cvec1 in enumerate(atoms.cell): diff --git a/sitator/visualization/common.py b/sitator/visualization/common.py index 9c4e702..74db185 100644 --- a/sitator/visualization/common.py +++ b/sitator/visualization/common.py @@ -33,8 +33,10 @@ def plotter_wrapper(func): def plotter_wraped(*args, **kwargs): fig = None ax = None + toplevel = False if not ('ax' in kwargs and 'fig' in kwargs): - # No existing axis/figure + # No existing axis/figure - toplevel + toplevel = True fig = plt.figure(**outer) if is3D: ax = fig.add_subplot(111, projection = '3d') @@ -54,14 +56,19 @@ def plotter_wraped(*args, **kwargs): kwargs['i'] = 0 func(*args, fig = fig, ax = ax, **kwargs) + + if is3D and toplevel: + set_axes_equal(ax) + return fig, ax + return plotter_wraped + return plotter_wrapper @plotter(is3D = True) def layers(*args, **fax): i = fax['i'] - print i for p, kwargs in args: p(fig = fax['fig'], ax = fax['ax'], i = i, **kwargs) i += 1 diff --git a/sitator/voronoi.py b/sitator/voronoi.py new file mode 100644 index 0000000..748e819 --- /dev/null +++ b/sitator/voronoi.py @@ -0,0 +1,50 @@ + +import numpy as np + +import os + +from sitator import SiteNetwork +from sitator.util import Zeopy + +class VoronoiSiteGenerator(object): + """Given an empty SiteNetwork, use the Voronoi decomposition to predict/generate sites. + + :param str zeopp_path: Path to the Zeo++ ``network`` executable + :param bool radial: Whether to use the radial Voronoi transform. Defaults to, + and should typically be, ``False``. + """ + + def __init__(self, + zeopp_path = os.getenv("SITATOR_ZEO_PATH", default = "network"), + radial = False): + self._radial = radial + self._zeopy = Zeopy(zeopp_path) + + def run(self, sn, seed_mask = None): + """ + Args: + sn (SiteNetwork): Any sites will be ignored; needed for structure + and static mask. + Returns: + A ``SiteNetwork``. + """ + assert isinstance(sn, SiteNetwork) + + if seed_mask is None: + seed_mask = sn.static_mask + assert not np.any(seed_mask & sn.mobile_mask), "Seed mask must not overlap with mobile mask" + assert not np.any(seed_mask & ~sn.static_mask), "All seed atoms must be static." + voro_struct = sn.structure[seed_mask] + translation = np.zeros(shape = len(sn.static_mask), dtype = np.int) + translation[sn.static_mask] = np.arange(sn.n_static) + translation = translation[seed_mask] + + with self._zeopy: + nodes, verts, edges, _ = self._zeopy.voronoi(voro_struct, + radial = self._radial) + + out = sn.copy() + out.centers = nodes + out.vertices = [translation[v] for v in verts] + + return out diff --git a/sitator/voronoi/VoronoiSiteGenerator.py b/sitator/voronoi/VoronoiSiteGenerator.py deleted file mode 100644 index 64c2058..0000000 --- a/sitator/voronoi/VoronoiSiteGenerator.py +++ /dev/null @@ -1,34 +0,0 @@ - -import numpy as np - -from sitator import SiteNetwork -from sitator.util import Zeopy - -class VoronoiSiteGenerator(object): - """Given an empty SiteNetwork, use the Voronoi decomposition to predict/generate sites. 
- - :param str zeopp_path: Path to the Zeo++ `network` executable - :param bool radial: Whether to use the radial Voronoi transform. Defaults to, - and should typically be, False. - :param bool verbose: - """ - - def __init__(self, zeopp_path = "network", radial = False, verbose = True): - self._radial = radial - self._verbose = verbose - self._zeopy = Zeopy(zeopp_path) - - def run(self, sn): - """SiteNetwork -> SiteNetwork""" - assert isinstance(sn, SiteNetwork) - - with self._zeopy: - nodes, verts, edges, _ = self._zeopy.voronoi(sn.static_structure, - radial = self._radial, - verbose = self._verbose) - - out = sn.copy() - out.centers = nodes - out.vertices = verts - - return out diff --git a/sitator/voronoi/__init__.py b/sitator/voronoi/__init__.py deleted file mode 100644 index 17c3b95..0000000 --- a/sitator/voronoi/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ - -from VoronoiSiteGenerator import VoronoiSiteGenerator
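The new `PBCCalculator.pairwise_distances` above builds a symmetric distance matrix out of repeated shift-and-wrap `distances` calls; it is what `chemistry.py` uses to precompute its `dmat`. A minimal usage sketch (not part of the patch), with a made-up cubic cell and points chosen so the minimum-image behavior is visible:

```python
import numpy as np
from sitator.util import PBCCalculator

# Toy 4 Angstrom cubic cell; sitator assumes Angstroms and Cartesian coordinates.
cell = 4.0 * np.eye(3)
pts = np.array([[0.5, 0.5, 0.5],
                [3.9, 0.5, 0.5],   # neighbors the first point across the boundary
                [2.0, 2.0, 2.0]])

pbcc = PBCCalculator(cell)
dmat = pbcc.pairwise_distances(pts)

# Symmetric with a zero diagonal; dmat[0, 1] is the minimum-image
# distance (~0.6), not the naive Euclidean distance (~3.4).
print(dmat)
```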
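`identify_polyatomic_ions` from the new `sitator/util/chemistry.py` can be exercised on a toy structure. A hedged sketch, also not from the repository: the lone-phosphate geometry below is invented and the bond length is only approximate, but it falls inside the summed covalent-radii cutoff the function computes:

```python
import numpy as np
from ase import Atoms
from sitator.util.chemistry import identify_polyatomic_ions

d = 1.55  # rough P-O bond length, Angstroms (made up for illustration)
structure = Atoms('PO4',
                  positions=[[5.0, 5.0, 5.0],
                             [5.0 + d, 5.0, 5.0], [5.0 - d, 5.0, 5.0],
                             [5.0, 5.0 + d, 5.0], [5.0, 5.0 - d, 5.0]],
                  cell=10.0 * np.eye(3), pbc=True)

# Should report the phosphate group: formula 'PO4', center atom 0,
# members array([1, 2, 3, 4]).
for formula, center_i, members in identify_polyatomic_ions(structure):
    print(formula, center_i, members)
```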
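The Markov Clustering implementation in the new `sitator/util/mcl.py` expects a column-stochastic matrix with a nonzero diagonal (the self-loops its docstring mentions). A small sketch with made-up numbers, two weakly coupled pairs of states:

```python
import numpy as np
from sitator.util.mcl import markov_clustering

# Column-stochastic toy transition matrix; every column sums to 1 and the
# diagonal is nonzero, as the implementation asserts.
T = np.array([[0.60, 0.35, 0.00, 0.05],
              [0.35, 0.60, 0.05, 0.00],
              [0.00, 0.05, 0.60, 0.35],
              [0.05, 0.00, 0.35, 0.60]])

clusters = markov_clustering(T)
print(clusters)  # expect the two coupled pairs, e.g. [(0, 1), (2, 3)]
```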
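The new `sitator/util/progress.py` centralizes the `tqdm` import behind the `SITATOR_PROGRESSBAR` environment variable. One caveat worth showing in a sketch: the variable is read at module import time, so it must be set before any `sitator` import:

```python
import os
os.environ["SITATOR_PROGRESSBAR"] = "off"  # "true"/"yes"/"on" enable the bar

# With the bar disabled (or tqdm missing), this is a no-op pass-through.
from sitator.util.progress import tqdm

for _ in tqdm(range(10), desc="Frames"):
    pass  # per-frame work would go here
```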
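Finally, the rewritten `VoronoiSiteGenerator` (now `sitator/voronoi.py`) picks up the Zeo++ path from `SITATOR_ZEO_PATH` and grows an optional `seed_mask` argument restricting which static atoms seed the decomposition; per its asserts, the mask must select only static, non-mobile atoms. A sketch in which `sn` (an empty `SiteNetwork` built elsewhere) and `anion_mask` (a boolean mask over all atoms) are placeholders:

```python
from sitator.voronoi import VoronoiSiteGenerator

# Falls back to "network" on PATH if SITATOR_ZEO_PATH is unset; a path can
# also be given explicitly: VoronoiSiteGenerator(zeopp_path="/path/to/network")
gen = VoronoiSiteGenerator()

sn_sites = gen.run(sn)                        # seed on every static atom
sn_anion = gen.run(sn, seed_mask=anion_mask)  # seed only on a sublattice
```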