55 CNN_USE_LAYER_MEMBERS;
68 backend_t backend_type = core::default_engine())
75 backend_t backend_type = core::default_engine())
79 max_pooling_layer(serial_size_t in_width,
80 serial_size_t in_height,
82 serial_size_t pooling_size,
84 backend_t backend_type = core::default_engine())
85 : max_pooling_layer(in_width, in_height,
in_channels, pooling_size,
86 pooling_size, stride, stride, padding::valid,
101 serial_size_t stride_x,
102 serial_size_t stride_y,
103 padding pad_type = padding::valid,
104 backend_t backend_type = core::default_engine())
105 :
Base({ vector_type::data }) {
114 init_backend(backend_type);
115 Base::set_backend_type(backend_type);
119 max_pooling_layer(max_pooling_layer&& other)
120 : Base(std::move(other))
121 , params_(std::move(other.params_)) {
123 init_backend(std::move(Base::engine()));
127 return static_cast<serial_size_t
>(params_.out2in[0].size());
135 std::vector<tensor_t*>&
out_data)
override {
138 ctx.setParallelize(layer::parallelize());
139 ctx.setEngine(layer::engine());
142 kernel_fwd_->compute(
ctx);
149 const std::vector<tensor_t*>&
out_data,
151 std::vector<tensor_t*>&
in_grad)
override {
158 ctx.setParallelize(layer::parallelize());
159 ctx.setEngine(layer::engine());
162 kernel_back_->compute(
ctx);
165 std::vector<index3d<serial_size_t>>
166 in_shape()
const override {
return { params_.in }; }
168 std::vector<index3d<serial_size_t>>
169 out_shape()
const override {
return { params_.out, params_.out }; }
172 return std::string(
"max-pool");
/**
 * Returns the hard-coded relative path of the pooling kernel source file
 * (a .cl file under core/kernels/cl_kernels).
 *
 * NOTE(review): the path is relative, so it presumably depends on the
 * process working directory -- confirm with whoever consumes this value.
 */
175 std::string kernel_file()
const override {
176 return std::string(
"../tiny_cnn/core/kernels/cl_kernels/pooling.cl");
/**
 * Returns the configured pooling size.
 *
 * @return pair whose first element is params_.pool_size_x and whose second
 *         element is params_.pool_size_y.
 */
179 std::pair<serial_size_t, serial_size_t> pool_size()
const {
180 return std::make_pair(params_.pool_size_x, params_.pool_size_y);
/**
 * Updates the number of samples handled by this layer.
 *
 * Forwards sample_count to Base::set_sample_count and then resizes
 * params_.out2inmax so that it holds sample_count rows.
 *
 * @param sample_count number of rows params_.out2inmax must hold afterwards.
 */
183 void set_sample_count(serial_size_t sample_count)
override {
184 Base::set_sample_count(sample_count);
// Rows added by the resize are vectors of params_.out.size() value-initialized
// elements; rows that already exist are left untouched by std::vector::resize.
185 params_.out2inmax.resize(
186 sample_count, std::vector<serial_size_t>(params_.out.size()));
190 template <
class Archive>
192 load_and_construct(Archive & ar,
193 cereal::construct<max_pooling_layer> & construct) {
195 serial_size_t stride_x, stride_y, pool_size_x, pool_size_y;
198 ar(cereal::make_nvp(
"in_size", in),
199 cereal::make_nvp(
"pool_size_x", pool_size_x),
200 cereal::make_nvp(
"pool_size_y", pool_size_y),
201 cereal::make_nvp(
"stride_x", stride_x),
202 cereal::make_nvp(
"stride_y", stride_y),
203 cereal::make_nvp(
"pad_type", pad_type));
204 construct(in.width_, in.height_, in.depth_, pool_size_x, pool_size_y,
205 stride_x, stride_y, pad_type);
/**
 * Writes or reads this layer's configuration through a cereal archive.
 *
 * Calls layer::serialize_prolog(ar) first, then passes six name/value pairs
 * taken from params_ to the archive: "in_size", "pool_size_x",
 * "pool_size_y", "stride_x", "stride_y" and "pad_type". The same names are
 * used by load_and_construct when it reads the values back.
 *
 * @param ar cereal archive the values are passed to.
 */
208 template <
class Archive>
209 void serialize(Archive & ar) {
210 layer::serialize_prolog(ar);
211 ar(cereal::make_nvp(
"in_size", params_.in),
212 cereal::make_nvp(
"pool_size_x", params_.pool_size_x),
213 cereal::make_nvp(
"pool_size_y", params_.pool_size_y),
214 cereal::make_nvp(
"stride_x", params_.stride_x),
215 cereal::make_nvp(
"stride_y", params_.stride_y),
216 cereal::make_nvp(
"pad_type", params_.pad_type));
// Pooling configuration and index tables: in/out shapes, pool sizes, strides,
// pad type, and the in2out / out2in / out2inmax mappings used by this layer.
221 maxpool_params params_;
// Forward kernel; init_backend resets it to a MaxPoolOp and the forward pass
// invokes its compute().
224 std::shared_ptr<core::OpKernel> kernel_fwd_;
// Backward kernel; init_backend resets it to a MaxPoolGradOp and the backward
// pass invokes its compute().
225 std::shared_ptr<core::OpKernel> kernel_back_;
227 void connect_kernel(serial_size_t pooling_size_x,
228 serial_size_t pooling_size_y,
232 serial_size_t dxmax =
static_cast<serial_size_t
>(
233 std::min(
static_cast<serial_size_t
>(pooling_size_x),
234 params_.in.width_ - outx * params_.stride_x));
236 serial_size_t dymax =
static_cast<serial_size_t
>(
237 std::min(
static_cast<serial_size_t
>(pooling_size_y),
238 params_.in.height_ - outy * params_.stride_y));
240 for (serial_size_t dy = 0; dy < dymax; dy++) {
241 for (serial_size_t dx = 0; dx < dxmax; dx++) {
242 serial_size_t in_index = params_.in.get_index(
243 static_cast<serial_size_t
>(outx * params_.stride_x + dx),
244 static_cast<serial_size_t
>(outy * params_.stride_y + dy), c);
245 serial_size_t out_index = params_.out.get_index(outx, outy, c);
247 if (in_index >= params_.in2out.size()) {
248 throw nn_error(
"index overflow");
250 if (out_index >= params_.out2in.size()) {
251 throw nn_error(
"index overflow");
253 params_.in2out[in_index] = out_index;
254 params_.out2in[out_index].push_back(in_index);
259 void init_connection() {
260 params_.in2out.resize(params_.in.size());
261 params_.out2in.resize(params_.out.size());
263 for (serial_size_t c = 0; c < params_.in.depth_; ++c) {
264 for (serial_size_t y = 0; y < params_.out.height_; ++y) {
265 for (serial_size_t x = 0; x < params_.out.width_; ++x) {
266 connect_kernel(params_.pool_size_x,
274 void init_backend(backend_t backend_type) {
275 core::OpKernelConstruction ctx =
276 core::OpKernelConstruction(layer::device(), ¶ms_);
278 if (backend_type == backend_t::internal ||
279 backend_type == backend_t::nnpack ||
280 backend_type == backend_t::avx) {
282 kernel_fwd_.reset(
new MaxPoolOp(ctx));
283 kernel_back_.reset(
new MaxPoolGradOp(ctx));
287 throw nn_error(
"Not supported engine: " + to_string(backend_type));
292 void set_maxpool_params(
const shape3d& in,
294 serial_size_t pooling_size_x,
295 serial_size_t pooling_size_y,
296 serial_size_t stride_x,
297 serial_size_t stride_y,
301 params_.pool_size_x = pooling_size_x;
302 params_.pool_size_y = pooling_size_y;
303 params_.stride_x = stride_x;
304 params_.stride_y = stride_y;
305 params_.pad_type = pad_type;
max_pooling_layer(serial_size_t in_width, serial_size_t in_height, serial_size_t in_channels, serial_size_t pooling_size_x, serial_size_t pooling_size_y, serial_size_t stride_x, serial_size_t stride_y, padding pad_type=padding::valid, backend_t backend_type=core::default_engine())
Definition max_pooling_layer.h:96