22 using namespace shogun;
25 :m_max_num_iter(3), m_A(0.5), m_B(5), m_svm_C(1), m_svm_epsilon(0.001),
26 m_kernel(NULL), m_feats(NULL), m_machine_for_confusion_matrix(NULL), m_num_classes(0)
47 REQUIRE(feats != NULL, (
"Require non-NULL dense features of float64_t\n"));
107 node = node->
right();
113 if (node->
data.
mu[i] >= 0)
132 SG_ERROR(
"Call set_machine_for_confusion_matrix before training\n");
134 SG_ERROR(
"assign a valid kernel before training\n");
140 SG_ERROR(
"Require non-NULL dense features of float64_t\n");
159 std::queue<node_t *> node_q;
162 while (node_q.size() != 0)
174 left_classes[k++] = i;
177 left_classes.
vlen = k;
179 if (left_classes.
vlen >= 2)
182 node->
left(left_node);
183 node_q.push(left_node);
192 if (node->
data.
mu[i] >= 0)
193 right_classes[k++] = i;
196 right_classes.
vlen = k;
198 if (right_classes.
vlen >= 2)
201 node->
right(right_node);
202 node_q.push(right_node);
216 CSVM *best_svm = NULL;
217 float64_t best_score = std::numeric_limits<float64_t>::max();
219 std::vector<CRelaxedTree::entry_t> mu_init =
init_node(conf_mat, classes);
220 for (std::vector<CRelaxedTree::entry_t>::const_iterator it = mu_init.begin(); it != mu_init.end(); ++it)
229 if (score < best_score)
250 for (int32_t i=0; i < best_mu.
vlen; ++i)
253 long_mu[classes[i]] = 1;
254 else if (best_mu[i] == -1)
255 long_mu[classes[i]] = -1;
256 else if (best_mu[i] == 0)
257 long_mu[classes[i]] = 0;
267 for (int32_t i=0; i < mu.
vlen; ++i)
271 else if (mu[i] == -1)
276 float64_t score = num_neg/(num_neg+num_pos) * totalSV/num_pos +
277 num_pos/(num_neg+num_pos)*totalSV/num_neg;
285 mu[mu_entry.first.first] = 1;
286 mu[mu_entry.first.second] = -1;
295 for (int32_t i=0; i < classes.
vlen; ++i)
298 long_mu[classes[i]] = 1;
299 else if (mu[i] == -1)
300 long_mu[classes[i]] = -1;
308 for (int32_t i=0; i < binlab.vlen; ++i)
311 binlab[i] = long_mu[lab];
312 if (long_mu[lab] != 0)
333 std::copy(&mu[0], &mu[mu.vlen], &prev_mu[0]);
338 for (int32_t i=0; i < mu.vlen; ++i)
340 if (mu[i] != prev_mu[i])
358 return e1.second < e2.second;
369 conf_mat(i, j) = global_conf_mat(classes[i], classes[j]);
378 conf_mat(i,j) += conf_mat(j,i);
382 std::vector<CRelaxedTree::entry_t> entries;
387 entries.push_back(std::make_pair(std::make_pair(i, j), conf_mat(i,j)));
393 const size_t max_n_samples = 30;
394 int32_t n_samples = std::min(max_n_samples, entries.size());
396 return std::vector<CRelaxedTree::entry_t>(entries.begin(), entries.begin() + n_samples);
410 for (int32_t i=0; i < classes.
vlen; ++i)
424 for (int32_t j=0; j < resp.
vlen; ++j)
428 xi_pos_class[i] += std::max(0.0, 1 - resp[j]);
429 xi_neg_class[i] += std::max(0.0, 1 + resp[j]);
436 if (delta_pos[i] > 0 && delta_neg[i] > 0)
442 if (delta_pos[i] < delta_neg[i])
452 for (int32_t i=0; i < mu.
vlen; ++i)
489 if (mu[min_idx] == 1 && (npos == 0 || npos == 1))
494 for (i=0; i < xi_neg_class.vlen; ++i)
498 min_val = xi_neg_class[i];
503 for (i=i+1; i < xi_neg_class.vlen; ++i)
505 if (mu[i] != 1 && xi_neg_class[i] < min_val)
507 min_val = xi_neg_class[i];
528 int32_t num_zero = index_zero.
vlen;
529 int32_t num_pos = index_pos.
vlen;
532 std::copy(&index_zero[0], &index_zero[num_zero], &class_index[0]);
533 std::copy(&index_pos[0], &index_pos[num_pos], &class_index[num_zero]);
534 std::copy(&index_pos[0], &index_pos[num_pos], &class_index[num_pos+num_zero]);
538 std::fill(&orig_mu[num_zero], &orig_mu[orig_mu.
vlen], 1);
541 std::fill(&delta_steps[0], &delta_steps[delta_steps.
vlen], 1);
545 std::fill(&new_mu[0], &new_mu[num_zero], -1);
549 for (
index_t i=0; i < num_zero; ++i)
550 S_delta[i] = delta_neg[index_zero[i]];
552 for (int32_t i=0; i < num_pos; ++i)
554 float64_t delta_k = delta_neg[index_pos[i]];
555 float64_t delta_k_0 = -delta_pos[index_pos[i]];
557 index_t tmp_index = num_zero + i*2;
558 if (delta_k_0 <= delta_k)
560 new_mu[tmp_index] = 0;
561 new_mu[tmp_index+1] = -1;
563 S_delta[tmp_index] = delta_k_0;
564 S_delta[tmp_index+1] = delta_k;
566 delta_steps[tmp_index] = 1;
567 delta_steps[tmp_index+1] = 1;
571 new_mu[tmp_index] = -1;
572 new_mu[tmp_index+1] = 0;
574 S_delta[tmp_index] = (delta_k_0+delta_k)/2;
575 S_delta[tmp_index+1] = delta_k_0;
577 delta_steps[tmp_index] = 2;
578 delta_steps[tmp_index+1] = 1;
586 S_delta_sorted[i] = S_delta[sorted_index[i]];
587 new_mu[i] = new_mu[sorted_index[i]];
588 orig_mu[i] = orig_mu[sorted_index[i]];
589 class_index[i] = class_index[sorted_index[i]];
590 delta_steps[i] = delta_steps[sorted_index[i]];
594 std::fill(&valid_flag[0], &valid_flag[valid_flag.
vlen], 1);
601 if (d == B_prime -
m_B || d == B_prime -
m_B + 1)
604 while (!valid_flag[ctr])
607 if (delta_steps[ctr] == 1)
609 mu[class_index[ctr]] = new_mu[ctr];
615 if (d <= B_prime -
m_B - 2)
617 mu[class_index[ctr]] = new_mu[ctr];
618 ASSERT(new_mu[ctr] == -1);
620 for (
index_t i=0; i < class_index.vlen; ++i)
622 if (class_index[i] == class_index[ctr])
628 float64_t Delta_k_minus = 2*S_delta_sorted[ctr];
633 for (int32_t itr=ctr+1; itr < S_delta_sorted.
vlen; ++itr)
635 if (valid_flag[itr] == 0)
638 if (delta_steps[itr] == 1)
641 Delta_j_min = S_delta_sorted[j];
648 for (int32_t itr=ctr-1; itr >= 0; --itr)
650 if (delta_steps[itr] == 1 && valid_flag[itr] == 1)
653 Delta_i_max = S_delta_sorted[i];
658 float64_t Delta_l_max = std::numeric_limits<float64_t>::min();
660 for (int32_t itr=ctr-1; itr >= 0; itr--)
662 if (delta_steps[itr] == 2)
664 float64_t delta_tmp = xi_neg_class[class_index[itr]];
665 if (delta_tmp > Delta_l_max)
668 Delta_l_max = delta_tmp;
674 if (Delta_j_min <= Delta_k_minus - Delta_i_max &&
675 Delta_j_min <= Delta_k_minus - Delta_l_max)
677 mu[class_index[j]] = new_mu[j];
683 if (Delta_k_minus - Delta_i_max <= Delta_j_min &&
684 Delta_k_minus - Delta_i_max <= Delta_k_minus - Delta_l_max)
686 mu[class_index[ctr]] = -1;
689 mu[class_index[i]] = orig_mu[i];
700 mu[class_index[l]] = 0;
701 mu[class_index[ctr]] = -1;
717 int32_t num_zero = index_zero.
vlen;
718 int32_t num_neg = index_neg.
vlen;
721 std::copy(&index_zero[0], &index_zero[num_zero], &class_index[0]);
722 std::copy(&index_neg[0], &index_neg[num_neg], &class_index[num_zero]);
723 std::copy(&index_neg[0], &index_neg[num_neg], &class_index[num_neg+num_zero]);
727 std::fill(&orig_mu[num_zero], &orig_mu[orig_mu.
vlen], -1);
730 std::fill(&delta_steps[0], &delta_steps[delta_steps.
vlen], 1);
734 std::fill(&new_mu[0], &new_mu[num_zero], 1);
738 for (
index_t i=0; i < num_zero; ++i)
739 S_delta[i] = delta_pos[index_zero[i]];
741 for (int32_t i=0; i < num_neg; ++i)
743 float64_t delta_k = delta_pos[index_neg[i]];
744 float64_t delta_k_0 = -delta_neg[index_neg[i]];
746 index_t tmp_index = num_zero + i*2;
747 if (delta_k_0 <= delta_k)
749 new_mu[tmp_index] = 0;
750 new_mu[tmp_index+1] = 1;
752 S_delta[tmp_index] = delta_k_0;
753 S_delta[tmp_index+1] = delta_k;
755 delta_steps[tmp_index] = 1;
756 delta_steps[tmp_index+1] = 1;
760 new_mu[tmp_index] = 1;
761 new_mu[tmp_index+1] = 0;
763 S_delta[tmp_index] = (delta_k_0+delta_k)/2;
764 S_delta[tmp_index+1] = delta_k_0;
766 delta_steps[tmp_index] = 2;
767 delta_steps[tmp_index+1] = 1;
775 S_delta_sorted[i] = S_delta[sorted_index[i]];
776 new_mu[i] = new_mu[sorted_index[i]];
777 orig_mu[i] = orig_mu[sorted_index[i]];
778 class_index[i] = class_index[sorted_index[i]];
779 delta_steps[i] = delta_steps[sorted_index[i]];
783 std::fill(&valid_flag[0], &valid_flag[valid_flag.
vlen], 1);
790 if (d == -
m_B - B_prime || d == -
m_B - B_prime + 1)
793 while (!valid_flag[ctr])
796 if (delta_steps[ctr] == 1)
798 mu[class_index[ctr]] = new_mu[ctr];
804 if (d >= -
m_B - B_prime - 2)
806 mu[class_index[ctr]] = new_mu[ctr];
810 for (
index_t i=0; i < class_index.vlen; ++i)
812 if (class_index[i] == class_index[ctr])
818 float64_t Delta_k_minus = 2*S_delta_sorted[ctr];
823 for (int32_t itr=ctr+1; itr < S_delta_sorted.
vlen; ++itr)
825 if (valid_flag[itr] == 0)
828 if (delta_steps[itr] == 1)
831 Delta_j_min = S_delta_sorted[j];
838 for (int32_t itr=ctr-1; itr >= 0; --itr)
840 if (delta_steps[itr] == 1 && valid_flag[itr] == 1)
843 Delta_i_max = S_delta_sorted[i];
848 float64_t Delta_l_max = std::numeric_limits<float64_t>::min();
850 for (int32_t itr=ctr-1; itr >= 0; itr--)
852 if (delta_steps[itr] == 2)
854 float64_t delta_tmp = xi_neg_class[class_index[itr]];
855 if (delta_tmp > Delta_l_max)
858 Delta_l_max = delta_tmp;
864 if (Delta_j_min <= Delta_k_minus - Delta_i_max &&
865 Delta_j_min <= Delta_k_minus - Delta_l_max)
867 mu[class_index[j]] = new_mu[j];
873 if (Delta_k_minus - Delta_i_max <= Delta_j_min &&
874 Delta_k_minus - Delta_i_max <= Delta_k_minus - Delta_l_max)
876 mu[class_index[ctr]] = -1;
879 mu[class_index[i]] = orig_mu[i];
890 mu[class_index[l]] = 0;
891 mu[class_index[ctr]] = -1;
905 for (int32_t i=0; i < resp.vlen; ++i)