IPPL (Independent Parallel Particle Layer)
IPPL
Loading...
Searching...
No Matches
FFTPeriodicPoissonSolver.hpp
Go to the documentation of this file.
1//
2// Class FFTPeriodicPoissonSolver
3// Solves the periodic Poisson problem using Fourier transforms
4// cf. https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 5
5//
6//
7
8namespace ippl {
9
// NOTE(review): original source lines 11-14 (the declarator and body of this
// template definition) are missing from this rendering; only the template header
// (line 10) and a blank line (15) survive. Restore them from the upstream
// FFTPeriodicPoissonSolver.hpp before editing this definition.
10 template <typename FieldLHS, typename FieldRHS>
15
// NOTE(review): original line 17 (this function's declarator) is missing from this
// rendering — restore it from the upstream header before editing.
//
// Sets up the Fourier-space data structures for the current right-hand side:
//  - domain_m is taken from the rhs layout;
//  - a complex domain of the same shape, except that the axis selected by the
//    "r2c_direction" parameter only has length/2 + 1 entries (the non-redundant
//    half of a real-to-complex transform);
//  - a complex layout (same communicator and per-axis parallel flags as the rhs
//    layout), a unit-spacing/zero-origin complex mesh, and the complex field
//    fieldComplex_m; the scratch field tempFieldComplex_m is only allocated when
//    "output_type" is GRAD;
//  - the FFT object, warmed up once on (*rhs_mp, fieldComplex_m).
16 template <typename FieldLHS, typename FieldRHS>
18 const Layout_t& layout_r = this->rhs_mp->getLayout();
19 domain_m = layout_r.getDomain();
20
21 NDIndex<Dim> domainComplex;
22
23 vector_type hComplex;
24 vector_type originComplex;
25
26 std::array<bool, Dim> isParallel = layout_r.isParallel();
27 for (unsigned d = 0; d < Dim; ++d) {
// The complex mesh is purely an index space (spacing 1, origin 0); the solver
// computes physical wave numbers from the real-space rhs mesh instead.
28 hComplex[d] = 1.0;
29 originComplex[d] = 0.0;
30
// Only the r2c axis is truncated to length/2 + 1 (integer division); all other
// axes keep the full real-space length.
31 if (this->params_m.template get<int>("r2c_direction") == (int)d) {
32 domainComplex[d] = Index(domain_m[d].length() / 2 + 1);
33 } else {
34 domainComplex[d] = Index(domain_m[d].length());
35 }
36 }
37
38 layoutComplex_mp = std::make_shared<Layout_t>(layout_r.comm, domainComplex, isParallel);
39
40 mesh_type meshComplex(domainComplex, hComplex, originComplex);
41
42 fieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
43
// Scratch buffer needed only by the gradient path, which rescales a copy of the
// transformed rhs once per axis.
44 if (this->params_m.template get<int>("output_type") == Base::GRAD) {
45 tempFieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
46 }
47
48 fft_mp = std::make_shared<FFT_t>(layout_r, *layoutComplex_mp, this->params_m);
49 fft_mp->warmup(*this->rhs_mp, fieldComplex_m); // warmup the FFT object
50 }
51
// NOTE(review): several original lines are missing from this rendering:
//  - 53: this function's declarator;
//  - 67: presumably the declaration of N, which is used below but is not declared
//    in what is shown;
//  - 79 / 121 / 155: presumably the heads of the parallel_for calls whose
//    name/policy/lambda arguments follow.
// Restore these lines from the upstream header before editing.
//
// Periodic Poisson solve in Fourier space. Per the exception text below, this is
// FFTPeriodicPoissonSolver::solve. Steps:
//  1. Forward-transform *rhs_mp into fieldComplex_m.
//  2. For every local complex mode, build the wave vector k:
//       k[d] = 2*pi/L[d] * (i[d] - N[d])   if i[d] > N[d]/2,
//       k[d] = 2*pi/L[d] * i[d]            otherwise,
//     where L[d] = N[d] * hx[d] is taken from the rhs mesh.
//  3. Branch on output_type:
//     - SOL: scale by 1/|k|^2, with the k = 0 mode set to 0, then
//       backward-transform into *rhs_mp. The solution therefore overwrites the
//       rhs field, and lhs_mp is not written.
//     - GRAD: for each axis gd, scale a copy by -i*k[gd]/|k|^2, backward-transform
//       it into *rhs_mp, and copy that into component gd of *lhs_mp. On return,
//       *rhs_mp holds the last gradient component.
//     - Any other output_type throws IpplException.
52 template <typename FieldLHS, typename FieldRHS>
54 fft_mp->transform(FORWARD, *this->rhs_mp, fieldComplex_m);
55
56 auto view = fieldComplex_m.getView();
57 const int nghost = fieldComplex_m.getNghost();
58
59 scalar_type pi = Kokkos::numbers::pi_v<scalar_type>;
60 const mesh_type& mesh = this->rhs_mp->get_mesh();
// Local portion of the complex layout; its first() offsets turn local view indices
// into global mode indices inside the kernels.
61 const auto& lDomComplex = layoutComplex_mp->getLocalNDIndex();
62 using vector_type = typename mesh_type::vector_type;
63 const vector_type& origin = mesh.getOrigin();
64 const vector_type& hx = mesh.getMeshSpacing();
65
// rmax[d] - origin[d] == N[d] * hx[d] is the periodic box length used for k.
66 vector_type rmax;
68 for (size_t d = 0; d < Dim; ++d) {
69 N[d] = domain_m[d].length();
70 rmax[d] = origin[d] + (N[d] * hx[d]);
71 }
72
73 // Based on output_type calculate either solution
74 // or gradient
75
76 using index_array_type = typename RangePolicy<Dim>::index_array_type;
77 switch (this->params_m.template get<int>("output_type")) {
78 case Base::SOL: {
80 "Solution FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
81 KOKKOS_LAMBDA(const index_array_type& args) {
// View index (minus the ghost offset) -> global mode index.
82 Vector<int, Dim> iVec = args - nghost;
83 for (unsigned d = 0; d < Dim; ++d) {
84 iVec[d] += lDomComplex[d].first();
85 }
86
87 Vector_t kVec;
88
89 for (size_t d = 0; d < Dim; ++d) {
90 const scalar_type Len = rmax[d] - origin[d];
// Indices above N/2 stand for negative frequencies (i - N).
91 bool shift = (iVec[d] > (N[d] / 2));
92 kVec[d] = 2 * pi / Len * (iVec[d] - shift * N[d]);
93 }
94
95 scalar_type Dr = 0;
96 for (unsigned d = 0; d < Dim; ++d) {
97 Dr += kVec[d] * kVec[d];
98 }
99
// Branchless 1/|k|^2: when Dr == 0 (the k = 0 / mean mode) the denominator is
// bumped to 1 to avoid dividing by zero, and the factor is forced to 0.
100 bool isNotZero = (Dr != 0.0);
101 scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));
102
103 apply(view, args) *= factor;
104 });
105
// The inverse transform writes the solution into the rhs field.
106 fft_mp->transform(BACKWARD, *this->rhs_mp, fieldComplex_m);
107
108 break;
109 }
110 case Base::GRAD: {
111 // Compute gradient in Fourier space and then
112 // take inverse FFT.
113
114 Complex_t imag = {0.0, 1.0};
115 auto tempview = tempFieldComplex_m.getView();
116 auto viewRhs = this->rhs_mp->getView();
117 auto viewLhs = this->lhs_mp->getView();
118 const int nghostL = this->lhs_mp->getNghost();
119
// One pass per gradient component gd. `view` (the forward-transformed rhs) is only
// read here, so it can be reused for every gd.
120 for (size_t gd = 0; gd < Dim; ++gd) {
122 "Gradient FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
123 KOKKOS_LAMBDA(const index_array_type& args) {
124 Vector<int, Dim> iVec = args - nghost;
125 for (unsigned d = 0; d < Dim; ++d) {
126 iVec[d] += lDomComplex[d].first();
127 }
128
129 Vector_t kVec;
130
131 for (size_t d = 0; d < Dim; ++d) {
132 const scalar_type Len = rmax[d] - origin[d];
133 bool shift = (iVec[d] > (N[d] / 2));
134 bool notMid = (iVec[d] != (N[d] / 2));
135 // For the notMid part see
136 // https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 1
// NOTE(review): index N/2 is a Nyquist mode only when N[d] is even. For odd
// N[d], index N/2 is a regular frequency, and zeroing it drops a real
// component. Consider `notMid = (N[d] % 2 != 0) || (iVec[d] != N[d] / 2)` —
// confirm.
137 kVec[d] = notMid * 2 * pi / Len * (iVec[d] - shift * N[d]);
138 }
139
// NOTE(review): kVec has the Nyquist entry zeroed on every axis, not only on gd,
// so Dr (the |k|^2 denominator) differs from the SOL branch for modes that sit on
// the Nyquist index of some axis d != gd. Confirm this is intended.
140 scalar_type Dr = 0;
141 for (unsigned d = 0; d < Dim; ++d) {
142 Dr += kVec[d] * kVec[d];
143 }
144
145 apply(tempview, args) = apply(view, args);
146
147 bool isNotZero = (Dr != 0.0);
148 scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));
149
// Multiply by -i * k[gd] / |k|^2 (0 for the k = 0 mode).
150 apply(tempview, args) *= -(imag * kVec[gd] * factor);
151 });
152
// The inverse transform overwrites *rhs_mp with gradient component gd.
153 fft_mp->transform(BACKWARD, *this->rhs_mp, tempFieldComplex_m);
154
// Copy that component into slot gd of the vector-valued lhs field.
156 "Assign Gradient FFTPeriodicPoissonSolver",
157 getRangePolicy(viewLhs, nghostL),
158 KOKKOS_LAMBDA(const index_array_type& args) {
159 apply(viewLhs, args)[gd] = apply(viewRhs, args);
160 });
161 }
162
163 break;
164 }
165
166 default:
167 throw IpplException("FFTPeriodicPoissonSolver::solve", "Unrecognized output_type");
168 }
169 }
170} // namespace ippl
const double pi
Definition Archive.h:20
void initialize(int &argc, char *argv[], MPI_Comm comm)
Definition Ippl.cpp:16
@ FORWARD
Definition FFT.h:63
@ BACKWARD
Definition FFT.h:64
KOKKOS_INLINE_FUNCTION constexpr decltype(auto) apply(const View &view, const Coords &coords)
RangePolicy< View::rank, typenameView::execution_space, PolicyArgs... >::policy_type getRangePolicy(const View &view, int shift=0)
void parallel_for(const std::string &name, const ExecPolicy &policy, const FunctorType &functor)
KOKKOS_INLINE_FUNCTION auto & get(Tuple< Ts... > &t)
Accessor function that returns a mutable reference to the element at a given index of a Tuple.
Definition Tuple.h:314
const NDIndex< Dim > & getDomain() const
std::array< bool, Dim > isParallel() const
mpi::Communicator comm
typename FieldLHS::Mesh_t::vector_type vector_type
void setRhs(rhs_type &rhs) override
std::shared_ptr< Layout_t > layoutComplex_mp
typename FieldLHS::Mesh_t::value_type scalar_type
lhs_type * lhs_mp
Definition Poisson.h:123
ParameterList params_m
Definition Poisson.h:120
virtual void setRhs(rhs_type &rhs)
Definition Poisson.h:96
rhs_type * rhs_mp
Definition Poisson.h:122
::ippl::Vector< index_type, Dim > index_array_type