tiny_dnn 1.0.0
A header-only, dependency-free deep learning framework in C++11
Loading...
Searching...
No Matches
graph_visualizer.h
1/*
2Copyright (c) 2016, Taiga Nomi
3All rights reserved.
4
5Redistribution and use in source and binary forms, with or without
6modification, are permitted provided that the following conditions are met:
7* Redistributions of source code must retain the above copyright
8notice, this list of conditions and the following disclaimer.
9* Redistributions in binary form must reproduce the above copyright
10notice, this list of conditions and the following disclaimer in the
11documentation and/or other materials provided with the distribution.
12* Neither the name of the <organization> nor the
13names of its contributors may be used to endorse or promote products
14derived from this software without specific prior written permission.
15
16THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
17EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
20DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*/
27#pragma once
28
29#include "tiny_dnn/node.h"
30#include "tiny_dnn/layers/layer.h"
31#include "tiny_dnn/network.h"
32
33namespace tiny_dnn {
34
39public:
40 explicit graph_visualizer(layer *root_node, const std::string& graph_name = "graph")
41 : root_(root_node), name_(graph_name) {}
42
43 template <typename N>
44 explicit graph_visualizer(network<N>& network, const std::string& graph_name = "graph")
45 : root_(network[0]), name_(graph_name) {}
46
50 void generate(std::ostream& stream) {
51 generate_header(stream);
52 generate_nodes(stream);
53 generate_footer(stream);
54 }
55
56private:
57 typedef std::unordered_map<const node*, std::string> node2name_t;
58
59 void generate_header(std::ostream& stream) {
60 stream << "digraph \"" << name_ << "\" {" << std::endl;
61 stream << " node [ shape=record ];" << std::endl;
62 }
63
64 void generate_nodes(std::ostream& stream) {
65 node2name_t node2name;
66 get_layer_names(node2name);
67
68 graph_traverse(root_,
69 [&](const layer& l) { generate_layer(stream, l, node2name); },
70 [&](const edge& e) { generate_edge(stream, e, node2name); });
71 }
72
73 void get_layer_names(node2name_t& node2name) {
74 std::unordered_map<std::string, int> layer_counts; // [layer_type -> num]
75
76 auto namer = [&](const layer& l) {
77 std::string ltype = l.layer_type();
78
79 // add quote and sequential-id
80 node2name[&l] = "\"" + ltype + to_string(layer_counts[l.layer_type()]++) + "\"";
81 };
82
83 graph_traverse(root_, namer, [&](const edge&){});
84 }
85
86 void generate_edge(std::ostream& stream, const edge& e, node2name_t& node2name) {
87 auto next = e.next();
88 auto prev = e.prev();
89
90 for (auto n : next) {
91 serial_size_t dst_port = n->prev_port(e);
92 serial_size_t src_port = prev->next_port(e);
93 stream << " " << node2name[prev] << ":out" << src_port <<
94 " -> " << node2name[n] << ":in" << dst_port << ";" << std::endl;
95 }
96 }
97
98 void generate_layer(std::ostream& stream, const layer& layer, node2name_t& node2name) {
99 stream << " " << node2name[&layer] << " [" << std::endl;
100 stream << " label= \"";
101 stream << layer.layer_type() << "|{{in";
102 generate_layer_channels(stream, layer.in_shape(), layer.in_types(), "in");
103 stream << "}|{out";
104 generate_layer_channels(stream, layer.out_shape(), layer.out_types(), "out");
105 stream << "}}\""<< std::endl;
106 stream << " ];" << std::endl;
107 }
108
109 void generate_layer_channels(std::ostream& stream,
110 const std::vector<shape3d>& shapes,
111 const std::vector<vector_type>& vtypes,
112 const std::string& port_prefix) {
113 CNN_UNREFERENCED_PARAMETER(vtypes);
114 for (size_t i = 0; i < shapes.size(); i++) {
115 stream << "|<" << port_prefix << i << ">" << shapes[i] << "(" << vtypes[i] << ")";
116 }
117 }
118
/**
 * Closes the `digraph` block opened by generate_header.
 * std::endl is kept here so the finished graph is flushed once at the end.
 */
void generate_footer(std::ostream& stream) {
  stream << "}";
  stream << std::endl;
}
122
123 layer* root_;
124 std::string name_;
125};
126
127
128} // namespace tiny_dnn
Utility for graph visualization.
Definition graph_visualizer.h:38
void generate(std::ostream &stream)
Generates the graph structure in DOT language format.
Definition graph_visualizer.h:50
Simple image utility class.
Definition image.h:94
base class of all kind of NN layers
Definition layer.h:62
A model of neural networks in tiny-dnn.
Definition network.h:167