Skip to content

Commit 4e91d90

Browse files
committed
Add CommonRelaxInputGenerator subclass for DFT based codes
1 parent f83cbba commit 4e91d90

File tree

3 files changed

+108
-6
lines changed

3 files changed

+108
-6
lines changed

aiida_common_workflows/common/types.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""Module with basic type definitions."""
33
from enum import Enum
44

5-
__all__ = ('ElectronicType', 'SpinType', 'RelaxType')
5+
__all__ = ('ElectronicType', 'SpinType', 'RelaxType', 'OccupationType', 'XcFunctionalType')
66

77

88
class RelaxType(Enum):
@@ -24,12 +24,30 @@ class SpinType(Enum):
2424
NONE = 'none'
2525
COLLINEAR = 'collinear'
2626
NON_COLLINEAR = 'non_collinear'
27-
SPIN_ORBIT = 'spin_orbit'
2827

2928

3029
class ElectronicType(Enum):
3130
"""Enumeration of known electronic types."""
3231

33-
AUTOMATIC = 'automatic'
32+
UNKNOWN = 'unknown'
3433
METAL = 'metal'
3534
INSULATOR = 'insulator'
35+
36+
37+
class OccupationType(Enum):
38+
"""Enumeration of known methods of treating electronic occupations."""
39+
40+
FIXED = 'fixed'
41+
TETRAHEDRON = 'tetrahedron'
42+
GAUSSIAN = 'gaussian'
43+
FERMI_DIRAC = 'fermi-dirac'
44+
METHFESSEL_PAXTON = 'methfessel-paxton'
45+
MARZARI_VANDERBILT = 'marzari-vanderbilt'
46+
47+
48+
class XcFunctionalType(Enum):
49+
"""Enumeration of known exchange-correlation functional types."""
50+
51+
LDA = 'lda'
52+
PBE = 'pbe'
53+
PBESOL = 'pbesol'

aiida_common_workflows/workflows/relax/generator.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from aiida import orm
66
from aiida import plugins
77

8-
from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
8+
from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType, OccupationType, XcFunctionalType
99
from aiida_common_workflows.generators import ChoiceType, InputGenerator
1010

1111
__all__ = ('CommonRelaxInputGenerator',)
@@ -37,6 +37,12 @@ def define(cls, spec):
3737
help='The protocol to use for the automated input generation. This value indicates the level of precision '
3838
'of the results and computational cost that the input parameters will be selected for.',
3939
)
40+
spec.input(
41+
'spin_orbit',
42+
valid_type=bool,
43+
default=False,
44+
help='Whether to apply spin-orbit coupling.',
45+
)
4046
spec.input(
4147
'spin_type',
4248
valid_type=SpinType,
@@ -109,3 +115,81 @@ def define(cls, spec):
109115
required=False,
110116
help='Options for the geometry optimization calculation jobs.',
111117
)
118+
119+
120+
class CommonDftRelaxInputGenerator(CommonRelaxInputGenerator, metaclass=abc.ABCMeta):
121+
"""Input generator for the common relax workflow.
122+
123+
.. note:: This class is a subclass of the ``CommonRelaxInputGenerator`` but defines some additional inputs that are
124+
common to a number of implementations.
125+
126+
This class should be subclassed by implementations for specific quantum engines. After calling the super, they can
127+
modify the ports defined here in the base class as well as add additional custom ports.
128+
"""
129+
130+
@staticmethod
131+
def validate_kpoints_shift(value, _):
132+
"""Validate the ``kpoints_shift`` input."""
133+
if not isinstance(value, list) or len(value) != 3 or any(not isinstance(element, float) for element in value):
134+
return f'The `kpoints_shift` argument should be a list of three floats, but got: `{value}`.'
135+
136+
@staticmethod
137+
def validate_inputs(value, _):
138+
"""Docs."""
139+
if value['spin_orbit'] is True and value['spin_type'] == SpinType.NONE:
140+
return '`spin_type` cannot be `SpinType.NONE` for `spin_orbit = True`.'
141+
142+
smearing_broadening = value['smearing_broadening']
143+
occupation_type = value['occupation_type']
144+
145+
if smearing_broadening is not None and occupation_type not in [
146+
OccupationType.FIXED, OccupationType.TETRAHEDRON
147+
]:
148+
return f'cannot define `smearing_broadening` for `occupation_type = {occupation_type}.'
149+
150+
@classmethod
151+
def define(cls, spec):
152+
"""Define the specification of the input generator.
153+
154+
The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
155+
"""
156+
super().define(spec)
157+
spec.inputs.validator = cls.validate_inputs
158+
spec.input(
159+
'occupation_type',
160+
valid_type=OccupationType,
161+
serializer=OccupationType,
162+
default=OccupationType.FIXED,
163+
help='The way to treat electronic occupations.',
164+
)
165+
spec.input(
166+
'smearing_broadening',
167+
valid_type=float,
168+
required=False,
169+
help='The broadening of the smearing in eV. Should only be specified if a smearing method is defined for'
170+
'the `occupation_type` input.',
171+
)
172+
spec.input(
173+
'xc_functional',
174+
valid_type=XcFunctionalType,
175+
serializer=XcFunctionalType,
176+
default=XcFunctionalType.PBE,
177+
help='The functional for the exchange-correlation to be used.',
178+
)
179+
spec.input(
180+
'kpoints_distance',
181+
valid_type=float,
182+
required=False,
183+
help='The desired minimum distance between k-points in reciprocal space in 1/Å. The implementation will'
184+
'guarantee that a k-point mesh is generated for which the distances between all adjacent k-points along '
185+
'each cell vector are at most this distance. It is therefore possible that the distance is smaller than '
186+
'requested along certain directions.',
187+
)
188+
spec.input(
189+
'kpoints_shift',
190+
valid_type=list,
191+
validator=cls.validate_kpoints_shift,
192+
required=False,
193+
help='Optional shift to apply to all k-points of the k-point mesh. Should be a list of three floats where '
194+
'each float is a number between 0 and 1.',
195+
)

aiida_common_workflows/workflows/relax/quantum_espresso/generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
88
from aiida_common_workflows.generators import ChoiceType, CodeType
99

10-
from ..generator import CommonRelaxInputGenerator
10+
from ..generator import CommonDftRelaxInputGenerator
1111

1212
__all__ = ('QuantumEspressoCommonRelaxInputGenerator',)
1313

@@ -62,7 +62,7 @@ def create_magnetic_allotrope(structure, magnetization_per_site):
6262
return (allotrope, allotrope_magnetic_moments)
6363

6464

65-
class QuantumEspressoCommonRelaxInputGenerator(CommonRelaxInputGenerator):
65+
class QuantumEspressoCommonRelaxInputGenerator(CommonDftRelaxInputGenerator):
6666
"""Input generator for the common relax workflow implementation of Quantum ESPRESSO."""
6767

6868
def __init__(self, *args, **kwargs):

0 commit comments

Comments
 (0)