Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I want to make a Trojan client and need some help. #701

Open
iwaitu opened this issue Jul 4, 2024 · 0 comments
Open

I want to make a Trojan client and need some help. #701

iwaitu opened this issue Jul 4, 2024 · 0 comments

Comments

@iwaitu
Copy link

iwaitu commented Jul 4, 2024

#include "pch.h"
#include "TrojanClient.h"
#include "trojanrequest.h"

#include <algorithm>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string>

#include <openssl/sha.h>

/// Result of splitting a URL into the parts needed to build the tunneled request.
struct ParsedURL {
    std::string hostname;  // host part, e.g. "ipinfo.io"; a ":port" suffix is kept as-is
    std::string path;      // request path including the leading '/'; "/" when absent
};

/// Splits `url` into hostname and path.
///
/// An optional "http://" or "https://" prefix is matched and dropped. The
/// hostname is the run of non-'/' characters that follows, and the path is the
/// remainder starting at the next '/'.
///
/// @param url  e.g. "ipinfo.io/ip" or "https://example.com/a/b".
/// @return     {hostname, path}; path is "/" when the URL has none.
/// @throws std::invalid_argument if `url` does not match, e.g. it is empty or
///         starts with '/'.
ParsedURL ParseURL(const std::string& url) {
    // Compiled once: building a std::regex is expensive and the pattern never changes.
    static const std::regex url_regex(R"((https?://)?([^/]+)(/.*)?)");

    std::smatch match;
    if (!std::regex_match(url, match, url_regex)) {
        throw std::invalid_argument("Invalid URL");
    }

    ParsedURL parsed;
    parsed.hostname = match[2].str();
    // When group 3 matched it always begins with '/', so it is never empty.
    parsed.path = match[3].matched ? match[3].str() : std::string("/");
    return parsed;
}

/// Stores the connection parameters and builds the shared SSL_CTX.
///
/// No network activity happens here; call ConnectToServer() to connect.
///
/// @param server_ip    numeric IPv4 address of the Trojan server.
/// @param server_port  TCP port of the Trojan server.
/// @param password     Trojan password (hashed by ComputePasswordHash()).
/// @param sni          server name sent in the TLS ClientHello.
/// @throws whatever InitializeSSL() throws, after reporting it.
TrojanClient::TrojanClient(const std::string& server_ip, int server_port, const std::string& password, const std::string& sni)
    : server_ip_(server_ip), server_port_(server_port), password_(password), sni_(sni), ctx_(nullptr, SSL_CTX_free), ssl_(nullptr, SSL_free), server_fd_(INVALID_SOCKET), connected_(false) {
    try {
        // InitializeSSL() performs the OpenSSL library initialization itself,
        // so it is not repeated here.
        ctx_.reset(InitializeSSL());
    }
    catch (const std::exception& e) {
        TrojanClientException::ReportError("Initialization error: " + std::string(e.what()));
        throw;
    }
}

/// Releases the connection on destruction.
///
/// Destructors must not throw, so a failure during Disconnect() is only
/// reported, never propagated.
TrojanClient::~TrojanClient() {
    try {
        this->Disconnect();
    }
    catch (const std::exception& ex) {
        const std::string reason = ex.what();
        TrojanClientException::ReportError("Error during disconnection: " + reason);
    }
}

