|
5 | 5 | from aiida import orm |
6 | 6 | from aiida import plugins |
7 | 7 |
|
8 | | -from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType |
| 8 | +from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType, OccupationType, XcFunctionalType |
9 | 9 | from aiida_common_workflows.generators import ChoiceType, InputGenerator |
10 | 10 |
|
11 | 11 | __all__ = ('CommonRelaxInputGenerator',) |
@@ -37,6 +37,12 @@ def define(cls, spec): |
37 | 37 | help='The protocol to use for the automated input generation. This value indicates the level of precision ' |
38 | 38 | 'of the results and computational cost that the input parameters will be selected for.', |
39 | 39 | ) |
| 40 | + spec.input( |
| 41 | + 'spin_orbit', |
| 42 | + valid_type=bool, |
| 43 | + default=False, |
| 44 | + help='Whether to apply spin-orbit coupling.', |
| 45 | + ) |
40 | 46 | spec.input( |
41 | 47 | 'spin_type', |
42 | 48 | valid_type=SpinType, |
@@ -109,3 +115,81 @@ def define(cls, spec): |
109 | 115 | required=False, |
110 | 116 | help='Options for the geometry optimization calculation jobs.', |
111 | 117 | ) |
| 118 | + |
| 119 | + |
| 120 | +class CommonDftRelaxInputGenerator(CommonRelaxInputGenerator, metaclass=abc.ABCMeta): |
| 121 | + """Input generator for the common relax workflow. |
| 122 | +
|
| 123 | + .. note:: This class is a subclass of the ``CommonRelaxInputGenerator`` but defines some additional inputs that are |
| 124 | + common to a number of implementations. |
| 125 | +
|
| 126 | + This class should be subclassed by implementations for specific quantum engines. After calling the super, they can |
| 127 | + modify the ports defined here in the base class as well as add additional custom ports. |
| 128 | + """ |
| 129 | + |
| 130 | + @staticmethod |
| 131 | + def validate_kpoints_shift(value, _): |
| 132 | + """Validate the ``kpoints_shift`` input.""" |
| 133 | + if not isinstance(value, list) or len(value) != 3 or any(not isinstance(element, float) for element in value): |
| 134 | + return f'The `kpoints_shift` argument should be a list of three floats, but got: `{value}`.' |
| 135 | + |
| 136 | + @staticmethod |
| 137 | + def validate_inputs(value, _): |
| 138 | + """Docs.""" |
| 139 | + if value['spin_orbit'] is True and value['spin_type'] == SpinType.NONE: |
| 140 | + return '`spin_type` cannot be `SpinType.NONE` for `spin_orbit = True`.' |
| 141 | + |
| 142 | + smearing_broadening = value['smearing_broadening'] |
| 143 | + occupation_type = value['occupation_type'] |
| 144 | + |
| 145 | + if smearing_broadening is not None and occupation_type not in [ |
| 146 | + OccupationType.FIXED, OccupationType.TETRAHEDRON |
| 147 | + ]: |
| 148 | + return f'cannot define `smearing_broadening` for `occupation_type = {occupation_type}.' |
| 149 | + |
| 150 | + @classmethod |
| 151 | + def define(cls, spec): |
| 152 | + """Define the specification of the input generator. |
| 153 | +
|
| 154 | + The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method. |
| 155 | + """ |
| 156 | + super().define(spec) |
| 157 | + spec.inputs.validator = cls.validate_inputs |
| 158 | + spec.input( |
| 159 | + 'occupation_type', |
| 160 | + valid_type=OccupationType, |
| 161 | + serializer=OccupationType, |
| 162 | + default=OccupationType.FIXED, |
| 163 | + help='The way to treat electronic occupations.', |
| 164 | + ) |
| 165 | + spec.input( |
| 166 | + 'smearing_broadening', |
| 167 | + valid_type=float, |
| 168 | + required=False, |
| 169 | + help='The broadening of the smearing in eV. Should only be specified if a smearing method is defined for' |
| 170 | + 'the `occupation_type` input.', |
| 171 | + ) |
| 172 | + spec.input( |
| 173 | + 'xc_functional', |
| 174 | + valid_type=XcFunctionalType, |
| 175 | + serializer=XcFunctionalType, |
| 176 | + default=XcFunctionalType.PBE, |
| 177 | + help='The functional for the exchange-correlation to be used.', |
| 178 | + ) |
| 179 | + spec.input( |
| 180 | + 'kpoints_distance', |
| 181 | + valid_type=float, |
| 182 | + required=False, |
| 183 | + help='The desired minimum distance between k-points in reciprocal space in 1/Å. The implementation will' |
| 184 | + 'guarantee that a k-point mesh is generated for which the distances between all adjacent k-points along ' |
| 185 | + 'each cell vector are at most this distance. It is therefore possible that the distance is smaller than ' |
| 186 | + 'requested along certain directions.', |
| 187 | + ) |
| 188 | + spec.input( |
| 189 | + 'kpoints_shift', |
| 190 | + valid_type=list, |
| 191 | + validator=cls.validate_kpoints_shift, |
| 192 | + required=False, |
| 193 | + help='Optional shift to apply to all k-points of the k-point mesh. Should be a list of three floats where ' |
| 194 | + 'each float is a number between 0 and 1.', |
| 195 | + ) |
0 commit comments