@@ -17,27 +17,21 @@ struct Abs {
1717 T operator ()(T x) {
1818 return metal::abs (x);
1919 };
20- template <>
2120 uint8_t operator ()(uint8_t x) {
2221 return x;
2322 };
24- template <>
2523 uint16_t operator ()(uint16_t x) {
2624 return x;
2725 };
28- template <>
2926 uint32_t operator ()(uint32_t x) {
3027 return x;
3128 };
32- template <>
3329 uint64_t operator ()(uint64_t x) {
3430 return x;
3531 };
36- template <>
3732 bool operator ()(bool x) {
3833 return x;
3934 };
40- template <>
4135 complex64_t operator ()(complex64_t x) {
4236 return {metal::precise::sqrt (x.real * x.real + x.imag * x.imag ), 0 };
4337 };
@@ -48,6 +42,8 @@ struct ArcCos {
4842 T operator ()(T x) {
4943 return metal::precise::acos (x);
5044 };
45+
46+ complex64_t operator ()(complex64_t x);
5147};
5248
5349struct ArcCosh {
@@ -62,6 +58,8 @@ struct ArcSin {
6258 T operator ()(T x) {
6359 return metal::precise::asin (x);
6460 };
61+
62+ complex64_t operator ()(complex64_t x);
6563};
6664
6765struct ArcSinh {
@@ -76,6 +74,8 @@ struct ArcTan {
7674 T operator ()(T x) {
7775 return metal::precise::atan (x);
7876 };
77+
78+ complex64_t operator ()(complex64_t x);
7979};
8080
8181struct ArcTanh {
@@ -97,39 +97,30 @@ struct Ceil {
9797 T operator ()(T x) {
9898 return metal::ceil (x);
9999 };
100- template <>
101100 int8_t operator ()(int8_t x) {
102101 return x;
103102 };
104- template <>
105103 int16_t operator ()(int16_t x) {
106104 return x;
107105 };
108- template <>
109106 int32_t operator ()(int32_t x) {
110107 return x;
111108 };
112- template <>
113109 int64_t operator ()(int64_t x) {
114110 return x;
115111 };
116- template <>
117112 uint8_t operator ()(uint8_t x) {
118113 return x;
119114 };
120- template <>
121115 uint16_t operator ()(uint16_t x) {
122116 return x;
123117 };
124- template <>
125118 uint32_t operator ()(uint32_t x) {
126119 return x;
127120 };
128- template <>
129121 uint64_t operator ()(uint64_t x) {
130122 return x;
131123 };
132- template <>
133124 bool operator ()(bool x) {
134125 return x;
135126 };
@@ -141,7 +132,6 @@ struct Cos {
141132 return metal::precise::cos (x);
142133 };
143134
144- template <>
145135 complex64_t operator ()(complex64_t x) {
146136 return {
147137 metal::precise::cos (x.real ) * metal::precise::cosh (x.imag ),
@@ -155,7 +145,6 @@ struct Cosh {
155145 return metal::precise::cosh (x);
156146 };
157147
158- template <>
159148 complex64_t operator ()(complex64_t x) {
160149 return {
161150 metal::precise::cosh (x.real ) * metal::precise::cos (x.imag ),
@@ -188,7 +177,6 @@ struct Exp {
188177 T operator ()(T x) {
189178 return metal::precise::exp (x);
190179 };
191- template <>
192180 complex64_t operator ()(complex64_t x) {
193181 auto m = metal::precise::exp (x.real );
194182 return {m * metal::precise::cos (x.imag ), m * metal::precise::sin (x.imag )};
@@ -207,39 +195,30 @@ struct Floor {
207195 T operator ()(T x) {
208196 return metal::floor (x);
209197 };
210- template <>
211198 int8_t operator ()(int8_t x) {
212199 return x;
213200 };
214- template <>
215201 int16_t operator ()(int16_t x) {
216202 return x;
217203 };
218- template <>
219204 int32_t operator ()(int32_t x) {
220205 return x;
221206 };
222- template <>
223207 int64_t operator ()(int64_t x) {
224208 return x;
225209 };
226- template <>
227210 uint8_t operator ()(uint8_t x) {
228211 return x;
229212 };
230- template <>
231213 uint16_t operator ()(uint16_t x) {
232214 return x;
233215 };
234- template <>
235216 uint32_t operator ()(uint32_t x) {
236217 return x;
237218 };
238- template <>
239219 uint64_t operator ()(uint64_t x) {
240220 return x;
241221 };
242- template <>
243222 bool operator ()(bool x) {
244223 return x;
245224 };
@@ -258,7 +237,6 @@ struct Log {
258237 return metal::precise::log (x);
259238 };
260239
261- template <>
262240 complex64_t operator ()(complex64_t x) {
263241 auto r = metal::precise::log (Abs{}(x).real );
264242 auto i = metal::precise::atan2 (x.imag , x.real );
@@ -272,7 +250,6 @@ struct Log2 {
272250 return metal::precise::log2 (x);
273251 };
274252
275- template <>
276253 complex64_t operator ()(complex64_t x) {
277254 auto y = Log{}(x);
278255 return {y.real / M_LN2_F, y.imag / M_LN2_F};
@@ -285,7 +262,6 @@ struct Log10 {
285262 return metal::precise::log10 (x);
286263 };
287264
288- template <>
289265 complex64_t operator ()(complex64_t x) {
290266 auto y = Log{}(x);
291267 return {y.real / M_LN10_F, y.imag / M_LN10_F};
@@ -325,7 +301,6 @@ struct Round {
325301 T operator ()(T x) {
326302 return metal::rint (x);
327303 };
328- template <>
329304 complex64_t operator ()(complex64_t x) {
330305 return {metal::rint (x.real ), metal::rint (x.imag )};
331306 };
@@ -344,11 +319,9 @@ struct Sign {
344319 T operator ()(T x) {
345320 return (x > T (0 )) - (x < T (0 ));
346321 };
347- template <>
348322 uint32_t operator ()(uint32_t x) {
349323 return x != 0 ;
350324 };
351- template <>
352325 complex64_t operator ()(complex64_t x) {
353326 if (x == complex64_t (0 )) {
354327 return x;
@@ -364,7 +337,6 @@ struct Sin {
364337 return metal::precise::sin (x);
365338 };
366339
367- template <>
368340 complex64_t operator ()(complex64_t x) {
369341 return {
370342 metal::precise::sin (x.real ) * metal::precise::cosh (x.imag ),
@@ -378,7 +350,6 @@ struct Sinh {
378350 return metal::precise::sinh (x);
379351 };
380352
381- template <>
382353 complex64_t operator ()(complex64_t x) {
383354 return {
384355 metal::precise::sinh (x.real ) * metal::precise::cos (x.imag ),
@@ -398,13 +369,28 @@ struct Sqrt {
398369 T operator ()(T x) {
399370 return metal::precise::sqrt (x);
400371 };
372+
373+ complex64_t operator ()(complex64_t x) {
374+ if (x.real == 0.0 && x.imag == 0.0 ) {
375+ return {0.0 , 0.0 };
376+ }
377+ auto r = Abs{}(x).real ;
378+ auto a = metal::precise::sqrt ((r + x.real ) / 2.0 );
379+ auto b_abs = metal::precise::sqrt ((r - x.real ) / 2.0 );
380+ auto b = metal::copysign (b_abs, x.imag );
381+ return {a, b};
382+ }
401383};
402384
403385struct Rsqrt {
404386 template <typename T>
405387 T operator ()(T x) {
406388 return metal::precise::rsqrt (x);
407389 };
390+
391+ complex64_t operator ()(complex64_t x) {
392+ return 1.0 / Sqrt{}(x);
393+ }
408394};
409395
410396struct Tan {
@@ -413,7 +399,6 @@ struct Tan {
413399 return metal::precise::tan (x);
414400 };
415401
416- template <>
417402 complex64_t operator ()(complex64_t x) {
418403 float tan_a = metal::precise::tan (x.real );
419404 float tanh_b = metal::precise::tanh (x.imag );
@@ -429,7 +414,6 @@ struct Tanh {
429414 return metal::precise::tanh (x);
430415 };
431416
432- template <>
433417 complex64_t operator ()(complex64_t x) {
434418 float tanh_a = metal::precise::tanh (x.real );
435419 float tan_b = metal::precise::tan (x.imag );
@@ -438,3 +422,21 @@ struct Tanh {
438422 return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
439423 };
440424};
425+
426+ complex64_t ArcCos::operator ()(complex64_t x) {
427+ auto i = complex64_t {0.0 , 1.0 };
428+ auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
429+ return {y.imag , -y.real };
430+ };
431+
432+ complex64_t ArcSin::operator ()(complex64_t x) {
433+ auto i = complex64_t {0.0 , 1.0 };
434+ auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
435+ return {y.imag , -y.real };
436+ };
437+
438+ complex64_t ArcTan::operator ()(complex64_t x) {
439+ auto i = complex64_t {0.0 , 1.0 };
440+ auto ix = i * x;
441+ return (1.0 / complex64_t {0.0 , 2.0 }) * Log{}((1.0 + ix) / (1.0 - ix));
442+ };
0 commit comments