You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
7.9 KiB
255 lines
7.9 KiB
# Copyright (c) [2024] []
|
|
#
|
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
# of this software and associated documentation files (the "Software"), to deal
|
|
# in the Software without restriction, including without limitation the rights
|
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
# copies of the Software, and to permit persons to whom the Software is
|
|
# furnished to do so, subject to the following conditions:
|
|
#
|
|
# The above copyright notice and this permission notice shall be included in all
|
|
# copies or substantial portions of the Software.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
# SOFTWARE.
|
|
|
|
import io
|
|
import os
|
|
import tempfile
|
|
from contextlib import redirect_stdout
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from grogupy.io import (
|
|
default_args,
|
|
load_pickle,
|
|
print_atoms_and_pairs,
|
|
print_job_description,
|
|
print_parameters,
|
|
print_runtime_information,
|
|
save_pickle,
|
|
)
|
|
|
|
|
|
# Fixtures for common test data
|
|
@pytest.fixture
def simulation_parameters():
    """Return a parameter dictionary derived from ``default_args``.

    Starts from a copy of the library defaults and overrides the input and
    output file names, the parallel size (4), the unit cell (3x3 identity
    matrix) and the ``automatic_ebot`` flag (enabled).

    NOTE(review): ``dict.copy`` is shallow, so any mutable values inherited
    from ``default_args`` are shared with the module-level defaults. Tests
    should only reassign top-level keys, never mutate nested values in place.
    """
    overrides = {
        "infile": "test.fdf",
        "outfile": "test_output.pickle",
        "parallel_size": 4,
        "cell": np.array(
            [
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
            ]
        ),
        "automatic_ebot": True,
    }
    params = default_args.copy()
    params.update(overrides)
    return params
|
|
|
|
|
|
@pytest.fixture
def magnetic_entities():
    """Return two mock Fe(d) magnetic entities.

    Each entity is a dict with a single tag, a single position array, a
    three-component ``K`` array and a scalar ``K_consistency``. The list is
    used as input to ``print_atoms_and_pairs`` in this module.
    """
    # (tag, position, K, K_consistency) for each entity, in output order.
    specs = (
        ("[0]Fe(d)", [0.0, 0.0, 0.0], [1.0, 2.0, 3.0], 0.001),
        ("[1]Fe(d)", [1.0, 1.0, 1.0], [2.0, 3.0, 4.0], 0.002),
    )
    return [
        {
            "tags": [tag],
            "xyz": [np.array(position)],
            "K": np.array(anisotropy),
            "K_consistency": consistency,
        }
        for tag, position, anisotropy, consistency in specs
    ]
|
|
|
|
|
|
@pytest.fixture
def pairs():
    """Return a single mock pair connecting the two Fe(d) entities.

    The pair has a zero ``Ruc`` shift and a ``dist`` of 1.732. That is roughly
    sqrt(3), the distance between the mock positions (0, 0, 0) and (1, 1, 1).
    It also carries sample values under ``J_iso``, ``D``, ``J_S`` and a
    3x3 ``J`` matrix.
    """
    exchange_matrix = np.array(
        [
            [1.0, 0.1, 0.2],
            [0.1, 1.1, 0.3],
            [0.2, 0.3, 1.2],
        ]
    )
    pair = dict(
        tags=["[0]Fe(d)", "[1]Fe(d)"],
        Ruc=np.array([0, 0, 0]),
        dist=1.732,
        J_iso=10.0,
        D=np.array([0.1, 0.2, 0.3]),
        J_S=np.array([0.5, 0.6, 0.7, 0.8, 0.9]),
        J=exchange_matrix,
    )
    return [pair]
|
|
|
|
|
|
@pytest.fixture
def runtime_info():
    """Return mock stage timestamps for ``print_runtime_information``.

    Stages are listed in order, and each timestamp is 1.0 larger than the
    previous one, starting at 0.0 for ``start_time``.
    """
    stage_keys = (
        "start_time",
        "setup_time",
        "H_and_XCF_time",
        "site_and_pair_dictionaries_time",
        "k_set_time",
        "reference_rotations_time",
        "green_function_inversion_time",
        "end_time",
    )
    return {key: float(index) for index, key in enumerate(stage_keys)}
|
|
|
|
|
|
# Test pickle save/load functionality
|
|
def test_pickle_save_load(simulation_parameters):
    """Round-trip the simulation parameters through save_pickle/load_pickle.

    The parameters contain NumPy arrays (e.g. ``cell``), so the loaded dict
    cannot be compared with plain ``==``. The unpickled arrays are new
    objects, so dict equality falls back to ``ndarray.__eq__``. That returns
    an element-wise array whose truth value is ambiguous and raises
    ``ValueError``. ``np.testing.assert_equal`` compares dicts key by key
    (recursively) and arrays element-wise, raising ``AssertionError`` on a
    mismatch.
    """
    # delete=False so the file survives closing and save_pickle can reopen it.
    with tempfile.NamedTemporaryFile(suffix=".pickle", delete=False) as tmp:
        temp_path = tmp.name

    try:
        # Saving must write data into the (initially empty) temp file.
        save_pickle(temp_path, simulation_parameters)
        assert os.path.exists(temp_path)
        assert os.path.getsize(temp_path) > 0

        # Loading must reproduce the same keys and values, arrays included.
        loaded_data = load_pickle(temp_path)
        np.testing.assert_equal(loaded_data, simulation_parameters)
    finally:
        # The temp file was created with delete=False, so remove it ourselves.
        if os.path.exists(temp_path):
            os.unlink(temp_path)
