1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 package unbbayes.controller;
22
23 import java.awt.Cursor;
24 import java.awt.event.ActionEvent;
25 import java.awt.event.ActionListener;
26 import java.text.NumberFormat;
27 import java.util.ArrayList;
28 import java.util.List;
29 import java.util.Locale;
30 import java.util.ResourceBundle;
31
32 import javax.swing.JOptionPane;
33 import javax.swing.JTable;
34 import javax.swing.event.TableModelEvent;
35 import javax.swing.event.TableModelListener;
36
37 import unbbayes.gui.ExplanationProperties;
38 import unbbayes.gui.NetworkWindow;
39 import unbbayes.gui.continuous.ContinuousNormalDistributionPane;
40 import unbbayes.gui.table.GUIPotentialTable;
41 import unbbayes.gui.table.ReplaceTextCellEditor;
42 import unbbayes.prs.Edge;
43 import unbbayes.prs.Node;
44 import unbbayes.prs.bn.IProbabilityFunction;
45 import unbbayes.prs.bn.IRandomVariable;
46 import unbbayes.prs.bn.JunctionTreeAlgorithm;
47 import unbbayes.prs.bn.PotentialTable;
48 import unbbayes.prs.bn.ProbabilisticNode;
49 import unbbayes.prs.bn.SingleEntityNetwork;
50 import unbbayes.prs.exception.InvalidParentException;
51 import unbbayes.prs.hybridbn.CNNormalDistribution;
52 import unbbayes.prs.hybridbn.ContinuousNode;
53 import unbbayes.prs.hybridbn.GaussianMixture;
54 import unbbayes.prs.id.DecisionNode;
55 import unbbayes.prs.id.UtilityNode;
56 import unbbayes.util.extension.bn.inference.IInferenceAlgorithm;
57
58 public class SENController {
59
60 private NetworkWindow screen;
61
62 private SingleEntityNetwork singleEntityNetwork;
63
64 private NumberFormat df;
65
66
67 protected GaussianMixture gmInference;
68
69 public enum InferenceAlgorithmEnum {
70 JUNCTION_TREE,
71 LIKELIHOOD_WEIGHTING,
72 GIBBS,
73 GAUSSIAN_MIXTURE
74 }
75
76
77 private IInferenceAlgorithm inferenceAlgorithm = new JunctionTreeAlgorithm();
78
79 public IInferenceAlgorithm getInferenceAlgorithm() {
80 return inferenceAlgorithm;
81 }
82
83 public void setInferenceAlgorithm(IInferenceAlgorithm inferenceAlgorithm) {
84 this.inferenceAlgorithm = inferenceAlgorithm;
85 }
86
87
88 private static ResourceBundle resource = unbbayes.util.ResourceController.newInstance().getBundle(
89 unbbayes.controller.resources.ControllerResources.class.getName());
90
91
92
93
94
95 public SENController(SingleEntityNetwork singleEntityNetwork,
96 NetworkWindow screen) {
97 this.singleEntityNetwork = singleEntityNetwork;
98 this.screen = screen;
99 df = NumberFormat.getInstance(Locale.getDefault());
100 df.setMaximumFractionDigits(4);
101 }
102
103
104
105
106
107
108
109
110 public void insertState(Node node) {
111 if (node instanceof ProbabilisticNode) {
112 node.appendState(resource.getString("stateProbabilisticName")
113 + node.getStatesSize());
114
115
116 for (Node child : node.getChildren()) {
117 if (child.getType() == Node.CONTINUOUS_NODE_TYPE) {
118 ((ContinuousNode)child).getCnNormalDistribution().refreshParents();
119 }
120 }
121 } else if (node instanceof DecisionNode) {
122 node.appendState(resource.getString("stateDecisionName")
123 + node.getStatesSize());
124 }
125 screen.setTable(makeTable(node), node);
126 }
127
128
129
130
131
132
133
134
135
136 public void removeState(Node node) {
137 node.removeLastState();
138
139
140 for (Node child : node.getChildren()) {
141 if (child.getType() == Node.CONTINUOUS_NODE_TYPE) {
142 ((ContinuousNode)child).getCnNormalDistribution().refreshParents();
143 }
144 }
145 screen.setTable(makeTable(node), node);
146 }
147
148
149
150
151 public void initialize() {
152 try {
153 this.getInferenceAlgorithm().reset();
154
155 screen.getEvidenceTree().updateTree(true);
156 } catch (Exception e) {
157 e.printStackTrace();
158 }
159 }
160
161
162
163
164 public void propagate() {
165
166 boolean bReset = false;
167 screen.setCursor(new Cursor(Cursor.WAIT_CURSOR));
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192 try {
193 this.getInferenceAlgorithm().propagate();
194 } catch (Exception e) {
195 JOptionPane.showMessageDialog(screen, e.getMessage(), resource
196 .getString("statusError"), JOptionPane.ERROR_MESSAGE);
197 bReset = true;
198 }
199
200
201 screen.getEvidenceTree().updateTree(bReset);
202 screen.setCursor(new Cursor(Cursor.DEFAULT_CURSOR));
203 }
204
205
206
207
208
209
210
211
212
213 public boolean compileNetwork() {
214 long ini = System.currentTimeMillis();
215 screen.setCursor(new Cursor(Cursor.WAIT_CURSOR));
216
217
218 singleEntityNetwork.resetNodesCopy();
219
220
221 screen.getEvidenceTree().resetTree();
222
223
224 try {
225 singleEntityNetwork.resetEvidences();
226 this.getInferenceAlgorithm().setNetwork(singleEntityNetwork);
227 this.getInferenceAlgorithm().run();
228 } catch (Exception e) {
229 e.printStackTrace();
230 JOptionPane.showMessageDialog(null, e.getMessage(), resource
231 .getString("statusError"), JOptionPane.ERROR_MESSAGE);
232 screen.setCursor(new Cursor(Cursor.DEFAULT_CURSOR));
233 return false;
234 }
235
236
237
238 screen.setCursor(new Cursor(Cursor.DEFAULT_CURSOR));
239
240 screen.setStatus(resource.getString("statusTotalTime")
241 + df.format(((System.currentTimeMillis() - ini)) / 1000.0)
242 + resource.getString("statusSeconds"));
243 return true;
244 }
245
246
247
248
249
250
251 public Node insertContinuousNode(double x, double y) {
252 ContinuousNode node = new ContinuousNode();
253 node.setPosition(x, y);
254 node.setName(resource.getString("probabilisticNodeName")
255 + singleEntityNetwork.getNodeCount());
256 node.setDescription(node.getName());
257
258 singleEntityNetwork.addNode(node);
259
260 return node;
261 }
262
263
264
265
266
267
268 public Node insertProbabilisticNode(double x, double y) {
269 ProbabilisticNode node = new ProbabilisticNode();
270 node.setPosition(x, y);
271 node.appendState(resource.getString("firstStateProbabilisticName"));
272 node.setName(resource.getString("probabilisticNodeName")
273 + singleEntityNetwork.getNodeCount());
274 node.setDescription(node.getName());
275 PotentialTable auxTabProb = (PotentialTable)(node)
276 .getProbabilityFunction();
277 auxTabProb.addVariable(node);
278 auxTabProb.setValue(0, 1);
279 singleEntityNetwork.addNode(node);
280
281 return node;
282 }
283
284
285
286
287
288
289 public Node insertDecisionNode(double x, double y) {
290 DecisionNode node = new DecisionNode();
291 node.setPosition(x, y);
292 node.appendState(resource.getString("firstStateDecisionName"));
293 node.setName(resource.getString("decisionNodeName")
294 + singleEntityNetwork.getNodeCount());
295 node.setDescription(node.getName());
296 singleEntityNetwork.addNode(node);
297
298 return node;
299 }
300
301
302
303
304
305
306 public Node insertUtilityNode(double x, double y) {
307 UtilityNode node = new UtilityNode();
308 node.setPosition(x, y);
309 node.setName(resource.getString("utilityNodeName")
310 + singleEntityNetwork.getNodeCount());
311 node.setDescription(node.getName());
312 IProbabilityFunction auxTab = ((IRandomVariable) node).getProbabilityFunction();
313 auxTab.addVariable(node);
314 singleEntityNetwork.addNode(node);
315
316 return node;
317 }
318
319
320
321
322
323
324
325 public boolean insertEdge(Edge edge) {
326 try {
327 singleEntityNetwork.addEdge(edge);
328 } catch (InvalidParentException e) {
329 JOptionPane.showMessageDialog(null, e.getMessage(), resource
330 .getString("statusError"), JOptionPane.ERROR_MESSAGE);
331
332 return false;
333 }
334
335 return true;
336 }
337
338
339
340
341 public void deleteSelectedItem(){
342
343 Object selected = screen.getGraphPane().getSelected();
344 if(selected != null){
345 deleteSelected(selected);
346 }
347
348 screen.getGraphPane().update();
349 }
350
351
352
353
354
355
356 public void createContinuousDistribution(final ContinuousNode node) {
357
358 final List<Node> discreteNodeList = new ArrayList<Node>();
359 final List<String> discreteNodeNameList = new ArrayList<String>();
360 final List<String> continuousNodeNameList = new ArrayList<String>();
361 for (Node n : node.getParents()) {
362 if (n.getType() == Node.PROBABILISTIC_NODE_TYPE) {
363 discreteNodeList.add(n);
364 discreteNodeNameList.add(n.getName());
365 } else if (n.getType() == Node.CONTINUOUS_NODE_TYPE) {
366 continuousNodeNameList.add(n.getName());
367 }
368 }
369
370
371 final ContinuousNormalDistributionPane distributionPane = new ContinuousNormalDistributionPane(discreteNodeNameList, continuousNodeNameList);
372 screen.setDistributionPane(distributionPane);
373 screen.setTableOwner(node);
374
375
376 for (Node n : discreteNodeList) {
377 List<String> stateList = new ArrayList<String>(n.getStatesSize());
378 for (int i = 0; i < n.getStatesSize(); i++) {
379 stateList.add(n.getStateAt(i));
380 }
381 distributionPane.fillDiscreteParentStateSelection(n.getName(), stateList);
382 }
383
384 loadContinuousDistributionPaneValues(distributionPane, node.getCnNormalDistribution());
385
386
387 ActionListener confirmAL = new ActionListener() {
388 public void actionPerformed(ActionEvent ae) {
389 setContinuousDistributionValues(distributionPane, node.getCnNormalDistribution());
390 }
391 };
392
393
394
395 ActionListener restoreValuesFromDistributionAL = new ActionListener() {
396 public void actionPerformed(ActionEvent ae) {
397 loadContinuousDistributionPaneValues(distributionPane, node.getCnNormalDistribution());
398 }
399 };
400
401 distributionPane.addConfirmButtonActionListener(confirmAL);
402 distributionPane.addCancelButtonActionListener(restoreValuesFromDistributionAL);
403 distributionPane.addParentStateChangeActionListener(restoreValuesFromDistributionAL);
404
405
406
407
408 }
409
410
411
412
413
414
415 private void loadContinuousDistributionPaneValues(ContinuousNormalDistributionPane distributionPane, CNNormalDistribution distribution) {
416
417 int[] mCoord = distributionPane.getDiscreteParentNodeStateSelectedList();
418 if (mCoord.length == 0) {
419 mCoord = new int[1];
420 mCoord[0] = 0;
421 }
422
423
424 String mean = String.valueOf(distribution.getMean(mCoord));
425 String variance = String.valueOf(distribution.getVariance(mCoord));
426 List<String> constantList = new ArrayList<String>(distribution.getConstantListSize());
427 for (int i = 0; i < distribution.getConstantListSize(); i++) {
428 constantList.add(String.valueOf(distribution.getConstantAt(i, mCoord)));
429 }
430
431
432 distributionPane.setMeanText(mean);
433 distributionPane.setVarianceText(variance);
434 distributionPane.setConstantTextList(constantList);
435 }
436
437
438
439
440
441
442 private void setContinuousDistributionValues(ContinuousNormalDistributionPane distributionPane, CNNormalDistribution distribution) {
443 try {
444
445 double mean = Double.parseDouble(distributionPane.getMeanText());
446 double variance = Double.parseDouble(distributionPane.getVarianceText());
447 List<String> constantTextList = distributionPane.getConstantTextList();
448 List<Double> constantList = new ArrayList<Double>(constantTextList.size());
449 for (String constantText : constantTextList) {
450 constantList.add(Double.parseDouble(constantText));
451 }
452
453
454 int[] multidimensionalCoord = distributionPane.getDiscreteParentNodeStateSelectedList();
455 if (multidimensionalCoord.length == 0) {
456 multidimensionalCoord = new int[1];
457 multidimensionalCoord[0] = 0;
458 }
459
460
461 distribution.setMean(mean, multidimensionalCoord);
462 distribution.setVariance(variance, multidimensionalCoord);
463 for (int i = 0; i < constantList.size(); i++) {
464 distribution.setConstantAt(i, constantList.get(i), multidimensionalCoord);
465 }
466
467 } catch (NumberFormatException e) {
468 JOptionPane.showMessageDialog(null, resource.getString("continuousNormalDistributionParamError"), resource
469 .getString("statusError"), JOptionPane.ERROR_MESSAGE);
470 }
471 }
472
473
474
475
476
477
478 public void createDiscreteTable(Node node) {
479 screen.setTable(makeTable(node), node);
480
481 if (screen.isCompiled()) {
482 for (int i = 0; i < screen.getEvidenceTree().getRowCount(); i++) {
483 if (screen.getEvidenceTree().getPathForRow(i).getLastPathComponent().toString().equals(node.toString())) {
484 if (screen.getEvidenceTree().isExpanded(screen.getEvidenceTree().getPathForRow(i))) {
485 screen.getEvidenceTree().collapsePath(screen.getEvidenceTree().getPathForRow(i));
486 }
487 else {
488 screen.getEvidenceTree().expandPath(screen.getEvidenceTree().getPathForRow(i));
489 }
490 break;
491 }
492 }
493 }
494 screen.setAddRemoveStateButtonVisible(true);
495 }
496
497
498
499
500
501
502
503
504 public JTable makeTable(final Node node) {
505 screen.getTxtDescription().setEnabled(true);
506 screen.getTxtName().setEnabled(true);
507 screen.getTxtDescription().setText(node.getDescription());
508 screen.getTxtName().setText(node.getName());
509
510 final JTable table;
511 final PotentialTable potTab;
512
513
514 if (node.getStatesSize() == 0) {
515 Node parent = node.getParents().get(0);
516 int numClasses = parent.getStatesSize();
517 double[] mean = node.getMean();
518 double[] stdDev = node.getStandardDeviation();
519
520 table = new JTable(3, numClasses + 1);
521 table.setAutoResizeMode(JTable.AUTO_RESIZE_OFF);
522 table.setTableHeader(null);
523
524
525 table.setValueAt(parent.getName(), 0, 0);
526 table.setValueAt(resource.getString("mean"), 1, 0);
527 table.setValueAt(resource.getString("stdDev"), 2, 0);
528
529
530 for (int i = 0; i < numClasses; i++) {
531 table.setValueAt(parent.getStateAt(i), 0, i + 1);
532 table.setValueAt(mean[i], 1, i + 1);
533 table.setValueAt(stdDev[i], 2, i + 1);
534 }
535
536 return table;
537 }
538
539 if (node instanceof IRandomVariable) {
540 potTab = (PotentialTable)((IRandomVariable) node).getProbabilityFunction();
541
542 table = new GUIPotentialTable(potTab).makeTable();
543
544 } else {
545
546
547
548
549
550
551
552
553
554 potTab = null;
555
556 table = new JTable(node.getStatesSize(), 1);
557
558 for (int i = 0; i < node.getStatesSize(); i++) {
559 table.setValueAt(node.getStateAt(i), i, 0);
560 }
561
562 table.setAutoResizeMode(JTable.AUTO_RESIZE_OFF);
563 table.setTableHeader(null);
564
565 }
566
567 table.getModel().addTableModelListener(new TableModelListener() {
568 public void tableChanged(TableModelEvent e) {
569
570 if (e.getColumn() == 0) {
571 if (!table.getValueAt(e.getLastRow(), e.getColumn()).toString().trim().equals("")) {
572 node.setStateAt(table.getValueAt(e.getLastRow(),
573 e.getColumn()).toString(), e.getLastRow()
574 - (table.getRowCount() - node.getStatesSize()));
575 } else {
576 table.revalidate();
577 table.setValueAt(node.getStateAt(e.getLastRow()
578 - (table.getRowCount() - node.getStatesSize())), e.getLastRow(),
579 e.getColumn());
580 }
581
582 } else if (potTab != null) {
583 String valueText = table.getValueAt(e.getLastRow(),
584 e.getColumn()).toString().replace(',', '.');
585 try {
586 float value = Float.parseFloat(valueText);
587 potTab.setValue((e.getColumn() - 1) * node.getStatesSize() + e.getLastRow(), value);
588 } catch (NumberFormatException nfe) {
589
590 if (!valueText.trim().equals("")) {
591 JOptionPane.showMessageDialog(null,
592 resource.getString("numberFormatError"),
593 resource.getString("error"),
594 JOptionPane.ERROR_MESSAGE);
595 }
596 table.revalidate();
597 table.setValueAt(""
598 + potTab.getValue((e.getColumn() - 1) * node.getStatesSize() + e.getLastRow()),
599 e.getLastRow(), e.getColumn());
600 }
601 }
602 }
603 });
604
605
606 ReplaceTextCellEditor cellEditor = new ReplaceTextCellEditor();
607 for (int i = 0; i < table.getColumnModel().getColumnCount(); i++) {
608 table.getColumnModel().getColumn(i).setCellEditor(cellEditor);
609 }
610
611
612 table.setSurrendersFocusOnKeystroke(true);
613
614 return table;
615 }
616
617
618
619
620
621
622
623
624
625
626 public JTable makeTableOld(final Node node) {
627 screen.getTxtDescription().setEnabled(true);
628 screen.getTxtName().setEnabled(true);
629 screen.getTxtDescription().setText(node.getDescription());
630 screen.getTxtName().setText(node.getName());
631
632 final JTable table;
633 final PotentialTable potTab;
634 final int variables;
635
636
637 if (node.getStatesSize() == 0) {
638 Node parent = node.getParents().get(0);
639 int numClasses = parent.getStatesSize();
640 double[] mean = node.getMean();
641 double[] stdDev = node.getStandardDeviation();
642
643 table = new JTable(3, numClasses + 1);
644 table.setAutoResizeMode(JTable.AUTO_RESIZE_OFF);
645 table.setTableHeader(null);
646
647
648 table.setValueAt(parent.getName(), 0, 0);
649 table.setValueAt(resource.getString("mean"), 1, 0);
650 table.setValueAt(resource.getString("stdDev"), 2, 0);
651
652
653 for (int i = 0; i < numClasses; i++) {
654 table.setValueAt(parent.getStateAt(i), 0, i + 1);
655 table.setValueAt(mean[i], 1, i + 1);
656 table.setValueAt(stdDev[i], 2, i + 1);
657 }
658
659 return table;
660 }
661
662 if (node instanceof IRandomVariable) {
663 potTab = (PotentialTable)((IRandomVariable) node).getProbabilityFunction();
664
665 int states = 1;
666 variables = potTab.variableCount();
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682 states = potTab.tableSize() / node.getStatesSize();
683
684
685
686
687
688
689
690
691 int rows = node.getStatesSize() + variables - 1;
692
693
694
695
696 int columns = states + 1;
697
698 table = new JTable(rows, columns);
699
700
701
702
703
704 for (int k = variables - 1, l = 0; k < table.getRowCount(); k++, l++) {
705 table.setValueAt(node.getStateAt(l), k, 0);
706 }
707
708
709
710 for (int k = variables - 1, l = 0; k >= 1; k--, l++) {
711 Node variable = (Node)potTab.getVariableAt(k);
712
713
714
715 states /= variable.getStatesSize();
716
717
718 table.setValueAt(variable.getName(), l, 0);
719
720
721
722
723 for (int i = 0; i < table.getColumnCount() - 1; i++) {
724 table.setValueAt(variable.getStateAt((i / states)
725 % variable.getStatesSize()), l, i + 1);
726 }
727 }
728
729
730 states = node.getStatesSize();
731
732
733
734 for (int i = 1, k = 0; i < table.getColumnCount(); i++, k += states) {
735 for (int j = variables - 1, l = 0; j < table.getRowCount(); j++, l++) {
736 table.setValueAt("" + df.format(potTab.getValue(k + l)), j,
737 i);
738 }
739 }
740
741 } else {
742
743
744
745
746
747
748
749
750
751 potTab = null;
752 variables = node.getParents().size();
753
754 table = new JTable(node.getStatesSize(), 1);
755
756 for (int i = 0; i < node.getStatesSize(); i++) {
757 table.setValueAt(node.getStateAt(i), i, 0);
758 }
759
760 }
761
762 table.setTableHeader(null);
763 table.setAutoResizeMode(JTable.AUTO_RESIZE_OFF);
764 table.getModel().addTableModelListener(new TableModelListener() {
765 public void tableChanged(TableModelEvent e) {
766 if (e.getLastRow() < variables - 1) {
767 return;
768 }
769 if (e.getColumn() == 0) {
770 if (!table.getValueAt(e.getLastRow(), e.getColumn())
771 .equals("")) {
772 node.setStateAt(table.getValueAt(e.getLastRow(),
773 e.getColumn()).toString(), e.getLastRow()
774 - (table.getRowCount() - node.getStatesSize()));
775 }
776 } else {
777 String temp = table.getValueAt(e.getLastRow(),
778 e.getColumn()).toString().replace(',', '.');
779 try {
780 float valor = Float.parseFloat(temp);
781 potTab.setValue((e.getColumn() - 1) * node.getStatesSize() + e.getLastRow(), valor);
782 } catch (NumberFormatException nfe) {
783 JOptionPane.showMessageDialog(null,
784 resource.getString("error"),
785 resource.getString("realNumberError"),
786 JOptionPane.ERROR_MESSAGE);
787 table.revalidate();
788 table.setValueAt(""
789 + potTab.getValue((e.getColumn() - 1) * node.getStatesSize() + e.getLastRow()),
790 e.getLastRow(), e.getColumn());
791 }
792 }
793 }
794 });
795
796 return table;
797 }
798
799 public void deleteSelected(Object selecionado) {
800
801 if (selecionado instanceof Edge) {
802 singleEntityNetwork.removeEdge((Edge) selecionado);
803 } else if (selecionado instanceof Node) {
804 singleEntityNetwork.removeNode((Node) selecionado);
805 }
806
807
808 }
809
810 public void showExplanationProperties(ProbabilisticNode node) {
811 screen.setCursor(new Cursor(Cursor.WAIT_CURSOR));
812 ExplanationProperties explanation = new ExplanationProperties(screen,
813 singleEntityNetwork);
814 explanation.setProbabilisticNode(node);
815 explanation.setVisible(true);
816 screen.setCursor(new Cursor(Cursor.DEFAULT_CURSOR));
817 }
818
819 }