diff --git a/repos/os/include/net/internet_checksum.h b/repos/os/include/net/internet_checksum.h index 0768f9ca8d..8ff8ccd566 100644 --- a/repos/os/include/net/internet_checksum.h +++ b/repos/os/include/net/internet_checksum.h @@ -52,6 +52,35 @@ namespace Net { Ipv4_packet::Protocol ip_prot, Ipv4_address &ip_src, Ipv4_address &ip_dst); + + /** + * Accumulating modifier for incremental updates of internet checksums + */ + class Internet_checksum_diff + { + private: + + signed long _value { 0 }; + + public: + + /** + * Update modifier according to a data update in the target region + * + * PRECONDITIONS + * + * * The pointers must refer to data that is at an offset inside + * the checksum'd region that is a multiple of 2 bytes (16 bits). + */ + void add_up_diff(Packed_uint16 const *new_data_ptr, + Packed_uint16 const *old_data_ptr, + Genode::size_t data_sz); + + /** + * Return the given checksum with this modifier applied + */ + Genode::uint16_t apply_to(signed long sum) const; + }; } #endif /* _NET__INTERNET_CHECKSUM_H_ */ diff --git a/repos/os/include/net/ipv4.h b/repos/os/include/net/ipv4.h index 7609bd9b9d..ad01ceea83 100644 --- a/repos/os/include/net/ipv4.h +++ b/repos/os/include/net/ipv4.h @@ -27,8 +27,10 @@ namespace Genode { class Output; } namespace Net { + enum { IPV4_ADDR_LEN = 4 }; + class Internet_checksum_diff; class Ipv4_address; class Ipv4_packet; @@ -94,6 +96,8 @@ class Net::Ipv4_packet void update_checksum(); + void update_checksum(Internet_checksum_diff const &icd); + bool checksum_error() const; private: @@ -237,6 +241,9 @@ class Net::Ipv4_packet _offset_6_u16 = host_to_big_endian(be); } + void src(Ipv4_address v, Internet_checksum_diff &icd); + void dst(Ipv4_address v, Internet_checksum_diff &icd); + /********* ** log ** diff --git a/repos/os/src/lib/net/internet_checksum.cc b/repos/os/src/lib/net/internet_checksum.cc index 9628fd6607..0580337e90 100644 --- a/repos/os/src/lib/net/internet_checksum.cc +++ b/repos/os/src/lib/net/internet_checksum.cc @@ -29,6 +29,14 @@ struct Packed_uint8 } __attribute__((packed)); +static void fold_checksum_to_16_bits(signed long &sum) +{ + while (addr_t const remainder = sum >> 16) { + sum = (sum & 0xffff) + remainder; + } +} + + static uint16_t checksum_of_raw_data(Packed_uint16 const *data_ptr, size_t data_sz, signed long sum) @@ -42,9 +50,7 @@ static uint16_t checksum_of_raw_data(Packed_uint16 const *data_ptr, if (data_sz > 0) { sum += ((Packed_uint8 const *)data_ptr)->value; } - /* fold sum to 16-bit value */ - while (addr_t const sum_rsh = sum >> 16) - sum = (sum & 0xffff) + sum_rsh; + fold_checksum_to_16_bits(sum); /* return one's complement */ return (uint16_t)(~sum); @@ -84,3 +90,34 @@ uint16_t Net::internet_checksum_pseudo_ip(Packed_uint16 const *data_ptr, /* add up data bytes */ return checksum_of_raw_data(data_ptr, data_sz, sum); } + + +/**************************** + ** Internet_checksum_diff ** + ****************************/ + +void Internet_checksum_diff::add_up_diff(Packed_uint16 const *new_data_ptr, + Packed_uint16 const *old_data_ptr, + size_t data_sz) +{ + /* add up byte differences in pairs */ + signed long diff { 0 }; + for (; data_sz > 1; data_sz -= sizeof(Packed_uint16)) { + diff += old_data_ptr->value - new_data_ptr->value; + old_data_ptr++; + new_data_ptr++; + } + /* add difference of left-over byte, if any */ + if (data_sz > 0) { + diff += *(uint8_t *)old_data_ptr - *(uint8_t *)new_data_ptr; + } + _value += diff; +} + + +uint16_t Internet_checksum_diff::apply_to(signed long sum) const +{ + sum += _value; + fold_checksum_to_16_bits(sum); + return (uint16_t)sum; +} diff --git a/repos/os/src/lib/net/ipv4.cc b/repos/os/src/lib/net/ipv4.cc index e4383c54c8..e2a4e5646a 100644 --- a/repos/os/src/lib/net/ipv4.cc +++ b/repos/os/src/lib/net/ipv4.cc @@ -156,3 +156,23 @@ size_t Ipv4_packet::size(size_t max_size) const size_t const stated_size = total_length(); return stated_size < max_size ? stated_size : max_size; } + + +void Ipv4_packet::src(Ipv4_address v, Internet_checksum_diff &icd) +{ + icd.add_up_diff((Packed_uint16 *)&v.addr[0], (Packed_uint16 *)&_src[0], 4); + src(v); +} + + +void Ipv4_packet::dst(Ipv4_address v, Internet_checksum_diff &icd) +{ + icd.add_up_diff((Packed_uint16 *)&v.addr[0], (Packed_uint16 *)&_dst[0], 4); + dst(v); +} + + +void Ipv4_packet::update_checksum(Internet_checksum_diff const &icd) +{ + _checksum = icd.apply_to(_checksum); +} diff --git a/repos/os/src/server/nic_router/interface.cc b/repos/os/src/server/nic_router/interface.cc index 2a7a838e52..114587a032 100644 --- a/repos/os/src/server/nic_router/interface.cc +++ b/repos/os/src/server/nic_router/interface.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include /* local includes */ @@ -293,18 +294,19 @@ void Interface::_destroy_link(Link &link) } -void Interface::_pass_prot_to_domain(Domain &domain, - Ethernet_frame ð, - Size_guard &size_guard, - Ipv4_packet &ip, - L3_protocol const prot, - void *const prot_base, - size_t const prot_size) +void Interface::_pass_prot_to_domain(Domain &domain, + Ethernet_frame ð, + Size_guard &size_guard, + Ipv4_packet &ip, + Internet_checksum_diff const &ip_icd, + L3_protocol const prot, + void *const prot_base, + size_t const prot_size) { _update_checksum( prot, prot_base, prot_size, ip.src(), ip.dst(), ip.total_length()); - ip.update_checksum(); + ip.update_checksum(ip_icd); domain.interfaces().for_each([&] (Interface &interface) { eth.src(interface._router_mac); @@ -576,15 +578,16 @@ void Interface::_adapt_eth(Ethernet_frame ð, } -void Interface::_nat_link_and_pass(Ethernet_frame ð, - Size_guard &size_guard, - Ipv4_packet &ip, - L3_protocol const prot, - void *const prot_base, - size_t const prot_size, - Link_side_id const &local_id, - Domain &local_domain, - Domain &remote_domain) +void Interface::_nat_link_and_pass(Ethernet_frame ð, + Size_guard &size_guard, + Ipv4_packet &ip, + Internet_checksum_diff &ip_icd, + L3_protocol const prot, + void *const prot_base, + size_t const prot_size, + Link_side_id const &local_id, + Domain &local_domain, + Domain &remote_domain) { try { Pointer remote_port_alloc; @@ -596,7 +599,7 @@ void Interface::_nat_link_and_pass(Ethernet_frame ð, log("[", local_domain, "] using NAT rule: ", nat); } _src_port(prot, prot_base, nat.port_alloc(prot).alloc()); - ip.src(remote_domain.ip_config().interface().address); + ip.src(remote_domain.ip_config().interface().address, ip_icd); remote_port_alloc = nat.port_alloc(prot); }, [&] /* no_match */ () { } @@ -605,7 +608,8 @@ void Interface::_nat_link_and_pass(Ethernet_frame ð, ip.src(), _src_port(prot, prot_base) }; _new_link(prot, local_id, remote_port_alloc, remote_domain, remote_id); _pass_prot_to_domain( - remote_domain, eth, size_guard, ip, prot, prot_base, prot_size); + remote_domain, eth, size_guard, ip, ip_icd, prot, prot_base, + prot_size); } catch (Port_allocator_guard::Out_of_indices) { switch (prot) { @@ -975,6 +979,7 @@ void Interface::_send_icmp_echo_reply(Ethernet_frame ð, void Interface::_handle_icmp_query(Ethernet_frame ð, Size_guard &size_guard, Ipv4_packet &ip, + Internet_checksum_diff &ip_icd, Packet_descriptor const &pkt, L3_protocol prot, void *prot_base, @@ -999,13 +1004,13 @@ void Interface::_handle_icmp_query(Ethernet_frame ð, " link: ", link); } _adapt_eth(eth, remote_side.src_ip(), pkt, remote_domain); - ip.src(remote_side.dst_ip()); - ip.dst(remote_side.src_ip()); + ip.src(remote_side.dst_ip(), ip_icd); + ip.dst(remote_side.src_ip(), ip_icd); _src_port(prot, prot_base, remote_side.dst_port()); _dst_port(prot, prot_base, remote_side.src_port()); _pass_prot_to_domain( - remote_domain, eth, size_guard, ip, prot, prot_base, - prot_size); + remote_domain, eth, size_guard, ip, ip_icd, prot, + prot_base, prot_size); _link_packet(prot, prot_base, link, client); done = true; @@ -1026,8 +1031,9 @@ void Interface::_handle_icmp_query(Ethernet_frame ð, Domain &remote_domain = rule.domain(); _adapt_eth(eth, local_id.dst_ip, pkt, remote_domain); - _nat_link_and_pass(eth, size_guard, ip, prot, prot_base, prot_size, - local_id, local_domain, remote_domain); + _nat_link_and_pass(eth, size_guard, ip, ip_icd, prot, prot_base, + prot_size, local_id, local_domain, + remote_domain); done = true; }, @@ -1044,13 +1050,16 @@ void Interface::_handle_icmp_query(Ethernet_frame ð, void Interface::_handle_icmp_error(Ethernet_frame ð, Size_guard &size_guard, Ipv4_packet &ip, + Internet_checksum_diff &ip_icd, Packet_descriptor const &pkt, Domain &local_domain, Icmp_packet &icmp, size_t icmp_sz) { + Ipv4_packet &embed_ip { icmp.data(size_guard) }; + Internet_checksum_diff embed_ip_icd { }; + /* drop packet if embedded IP checksum invalid */ - Ipv4_packet &embed_ip = icmp.data(size_guard); if (embed_ip.checksum_error()) { throw Drop_packet("bad checksum in IP packet embedded in ICMP error"); } @@ -1078,20 +1087,20 @@ void Interface::_handle_icmp_error(Ethernet_frame ð, /* adapt source and destination of Ethernet frame and IP packet */ _adapt_eth(eth, remote_side.src_ip(), pkt, remote_domain); if (remote_side.dst_ip() == remote_domain.ip_config().interface().address) { - ip.src(remote_side.dst_ip()); + ip.src(remote_side.dst_ip(), ip_icd); } - ip.dst(remote_side.src_ip()); + ip.dst(remote_side.src_ip(), ip_icd); /* adapt source and destination of embedded IP and transport packet */ - embed_ip.src(remote_side.src_ip()); - embed_ip.dst(remote_side.dst_ip()); + embed_ip.src(remote_side.src_ip(), embed_ip_icd); + embed_ip.dst(remote_side.dst_ip(), embed_ip_icd); _src_port(embed_prot, embed_prot_base, remote_side.src_port()); _dst_port(embed_prot, embed_prot_base, remote_side.dst_port()); /* update checksum of both IP headers and the ICMP header */ - embed_ip.update_checksum(); + embed_ip.update_checksum(embed_ip_icd); icmp.update_checksum(icmp_sz - sizeof(Icmp_packet)); - ip.update_checksum(); + ip.update_checksum(ip_icd); /* send adapted packet to all interfaces of remote domain */ remote_domain.interfaces().for_each([&] (Interface &interface) { @@ -1113,6 +1122,7 @@ void Interface::_handle_icmp_error(Ethernet_frame ð, void Interface::_handle_icmp(Ethernet_frame ð, Size_guard &size_guard, Ipv4_packet &ip, + Internet_checksum_diff &ip_icd, Packet_descriptor const &pkt, L3_protocol prot, void *prot_base, @@ -1139,8 +1149,8 @@ void Interface::_handle_icmp(Ethernet_frame ð, /* try to act as ICMP router */ switch (icmp.type()) { case Icmp_packet::Type::ECHO_REPLY: - case Icmp_packet::Type::ECHO_REQUEST: _handle_icmp_query(eth, size_guard, ip, pkt, prot, prot_base, prot_size, local_domain); break; - case Icmp_packet::Type::DST_UNREACHABLE: _handle_icmp_error(eth, size_guard, ip, pkt, local_domain, icmp, prot_size); break; + case Icmp_packet::Type::ECHO_REQUEST: _handle_icmp_query(eth, size_guard, ip, ip_icd, pkt, prot, prot_base, prot_size, local_domain); break; + case Icmp_packet::Type::DST_UNREACHABLE: _handle_icmp_error(eth, size_guard, ip, ip_icd, pkt, local_domain, icmp, prot_size); break; default: Drop_packet("unhandled type in ICMP"); } } @@ -1150,8 +1160,10 @@ void Interface::_handle_ip(Ethernet_frame ð, Packet_descriptor const &pkt, Domain &local_domain) { + Ipv4_packet &ip { eth.data(size_guard) }; + Internet_checksum_diff ip_icd { }; + /* drop fragmented IPv4 as it isn't supported */ - Ipv4_packet &ip = eth.data(size_guard); Ipv4_address_prefix const &local_intf = local_domain.ip_config().interface(); if (ip.more_fragments() || ip.fragment_offset() != 0) { @@ -1222,8 +1234,8 @@ void Interface::_handle_ip(Ethernet_frame ð, } } else if (prot == L3_protocol::ICMP) { - _handle_icmp(eth, size_guard, ip, pkt, prot, prot_base, prot_size, - local_domain, local_intf); + _handle_icmp(eth, size_guard, ip, ip_icd, pkt, prot, prot_base, + prot_size, local_domain, local_intf); return; } @@ -1246,13 +1258,13 @@ void Interface::_handle_ip(Ethernet_frame ð, " link: ", link); } _adapt_eth(eth, remote_side.src_ip(), pkt, remote_domain); - ip.src(remote_side.dst_ip()); - ip.dst(remote_side.src_ip()); + ip.src(remote_side.dst_ip(), ip_icd); + ip.dst(remote_side.src_ip(), ip_icd); _src_port(prot, prot_base, remote_side.dst_port()); _dst_port(prot, prot_base, remote_side.src_port()); _pass_prot_to_domain( - remote_domain, eth, size_guard, ip, prot, prot_base, - prot_size); + remote_domain, eth, size_guard, ip, ip_icd, prot, + prot_base, prot_size); _link_packet(prot, prot_base, link, client); done = true; @@ -1276,13 +1288,13 @@ void Interface::_handle_ip(Ethernet_frame ð, } Domain &remote_domain = rule.domain(); _adapt_eth(eth, rule.to_ip(), pkt, remote_domain); - ip.dst(rule.to_ip()); + ip.dst(rule.to_ip(), ip_icd); if (!(rule.to_port() == Port(0))) { _dst_port(prot, prot_base, rule.to_port()); } _nat_link_and_pass( - eth, size_guard, ip, prot, prot_base, prot_size, - local_id, local_domain, remote_domain); + eth, size_guard, ip, ip_icd, prot, prot_base, + prot_size, local_id, local_domain, remote_domain); done = true; }, @@ -1307,7 +1319,7 @@ void Interface::_handle_ip(Ethernet_frame ð, Domain &remote_domain = permit_rule.domain(); _adapt_eth(eth, local_id.dst_ip, pkt, remote_domain); _nat_link_and_pass( - eth, size_guard, ip, prot, prot_base, prot_size, + eth, size_guard, ip, ip_icd, prot, prot_base, prot_size, local_id, local_domain, remote_domain); done = true; diff --git a/repos/os/src/server/nic_router/interface.h b/repos/os/src/server/nic_router/interface.h index 7e585767e5..b7e798c4a9 100644 --- a/repos/os/src/server/nic_router/interface.h +++ b/repos/os/src/server/nic_router/interface.h @@ -230,6 +230,7 @@ class Net::Interface : private Interface_list::Element void _handle_icmp_query(Ethernet_frame ð, Size_guard &size_guard, Ipv4_packet &ip, + Internet_checksum_diff &ip_icd, Packet_descriptor const &pkt, L3_protocol prot, void *prot_base, @@ -239,6 +240,7 @@ class Net::Interface : private Interface_list::Element void _handle_icmp_error(Ethernet_frame ð, Size_guard &size_guard, Ipv4_packet &ip, + Internet_checksum_diff &ip_icd, Packet_descriptor const &pkt, Domain &local_domain, Icmp_packet &icmp, @@ -247,6 +249,7 @@ class Net::Interface : private Interface_list::Element void _handle_icmp(Ethernet_frame ð, Size_guard &size_guard, Ipv4_packet &ip, + Internet_checksum_diff &ip_icd, Packet_descriptor const &pkt, L3_protocol prot, void *prot_base, @@ -262,6 +265,7 @@ class Net::Interface : private Interface_list::Element void _nat_link_and_pass(Ethernet_frame ð, Size_guard &size_guard, Ipv4_packet &ip, + Internet_checksum_diff &ip_icd, L3_protocol const prot, void *const prot_base, Genode::size_t const prot_size, @@ -276,13 +280,14 @@ class Net::Interface : private Interface_list::Element Size_guard &size_guard, Domain &local_domain); - void _pass_prot_to_domain(Domain &domain, - Ethernet_frame ð, - Size_guard &size_guard, - Ipv4_packet &ip, - L3_protocol const prot, - void *const prot_base, - Genode::size_t const prot_size); + void _pass_prot_to_domain(Domain &domain, + Ethernet_frame ð, + Size_guard &size_guard, + Ipv4_packet &ip, + Internet_checksum_diff const &ip_icd, + L3_protocol const prot, + void *const prot_base, + Genode::size_t const prot_size); void _handle_pkt();