|
|
|
|
|
|
# Test print functions
|
|
def test_print_parameters(simulation_parameters):
    """``print_parameters`` prints the main labelled fields and the k-set."""
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        print_parameters(simulation_parameters)
    text = buffer.getvalue()

    for label in ("Input file:", "Cell [Ang]:", "DFT axis:"):
        assert label in text
    # The configured k-set value must appear verbatim somewhere in the output.
    assert str(simulation_parameters["kset"]) in text
|
|
|
|
|
|
def test_print_atoms_and_pairs(magnetic_entities, pairs):
    """``print_atoms_and_pairs`` prints section headers and every tag."""
    captured = io.StringIO()
    with redirect_stdout(captured):
        print_atoms_and_pairs(magnetic_entities, pairs)
    report = captured.getvalue()

    # Section headers.
    for header in ("Atomic information:", "Anisotropy [meV]", "Exchange [meV]"):
        assert header in report

    # The first tag of each magnetic entity.
    for entity in magnetic_entities:
        assert entity["tags"][0] in report

    # Both endpoint tags of each pair.
    for pair in pairs:
        tags = pair["tags"]
        assert tags[0] in report
        assert tags[1] in report
|
|
|
|
|
|
def test_print_runtime_information(runtime_info):
    """``print_runtime_information`` prints its header and timing labels."""
    sink = io.StringIO()
    with redirect_stdout(sink):
        print_runtime_information(runtime_info)
    rendered = sink.getvalue()

    for label in ("Runtime information:", "Total runtime:", "Initial setup:"):
        assert label in rendered
|
|
|
|
|
|
def test_print_job_description(simulation_parameters):
    """``print_job_description`` prints the expected section labels."""
    sink = io.StringIO()
    with redirect_stdout(sink):
        print_job_description(simulation_parameters)
    rendered = sink.getvalue()

    expected_labels = (
        "Input file:",
        "Number of nodes in the parallel cluster:",
        "Cell [Ang]:",
        "Parameters for the contour integral:",
    )
    for label in expected_labels:
        assert label in rendered
|
|
|
|
|
|
# Test default arguments
|
|
def test_default_args_structure():
    """``default_args`` is a dict with the expected keys and default values."""
    assert isinstance(default_args, dict)

    for key in ("infile", "outfile", "kset", "kdirs", "eset", "esetp"):
        assert key in default_args

    expected_defaults = {"kset": 2, "kdirs": "xyz", "eset": 42, "esetp": 1000}
    for key, expected in expected_defaults.items():
        assert default_args[key] == expected

    # Boolean flags are checked by identity, so e.g. 0/1 would not pass.
    assert default_args["parallel_solver_for_Gk"] is False
    assert default_args["padawan_mode"] is True
|
|
|
|
|
|
def test_simulation_parameters_validation(simulation_parameters):
    """The fixture parameters contain required keys and sane values."""
    # Every required key must be present.
    required_fields = ("infile", "outfile", "kset", "kdirs", "eset", "esetp")
    assert all(field in simulation_parameters for field in required_fields)

    # These numeric parameters must be strictly positive.
    for key in ("kset", "eset", "esetp"):
        assert simulation_parameters[key] > 0

    # The unit cell must be a 3x3 array.
    assert simulation_parameters["cell"].shape == (3, 3)
|
|
|
|
|
|
@pytest.mark.parametrize("test_dir", ["x", "y", "z", "xy", "yz", "xz", "xyz"])
def test_kdirs_validation(simulation_parameters, test_dir):
    """``print_parameters`` accepts every listed ``kdirs`` combination.

    There is no explicit assertion: the test passes as long as printing
    completes without raising.
    """
    params = simulation_parameters
    params["kdirs"] = test_dir
    print_parameters(params)
|
|
|
|
|
|
def test_ebot_automatic_detection(simulation_parameters):
    """``print_job_description`` warns when ``automatic_ebot`` is enabled.

    The test enables the flag itself rather than only checking that the key
    exists. A membership check also passes when the flag is ``False``, and
    then the warning assertion below would fail for a misleading reason.
    """
    simulation_parameters["automatic_ebot"] = True

    output = io.StringIO()
    with redirect_stdout(output):
        print_job_description(simulation_parameters)

    printed_output = output.getvalue()
    assert "WARNING: This was automatically determined!" in printed_output
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly
    # reports failures to the shell/CI instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__]))
|