00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include "socket.h"
00023 #include "usbwrap.h"
00024 #include "data.h"
00025 #include "protocol.h"
00026 #include "protostructs.h"
00027 #include "endian.h"
00028 #include "debug.h"
00029 #include "packet.h"
00030 #include "sha1.h"
00031 #include <sstream>
00032 #include <string.h>
00033
00034 using namespace Usb;
00035
00036
00037 namespace Barry {
00038
00039
00040
00041
00042
00043 SocketZero::SocketZero( SocketRoutingQueue &queue,
00044 int writeEndpoint,
00045 uint8_t zeroSocketSequenceStart)
00046 : m_dev(0),
00047 m_queue(&queue),
00048 m_writeEp(writeEndpoint),
00049 m_readEp(0),
00050 m_zeroSocketSequence(zeroSocketSequenceStart),
00051 m_sequenceId(0),
00052 m_halfOpen(false),
00053 m_challengeSeed(0),
00054 m_remainingTries(0)
00055 {
00056 }
00057
00058 SocketZero::SocketZero( Device &dev,
00059 int writeEndpoint, int readEndpoint,
00060 uint8_t zeroSocketSequenceStart)
00061 : m_dev(&dev),
00062 m_queue(0),
00063 m_writeEp(writeEndpoint),
00064 m_readEp(readEndpoint),
00065 m_zeroSocketSequence(zeroSocketSequenceStart),
00066 m_sequenceId(0),
00067 m_halfOpen(false),
00068 m_challengeSeed(0),
00069 m_remainingTries(0)
00070 {
00071 }
00072
00073 SocketZero::~SocketZero()
00074 {
00075
00076 }
00077
00078
00079
00080
00081
00082
00083
00084
00085 void SocketZero::AppendFragment(Data &whole, const Data &fragment)
00086 {
00087 if( whole.GetSize() == 0 ) {
00088
00089 whole = fragment;
00090 }
00091 else {
00092
00093 int size = whole.GetSize();
00094 unsigned char *buf = whole.GetBuffer(size + fragment.GetSize());
00095 MAKE_PACKET(fpack, fragment);
00096 int fragsize = fragment.GetSize() - SB_FRAG_HEADER_SIZE;
00097
00098 memcpy(buf+size, &fpack->u.db.u.fragment, fragsize);
00099 whole.ReleaseBuffer(size + fragsize);
00100 }
00101
00102
00103 Barry::Protocol::Packet *wpack = (Barry::Protocol::Packet *) whole.GetBuffer();
00104 wpack->size = htobs((uint16_t) whole.GetSize());
00105 wpack->command = SB_COMMAND_DB_DATA;
00106
00107
00108 }
00109
00110
00111
00112
00113 unsigned int SocketZero::MakeNextFragment(const Data &whole, Data &fragment, unsigned int offset)
00114 {
00115
00116 if( whole.GetSize() < SB_FRAG_HEADER_SIZE ) {
00117 eout("Whole packet too short to fragment: " << whole.GetSize());
00118 throw Error("Socket: Whole packet too short to fragment");
00119 }
00120
00121
00122 unsigned int todo = whole.GetSize() - SB_FRAG_HEADER_SIZE - offset;
00123 unsigned int nextOffset = 0;
00124 if( todo > (MAX_PACKET_SIZE - SB_FRAG_HEADER_SIZE) ) {
00125 todo = MAX_PACKET_SIZE - SB_FRAG_HEADER_SIZE;
00126 nextOffset = offset + todo;
00127 }
00128
00129
00130 unsigned char *buf = fragment.GetBuffer(SB_FRAG_HEADER_SIZE + todo);
00131 memcpy(buf, whole.GetData(), SB_FRAG_HEADER_SIZE);
00132
00133
00134 memcpy(buf + SB_FRAG_HEADER_SIZE, whole.GetData() + SB_FRAG_HEADER_SIZE + offset, todo);
00135
00136
00137 Barry::Protocol::Packet *wpack = (Barry::Protocol::Packet *) buf;
00138 wpack->size = htobs((uint16_t) (todo + SB_FRAG_HEADER_SIZE));
00139 if( nextOffset )
00140 wpack->command = SB_COMMAND_DB_FRAGMENTED;
00141 else
00142 wpack->command = SB_COMMAND_DB_DATA;
00143
00144
00145 fragment.ReleaseBuffer(SB_FRAG_HEADER_SIZE + todo);
00146
00147
00148 return nextOffset;
00149 }
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160 void SocketZero::CheckSequence(uint16_t socket, const Data &seq)
00161 {
00162 MAKE_PACKET(spack, seq);
00163 if( (unsigned int) seq.GetSize() < SB_SEQUENCE_PACKET_SIZE ) {
00164 eout("Short sequence packet:\n" << seq);
00165 throw Error("Socket: invalid sequence packet");
00166 }
00167
00168
00169
00170 uint32_t sequenceId = btohl(spack->u.sequence.sequenceId);
00171 if( sequenceId == 0 ) {
00172
00173 m_sequenceId = 0;
00174 }
00175 else {
00176 if( sequenceId != m_sequenceId ) {
00177 if( socket != 0 ) {
00178 std::ostringstream oss;
00179 oss << "Socket 0x" << std::hex << (unsigned int)socket
00180 << ": out of sequence. "
00181 << "(Global sequence: " << m_sequenceId
00182 << ". Packet sequence: " << sequenceId
00183 << ")";
00184 eout(oss.str());
00185 throw Error(oss.str());
00186 }
00187 else {
00188 dout("Bad sequence on socket 0: expected: "
00189 << msequenceId
00190 << ". Packet sequence: " << sequenceId);
00191 }
00192 }
00193 }
00194
00195
00196 m_sequenceId++;
00197 }
00198
00199 void SocketZero::SendOpen(uint16_t socket, Data &receive)
00200 {
00201
00202 Barry::Protocol::Packet packet;
00203 packet.socket = 0;
00204 packet.size = htobs(SB_SOCKET_PACKET_HEADER_SIZE);
00205 packet.command = SB_COMMAND_OPEN_SOCKET;
00206 packet.u.socket.socket = htobs(socket);
00207 packet.u.socket.sequence = m_zeroSocketSequence;
00208
00209 Data send(&packet, SB_SOCKET_PACKET_HEADER_SIZE);
00210 try {
00211 RawSend(send);
00212 RawReceive(receive);
00213 } catch( Usb::Error & ) {
00214 eeout(send, receive);
00215 throw;
00216 }
00217
00218
00219 Protocol::CheckSize(receive);
00220 if( IS_COMMAND(receive, SB_COMMAND_SEQUENCE_HANDSHAKE) ) {
00221 CheckSequence(0, receive);
00222
00223
00224 RawReceive(receive);
00225 }
00226
00227
00228 }
00229
00230
00231 void SocketZero::SendPasswordHash(uint16_t socket, const char *password, Data &receive)
00232 {
00233 unsigned char pwdigest[SHA_DIGEST_LENGTH];
00234 unsigned char prefixedhash[SHA_DIGEST_LENGTH + 4];
00235
00236
00237 SHA1((unsigned char *) password, strlen(password), pwdigest);
00238
00239
00240 uint32_t seed = htobl(m_challengeSeed);
00241 memcpy(&prefixedhash[0], &seed, sizeof(uint32_t));
00242 memcpy(&prefixedhash[4], pwdigest, SHA_DIGEST_LENGTH);
00243
00244
00245 SHA1((unsigned char *) prefixedhash, SHA_DIGEST_LENGTH + 4, pwdigest);
00246
00247
00248 size_t size = SB_SOCKET_PACKET_HEADER_SIZE + PASSWORD_CHALLENGE_SIZE;
00249
00250
00251 Barry::Protocol::Packet packet;
00252 packet.socket = 0;
00253 packet.size = htobs(size);
00254 packet.command = SB_COMMAND_PASSWORD;
00255 packet.u.socket.socket = htobs(socket);
00256 packet.u.socket.sequence = m_zeroSocketSequence;
00257 packet.u.socket.u.password.remaining_tries = 0;
00258 packet.u.socket.u.password.unknown = 0;
00259 packet.u.socket.u.password.param = htobs(0x14);
00260 memcpy(packet.u.socket.u.password.u.hash, pwdigest,
00261 sizeof(packet.u.socket.u.password.u.hash));
00262
00263
00264 memset(pwdigest, 0, sizeof(pwdigest));
00265 memset(prefixedhash, 0, sizeof(prefixedhash));
00266
00267 Data send(&packet, size);
00268 RawSend(send);
00269 RawReceive(receive);
00270
00271
00272 memset(packet.u.socket.u.password.u.hash, 0,
00273 sizeof(packet.u.socket.u.password.u.hash));
00274 send.Zap();
00275
00276
00277 Protocol::CheckSize(receive);
00278 if( IS_COMMAND(receive, SB_COMMAND_SEQUENCE_HANDSHAKE) ) {
00279 CheckSequence(0, receive);
00280
00281
00282 RawReceive(receive);
00283 }
00284
00285
00286 }
00287
00288 void SocketZero::RawSend(Data &send, int timeout)
00289 {
00290 Usb::Device *dev = m_queue ? m_queue->GetUsbDevice() : m_dev;
00291
00292
00293
00294
00295
00296
00297
00298
00299 if( (send.GetSize() % 0x40) == 0 ) {
00300 Protocol::SizePacket packet;
00301 packet.size = htobs(send.GetSize());
00302 packet.buffer[2] = 0;
00303 Data sizeCommand(&packet, 3);
00304
00305 dev->BulkWrite(m_writeEp, sizeCommand);
00306 }
00307
00308 dev->BulkWrite(m_writeEp, send);
00309 }
00310
00311 void SocketZero::RawReceive(Data &receive, int timeout)
00312 {
00313 do {
00314 if( m_queue ) {
00315 if( !m_queue->DefaultRead(receive, timeout) )
00316 throw Timeout("SocketZero::RawReceive: queue DefaultRead returned false (likely a timeout)");
00317 }
00318 else {
00319 m_dev->BulkRead(m_readEp, receive, timeout);
00320 }
00321 ddout("SocketZero::RawReceive: Endpoint " << m_readEp
00322 << "\nReceived:\n" << receive);
00323 } while( SequencePacket(receive) );
00324 }
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340 bool SocketZero::SequencePacket(const Data &data)
00341 {
00342 if( data.GetSize() >= MIN_PACKET_SIZE ) {
00343 if( IS_COMMAND(data, SB_COMMAND_SEQUENCE_HANDSHAKE) ) {
00344 CheckSequence(0, data);
00345 return true;
00346 }
00347 }
00348 return false;
00349 }
00350
00351
00352
00353
00354
00355 void SocketZero::SetRoutingQueue(SocketRoutingQueue &queue)
00356 {
00357
00358 m_queue = &queue;
00359 }
00360
00361 void SocketZero::UnlinkRoutingQueue()
00362 {
00363 m_queue = 0;
00364 }
00365
00366 void SocketZero::Send(Data &send, int timeout)
00367 {
00368
00369 if( send.GetSize() >= SB_SOCKET_PACKET_HEADER_SIZE ) {
00370 MAKE_PACKETPTR_BUF(spack, send.GetBuffer());
00371 spack->socket = 0;
00372 }
00373
00374
00375
00376 if( send.GetSize() >= SB_SOCKET_PACKET_HEADER_SIZE ) {
00377 MAKE_PACKETPTR_BUF(spack, send.GetBuffer());
00378 spack->u.socket.sequence = m_zeroSocketSequence;
00379 m_zeroSocketSequence++;
00380 }
00381
00382 RawSend(send, timeout);
00383 }
00384
00385 void SocketZero::Send(Data &send, Data &receive, int timeout)
00386 {
00387 Send(send, timeout);
00388 RawReceive(receive, timeout);
00389 }
00390
00391 void SocketZero::Send(Barry::Packet &packet, int timeout)
00392 {
00393 Send(packet.m_send, packet.m_receive, timeout);
00394 }
00395
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409
00410
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420 SocketHandle SocketZero::Open(uint16_t socket, const char *password)
00421 {
00422
00423
00424
00425
00426
00427 Data send, receive;
00428 ZeroPacket packet(send, receive);
00429
00430
00431 uint8_t closeFlag = GetZeroSocketSequence();
00432
00433 if( !m_halfOpen ) {
00434
00435 m_remainingTries = 0;
00436
00437 SendOpen(socket, receive);
00438
00439
00440 if( packet.Command() == SB_COMMAND_PASSWORD_CHALLENGE ) {
00441 m_halfOpen = true;
00442 m_challengeSeed = packet.ChallengeSeed();
00443 m_remainingTries = packet.RemainingTries();
00444 }
00445
00446
00447 }
00448
00449 if( m_halfOpen ) {
00450
00451
00452 if( !password ) {
00453 throw BadPassword("No password specified.", m_remainingTries, false);
00454 }
00455
00456
00457
00458
00459
00460 if( m_remainingTries < BARRY_MIN_PASSWORD_TRIES ) {
00461 throw BadPassword("Fewer than " BARRY_MIN_PASSWORD_TRIES_ASC " password tries remaining in device. Refusing to proceed, to avoid device zapping itself. Use a Windows client, or re-cradle the device.",
00462 m_remainingTries,
00463 true);
00464 }
00465
00466
00467 closeFlag = GetZeroSocketSequence();
00468
00469 SendPasswordHash(socket, password, receive);
00470
00471 if( packet.Command() == SB_COMMAND_PASSWORD_FAILED ) {
00472 m_halfOpen = true;
00473 m_challengeSeed = packet.ChallengeSeed();
00474 m_remainingTries = packet.RemainingTries();
00475 throw BadPassword("Password rejected by device.", m_remainingTries, false);
00476 }
00477
00478
00479
00480 m_halfOpen = false;
00481
00482
00483 }
00484
00485 if( packet.Command() != SB_COMMAND_OPENED_SOCKET ||
00486 packet.SocketResponse() != socket ||
00487 packet.SocketSequence() != closeFlag )
00488 {
00489 eout("Packet:\n" << receive);
00490 throw Error("Socket: Bad OPENED packet in Open");
00491 }
00492
00493
00494 return SocketHandle(new Socket(*this, socket, closeFlag));
00495 }
00496
00497
00498
00499
00500
00501
00502
00503
00504
00505
00506
00507 void SocketZero::Close(Socket &socket)
00508 {
00509 if( socket.GetSocket() == 0 )
00510 return;
00511
00512
00513 Barry::Protocol::Packet packet;
00514 packet.socket = 0;
00515 packet.size = htobs(SB_SOCKET_PACKET_HEADER_SIZE);
00516 packet.command = SB_COMMAND_CLOSE_SOCKET;
00517 packet.u.socket.socket = htobs(socket.GetSocket());
00518 packet.u.socket.sequence = socket.GetCloseFlag();
00519
00520 Data command(&packet, SB_SOCKET_PACKET_HEADER_SIZE);
00521 Data response;
00522 try {
00523 Send(command, response);
00524 }
00525 catch( Usb::Error & ) {
00526
00527 socket.ForceClosed();
00528
00529 eeout(command, response);
00530 throw;
00531 }
00532
00533
00534 Protocol::CheckSize(response);
00535 if( IS_COMMAND(response, SB_COMMAND_SEQUENCE_HANDSHAKE) ) {
00536 CheckSequence(0, response);
00537
00538
00539 RawReceive(response);
00540 }
00541
00542 Protocol::CheckSize(response, SB_SOCKET_PACKET_HEADER_SIZE);
00543 MAKE_PACKET(rpack, response);
00544 if( rpack->command != SB_COMMAND_CLOSED_SOCKET ||
00545 btohs(rpack->u.socket.socket) != socket.GetSocket() ||
00546 rpack->u.socket.sequence != socket.GetCloseFlag() )
00547 {
00548
00549 socket.ForceClosed();
00550
00551 eout("Packet:\n" << response);
00552 throw Error("Socket: Bad CLOSED packet in Close");
00553 }
00554
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564 socket.ForceClosed();
00565 }
00566
00567
00568
00569
00570
00571
00572
00573
00574
00575 Socket::Socket( SocketZero &zero,
00576 uint16_t socket,
00577 uint8_t closeFlag)
00578 : m_zero(&zero)
00579 , m_socket(socket)
00580 , m_closeFlag(closeFlag)
00581 , m_registered(false)
00582 {
00583 }
00584
00585 Socket::~Socket()
00586 {
00587
00588 try {
00589
00590 Close();
00591 }
00592 catch( std::runtime_error &re ) {
00593
00594 dout("Exception caught in ~Socket: " << re.what());
00595 }
00596 }
00597
00598
00599
00600
00601
00602 void Socket::CheckSequence(const Data &seq)
00603 {
00604 m_zero->CheckSequence(m_socket, seq);
00605 }
00606
00607 void Socket::ForceClosed()
00608 {
00609 m_socket = 0;
00610 m_closeFlag = 0;
00611 }
00612
00613
00614
00615
00616
00617 void Socket::Close()
00618 {
00619 UnregisterInterest();
00620 m_zero->Close(*this);
00621 }
00622
00623
00624
00625
00626
00627
00628
00629
00630
00631
00632
00633 void Socket::Send(Data &send, int timeout)
00634 {
00635
00636 if( send.GetSize() >= SB_PACKET_HEADER_SIZE ) {
00637 MAKE_PACKETPTR_BUF(spack, send.GetBuffer());
00638 spack->socket = htobs(m_socket);
00639 }
00640 m_zero->RawSend(send, timeout);
00641 }
00642
00643
00644
00645
00646
00647
00648
00649
00650
00651
00652 void Socket::Send(Data &send, Data &receive, int timeout)
00653 {
00654 Send(send, timeout);
00655 Receive(receive, timeout);
00656 }
00657
00658 void Socket::Send(Barry::Packet &packet, int timeout)
00659 {
00660 Send(packet.m_send, packet.m_receive, timeout);
00661 }
00662
00663 void Socket::Receive(Data &receive, int timeout)
00664 {
00665 if( m_registered ) {
00666 if( m_zero->m_queue ) {
00667 if( !m_zero->m_queue->SocketRead(m_socket, receive, timeout) )
00668 throw Timeout("Socket::Receive: queue SocketRead returned false (likely a timeout)");
00669 }
00670 else {
00671 throw std::logic_error("NULL queue pointer in a registered socket read.");
00672 }
00673 }
00674 else {
00675 m_zero->RawReceive(receive, timeout);
00676 }
00677 }
00678
00679
00680
00681
00682
00683 void Socket::Packet(Data &send, Data &receive, int timeout)
00684 {
00685
00686
00687
00688
00689
00690
00691
00692
00693
00694
00695
00696 MAKE_PACKET(spack, send);
00697 if( send.GetSize() < MIN_PACKET_SIZE ||
00698 (spack->command != SB_COMMAND_DB_DATA &&
00699 spack->command != SB_COMMAND_DB_DONE) )
00700 {
00701
00702 throw std::logic_error("Socket: unknown send data in Packet()");
00703 }
00704
00705 Data inFrag;
00706 receive.Zap();
00707
00708 if( send.GetSize() <= MAX_PACKET_SIZE ) {
00709
00710 Send(send, inFrag, timeout);
00711 }
00712 else {
00713
00714 unsigned int offset = 0;
00715 Data outFrag;
00716
00717 do {
00718 offset = SocketZero::MakeNextFragment(send, outFrag, offset);
00719 Send(outFrag, inFrag, timeout);
00720
00721 MAKE_PACKET(rpack, inFrag);
00722
00723
00724
00725 if( offset && inFrag.GetSize() > 0 ) {
00726
00727 Protocol::CheckSize(inFrag);
00728
00729 switch( rpack->command )
00730 {
00731 case SB_COMMAND_SEQUENCE_HANDSHAKE:
00732 CheckSequence(inFrag);
00733 break;
00734
00735 default: {
00736 std::ostringstream oss;
00737 oss << "Socket: unhandled packet in Packet() (send): 0x" << std::hex << (unsigned int)rpack->command;
00738 eout(oss.str());
00739 throw Error(oss.str());
00740 }
00741 break;
00742 }
00743 }
00744
00745
00746 } while( offset > 0 );
00747 }
00748
00749 bool done = false, frag = false;
00750 int blankCount = 0;
00751 while( !done ) {
00752 MAKE_PACKET(rpack, inFrag);
00753
00754
00755 if( inFrag.GetSize() > 0 ) {
00756 blankCount = 0;
00757
00758 Protocol::CheckSize(inFrag);
00759
00760 switch( rpack->command )
00761 {
00762 case SB_COMMAND_SEQUENCE_HANDSHAKE:
00763 CheckSequence(inFrag);
00764 break;
00765
00766 case SB_COMMAND_DB_DATA:
00767 if( frag ) {
00768 SocketZero::AppendFragment(receive, inFrag);
00769 }
00770 else {
00771 receive = inFrag;
00772 }
00773 done = true;
00774 break;
00775
00776 case SB_COMMAND_DB_FRAGMENTED:
00777 SocketZero::AppendFragment(receive, inFrag);
00778 frag = true;
00779 break;
00780
00781 case SB_COMMAND_DB_DONE:
00782 receive = inFrag;
00783 done = true;
00784 break;
00785
00786 default: {
00787 std::ostringstream oss;
00788 oss << "Socket: unhandled packet in Packet() (read): 0x" << std::hex << (unsigned int)rpack->command;
00789 eout(oss.str());
00790 throw Error(oss.str());
00791 }
00792 break;
00793 }
00794 }
00795 else {
00796 blankCount++;
00797
00798 if( blankCount == 10 ) {
00799
00800
00801 throw Error("Socket: 10 blank packets received");
00802 }
00803 }
00804
00805 if( !done ) {
00806
00807 Receive(inFrag);
00808 }
00809 }
00810 }
00811
00812 void Socket::Packet(Barry::Packet &packet, int timeout)
00813 {
00814 Packet(packet.m_send, packet.m_receive, timeout);
00815 }
00816
00817 void Socket::NextRecord(Data &receive)
00818 {
00819 Barry::Protocol::Packet packet;
00820 packet.socket = htobs(GetSocket());
00821 packet.size = htobs(7);
00822 packet.command = SB_COMMAND_DB_DONE;
00823 packet.u.db.tableCmd = 0;
00824 packet.u.db.u.command.operation = 0;
00825
00826 Data command(&packet, 7);
00827 Packet(command, receive);
00828 }
00829
00830 void Socket::RegisterInterest(SocketRoutingQueue::SocketDataHandler handler,
00831 void *context)
00832 {
00833 if( !m_zero->m_queue )
00834 throw std::logic_error("SocketRoutingQueue required in SocketZero in order to call Socket::RegisterInterest()");
00835
00836 if( m_registered )
00837 throw std::logic_error("Socket already registered in Socket::RegisterInterest()!");
00838
00839 m_zero->m_queue->RegisterInterest(m_socket, handler, context);
00840 m_registered = true;
00841 }
00842
00843 void Socket::UnregisterInterest()
00844 {
00845 if( m_registered ) {
00846 if( m_zero->m_queue )
00847 m_zero->m_queue->UnregisterInterest(m_socket);
00848 m_registered = false;
00849 }
00850 }
00851
00852
00853 }
00854