tiny_dnn 1.0.0
A header only, dependency-free deep learning framework in C++11
Loading...
Searching...
No Matches
slice_layer.h
1/*
2 Copyright (c) 2016, Taiga Nomi
3 All rights reserved.
4
5 Redistribution and use in source and binary forms, with or without
6 modification, are permitted provided that the following conditions are met:
7 * Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9 * Redistributions in binary form must reproduce the above copyright
10 notice, this list of conditions and the following disclaimer in the
11 documentation and/or other materials provided with the distribution.
12 * Neither the name of the <organization> nor the
13 names of its contributors may be used to endorse or promote products
14 derived from this software without specific prior written permission.
15
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
17 EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
20 DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*/
27#pragma once
28#include "tiny_dnn/util/util.h"
29#include "tiny_dnn/layers/layer.h"
30
31namespace tiny_dnn {
32
/**
 * axis along which slice_layer splits its input
 **/
enum class slice_type {
    slice_samples,   ///< split along the batch (sample) dimension
    slice_channels   ///< split along the channel (depth) dimension
};
37
38
42class slice_layer : public layer {
43public:
44 typedef layer Base;
45
69 slice_layer(const shape3d& in_shape, slice_type slice_type, serial_size_t num_outputs)
70 : layer(std::vector<vector_type>(1, vector_type::data), std::vector<vector_type>(num_outputs, vector_type::data)),
71 in_shape_(in_shape), slice_type_(slice_type), num_outputs_(num_outputs) {
72 set_shape();
73 }
74
75 slice_layer(const layer& prev_layer, slice_type slice_type, serial_size_t num_outputs)
76 : layer(std::vector<vector_type>(1, vector_type::data), std::vector<vector_type>(num_outputs, vector_type::data)),
77 in_shape_(prev_layer.out_shape()[0]), slice_type_(slice_type), num_outputs_(num_outputs) {
78 set_shape();
79 }
80
81 std::string layer_type() const override {
82 return "slice";
83 }
84
85 std::vector<shape3d> in_shape() const override {
86 return {in_shape_};
87 }
88
89 std::vector<shape3d> out_shape() const override {
90 return out_shapes_;
91 }
92
93 void forward_propagation(const std::vector<tensor_t*>& in_data,
94 std::vector<tensor_t*>& out_data) override {
95 switch (slice_type_) {
96 case slice_type::slice_samples:
97 slice_data_forward(*in_data[0], out_data);
98 break;
99 case slice_type::slice_channels:
100 slice_channels_forward(*in_data[0], out_data);
101 break;
102 default:
104 }
105 }
106
107 void back_propagation(const std::vector<tensor_t*>& in_data,
108 const std::vector<tensor_t*>& out_data,
109 std::vector<tensor_t*>& out_grad,
110 std::vector<tensor_t*>& in_grad) override {
111 CNN_UNREFERENCED_PARAMETER(in_data);
112 CNN_UNREFERENCED_PARAMETER(out_data);
113
114 switch (slice_type_) {
115 case slice_type::slice_samples:
116 slice_data_backward(out_grad, *in_grad[0]);
117 break;
118 case slice_type::slice_channels:
119 slice_channels_backward(out_grad, *in_grad[0]);
120 break;
121 default:
123 }
124 }
125
126 template <class Archive>
127 static void load_and_construct(Archive & ar, cereal::construct<slice_layer> & construct) {
129 slice_type slice_type;
130 serial_size_t num_outputs;
131
132 ar(cereal::make_nvp("in_size", in_shape), cereal::make_nvp("slice_type", slice_type), cereal::make_nvp("num_outputs", num_outputs));
133 construct(in_shape, slice_type, num_outputs);
134 }
135
136 template <class Archive>
137 void serialize(Archive & ar) {
138 layer::serialize_prolog(ar);
139 ar(cereal::make_nvp("in_size", in_shape_), cereal::make_nvp("slice_type", slice_type_), cereal::make_nvp("num_outputs", num_outputs_));
140 }
141private:
142 void slice_data_forward(const tensor_t& in_data,
143 std::vector<tensor_t*>& out_data) {
144 const vec_t* in = &in_data[0];
145
146 for (serial_size_t i = 0; i < num_outputs_; i++) {
147 tensor_t& out = *out_data[i];
148
149 std::copy(in, in + slice_size_[i], &out[0]);
150
151 in += slice_size_[i];
152 }
153 }
154
155 void slice_data_backward(std::vector<tensor_t*>& out_grad,
156 tensor_t& in_grad) {
157 vec_t* in = &in_grad[0];
158
159 for (serial_size_t i = 0; i < num_outputs_; i++) {
160 tensor_t& out = *out_grad[i];
161
162 std::copy(&out[0], &out[0] + slice_size_[i], in);
163
164 in += slice_size_[i];
165 }
166 }
167
168 void slice_channels_forward(const tensor_t& in_data,
169 std::vector<tensor_t*>& out_data) {
170 serial_size_t num_samples = static_cast<serial_size_t>(in_data.size());
171 serial_size_t channel_idx = 0;
172 serial_size_t spatial_dim = in_shape_.area();
173
174 for (serial_size_t i = 0; i < num_outputs_; i++) {
175 for (serial_size_t s = 0; s < num_samples; s++) {
176 float_t *out = &(*out_data[i])[s][0];
177 const float_t *in = &in_data[s][0] + channel_idx*spatial_dim;
178
179 std::copy(in, in + slice_size_[i] * spatial_dim, out);
180 }
181 channel_idx += slice_size_[i];
182 }
183 }
184
185 void slice_channels_backward(std::vector<tensor_t*>& out_grad,
186 tensor_t& in_grad) {
187 serial_size_t num_samples = static_cast<serial_size_t>(in_grad.size());
188 serial_size_t channel_idx = 0;
189 serial_size_t spatial_dim = in_shape_.area();
190
191 for (serial_size_t i = 0; i < num_outputs_; i++) {
192 for (serial_size_t s = 0; s < num_samples; s++) {
193 const float_t *out = &(*out_grad[i])[s][0];
194 float_t *in = &in_grad[s][0] + channel_idx*spatial_dim;
195
196 std::copy(out, out + slice_size_[i] * spatial_dim, in);
197 }
198 channel_idx += slice_size_[i];
199 }
200 }
201
202 void set_sample_count(serial_size_t sample_count) override {
203 if (slice_type_ == slice_type::slice_samples) {
204 if (num_outputs_ == 0)
205 throw nn_error("num_outputs must be positive integer");
206
207 serial_size_t sample_per_out = sample_count / num_outputs_;
208
209 slice_size_.resize(num_outputs_, sample_per_out);
210 slice_size_.back() = sample_count - (sample_per_out*(num_outputs_-1));
211 }
212 Base::set_sample_count(sample_count);
213 }
214
215 void set_shape() {
216 switch (slice_type_) {
217 case slice_type::slice_samples:
218 set_shape_data();
219 break;
220 case slice_type::slice_channels:
221 set_shape_channels();
222 break;
223 default:
224 throw nn_not_implemented_error();
225 }
226 }
227
228 void set_shape_data() {
229 out_shapes_.resize(num_outputs_, in_shape_);
230 }
231
232 void set_shape_channels() {
233 serial_size_t channel_per_out = in_shape_.depth_ / num_outputs_;
234
235 out_shapes_.clear();
236 for (serial_size_t i = 0; i < num_outputs_; i++) {
237 serial_size_t ch = channel_per_out;
238
239 if (i == num_outputs_ - 1) {
240 assert(in_shape_.depth_ >= i * channel_per_out);
241 ch = in_shape_.depth_ - i * channel_per_out;
242 }
243
244 slice_size_.push_back(ch);
245 out_shapes_.push_back(shape3d(in_shape_.width_, in_shape_.height_, ch));
246 }
247 }
248
249 shape3d in_shape_;
250 slice_type slice_type_;
251 serial_size_t num_outputs_;
252 std::vector<shape3d> out_shapes_;
253 std::vector<serial_size_t> slice_size_;
254};
255
256} // namespace tiny_dnn
Simple image utility class.
Definition image.h:94
base class of all kind of NN layers
Definition layer.h:62
Definition nn_error.h:83
slice an input data into multiple outputs along a given slice dimension.
Definition slice_layer.h:42
void forward_propagation(const std::vector< tensor_t * > &in_data, std::vector< tensor_t * > &out_data) override
Definition slice_layer.h:93
slice_layer(const shape3d &in_shape, slice_type slice_type, serial_size_t num_outputs)
Definition slice_layer.h:69
std::vector< shape3d > in_shape() const override
array of input shapes (width x height x depth)
Definition slice_layer.h:85
std::string layer_type() const override
name of layer, should be unique for each concrete class
Definition slice_layer.h:81
std::vector< shape3d > out_shape() const override
array of output shapes (width x height x depth)
Definition slice_layer.h:89
void back_propagation(const std::vector< tensor_t * > &in_data, const std::vector< tensor_t * > &out_data, std::vector< tensor_t * > &out_grad, std::vector< tensor_t * > &in_grad) override
return delta of the previous layer (\(\delta = \frac{dE}{da}\), where \(a = wx\) in a fully-connected layer)
Definition slice_layer.h:107