/// Creates the client-side SSL_CTX used for connections to the Trojan server.
///
/// Restricts negotiation to TLS 1.2-1.3 and configures the preferred cipher
/// suites. Certificate verification is disabled (see the SECURITY note below).
///
/// @return a new SSL_CTX owned by the caller (release with SSL_CTX_free).
/// @throws std::runtime_error if any configuration step fails; the partially
///         configured context is freed before throwing.
SSL_CTX* TrojanClient::InitializeSSL() {
    // Initialize the OpenSSL library. On OpenSSL >= 1.1.0 this happens
    // automatically; the calls are kept for compatibility.
    SSL_library_init();
    OpenSSL_add_all_algorithms();
    SSL_load_error_strings();

    // Use the TLS client method. RAII ownership frees the context on every
    // early throw below.
    std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx(SSL_CTX_new(TLS_client_method()), &SSL_CTX_free);
    if (!ctx) {
        ERR_print_errors_fp(stderr);
        throw std::runtime_error("SSL_CTX_new failed");
    }

    // Set the minimum and maximum TLS versions.
    if (SSL_CTX_set_min_proto_version(ctx.get(), TLS1_2_VERSION) == 0) {
        ERR_print_errors_fp(stderr);
        throw std::runtime_error("Failed to set minimum TLS version to 1.2");
    }
    if (SSL_CTX_set_max_proto_version(ctx.get(), TLS1_3_VERSION) == 0) {
        ERR_print_errors_fp(stderr);
        throw std::runtime_error("Failed to set maximum TLS version to 1.3");
    }

    // SSL_CTX_set_cipher_list() only configures TLS <= 1.2 ciphers; TLS 1.3
    // names passed to it are ignored. TLS 1.3 suites are set separately with
    // SSL_CTX_set_ciphersuites().
    const char* const TLS12_CIPHERS =
        "ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:"
        "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:"
        "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256";
    const char* const TLS13_CIPHERSUITES =
        "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:TLS_AES_128_GCM_SHA256";

    if (SSL_CTX_set_cipher_list(ctx.get(), TLS12_CIPHERS) == 0) {
        ERR_print_errors_fp(stderr);
        throw std::runtime_error("Failed to set preferred ciphers");
    }
    if (SSL_CTX_set_ciphersuites(ctx.get(), TLS13_CIPHERSUITES) == 0) {
        ERR_print_errors_fp(stderr);
        throw std::runtime_error("Failed to set preferred TLS 1.3 ciphersuites");
    }

    // SECURITY: certificate verification is disabled, so any server (including
    // a man-in-the-middle) is accepted. Request() then sends it the password
    // hash. Use SSL_VERIFY_PEER with a trusted CA store and a hostname check
    // for real deployments.
    SSL_CTX_set_verify(ctx.get(), SSL_VERIFY_NONE, nullptr);
    SSL_CTX_set_cert_verify_callback(ctx.get(), [](X509_STORE_CTX* /*store*/, void* /*arg*/) -> int {
        return 1; // accept every certificate chain without checking it
    }, nullptr);

    return ctx.release();
}

bool TrojanClient::ConnectToServer() {
    std::lock_guard<std::recursive_mutex> lock(mtx_);

    try {
        WSADATA wsaData;
        if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
            TrojanClientException::ReportError("WSAStartup failed: " + std::to_string(WSAGetLastError()));
            return false;
        }

        server_fd_ = CreateSocket(server_ip_, server_port_);
        if (server_fd_ == INVALID_SOCKET) {
            TrojanClientException::ReportError("Unable to connect to server");
            WSACleanup();
            return false;
        }

        ssl_.reset(SSL_new(ctx_.get()));
        if (!ssl_) {
            ERR_print_errors_fp(stderr);
            TrojanClientException::ReportError("Failed to create SSL object");
            closesocket(server_fd_);
            WSACleanup();
            return false;
        }

        SSL_set_fd(ssl_.get(), server_fd_);
        SSL_set_tlsext_host_name(ssl_.get(), sni_.c_str());

        int ret;
        while ((ret = SSL_connect(ssl_.get())) <= 0) {
            int ssl_err = SSL_get_error(ssl_.get(), ret);
            if (ssl_err != SSL_ERROR_WANT_READ && ssl_err != SSL_ERROR_WANT_WRITE) {
                ERR_print_errors_fp(stderr);
                TrojanClientException::ReportError("SSL_connect() failed");
                closesocket(server_fd_);
                WSACleanup();
                return false;
            }
            // 如果需要等待,可以使用select等待,或者使用sleep稍作延时再试
            Sleep(100);
        }

        const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_.get());
        connected_ = true;
        std::cout << "SSL password authentication succeeded, connection established!" << std::endl;
        return true;
    }
    catch (const std::exception& e) {
        TrojanClientException::ReportError("Error during connection: " + std::string(e.what()));
        Disconnect();
        return false;
    }
}

