Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,44 @@ ws.onMessage([](std::variant<rtc::binary, rtc::string> message) {
ws.open("wss://my.websocket/service");
```

### WebSocket with Custom Headers

```cpp
rtc::WebSocket ws;

// Send custom headers during handshake
std::map<std::string, std::string> headers = {
{"Authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"},
{"X-API-Key", "my-secret-api-key"},
{"User-Agent", "MyApp/1.0"}
};

ws.open("wss://my.websocket/service", headers);
```

### WebSocket Server with Header Access

```cpp
rtc::WebSocketServer server(config);

server.onClient([](std::shared_ptr<rtc::WebSocket> client) {
// Access request headers sent by the client
auto requestHeaders = client->requestHeaders();

for (const auto& [name, value] : requestHeaders) {
std::cout << name << ": " << value << std::endl;
}

// Check for specific headers
auto authIt = std::find_if(requestHeaders.begin(), requestHeaders.end(),
[](const auto& header) { return header.first == "authorization"; });

if (authIt != requestHeaders.end()) {
std::cout << "Client authenticated with: " << authIt->second << std::endl;
}
});
```

## Compatibility

The library implements the following communication protocols:
Expand Down
1 change: 1 addition & 0 deletions include/rtc/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <atomic>
#include <functional>
#include <map>

namespace rtc {

Expand Down
3 changes: 2 additions & 1 deletion include/rtc/websocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,15 @@ class RTC_CPP_EXPORT WebSocket final : private CheshireCat<impl::WebSocket>, pub
bool isClosed() const override;
size_t maxMessageSize() const override;

void open(const string &url);
void open(const string &url, const std::map<string, string> &headers = {});
void close() override;
void forceClose();
bool send(const message_variant data) override;
bool send(const byte *data, size_t size) override;

optional<string> remoteAddress() const;
optional<string> path() const;
std::multimap<string, string> requestHeaders() const;

private:
using CheshireCat<impl::WebSocket>::impl;
Expand Down
2 changes: 1 addition & 1 deletion src/capi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ int rtcAddTrackEx(int pc, const rtcTrackInit *init) {
case RTC_CODEC_OPUS:
case RTC_CODEC_PCMU:
case RTC_CODEC_PCMA:
case RTC_CODEC_AAC:
case RTC_CODEC_AAC:
case RTC_CODEC_G722: {
auto audio = std::make_unique<Description::Audio>(mid, direction);
switch (init->codec) {
Expand Down
4 changes: 2 additions & 2 deletions src/impl/websocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ WebSocket::WebSocket(optional<Configuration> optConfig, certificate_ptr certific

WebSocket::~WebSocket() { PLOG_VERBOSE << "Destroying WebSocket"; }

void WebSocket::open(const string &url) {
void WebSocket::open(const string &url, const std::map<string, string> &headers) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, but since you plan to store the map, you should pass it by value and move it into the WsHandshake() constructor.

PLOG_VERBOSE << "Opening WebSocket to URL: " << url;

if (state != State::Closed)
Expand Down Expand Up @@ -126,7 +126,7 @@ void WebSocket::open(const string &url) {

mHostname = hostname; // for TLS SNI and Proxy
mService = service; // For proxy
std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols, headers));

changeState(State::Connecting);

Expand Down
2 changes: 1 addition & 1 deletion src/impl/websocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
WebSocket(optional<Configuration> optConfig = nullopt, certificate_ptr certificate = nullptr);
~WebSocket();

void open(const string &url);
void open(const string &url, const std::map<string, string> &headers = {});
void close();
void remoteClose();
bool outgoing(message_ptr message);
Expand Down
17 changes: 15 additions & 2 deletions src/impl/wshandshake.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ using std::chrono::system_clock;

WsHandshake::WsHandshake() {}

WsHandshake::WsHandshake(string host, string path, std::vector<string> protocols)
: mHost(std::move(host)), mPath(std::move(path)), mProtocols(std::move(protocols)) {
WsHandshake::WsHandshake(string host, string path, std::vector<string> protocols, std::map<string, string> headers)
: mHost(std::move(host)), mPath(std::move(path)), mProtocols(std::move(protocols)), mCustomHeaders(std::move(headers)) {

if (mHost.empty())
throw std::invalid_argument("WebSocket HTTP host cannot be empty");
Expand All @@ -55,6 +55,11 @@ std::vector<string> WsHandshake::protocols() const {
return mProtocols;
}

std::multimap<string, string> WsHandshake::requestHeaders() const {
std::unique_lock lock(mMutex);
return mRequestHeaders;
}

string WsHandshake::generateHttpRequest() {
std::unique_lock lock(mMutex);
mKey = generateKey();
Expand All @@ -73,6 +78,11 @@ string WsHandshake::generateHttpRequest() {
if (!mProtocols.empty())
out += "Sec-WebSocket-Protocol: " + utils::implode(mProtocols, ',') + "\r\n";

// Add custom headers
for (const auto& [headerName, headerValue] : mCustomHeaders) {
out += headerName + ": " + headerValue + "\r\n";
}
Comment on lines +82 to +84
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should at least sanitize headerValue:

Suggested change
for (const auto& [headerName, headerValue] : mCustomHeaders) {
out += headerName + ": " + headerValue + "\r\n";
}
for (auto [headerName, headerValue] : mCustomHeaders) {
headerValue.erase(std::remove(headerValue.begin(), headerValue.end(), '\r'), headerValue.end());
std::replace(headerValue.begin(), headerValue.end(), '\n', ' ');
out += headerName + ": " + headerValue + "\r\n";
}


out += "\r\n";

return out;
Expand Down Expand Up @@ -161,6 +171,9 @@ size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) {

auto headers = parseHttpHeaders(lines);

// Store all request headers for later access
mRequestHeaders = headers;

auto h = headers.find("host");
if (h == headers.end())
throw RequestError("WebSocket host header missing in request", 400);
Expand Down
5 changes: 4 additions & 1 deletion src/impl/wshandshake.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ namespace rtc::impl {
class WsHandshake final {
public:
WsHandshake();
WsHandshake(string host, string path = "/", std::vector<string> protocols = {});
WsHandshake(string host, string path = "/", std::vector<string> protocols = {}, std::map<string, string> headers = {});

string host() const;
string path() const;
std::vector<string> protocols() const;
std::multimap<string, string> requestHeaders() const;

string generateHttpRequest();
string generateHttpResponse();
Expand Down Expand Up @@ -57,6 +58,8 @@ class WsHandshake final {
string mHost;
string mPath;
std::vector<string> mProtocols;
std::map<string, string> mCustomHeaders;
std::multimap<string, string> mRequestHeaders;
Comment on lines +61 to +62
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to use two different maps here? You could use only one multimap.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey ! It's to add support for the websocket server.

But good suggestion. Rather than custom and request we can just rename them to request headers or just headers up to you.

string mKey;
mutable std::mutex mMutex;
};
Expand Down
10 changes: 9 additions & 1 deletion src/websocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ bool WebSocket::isClosed() const { return impl()->state.load() == State::Closed;

size_t WebSocket::maxMessageSize() const { return impl()->maxMessageSize(); }

void WebSocket::open(const string &url) { impl()->open(url); }
void WebSocket::open(const string &url, const std::map<string, string> &headers) {
impl()->open(url, headers);
}

void WebSocket::close() { impl()->close(); }

Expand All @@ -68,6 +70,12 @@ optional<string> WebSocket::path() const {
return state != State::Connecting && handshake ? make_optional(handshake->path()) : nullopt;
}

std::multimap<string, string> WebSocket::requestHeaders() const {
auto state = impl()->state.load();
auto handshake = impl()->getWsHandshake();
return state != State::Connecting && handshake ? handshake->requestHeaders() : std::multimap<string, string>{};
}

std::ostream &operator<<(std::ostream &out, WebSocket::State state) {
using State = WebSocket::State;
const char *str;
Expand Down
41 changes: 36 additions & 5 deletions test/websocketserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <atomic>
#include <chrono>
#include <iostream>
#include <map>
#include <memory>
#include <thread>

Expand All @@ -34,18 +35,42 @@ void test_websocketserver() {
WebSocketServer server(std::move(serverConfig));

shared_ptr<WebSocket> client;
server.onClient([&client](shared_ptr<WebSocket> incoming) {
std::atomic<bool> headersVerified = false;
server.onClient([&client, &headersVerified](shared_ptr<WebSocket> incoming) {
cout << "WebSocketServer: Client connection received" << endl;
client = incoming;

if(auto addr = client->remoteAddress())
cout << "WebSocketServer: Client remote address is " << *addr << endl;

client->onOpen([wclient = make_weak_ptr(client)]() {
client->onOpen([wclient = make_weak_ptr(client), &headersVerified]() {
cout << "WebSocketServer: Client connection open" << endl;
if(auto client = wclient.lock())
if(auto client = wclient.lock()) {
if(auto path = client->path())
cout << "WebSocketServer: Requested path is " << *path << endl;

// Test custom headers functionality
auto requestHeadersMultimap = client->requestHeaders();
std::map<string, string> requestHeaders;
for (const auto& [key, value] : requestHeadersMultimap) {
requestHeaders[key] = value;
}

cout << "WebSocketServer: Checking custom headers..." << endl;
bool customHeaderFound = requestHeaders.find("x-test-custom") != requestHeaders.end() &&
requestHeaders.at("x-test-custom") == "test-value-123";
bool authHeaderFound = requestHeaders.find("authorization") != requestHeaders.end() &&
requestHeaders.at("authorization") == "Bearer custom-token";

if (customHeaderFound && authHeaderFound) {
cout << "WebSocketServer: Custom headers verified successfully!" << endl;
headersVerified = true;
} else {
cout << "WebSocketServer: Custom headers verification failed!" << endl;
cout << " X-Test-Custom header: " << (customHeaderFound ? "FOUND" : "NOT FOUND") << endl;
cout << " Authorization header: " << (authHeaderFound ? "FOUND" : "NOT FOUND") << endl;
}
}
});

client->onClosed([]() {
Expand Down Expand Up @@ -91,10 +116,13 @@ void test_websocketserver() {
}
});

ws.open("wss://localhost:48080/");
ws.open("wss://localhost:48080/", {
{"X-Test-Custom", "test-value-123"},
{"Authorization", "Bearer custom-token"}
});

int attempts = 15;
while ((!ws.isOpen() || !received) && attempts--)
while ((!ws.isOpen() || !received || !headersVerified) && attempts--)
this_thread::sleep_for(1s);

if (!ws.isOpen())
Expand All @@ -103,6 +131,9 @@ void test_websocketserver() {
if (!received || !maxSizeReceived)
throw runtime_error("Expected messages not received");

if (!headersVerified)
throw runtime_error("Custom headers not verified");

ws.close();
this_thread::sleep_for(1s);

Expand Down
Loading