You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
grogu/tests/test_io.py

255 lines
7.9 KiB

# Copyright (c) [2024] []
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import io
import os
import tempfile
from contextlib import redirect_stdout
from pathlib import Path
import numpy as np
import pytest
from grogupy.io import (
default_args,
load_pickle,
print_atoms_and_pairs,
print_job_description,
print_parameters,
print_runtime_information,
save_pickle,
)
# Fixtures for common test data
@pytest.fixture
def simulation_parameters():
    """Return a fresh parameter dict: the package defaults plus test overrides.

    The overrides name dummy input/output files, request four parallel
    workers, supply a 3x3 identity cell and switch on ``automatic_ebot``.
    A new top-level dict is built for every test, so tests may reassign keys
    freely; nested values are still shared with ``default_args`` (shallow copy).
    """
    overrides = {
        "infile": "test.fdf",
        "outfile": "test_output.pickle",
        "parallel_size": 4,
        "cell": np.eye(3),
        "automatic_ebot": True,
    }
    return {**default_args, **overrides}
@pytest.fixture
def magnetic_entities():
    """Return two mock magnetic entities tagged ``[0]Fe(d)`` and ``[1]Fe(d)``.

    Each entry holds a one-element tag list, a one-element list of position
    arrays (``xyz``), a 3-component ``K`` array and a scalar ``K_consistency``.
    """

    def make_entity(index, position, anisotropy, consistency):
        # Builds one entity dict in the layout the print helpers consume.
        return {
            "tags": [f"[{index}]Fe(d)"],
            "xyz": [np.array(position)],
            "K": np.array(anisotropy),
            "K_consistency": consistency,
        }

    return [
        make_entity(0, [0.0, 0.0, 0.0], [1.0, 2.0, 3.0], 0.001),
        make_entity(1, [1.0, 1.0, 1.0], [2.0, 3.0, 4.0], 0.002),
    ]
@pytest.fixture
def pairs():
    """Return a list with one mock pair joining tags ``[0]Fe(d)`` and ``[1]Fe(d)``.

    The pair carries an all-zero integer ``Ruc`` array, a scalar distance,
    a scalar ``J_iso``, a 3-component ``D``, a 5-component ``J_S`` and a
    symmetric 3x3 ``J`` matrix.
    """
    exchange_matrix = np.array(
        [
            [1.0, 0.1, 0.2],
            [0.1, 1.1, 0.3],
            [0.2, 0.3, 1.2],
        ]
    )
    single_pair = dict(
        tags=["[0]Fe(d)", "[1]Fe(d)"],
        Ruc=np.array([0, 0, 0]),
        dist=1.732,
        J_iso=10.0,
        D=np.array([0.1, 0.2, 0.3]),
        J_S=np.array([0.5, 0.6, 0.7, 0.8, 0.9]),
        J=exchange_matrix,
    )
    return [single_pair]
@pytest.fixture
def runtime_info():
    """Return mock stage timestamps spaced one second apart, from 0.0 to 7.0."""
    # Order matters: each stage's timestamp is its position in this tuple.
    stage_keys = (
        "start_time",
        "setup_time",
        "H_and_XCF_time",
        "site_and_pair_dictionaries_time",
        "k_set_time",
        "reference_rotations_time",
        "green_function_inversion_time",
        "end_time",
    )
    return {key: float(offset) for offset, key in enumerate(stage_keys)}
# Test pickle save/load functionality
def test_pickle_save_load(simulation_parameters, tmp_path):
    """Round-trip the simulation parameters through save_pickle/load_pickle.

    ``simulation_parameters`` contains NumPy arrays (e.g. ``cell``). A plain
    ``loaded == original`` on the dicts would evaluate ``ndarray == ndarray``
    element-wise and then raise ``ValueError`` when the dict comparison asks
    for its truth value. ``np.testing.assert_equal`` instead recurses into
    dicts/lists and compares arrays element-wise, including the key sets.

    pytest's ``tmp_path`` fixture supplies an isolated per-test directory, so
    no manual NamedTemporaryFile/unlink bookkeeping is needed.
    """
    temp_path = tmp_path / "simulation_parameters.pickle"

    # Saving must create a non-empty file.
    save_pickle(str(temp_path), simulation_parameters)
    assert temp_path.exists()
    assert temp_path.stat().st_size > 0

    # Loading must reproduce every key and value, arrays included.
    loaded_data = load_pickle(str(temp_path))
    np.testing.assert_equal(loaded_data, simulation_parameters)