/// Closes the TLS session and the socket, if any, and releases Winsock.
///
/// Safe to call repeatedly, e.g. explicitly and then again from the
/// destructor. WSACleanup() is only called while this object still holds an
/// SSL object or socket, so later calls do not call WSACleanup() again.
void TrojanClient::Disconnect() {
    std::lock_guard<std::recursive_mutex> lock(mtx_);

    try {
        // Decide before clearing state whether there is anything to release.
        const bool had_connection = static_cast<bool>(ssl_) || server_fd_ != INVALID_SOCKET;

        if (ssl_) {
            // Best-effort close_notify; on a non-blocking socket it may not complete.
            SSL_shutdown(ssl_.get());
            ssl_.reset(nullptr);
        }
        if (server_fd_ != INVALID_SOCKET) {
            closesocket(server_fd_);
            server_fd_ = INVALID_SOCKET;
        }
        if (had_connection) {
            WSACleanup();
        }
        connected_ = false;
    }
    catch (const std::exception& e) {
        TrojanClientException::ReportError("Error during disconnection: " + std::string(e.what()));
    }
}

bool TrojanClient::SendData(const std::string& data) {
    std::lock_guard<std::recursive_mutex> lock(mtx_);
    int totalSent = 0;
    int len = data.length();

    try {
        if (ssl_ && connected_) {
            std::cout << "Sending data: " << data << std::endl;
            while (totalSent < len) {
                int ret = SSL_write(ssl_.get(), data.c_str() + totalSent, len - totalSent);
                if (ret > 0) {
                    totalSent += ret;
                }
                else {
                    int ssl_err = SSL_get_error(ssl_.get(), ret);
                    if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
                        Sleep(100);
                        continue;
                    }
                    else {
                        std::cerr << "SSL_write failed with error: " << ssl_err << std::endl;
                        while (unsigned long err = ERR_get_error()) {
                            char err_msg[256];
                            ERR_error_string_n(err, err_msg, sizeof(err_msg));
                            TrojanClientException::ReportError("SSL_write() error detail: " + std::string(err_msg));
                        }
                        Disconnect();
                        return false;
                    }
                }
            }
            return true;
        }
        else {
            TrojanClientException::ReportError("SSL connection is not established or already closed");
        }
        return false;
    }
    catch (const std::exception& e) {
        TrojanClientException::ReportError("Error during data send: " + std::string(e.what()));
        Disconnect();
        return false;
    }
}


int TrojanClient::ReceiveData(char* buffer, int len) {
    std::lock_guard<std::recursive_mutex> lock(mtx_);
    int totalReceived = 0;
    int maxRetries = 10; // 最大重试次数
    bool endOfData = false;

    try {
        if (ssl_ && connected_) {
            while (totalReceived < len && maxRetries > 0 && !endOfData) {
                int ret = SSL_read(ssl_.get(), buffer + totalReceived, len - totalReceived);
                if (ret > 0) {
                    totalReceived += ret;
                    if (strstr(buffer, "\r\n")) {
                        endOfData = true;
                    }
                }
                else {
                    int ssl_err = SSL_get_error(ssl_.get(), ret);
                    if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
                        Sleep(100);
                        maxRetries--;
                        continue;
                    }
                    else if (ssl_err == SSL_ERROR_ZERO_RETURN) {
                        Disconnect();
                        return -1;
                    }
                    else {
                        Disconnect();
                        return -1;
                    }
                }
            }
            return totalReceived;
        }
        return -1;
    }
    catch (const std::exception& e) {
        return -1;
    }
}


