// tiny_dnn 1.0.0
// A header only, dependency-free deep learning framework in C++11
// target_cost.h
#pragma once
#include "tiny_dnn/util/util.h"
#include <numeric> // std::accumulate

namespace tiny_dnn {

7// calculate the number of samples for each class label
8// - for example, if there are 10 samples having label 0, and
9// 20 samples having label 1, returns a vector [10, 20]
10inline std::vector<serial_size_t> calculate_label_counts(const std::vector<label_t>& t) {
11 std::vector<serial_size_t> label_counts;
12 for (label_t label : t) {
13 if (label >= label_counts.size()) {
14 label_counts.resize(label + 1);
15 }
16 label_counts[label]++;
17 }
18 assert(std::accumulate(label_counts.begin(), label_counts.end(), static_cast<serial_size_t>(0)) == t.size());
19 return label_counts;
20}
21
22// calculate the weight of a given sample needed for a balanced target cost
23// NB: we call a target cost matrix "balanced", if the cost of each *class* is equal
24// (this happens when the product weight * sample count is equal between the different
25// classes, and the sum of these products equals the total number of samples)
26inline float_t get_sample_weight_for_balanced_target_cost(serial_size_t classes, serial_size_t total_samples, serial_size_t this_class_samples)
27{
28 assert(this_class_samples <= total_samples);
29 return total_samples / static_cast<float_t>(classes * this_class_samples);
30}
31
32// create a target cost matrix implying equal cost for each *class* (distinct label)
33// - by default, each *sample* has an equal cost, which means e.g. that a classifier
34// may prefer to always guess the majority class (in case the degree of imbalance
35// is relatively high, and the classification task is relatively difficult)
36// - the parameter w can be used to fine-tune the balance:
37// * use 0 to have an equal cost for each *sample* (equal to not supplying any target costs at all)
38// * use 1 to have an equal cost for each *class* (default behaviour of this function)
39// * use a value between 0 and 1 to have something between the two extremes
40inline std::vector<vec_t> create_balanced_target_cost(const std::vector<label_t>& t, float_t w = 1.0)
41{
42 const auto label_counts = calculate_label_counts(t);
43 const serial_size_t total_sample_count = static_cast<serial_size_t>(t.size());
44 const serial_size_t class_count = static_cast<serial_size_t>(label_counts.size());
45
46 std::vector<vec_t> target_cost(t.size());
47
48 for (serial_size_t i = 0; i < total_sample_count; ++i) {
49 vec_t& sample_cost = target_cost[i];
50 sample_cost.resize(class_count);
51 const float_t balanced_weight = get_sample_weight_for_balanced_target_cost(class_count, total_sample_count, label_counts[t[i]]);
52 const float_t unbalanced_weight = 1;
53 const float_t sample_weight = w * balanced_weight + (1 - w) * unbalanced_weight;
54 std::fill(sample_cost.begin(), sample_cost.end(), sample_weight);
55 }
56
57 return target_cost;
58}

} // namespace tiny_dnn