tiny_dnn
1.0.0
A header-only, dependency-free deep learning framework in C++11
Loading...
Searching...
No Matches
tiny_dnn
core
kernels
conv2d_grad_op.h
1
/*
2
COPYRIGHT
3
4
All contributions by Taiga Nomi
5
Copyright (c) 2013, Taiga Nomi
6
All rights reserved.
7
8
All other contributions:
9
Copyright (c) 2013-2016, the respective contributors.
10
All rights reserved.
11
12
Each contributor holds copyright over their respective contributions.
13
The project versioning (Git) records all such contribution source information.
14
15
LICENSE
16
17
The BSD 3-Clause License
18
19
20
Redistribution and use in source and binary forms, with or without
21
modification, are permitted provided that the following conditions are met:
22
23
* Redistributions of source code must retain the above copyright notice, this
24
list of conditions and the following disclaimer.
25
26
* Redistributions in binary form must reproduce the above copyright notice,
27
this list of conditions and the following disclaimer in the documentation
28
and/or other materials provided with the distribution.
29
30
* Neither the name of tiny-dnn nor the names of its
31
contributors may be used to endorse or promote products derived from
32
this software without specific prior written permission.
33
34
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
*/
45
#pragma once
46
47
#include "tiny_dnn/core/framework/op_kernel.h"
48
49
#include "tiny_dnn/core/kernels/conv2d_grad_op_avx.h"
50
#include "tiny_dnn/core/kernels/conv2d_op_internal.h"
51
52
namespace
tiny_dnn {
53
54
class
Conv2dGradOp
:
public
core::OpKernel
{
55
public
:
56
explicit
Conv2dGradOp
(
const
core::OpKernelConstruction
&
context
)
57
:
core::OpKernel
(
context
) {}
58
59
void
compute(
const
core::OpKernelContext
&
context
)
override
{
60
auto
params = OpKernel::params_->conv();
61
62
// incoming/outcoming data
63
const
tensor_t&
prev_out
=
context
.input(0);
64
const
tensor_t&
W
=
context
.input(1);
65
tensor_t&
dW
=
context
.input_grad(1);
66
tensor_t&
db
=
context
.input_grad(2);
67
tensor_t&
prev_delta
=
context
.input_grad(0);
68
tensor_t&
curr_delta
=
context
.output_grad(1);
69
70
// initalize outputs
71
fill_tensor(
prev_delta
,
float_t
(0));
72
73
// call convolution algorithm depending
74
// on the selected engine type
75
76
const
core::backend_t engine =
context
.engine();
77
78
if
(engine == core::backend_t::internal) {
79
kernels::conv2d_op_internal(
80
prev_out
,
81
W
[0],
82
dW
,
83
db
,
84
curr_delta
,
85
prev_delta
,
86
params,
87
context
.parallelize());
88
}
89
else
if
(engine == core::backend_t::avx) {
90
kernels::conv2d_grad_op_avx(
91
prev_out
,
92
W
[0],
93
dW
,
94
db
,
95
curr_delta
,
96
prev_delta
,
97
params,
98
context
.parallelize());
99
}
100
else
{
101
throw
nn_error
(
"Not supported engine: "
+ to_string(engine));
102
}
103
}
104
};
105
106
}
// namespace tiny_dnn
tiny_dnn::Conv2dGradOp
Definition
conv2d_grad_op.h:54
tiny_dnn::core::OpKernelConstruction
Definition
op_kernel.h:55
tiny_dnn::core::OpKernelContext
Definition
op_kernel.h:72
tiny_dnn::core::OpKernel
Definition
op_kernel.h:175
tiny_dnn::image
Simple image utility class.
Definition
image.h:94
tiny_dnn::nn_error
error exception class for tiny-dnn
Definition
nn_error.h:37
Generated by
1.9.8