/// Opens a non-blocking IPv4 TCP socket and connects it to `ip`:`port`.
///
/// `ip` must be a numeric IPv4 address; it is parsed with inet_pton and host
/// names are not resolved. Completion of the connect is awaited with select()
/// for up to 10 seconds.
///
/// @return the connected, non-blocking socket, which the caller must
///         closesocket(); or INVALID_SOCKET on any failure (reported via
///         TrojanClientException::ReportError).
SOCKET TrojanClient::CreateSocket(const std::string& ip, int port) {
    try {
        // htons() takes a 16-bit value; reject ports that would silently wrap.
        if (port < 1 || port > 65535) {
            TrojanClientException::ReportError("Invalid port: " + std::to_string(port));
            return INVALID_SOCKET;
        }

        SOCKET sockfd = socket(AF_INET, SOCK_STREAM, 0);
        if (sockfd == INVALID_SOCKET) {
            TrojanClientException::ReportError("Socket creation error: " + std::to_string(WSAGetLastError()));
            return INVALID_SOCKET;
        }

        // Put the socket into non-blocking mode (1 = non-blocking, 0 = blocking).
        u_long mode = 1;
        if (ioctlsocket(sockfd, FIONBIO, &mode) != 0) {
            TrojanClientException::ReportError("Failed to set socket to non-blocking mode: " + std::to_string(WSAGetLastError()));
            closesocket(sockfd);
            return INVALID_SOCKET;
        }

        // Zero-initialize so sin_zero is cleared before the address reaches connect().
        struct sockaddr_in serv_addr {};
        serv_addr.sin_family = AF_INET;
        serv_addr.sin_port = htons(static_cast<u_short>(port));

        if (inet_pton(AF_INET, ip.c_str(), &serv_addr.sin_addr) <= 0) {
            TrojanClientException::ReportError("Invalid address/ Address not supported: " + std::to_string(WSAGetLastError()));
            closesocket(sockfd);
            return INVALID_SOCKET;
        }

        // On a non-blocking socket, connect() normally returns SOCKET_ERROR
        // with WSAEWOULDBLOCK while the connection is in progress.
        if (connect(sockfd, reinterpret_cast<struct sockaddr*>(&serv_addr), sizeof(serv_addr)) == SOCKET_ERROR) {
            int error = WSAGetLastError();
            if (error != WSAEWOULDBLOCK) {
                TrojanClientException::ReportError("Connection failed: " + std::to_string(error));
                closesocket(sockfd);
                return INVALID_SOCKET;
            }
        }

        // Wait with select() for the connection to complete or time out.
        // On Winsock, writability signals success and the except set signals failure.
        fd_set write_fds, except_fds;
        struct timeval timeout;
        FD_ZERO(&write_fds);
        FD_ZERO(&except_fds);
        FD_SET(sockfd, &write_fds);
        FD_SET(sockfd, &except_fds);
        timeout.tv_sec = 10;  // 10-second timeout
        timeout.tv_usec = 0;

        int select_result = select(0, nullptr, &write_fds, &except_fds, &timeout);
        if (select_result == 0) {
            TrojanClientException::ReportError("Connection timeout");
            closesocket(sockfd);
            return INVALID_SOCKET;
        }
        else if (select_result == SOCKET_ERROR) {
            TrojanClientException::ReportError("Select failed: " + std::to_string(WSAGetLastError()));
            closesocket(sockfd);
            return INVALID_SOCKET;
        }

        if (FD_ISSET(sockfd, &except_fds)) {
            TrojanClientException::ReportError("Connection failed");
            closesocket(sockfd);
            return INVALID_SOCKET;
        }

        // Connected.
        return sockfd;
    }
    catch (const std::exception& e) {
        TrojanClientException::ReportError("Error during socket creation: " + std::string(e.what()));
        return INVALID_SOCKET;
    }
}

bool TrojanClient::IsConnected() {
    std::lock_guard<std::recursive_mutex> lock(mtx_);
    if (!connected_ || !ssl_) {
        return false;
    }

    char buf;
    int ret = SSL_peek(ssl_.get(), &buf, 0);
    if (ret <= 0) {
        int ssl_error = SSL_get_error(ssl_.get(), ret);
        if (ssl_error != SSL_ERROR_WANT_READ && ssl_error != SSL_ERROR_WANT_WRITE) {
            Disconnect();
            return false;
        }
    }
    return true;
}

/// Returns the raw socket handle currently stored by this client.
/// It is INVALID_SOCKET before any connection and after Disconnect().
/// The handle remains owned by this object; do not close it.
SOCKET TrojanClient::get_socket() const {
    return this->server_fd_;
}

std::string TrojanClient::ComputePasswordHash() {
    unsigned char hash[SHA224_DIGEST_LENGTH];
    SHA224(reinterpret_cast<const unsigned char*>(password_.c_str()), password_.size(), hash);
    std::stringstream ss;
    for (int i = 0; i < SHA224_DIGEST_LENGTH; ++i) {
        ss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(hash[i]);
    }
    return ss.str();
}

