Skip to content

Proper check for Origins in WS server. #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions server-utils/src/cors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::{fmt, ops};
use hosts::{Host, Port};
use matcher::Matcher;
use matcher::{Matcher, Pattern};

/// Origin Protocol
#[derive(Clone, Hash, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -74,11 +74,6 @@ impl Origin {
Origin::with_host(protocol, hostname)
}

/// Checks if given string matches the pattern.
pub fn matches<T: AsRef<str>>(&self, other: T) -> bool {
self.matcher.matches(other)
}

fn to_string(protocol: &OriginProtocol, host: &Host) -> String {
format!(
"{}://{}",
Expand All @@ -92,6 +87,12 @@ impl Origin {
}
}

impl Pattern for Origin {
fn matches<T: AsRef<str>>(&self, other: T) -> bool {
self.matcher.matches(other)
}
}

impl ops::Deref for Origin {
type Target = str;
fn deref(&self) -> &Self::Target {
Expand Down
13 changes: 7 additions & 6 deletions server-utils/src/hosts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::collections::HashSet;
use std::net::SocketAddr;
use matcher::Matcher;
use matcher::{Matcher, Pattern};

const SPLIT_PROOF: &'static str = "split always returns non-empty iterator.";

Expand Down Expand Up @@ -80,11 +80,6 @@ impl Host {
Host::new(host, port)
}

/// Checks if given string matches the pattern.
pub fn matches<T: AsRef<str>>(&self, other: T) -> bool {
self.matcher.matches(other)
}

fn pre_process(host: &str) -> String {
// Remove possible protocol definition
let mut it = host.split("://");
Expand All @@ -111,6 +106,12 @@ impl Host {
}
}

impl Pattern for Host {
fn matches<T: AsRef<str>>(&self, other: T) -> bool {
self.matcher.matches(other)
}
}

impl ::std::ops::Deref for Host {
type Target = str;
fn deref(&self) -> &Self::Target {
Expand Down
2 changes: 2 additions & 0 deletions server-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ pub mod cors;
pub mod hosts;
pub mod reactor;
mod matcher;

pub use matcher::Pattern;
10 changes: 9 additions & 1 deletion server-utils/src/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ use globset::{GlobMatcher, GlobBuilder};
use std::ascii::AsciiExt;
use std::{fmt, hash};

/// Pattern that can be matched to string.
pub trait Pattern {
/// Returns true if given string matches the pattern.
fn matches<T: AsRef<str>>(&self, other: T) -> bool;
}

#[derive(Clone)]
pub struct Matcher(Option<GlobMatcher>, String);
impl Matcher {
Expand All @@ -16,8 +22,10 @@ impl Matcher {
string.into()
)
}
}

pub fn matches<T: AsRef<str>>(&self, other: T) -> bool {
impl Pattern for Matcher {
fn matches<T: AsRef<str>>(&self, other: T) -> bool {
let s = other.as_ref();
match self.0 {
Some(ref matcher) => matcher.is_match(s),
Expand Down
6 changes: 3 additions & 3 deletions ws/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std;
use std::ascii::AsciiExt;
use std::sync::{atomic, Arc};

use core;
use core::futures::Future;
use server_utils::Pattern;
use server_utils::cors::Origin;
use server_utils::hosts::Host;
use server_utils::tokio_core::reactor::Remote;
Expand Down Expand Up @@ -258,7 +258,7 @@ impl<M: core::Metadata, S: core::Middleware<M>> ws::Factory for Factory<M, S> {
}

fn header_is_allowed<T>(allowed: &Option<Vec<T>>, header: Option<&[u8]>) -> bool where
T: ::std::ops::Deref<Target=str>,
T: Pattern,
{
let header = header.map(std::str::from_utf8);

Expand All @@ -270,7 +270,7 @@ fn header_is_allowed<T>(allowed: &Option<Vec<T>>, header: Option<&[u8]>) -> bool
// Validate Origin
(Some(Ok(val)), Some(values)) => {
for v in values {
if val.eq_ignore_ascii_case(&v) {
if v.matches(val) {
return true
}
}
Expand Down