1 7 package com.maverick.multiplex; 8 9 import java.io.ByteArrayOutputStream ; 10 import java.io.DataInputStream ; 11 import java.io.DataOutputStream ; 12 import java.io.EOFException ; 13 import java.io.IOException ; 14 import java.io.InputStream ; 15 import java.io.InterruptedIOException ; 16 import java.io.OutputStream ; 17 import java.util.Enumeration ; 18 import java.util.Hashtable ; 19 import java.util.Vector ; 20 21 import com.jcraft.jzlib.JZlib; 22 import com.jcraft.jzlib.ZStream; 23 24 29 public class MultiplexedConnection implements RequestHandler { 30 31 int nextChannelId = 0; 32 Hashtable channelsById = new Hashtable (); 33 DataInputStream in; 34 DataOutputStream out; 35 boolean running = false; 36 ChannelFactory factory; 37 int totalChannels; 38 int maxChannels = 0; 39 Thread thread; 40 Vector activeChannels = new Vector (); 41 Vector listeners = new Vector (); 42 MessageStore globalMessages = new MessageStore(null, null); 43 Hashtable requestHandlers = new Hashtable (); 44 Object sendLock = new Object (); 45 Object startLock = new Object (); 46 Request lastRequest; 47 TimeoutCallback timeoutCallback = null; 48 49 static org.apache.commons.logging.Log log = org.apache.commons.logging.LogFactory 51 .getLog(MultiplexedConnection.class); 52 54 public static final int MSG_CHANNEL_OPEN = 1; 55 public static final int MSG_CHANNEL_OPEN_CONFIRMATION = 2; 56 public static final int MSG_CHANNEL_OPEN_FAILURE = 3; 57 public static final int MSG_CHANNEL_DATA = 4; 58 public static final int MSG_CHANNEL_WINDOW_ADJUST = 5; 59 public static final int MSG_CHANNEL_CLOSE = 6; 60 public static final int MSG_DISCONNECT = 7; 61 public static final int MSG_REQUEST = 8; 62 public static final int MSG_REQUEST_SUCCESS = 9; 63 public static final int MSG_REQUEST_FAILURE = 10; 64 65 public static final int MSG_CHANNEL_REQUEST = 11; 66 public static final int MSG_CHANNEL_REQUEST_SUCCESS = 12; 67 public static final int MSG_CHANNEL_REQUEST_FAILURE = 13; 68 69 MessageObserver channelOpenMessages = new MessageObserver() { 70 public boolean wantsNotification(Message msg) { 71 switch (msg.getMessageId()) { 72 case MSG_CHANNEL_OPEN_CONFIRMATION: 73 case MSG_CHANNEL_OPEN_FAILURE: 74 return true; 75 default: 76 return false; 77 } 78 } 79 }; 80 81 MessageObserver requestMessages = new MessageObserver() { 82 public boolean wantsNotification(Message msg) { 83 switch (msg.getMessageId()) { 84 case MSG_REQUEST_SUCCESS: 85 case MSG_REQUEST_FAILURE: 86 case MSG_DISCONNECT: 87 return true; 88 default: 89 return false; 90 } 91 } 92 }; 93 94 public MultiplexedConnection(ChannelFactory factory) { 95 this.factory = factory; 96 } 97 98 public void setChannelFactory(ChannelFactory factory) { 99 this.factory = factory; 100 } 101 102 public void startProtocol(InputStream _in, OutputStream _out, boolean threaded) { 103 setStreams(_in, _out); 104 if(threaded) { 105 this.thread = new Thread (new Runner(), "MultiplexProtocolThread"); 106 thread.start(); 107 } 108 else { 109 this.thread = Thread.currentThread(); 110 runProtocol(); 111 } 112 } 113 114 public void setMaxChannels(int maxChannels) { 115 this.maxChannels = maxChannels; 116 } 117 118 public int getMaxChannels() { 119 return maxChannels; 120 } 121 122 public int getTotalChannelCount() { 123 return totalChannels; 124 } 125 126 131 public Thread getThread() { 132 return thread; 133 } 134 135 public int getActiveChannelCount() { 136 return activeChannels.size(); 137 } 138 139 public synchronized Channel[] getActiveChannels() { 140 Channel[] tmp = new Channel[activeChannels.size()]; 141 activeChannels.copyInto(tmp); 142 return tmp; 143 } 144 145 public void addListener(MultiplexedConnectionListener listener) { 146 if (listener != null) 147 listeners.addElement(listener); 148 } 149 150 public void stop() { 151 if(log.isDebugEnabled()) 153 log.debug("Shutting down multiplexed connection thread"); 154 running = false; 156 if (thread != null && !Thread.currentThread().equals(thread)) 157 thread.interrupt(); 158 } 159 160 163 public void close() { 164 if(log.isDebugEnabled()) 166 log.debug("Closing multiplexed connection streams"); 167 running = false; 169 try { 170 in.close(); 171 } 172 catch(IOException ioe) { 173 } 174 try { 175 out.close(); 176 } 177 catch(IOException ioe) { 178 } 179 } 180 181 protected void onChannelOpen(Message msg) throws IOException { 182 String type = msg.readString(); 183 int remoteid = (int) msg.readInt(); 184 185 if(log.isDebugEnabled()) 187 log.debug("Open channel '" + type + "' [" + remoteid + "]"); 188 190 int remotepacket = (int) msg.readInt(); 191 int remotewindow = (int) msg.readInt(); 192 193 byte[] data = null; 194 195 if(msg.available() > 0) { 196 data = new byte[msg.available()]; 197 msg.read(data); 198 } 199 200 if (factory == null) { 201 sendChannelOpenFailure(remoteid, "This connection does not support the opening of channels"); 202 return; 203 } 204 205 Channel channel = null; 206 try { 207 channel = factory.createChannel(this, type); 208 if (channel == null) { 209 sendChannelOpenFailure(remoteid, "Failed to create channel of type " + type); 210 return; 211 } 212 } 213 catch(ChannelOpenException coe) { 214 sendChannelOpenFailure(remoteid, coe.getMessage() == null ? ( "Failed to create channel of type " + type + ". Reason " + coe.getReason()) : coe.getMessage() ); 215 return; 216 } 217 218 channel.init(this, remoteid, remotepacket, remotewindow); 219 220 if (allocateChannel(channel) == -1) { 221 sendChannelOpenFailure(remoteid, "Too many channels already open"); 222 223 } else { 224 try { 225 data = channel.open(data); 226 sendChannelOpenConfirmation(channel, data); 227 } 228 catch(ChannelOpenException coe) { 229 sendChannelOpenFailure(remoteid, coe.getMessage() == null ? ( "Failed to open channel. Reason " + coe.getReason()) : coe.getMessage() ); 230 return; 231 } 232 } 233 234 channel.fireChannelOpen(data); 235 236 } 237 238 protected void onChannelMessage(Message msg) throws IOException { 239 240 241 Integer channelid = new Integer ((int) msg.readInt()); 242 Channel channel = (Channel) channelsById.get(channelid); 243 244 if(log.isDebugEnabled()) 246 log.debug("Channel message [" + channelid + "]"); 247 249 if(channel != null) { 250 if(channel.isCompressionEnabled() && msg.getMessageId() == MSG_CHANNEL_DATA) { 251 252 ByteArrayOutputStream out = new ByteArrayOutputStream (); 253 byte[] tmp = new byte[65535]; 254 255 out.write(msg.array(), 0, msg.getPosition()); 256 257 ZStream compress = channel.getCompressionIn(); 258 259 compress.next_in = msg.array(); 260 compress.next_in_index = msg.getPosition()+4; 261 compress.avail_in = msg.available()-4; 262 263 int status; 264 265 do { 266 compress.next_out = tmp; 267 compress.next_out_index = 0; 268 compress.avail_out = tmp.length; 269 270 status = compress.deflate(JZlib.Z_PARTIAL_FLUSH); 271 switch(status) { 272 case JZlib.Z_OK: 273 out.write(tmp, 0, tmp.length - compress.avail_out); 274 break; 275 default: 276 throw new IOException ("Compression Failure: deflate returned " + status); 277 } 278 } while(channel.getCompressionIn().avail_out==0); 279 280 msg = new Message(out.toByteArray()); 281 msg.skip(5); } 283 284 if (channel.processChannelMessage(msg)) { 285 channel.messageStore.addMessage(msg); 286 } 287 } else { 288 log.warn("Message received for non-existent channel id " + channelid); 290 292 } 293 } 294 295 protected void onChannelRequest(Message msg) throws IOException { 296 Integer channelid = new Integer ((int) msg.readInt()); 297 Channel channel = (Channel) channelsById.get(channelid); 298 299 if(log.isDebugEnabled()) 301 log.debug("Channel request [" + channelid + "]"); 302 304 if (channel != null) { 305 String requestName = msg.readString(); 306 boolean wantReply = msg.readBoolean(); 307 byte[] data = msg.readBinaryString(); 308 309 Request request = new Request(requestName, data); 310 311 if(channel.onChannelRequest(request)) { 312 if(log.isDebugEnabled()) 314 log.debug("Channel request success [" + channelid + "]"); 315 if(wantReply) { 317 Packet reply = new Packet(); 318 reply.write(MSG_CHANNEL_REQUEST_SUCCESS); 319 reply.writeInt(channelid.intValue()); 320 reply.writeString(request.getRequestName()); 321 reply.writeBinaryString(request.getRequestData()); 322 sendMessage(reply); 323 } 324 } else { 325 if(log.isDebugEnabled()) 327 log.debug("Channel request failure [" + channelid + "]"); 328 if(wantReply) { 330 Packet reply = new Packet(); 331 reply.write(MSG_CHANNEL_REQUEST_FAILURE); 332 reply.writeInt(channelid.intValue()); 333 reply.writeString(request.getRequestName()); 334 reply.writeBinaryString(request.getRequestData()); 335 sendMessage(reply); 336 } 337 338 } 339 340 } else { 341 log.warn("Message received for non-existent channel id " + channelid); 343 } 345 } 346 347 protected void onRequestSuccessOrFailure(Message msg) throws IOException { 348 if(log.isDebugEnabled()) 350 log.debug("Request success or failure"); 351 globalMessages.addMessage(msg); 353 } 354 355 public void registerRequestHandler(String requestName, RequestHandler handler) { 356 requestHandlers.put(requestName, handler); 357 } 358 359 public void unregisterRequestHandler(String requestName) { 360 requestHandlers.remove(requestName); 361 } 362 363 protected void onRequest(Message msg) throws IOException { 364 365 String requestName = msg.readString(); 366 boolean wantReply = msg.readBoolean(); 367 boolean success = false; 368 369 if(log.isDebugEnabled()) 371 log.debug("Request '" + requestName + "' wantsReply=" + wantReply); 372 374 byte[] data = null; 375 376 if(msg.available() > 0) 377 data = msg.readBinaryString(); 378 379 Request request = new Request(requestName, data); 380 381 if(requestHandlers.containsKey(requestName)) { 382 success = ((RequestHandler)requestHandlers.get(requestName)).processRequest(request, this); 383 } else 384 success = processRequest(request, null); 385 386 if(wantReply) { 387 388 Packet p = new Packet(); 389 p.write(success ? MSG_REQUEST_SUCCESS : MSG_REQUEST_FAILURE); 390 if(request.getRequestData()!=null) 391 p.writeBinaryString(request.getRequestData()); 392 393 sendMessage(p); 394 } 395 396 397 if(requestHandlers.containsKey(requestName)) { 398 ((RequestHandler)requestHandlers.get(requestName)).postReply(this); 399 } else 400 postReply(this); 401 402 } 403 404 protected void onDisconnect(Message msg) throws IOException { 405 int reason = (int) msg.readInt(); 406 String desc = msg.readString(); 407 408 if(log.isDebugEnabled()) 410 log.debug("Remote disconnected:" + desc + " (" + reason + ")"); 411 413 stop(); 414 } 415 416 private int allocateChannel(Channel channel) { 417 synchronized (channelsById) { 418 419 if(maxChannels > 0 && channelsById.size() >= maxChannels) 420 return -1; 421 422 Integer channelid = null; 423 do { 424 channelid = new Integer (nextChannelId++); 426 } while(channelsById.containsKey(channelid)); 427 428 channelsById.put(channelid, channel); 429 channel.channelid = channelid.intValue(); 430 activeChannels.addElement(channel); 431 totalChannels++; 432 return channelid.intValue(); 433 } 434 } 435 436 void freeChannel(Channel channel) { 437 synchronized (channelsById) { 438 channelsById.remove(new Integer (channel.channelid)); 439 activeChannels.removeElement(channel); 440 } 441 } 442 443 public boolean sendRequest(Request request, boolean wantReply) throws IOException { 444 return sendRequest(request, wantReply, 0); 445 } 446 447 public boolean sendRequest(Request request, boolean wantReply, int timeoutMs) throws IOException { 448 449 log.info("Sending request " + request.getRequestName() + " timeout=" + timeoutMs + " wantReply=" + wantReply); 451 synchronized(sendLock) { 453 Packet msg = new Packet(); 454 msg.write(MSG_REQUEST); 455 msg.writeString(request.getRequestName()); 456 msg.writeBoolean(wantReply); 457 if(request.getRequestData()!=null) { 458 msg.writeBinaryString(request.getRequestData()); 459 } 460 461 sendMessage(msg); 462 463 if(wantReply) { 464 465 if(Thread.currentThread() == thread) { 466 throw new IOException ("You cannot send requests that require replies on the protocol thread."); 467 } 468 469 Message reply = globalMessages.nextMessage(requestMessages, timeoutMs); 470 471 switch (reply.getMessageId()) { 472 case MSG_REQUEST_SUCCESS: 473 case MSG_REQUEST_FAILURE: 474 475 byte[] data = null; 476 if(reply.available() > 0) 477 data = reply.readBinaryString(); 478 479 request.setRequestData(data); 480 481 boolean success = reply.getMessageId()==MSG_REQUEST_SUCCESS; 482 483 log.info("Remote responded to request " + request.getRequestName() + " with success=" + success); 485 return success; 487 case MSG_DISCONNECT: 488 throw new EOFException ("Connection closed before request reply"); 489 default: 490 throw new IOException ("Unexpected reply in channel open procedure"); 491 } 492 } else 493 return true; 494 } 495 } 496 497 public boolean processRequest(Request request, MultiplexedConnection connection) { 498 return false; 499 } 500 501 public void postReply(MultiplexedConnection connection) { 502 } 503 504 public void openChannel(Channel channel)throws IOException , ChannelOpenException { 505 openChannel(channel, 0); 506 } 507 508 public void openChannel(Channel channel, int timeout) throws IOException , ChannelOpenException { 509 510 if(log.isDebugEnabled()) 512 log.debug("Open channel '" + channel.getType() + "' timeout=" + timeout); 513 515 synchronized (sendLock) { 516 517 byte[] data = channel.create(); 518 519 if (allocateChannel(channel) == -1) { 520 throw new ChannelOpenException(ChannelOpenException.CHANNEL_LIMIT_EXCEEDED, 521 "Failed to allocate channel: too many active channels"); 522 } 523 524 Packet msg = new Packet(); 525 msg.write(MSG_CHANNEL_OPEN); 526 msg.writeString(channel.getType()); 527 msg.writeInt(channel.channelid); 528 msg.writeInt(channel.getLocalPacket()); 529 msg.writeInt(channel.getLocalWindow()); 530 if (data != null) 531 msg.write(data); 532 533 sendMessage(msg); 534 535 try { 536 Message reply = channel.messageStore.nextMessage(channelOpenMessages, timeout); 537 538 switch (reply.getMessageId()) { 539 case MSG_CHANNEL_OPEN_CONFIRMATION: 540 int remoteid = (int) reply.readInt(); 541 int remotepacket = (int) reply.readInt(); 542 int remotewindow = (int) reply.readInt(); 543 544 data = null; 545 if(reply.available() > 0) { 546 data = new byte[reply.available()]; 547 reply.read(data); 548 } 549 channel.init(this, remoteid, remotepacket, remotewindow); 550 channel.fireChannelOpen(data); 551 break; 552 case MSG_CHANNEL_OPEN_FAILURE: 553 String desc = reply.readString(); 554 freeChannel(channel); 555 throw new ChannelOpenException(ChannelOpenException.CHANNEL_REFUSED, desc); 556 default: 557 throw new IOException ("Unexpected reply in channel open procedure"); 558 } 559 } catch(InterruptedIOException ex) { 560 throw new ChannelOpenException(ChannelOpenException.COMMUNICATION_TIMEOUT, "Timeout limit exceeded"); 561 } 562 } 563 564 } 565 566 public void sendChannelData(Channel channel, byte[] data, int off, int len) throws IOException { 567 Packet msg = new Packet(); 568 msg.write(MSG_CHANNEL_DATA); 569 570 if(channel.isCompressionEnabled()) { 571 572 ByteArrayOutputStream out = new ByteArrayOutputStream (); 573 byte[] tmp = new byte[65535]; 574 575 ZStream compress = channel.getCompressionOut(); 576 577 compress.next_in = data; 578 compress.next_in_index = off; 579 compress.avail_in = len; 580 581 int status; 582 583 do { 584 compress.next_out = tmp; 585 compress.next_out_index = 0; 586 compress.avail_out = tmp.length; 587 588 status = compress.inflate(JZlib.Z_PARTIAL_FLUSH); 589 switch(status) { 590 case JZlib.Z_OK: 591 out.write(tmp, 0, tmp.length - compress.avail_out); 592 break; 593 default: 594 throw new IOException ("Compression Failure: inflate returned " + status); 595 } 596 } while(channel.getCompressionIn().avail_out==0); 597 598 data = out.toByteArray(); 599 off = 0; 600 len = data.length; 601 } 602 603 msg.writeInt(channel.remoteid); 604 msg.writeBinaryString(data, off, len); 605 606 sendMessage(msg); 607 } 608 609 private void sendChannelOpenFailure(int channelid, String desc) throws IOException { 610 Packet msg = new Packet(); 611 msg.write(MSG_CHANNEL_OPEN_FAILURE); 612 msg.writeInt(channelid); 613 msg.writeString(desc); 614 615 sendMessage(msg); 616 } 617 618 private void sendChannelOpenConfirmation(Channel channel, byte[] data) throws IOException { 619 Packet msg = new Packet(); 620 msg.write(MSG_CHANNEL_OPEN_CONFIRMATION); 621 msg.writeInt(channel.remoteid); 622 msg.writeInt(channel.channelid); 623 msg.writeInt(channel.getLocalPacket()); 624 msg.writeInt(channel.getLocalWindow()); 625 if (data != null) 626 msg.write(data); 627 628 sendMessage(msg); 629 } 630 631 public void sendWindowAdjust(Channel channel, int increment) throws IOException { 632 Packet msg = new Packet(); 633 msg.write(MSG_CHANNEL_WINDOW_ADJUST); 634 msg.writeInt(channel.remoteid); 635 msg.writeInt(increment); 636 637 sendMessage(msg); 638 } 639 640 public void closeChannel(Channel channel) throws IOException { 641 Packet msg = new Packet(); 642 msg.write(MSG_CHANNEL_CLOSE); 643 msg.writeInt(channel.remoteid); 644 645 sendMessage(msg); 646 } 647 648 public void closeAllChannels() { 649 650 Channel channel; 651 for(Enumeration e = channelsById.elements(); e.hasMoreElements();) { 652 channel = (Channel) e.nextElement(); 653 try { 654 channel.close(); 655 } catch(Throwable t) { } 656 } 657 658 channelsById.clear(); 659 } 660 661 public void disconnect(String desc) { 662 663 if(log.isDebugEnabled()) 665 log.debug("Disconnecting multiplexed connection"); 666 668 running = false; 669 670 closeAllChannels(); 671 672 try { 673 Packet msg = new Packet(); 674 msg.write(MSG_DISCONNECT); 675 msg.writeInt(0); 676 msg.writeString(desc); 677 678 sendMessage(msg); 679 680 } catch (IOException ex) { 681 if(log.isDebugEnabled()) 683 log.debug("Error on disconnect", ex); 684 } finally { 686 try { 687 in.close(); 688 } catch (Throwable t) { 689 } 690 try { 691 out.close(); 692 } catch (Throwable t) { 693 } 694 695 696 } 697 } 698 699 protected void sendMessage(Packet msg) throws IOException { 700 if(log.isDebugEnabled()) 702 log.debug("Sending message of " + msg.size() + " bytes"); 703 705 msg.prepare(); 706 707 out.write(msg.array(), 0, msg.size()); 709 out.flush(); 710 711 if(log.isDebugEnabled()) 713 log.debug("Sent message of " + msg.size() + " bytes"); 714 } 716 717 public void waitForProtocolStart(long timeout) throws InterruptedException { 718 synchronized(startLock) { 719 if(!isRunning()) { 720 startLock.wait(timeout); 721 } 722 } 723 } 724 725 public boolean isRunning() { 726 return running; 727 } 728 729 private void setStreams(InputStream _in, OutputStream _out) { 730 this.in = new DataInputStream (_in); 731 this.out = new DataOutputStream (_out); 732 } 733 734 735 private void runProtocol() { 736 737 738 try { 739 running = true; 740 synchronized (startLock) { 741 startLock.notifyAll(); 742 } 743 744 for (Enumeration it = listeners.elements(); it.hasMoreElements();) { 745 ((MultiplexedConnectionListener) it.nextElement()).onConnectionOpen(); 746 } 747 748 while (running) { 749 750 try { 751 752 int msglength = in.readInt(); 753 754 if (msglength <= 0) { 755 log.error("Invalid message length of " + msglength + " bytes"); 757 stop(); 759 } else { 760 761 byte[] tmp = new byte[msglength]; 762 in.readFully(tmp); 763 764 Message msg = new Message(tmp); 765 766 switch (msg.getMessageId()) { 767 case MSG_CHANNEL_OPEN: 768 onChannelOpen(msg); 769 break; 770 case MSG_CHANNEL_OPEN_CONFIRMATION: 771 onChannelMessage(msg); 772 break; 773 case MSG_CHANNEL_OPEN_FAILURE: 774 onChannelMessage(msg); 775 break; 776 case MSG_CHANNEL_DATA: 777 onChannelMessage(msg); 778 break; 779 case MSG_CHANNEL_REQUEST: 780 onChannelRequest(msg); 781 break; 782 case MSG_CHANNEL_WINDOW_ADJUST: 783 onChannelMessage(msg); 784 break; 785 case MSG_CHANNEL_CLOSE: 786 onChannelMessage(msg); 787 break; 788 case MSG_CHANNEL_REQUEST_SUCCESS: 789 case MSG_CHANNEL_REQUEST_FAILURE: 790 onChannelMessage(msg); 791 break; 792 case MSG_DISCONNECT: 793 onDisconnect(msg); 794 break; 795 case MSG_REQUEST_SUCCESS: 796 case MSG_REQUEST_FAILURE: 797 onRequestSuccessOrFailure(msg); 798 break; 799 case MSG_REQUEST: 800 onRequest(msg); 801 break; 802 default: 803 throw new IOException ("Unexpected message id " + msg.getMessageId()); 804 805 } 806 807 } 808 } catch(InterruptedIOException ex) { 809 if(timeoutCallback!=null && timeoutCallback.isAlive(MultiplexedConnection.this)) { 810 continue; 811 } 812 if(running) { 813 log.error("Multiplexed connection timed out", ex); 815 stop(); 817 } 818 } catch (IOException ex) { 819 if(running) { 820 if(!(ex instanceof EOFException )) { 821 log.error("Multiplexed connection thread failed", ex); 823 } 825 stop(); 826 } 827 } 828 829 } 830 831 } finally { 832 for (Enumeration it = listeners.elements(); it.hasMoreElements();) { 833 ((MultiplexedConnectionListener) it.nextElement()).onConnectionClose(); 834 } 835 } 836 } 837 838 class Runner implements Runnable { 839 840 public Runner() { 841 } 842 845 public void run() { 846 runProtocol(); 847 } 848 } 849 850 public void setTimeoutCallback(TimeoutCallback timeoutCallback) { 851 this.timeoutCallback = timeoutCallback; 852 } 853 } 854 | Popular Tags |