/// Builds the Trojan request header for destination `url`:`port` via
/// TrojanRequest::generate and appends "\r\n".
///
/// NOTE(review): TrojanRequest::generate is defined elsewhere. If it already
/// ends the header with "\r\n", as the Trojan protocol specifies, the extra
/// "\r\n" here becomes the first two bytes of the tunneled payload. TODO confirm.
std::string TrojanClient::GenerateTrojanRequest(const std::string& password_hash, const std::string& url, int port, bool tcp) {
    std::string header = TrojanRequest::generate(password_hash, url, port, tcp);
    header.append("\r\n");
    return header;
}


std::string TrojanClient::Request(const std::string& url, int port, bool tcp) {
    std::lock_guard<std::recursive_mutex> lock(mtx_);
    ParsedURL parsed_url = ParseURL(url);
    try {
        if (!IsConnected() && !ConnectToServer()) {
            TrojanClientException::ReportError("Failed to connect to server.");
            return "Failed to connect to server.";
        }

        std::string password_hash = ComputePasswordHash();
        std::string trojanRequest = GenerateTrojanRequest(password_hash, parsed_url.hostname, port, tcp);

        if (!SendData(trojanRequest)) {
            return "Send Trojan request failed.";
        }

        std::string httpRequest =
            "GET " + parsed_url.path + " HTTP/1.1\r\nHost: " + parsed_url.hostname + "\r\n"
            "User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:74.0) Gecko/20100101 Firefox/74.0 \r\n"
            "Accept-Type: */* \r\n"
            "Connection: close\r\n\r\n";
        if (!SendData(httpRequest)) {
            return "Send HTTP/HTTPS request failed.";
        }

        char responseBuffer[BUFFER_SIZE];
        int receivedLen = ReceiveData(responseBuffer, BUFFER_SIZE);
        if (receivedLen <= 0) {
            return "Failed to receive response.";
        }
        return std::string(responseBuffer, receivedLen);
    }
    catch (const std::exception& e) {
        TrojanClientException::ReportError("Error in Request method: " + std::string(e.what()));
        return "Error in Request method: " + std::string(e.what());
    }
}

console.cpp

/// Smoke test: connects to a local Trojan server and fetches ipinfo.io/ip
/// over plain HTTP (destination port 80) through the tunnel.
/// Returns 0 on success and 1 on any failure.
int main() {

    std::string server_ip = "127.0.0.1";
    std::string password = "password1";

    int server_port = 443;              // Trojan server port

    std::string sni = "baidu.com";      // arbitrary string, for testing only

    try {
        // The constructor rethrows SSL initialization failures.
        TrojanClient client(server_ip, server_port, password, sni);

        if (!client.ConnectToServer()) {
            std::cerr << "Failed to connect to server." << std::endl;
            return 1;
        }

        // Send test data
        std::string response = client.Request("ipinfo.io/ip", 80, true);
        std::cout << "Response: " << response << std::endl;
        client.Disconnect();
    }
    catch (const std::exception& e) {
        std::cerr << "Fatal error: " << e.what() << std::endl;
        return 1;
    }

    //// Wait for the user to press Enter
    //std::cout << "Press Enter to exit...";
    //std::cin.get();

    return 0;
}

When I request ipinfo.io/ip on port 80, I get the response data. But when I request ipinfo.io/ip on port 443, I get an error: `SSL_read` error, error code 6 (`SSL_ERROR_ZERO_RETURN`, i.e. the TLS connection was closed by the peer).

Looking at the server log, I found something wrong: when I request port 443, this function:

void ServerSession::in_async_write(const std::string &data) {
    auto self = shared_from_this();
    auto data_copy = std::make_shared<std::string>(data);
    boost::asio::async_write(in_socket, boost::asio::buffer(*data_copy), [this, self, data_copy](const boost::system::error_code error, size_t) {
        if (error) {
            destroy();
            return;
        }
        in_sent();
    });
}

is never called. Can anyone help me?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Development

No branches or pull requests

1 participant