diff options
-rw-r--r-- | sys/net/pfkeyv2.c | 135 |
1 files changed, 99 insertions, 36 deletions
diff --git a/sys/net/pfkeyv2.c b/sys/net/pfkeyv2.c index 94effba5cb5..8476a89eb60 100644 --- a/sys/net/pfkeyv2.c +++ b/sys/net/pfkeyv2.c @@ -1,4 +1,4 @@ -/* $OpenBSD: pfkeyv2.c,v 1.185 2018/06/20 09:44:51 mpi Exp $ */ +/* $OpenBSD: pfkeyv2.c,v 1.186 2018/06/25 09:48:17 mpi Exp $ */ /* * @(#)COPYRIGHT 1.1 (NRL) 17 January 1995 @@ -132,19 +132,27 @@ extern struct radix_node_head **spd_tables; struct sockaddr pfkey_addr = { 2, PF_KEY, }; struct domain pfkeydomain; +/* + * pfkey PCB + * + * Locks used to protect struct members in this file: + * I immutable after creation + * a atomic operations + * l pkptable's lock + * s socket lock + */ struct pkpcb { struct rawcb pkp_rcb; -#define kcb_socket pkp_rcb.rcb_socket -#define kcb_faddr pkp_rcb.rcb_faddr -#define kcb_laddr pkp_rcb.rcb_laddr -#define kcb_proto pkp_rcb.rcb_proto - - SRPL_ENTRY(pkpcb) kcb_list; - struct refcnt kcb_refcnt; - int kcb_flags; - uint32_t kcb_pid; - uint32_t kcb_registration; /* Inc. if SATYPE_MAX > 31 */ - unsigned int kcb_rdomain; +#define kcb_socket pkp_rcb.rcb_socket /* [I] associated socket */ +#define kcb_faddr pkp_rcb.rcb_faddr /* [I] */ +#define kcb_proto pkp_rcb.rcb_proto /* [I] */ + + SRPL_ENTRY(pkpcb) kcb_list; /* [l] */ + struct refcnt kcb_refcnt; /* [a] */ + int kcb_flags; /* [s] */ + uint32_t kcb_reg; /* [s] Inc if SATYPE_MAX > 31 */ + uint32_t kcb_pid; /* [I] */ + unsigned int kcb_rdomain; /* [I] routing domain */ }; #define sotokeycb(so) ((struct pkpcb *)(so)->so_pcb) #define keylock(kp) solock((kp)->kcb_socket) @@ -472,12 +480,11 @@ pfkeyv2_sendmessage(void **headers, int mode, struct socket *so, * original destination. */ SRPL_FOREACH(kp, &sr, &pkptable.pkp_list, kcb_list) { - if (kp->kcb_socket == so) + if (kp->kcb_socket == so || kp->kcb_rdomain != rdomain) continue; s = keylock(kp); - if ((kp->kcb_flags & PFKEYV2_SOCKETFLAGS_PROMISC) && - (kp->kcb_rdomain == rdomain)) + if (kp->kcb_flags & PFKEYV2_SOCKETFLAGS_PROMISC) pfkey_sendup(kp, packet, 1); keyunlock(kp, s); } @@ -491,15 +498,17 @@ pfkeyv2_sendmessage(void **headers, int mode, struct socket *so, * the specified satype (e.g., all IPSEC-ESP negotiators) */ SRPL_FOREACH(kp, &sr, &pkptable.pkp_list, kcb_list) { + if (kp->kcb_rdomain != rdomain) + continue; + s = keylock(kp); - if ((kp->kcb_flags & PFKEYV2_SOCKETFLAGS_REGISTERED) && - (kp->kcb_rdomain == rdomain)) { + if (kp->kcb_flags & PFKEYV2_SOCKETFLAGS_REGISTERED) { if (!satype) { /* Just send to everyone registered */ pfkey_sendup(kp, packet, 1); } else { /* Check for specified satype */ - if ((1 << satype) & kp->kcb_registration) + if ((1 << satype) & kp->kcb_reg) pfkey_sendup(kp, packet, 1); } } @@ -525,10 +534,12 @@ pfkeyv2_sendmessage(void **headers, int mode, struct socket *so, /* Send to all registered promiscuous listeners */ SRPL_FOREACH(kp, &sr, &pkptable.pkp_list, kcb_list) { + if (kp->kcb_rdomain != rdomain) + continue; + s = keylock(kp); if ((kp->kcb_flags & PFKEYV2_SOCKETFLAGS_PROMISC) && - !(kp->kcb_flags & PFKEYV2_SOCKETFLAGS_REGISTERED) && - (kp->kcb_rdomain == rdomain)) + !(kp->kcb_flags & PFKEYV2_SOCKETFLAGS_REGISTERED)) pfkey_sendup(kp, packet, 1); keyunlock(kp, s); } @@ -539,9 +550,11 @@ pfkeyv2_sendmessage(void **headers, int mode, struct socket *so, case PFKEYV2_SENDMESSAGE_BROADCAST: /* Send message to all sockets */ SRPL_FOREACH(kp, &sr, &pkptable.pkp_list, kcb_list) { + if (kp->kcb_rdomain != rdomain) + continue; + s = keylock(kp); - if (kp->kcb_rdomain == rdomain) - pfkey_sendup(kp, packet, 1); + pfkey_sendup(kp, packet, 1); keyunlock(kp, s); } SRPL_LEAVE(&sr); @@ -1029,8 +1042,6 @@ pfkeyv2_send(struct socket *so, void *message, int len) promisc = npromisc; mtx_leave(&pfkeyv2_mtx); - NET_LOCK(); - /* Verify that we received this over a legitimate pfkeyv2 socket */ bzero(headers, sizeof(headers)); @@ -1070,9 +1081,11 @@ pfkeyv2_send(struct socket *so, void *message, int len) /* Send to all promiscuous listeners */ SRPL_FOREACH(bkp, &sr, &pkptable.pkp_list, kcb_list) { + if (bkp->kcb_rdomain != rdomain) + continue; + s = keylock(bkp); - if ((bkp->kcb_flags & PFKEYV2_SOCKETFLAGS_PROMISC) && - (bkp->kcb_rdomain == rdomain)) + if (bkp->kcb_flags & PFKEYV2_SOCKETFLAGS_PROMISC) pfkey_sendup(bkp, packet, 1); keyunlock(bkp, s); } @@ -1109,16 +1122,20 @@ pfkeyv2_send(struct socket *so, void *message, int len) /* Find an unused SA identifier */ sprng = (struct sadb_spirange *) headers[SADB_EXT_SPIRANGE]; + NET_LOCK(); sa1->tdb_spi = reserve_spi(rdomain, sprng->sadb_spirange_min, sprng->sadb_spirange_max, &sa1->tdb_src, &sa1->tdb_dst, sa1->tdb_sproto, &rval); - if (sa1->tdb_spi == 0) + if (sa1->tdb_spi == 0) { + NET_UNLOCK(); goto ret; + } /* Send a message back telling what the SA (the SPI really) is */ if (!(freeme = malloc(sizeof(struct sadb_sa), M_PFKEY, M_NOWAIT | M_ZERO))) { rval = ENOMEM; + NET_UNLOCK(); goto ret; } @@ -1128,6 +1145,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) /* We really only care about the SPI, but we'll export the SA */ export_sa((void **) &bckptr, sa1); + NET_UNLOCK(); break; case SADB_UPDATE: @@ -1162,12 +1180,14 @@ pfkeyv2_send(struct socket *so, void *message, int len) #endif /* IPSEC */ /* Find TDB */ + NET_LOCK(); sa2 = gettdb(rdomain, ssa->sadb_sa_spi, sunionp, SADB_X_GETSPROTO(smsg->sadb_msg_satype)); /* If there's no such SA, we're done */ if (sa2 == NULL) { rval = ESRCH; + NET_UNLOCK(); goto ret; } @@ -1188,6 +1208,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) &newsa->tdb_sproto, &alg))) { tdb_free(freeme); freeme = NULL; + NET_UNLOCK(); goto ret; } @@ -1239,6 +1260,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) rval = EINVAL; tdb_free(freeme); freeme = NULL; + NET_UNLOCK(); goto ret; } @@ -1260,6 +1282,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) headers[SADB_EXT_IDENTITY_DST] || headers[SADB_EXT_SENSITIVITY]) { rval = EINVAL; + NET_UNLOCK(); goto ret; } @@ -1286,6 +1309,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) puttdb(sa2); } } + NET_UNLOCK(); break; case SADB_ADD: @@ -1319,18 +1343,21 @@ pfkeyv2_send(struct socket *so, void *message, int len) } #endif /* IPSEC */ + NET_LOCK(); sa2 = gettdb(rdomain, ssa->sadb_sa_spi, sunionp, SADB_X_GETSPROTO(smsg->sadb_msg_satype)); /* We can't add an existing SA! */ if (sa2 != NULL) { rval = EEXIST; + NET_UNLOCK(); goto ret; } /* We can only add "mature" SAs */ if (ssa->sadb_sa_state != SADB_SASTATE_MATURE) { rval = EINVAL; + NET_UNLOCK(); goto ret; } @@ -1349,6 +1376,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) &newsa->tdb_sproto, &alg))) { tdb_free(freeme); freeme = NULL; + NET_UNLOCK(); goto ret; } @@ -1402,12 +1430,14 @@ pfkeyv2_send(struct socket *so, void *message, int len) rval = EINVAL; tdb_free(freeme); freeme = NULL; + NET_UNLOCK(); goto ret; } } /* Add TDB in table */ puttdb((struct tdb *) freeme); + NET_UNLOCK(); freeme = NULL; break; @@ -1418,27 +1448,33 @@ pfkeyv2_send(struct socket *so, void *message, int len) (union sockaddr_union *)(headers[SADB_EXT_ADDRESS_DST] + sizeof(struct sadb_address)); + NET_LOCK(); sa2 = gettdb(rdomain, ssa->sadb_sa_spi, sunionp, SADB_X_GETSPROTO(smsg->sadb_msg_satype)); if (sa2 == NULL) { rval = ESRCH; + NET_UNLOCK(); goto ret; } tdb_delete(sa2); + NET_UNLOCK(); sa2 = NULL; break; case SADB_X_ASKPOLICY: /* Get the relevant policy */ + NET_LOCK(); ipa = ipsec_get_acquire(((struct sadb_x_policy *) headers[SADB_X_EXT_POLICY])->sadb_x_policy_seq); if (ipa == NULL) { rval = ESRCH; + NET_UNLOCK(); goto ret; } rval = pfkeyv2_policy(ipa, headers, &freeme); + NET_UNLOCK(); if (rval) mode = PFKEYV2_SENDMESSAGE_UNICAST; @@ -1450,26 +1486,31 @@ pfkeyv2_send(struct socket *so, void *message, int len) (union sockaddr_union *)(headers[SADB_EXT_ADDRESS_DST] + sizeof(struct sadb_address)); + NET_LOCK(); sa2 = gettdb(rdomain, ssa->sadb_sa_spi, sunionp, SADB_X_GETSPROTO(smsg->sadb_msg_satype)); if (sa2 == NULL) { rval = ESRCH; + NET_UNLOCK(); goto ret; } rval = pfkeyv2_get(sa2, headers, &freeme, NULL); + NET_UNLOCK(); if (rval) mode = PFKEYV2_SENDMESSAGE_UNICAST; break; case SADB_REGISTER: + s = keylock(kp); if (!(kp->kcb_flags & PFKEYV2_SOCKETFLAGS_REGISTERED)) { kp->kcb_flags |= PFKEYV2_SOCKETFLAGS_REGISTERED; mtx_enter(&pfkeyv2_mtx); nregistered++; mtx_leave(&pfkeyv2_mtx); } + keyunlock(kp, s); i = sizeof(struct sadb_supported) + sizeof(ealgs); @@ -1497,8 +1538,10 @@ pfkeyv2_send(struct socket *so, void *message, int len) } /* Keep track what this socket has registered for */ - kp->kcb_registration |= + s = keylock(kp); + kp->kcb_reg |= (1 << ((struct sadb_msg *)message)->sadb_msg_satype); + keyunlock(kp, s); ssup = (struct sadb_supported *) freeme; ssup->sadb_supported_len = i / sizeof(uint64_t); @@ -1540,6 +1583,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) case SADB_FLUSH: rval = 0; + NET_LOCK(); switch (smsg->sadb_msg_satype) { case SADB_SATYPE_UNSPEC: spd_table_walk(rdomain, pfkeyv2_policy_flush, NULL); @@ -1559,6 +1603,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) default: rval = EINVAL; /* Unknown/unsupported type */ } + NET_UNLOCK(); break; @@ -1568,7 +1613,9 @@ pfkeyv2_send(struct socket *so, void *message, int len) dump_state.sadb_msg = (struct sadb_msg *) headers[0]; dump_state.socket = so; + NET_LOCK(); rval = tdb_walk(rdomain, pfkeyv2_dump_walker, &dump_state); + NET_UNLOCK(); if (!rval) goto realret; if ((rval == ENOMEM) || (rval == ENOBUFS)) @@ -1585,10 +1632,12 @@ pfkeyv2_send(struct socket *so, void *message, int len) sunionp = (union sockaddr_union *) (headers[SADB_EXT_ADDRESS_DST] + sizeof(struct sadb_address)); + NET_LOCK(); tdb1 = gettdb(rdomain, ssa->sadb_sa_spi, sunionp, SADB_X_GETSPROTO(smsg->sadb_msg_satype)); if (tdb1 == NULL) { rval = ESRCH; + NET_UNLOCK(); goto ret; } @@ -1601,6 +1650,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) SADB_X_GETSPROTO(sa_proto->sadb_protocol_proto)); if (tdb2 == NULL) { rval = ESRCH; + NET_UNLOCK(); goto ret; } @@ -1608,6 +1658,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) for (tdb3 = tdb2; tdb3; tdb3 = tdb3->tdb_onext) if (tdb3 == tdb1) { rval = ESRCH; + NET_UNLOCK(); goto ret; } @@ -1623,6 +1674,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) /* Link them */ tdb1->tdb_onext = tdb2; tdb2->tdb_inext = tdb1; + NET_UNLOCK(); } break; @@ -1635,8 +1687,10 @@ pfkeyv2_send(struct socket *so, void *message, int len) union sockaddr_union *ssrc; int exists = 0; + NET_LOCK(); if ((rnh = spd_table_add(rdomain)) == NULL) { rval = ENOMEM; + NET_UNLOCK(); goto ret; } @@ -1645,6 +1699,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) if ((sab->sadb_protocol_direction != IPSP_DIRECTION_IN) && (sab->sadb_protocol_direction != IPSP_DIRECTION_OUT)) { rval = EINVAL; + NET_UNLOCK(); goto ret; } @@ -1693,6 +1748,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) if (exists && (ipo->ipo_flags & IPSP_POLICY_STATIC)) { if (!(sab->sadb_protocol_flags & SADB_X_POLICYFLAGS_POLICY)) { + NET_UNLOCK(); goto ret; } } @@ -1701,11 +1757,13 @@ pfkeyv2_send(struct socket *so, void *message, int len) if (delflag) { if (exists) { rval = ipsec_delete_policy(ipo); + NET_UNLOCK(); goto ret; } /* If we were asked to delete something non-existent, error. */ rval = ESRCH; + NET_UNLOCK(); break; } @@ -1721,6 +1779,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) ipo = pool_get(&ipsec_policy_pool, PR_NOWAIT|PR_ZERO); if (ipo == NULL) { rval = ENOMEM; + NET_UNLOCK(); goto ret; } } @@ -1757,6 +1816,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) ipsec_delete_policy(ipo); rval = EINVAL; + NET_UNLOCK(); goto ret; } @@ -1791,6 +1851,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) else pool_put(&ipsec_policy_pool, ipo); rval = ENOBUFS; + NET_UNLOCK(); goto ret; } } @@ -1820,7 +1881,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) if (ipo->ipo_ids) ipsp_ids_free(ipo->ipo_ids); pool_put(&ipsec_policy_pool, ipo); - + NET_UNLOCK(); goto ret; } TAILQ_INSERT_HEAD(&ipsec_policy_head, ipo, ipo_list); @@ -1836,6 +1897,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) } else { ipo->ipo_last_searched = ipo->ipo_flags = 0; } + NET_UNLOCK(); } break; @@ -1847,15 +1909,15 @@ pfkeyv2_send(struct socket *so, void *message, int len) goto ret; SRPL_FOREACH(bkp, &sr, &pkptable.pkp_list, kcb_list) { - if (bkp == kp) + if (bkp == kp || bkp->kcb_rdomain != rdomain) continue; - s = keylock(bkp); - if ((bkp->kcb_rdomain == rdomain) && - (!smsg->sadb_msg_seq || - (smsg->sadb_msg_seq == kp->kcb_pid))) + if (!smsg->sadb_msg_seq || + (smsg->sadb_msg_seq == kp->kcb_pid)) { + s = keylock(bkp); pfkey_sendup(bkp, packet, 1); - keyunlock(bkp, s); + keyunlock(bkp, s); + } } SRPL_LEAVE(&sr); @@ -1866,6 +1928,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) goto ret; } + s = keylock(kp); i = (kp->kcb_flags & PFKEYV2_SOCKETFLAGS_PROMISC) ? 1 : 0; j = smsg->sadb_msg_satype ? 1 : 0; @@ -1885,6 +1948,7 @@ pfkeyv2_send(struct socket *so, void *message, int len) mtx_leave(&pfkeyv2_mtx); } } + keyunlock(kp, s); } break; @@ -1922,7 +1986,6 @@ ret: rval = pfkeyv2_sendmessage(headers, mode, so, 0, 0, rdomain); realret: - NET_UNLOCK(); if (freeme) free(freeme, M_PFKEY, 0); |