# Test print functions
def test_print_parameters(simulation_parameters):
    """print_parameters must print the input-file, cell and DFT-axis labels.

    It must also print the ``kset`` value taken from the parameters.
    """
    with redirect_stdout(io.StringIO()) as captured:
        print_parameters(simulation_parameters)
    printed_output = captured.getvalue()

    expected_fragments = (
        "Input file:",
        "Cell [Ang]:",
        "DFT axis:",
        str(simulation_parameters["kset"]),
    )
    for fragment in expected_fragments:
        assert fragment in printed_output
def test_print_atoms_and_pairs(magnetic_entities, pairs):
    """print_atoms_and_pairs must print its section headers and every tag."""
    with redirect_stdout(io.StringIO()) as captured:
        print_atoms_and_pairs(magnetic_entities, pairs)
    printed_output = captured.getvalue()

    for header in ("Atomic information:", "Anisotropy [meV]", "Exchange [meV]"):
        assert header in printed_output

    # Each magnetic entity is identified by its first tag.
    for entity in magnetic_entities:
        first_tag = entity["tags"][0]
        assert first_tag in printed_output

    # Each pair must mention both of its endpoint tags.
    for pair in pairs:
        left_tag, right_tag = pair["tags"][0], pair["tags"][1]
        assert left_tag in printed_output
        assert right_tag in printed_output
def test_print_runtime_information(runtime_info):
    """print_runtime_information must print its header, total and setup lines."""
    with redirect_stdout(io.StringIO()) as captured:
        print_runtime_information(runtime_info)
    printed_output = captured.getvalue()

    for label in ("Runtime information:", "Total runtime:", "Initial setup:"):
        assert label in printed_output
def test_print_job_description(simulation_parameters):
    """print_job_description must print the input, cluster, cell and contour labels."""
    with redirect_stdout(io.StringIO()) as captured:
        print_job_description(simulation_parameters)
    printed_output = captured.getvalue()

    expected_labels = (
        "Input file:",
        "Number of nodes in the parallel cluster:",
        "Cell [Ang]:",
        "Parameters for the contour integral:",
    )
    for label in expected_labels:
        assert label in printed_output
# Test default arguments
def test_default_args_structure():
    """default_args must be a dict exposing the expected keys and default values."""
    assert isinstance(default_args, dict)

    for key in ("infile", "outfile", "kset", "kdirs", "eset", "esetp"):
        assert key in default_args

    expected_defaults = {"kset": 2, "kdirs": "xyz", "eset": 42, "esetp": 1000}
    for key, expected in expected_defaults.items():
        assert default_args[key] == expected

    # Boolean flags are checked by identity so truthy stand-ins (1, "yes") fail.
    assert default_args["parallel_solver_for_Gk"] is False
    assert default_args["padawan_mode"] is True
def test_simulation_parameters_validation(simulation_parameters):
    """The parameter fixture must contain the required keys with sane values."""
    for name in ("infile", "outfile", "kset", "kdirs", "eset", "esetp"):
        assert name in simulation_parameters

    # Sampling counts must be strictly positive.
    for name in ("kset", "eset", "esetp"):
        assert simulation_parameters[name] > 0

    assert simulation_parameters["cell"].shape == (3, 3)
@pytest.mark.parametrize("test_dir", ["x", "y", "z", "xy", "yz", "xz", "xyz"])
def test_kdirs_validation(simulation_parameters, test_dir):
    """print_parameters must accept every listed k-direction combination.

    The test passes as long as no exception is raised; output is not inspected.
    """
    simulation_parameters.update(kdirs=test_dir)
    print_parameters(simulation_parameters)
def test_ebot_automatic_detection(simulation_parameters):
    """With ``automatic_ebot`` enabled, the job description must print the warning.

    The precondition checks that the flag is actually ``True``. Merely
    checking that the key exists is vacuous, because the fixture always
    defines it, and would not guard against it being ``False``.
    """
    assert simulation_parameters["automatic_ebot"] is True

    output = io.StringIO()
    with redirect_stdout(output):
        print_job_description(simulation_parameters)
    printed_output = output.getvalue()

    assert "WARNING: This was automatically determined!" in printed_output
if __name__ == "__main__":
    # pytest.main returns the session exit code. Propagate it so running this
    # file as a script reports failures to the shell/CI instead of exiting 0.
    raise SystemExit(pytest.main([__file__]))