5G-LENA nr-v3.3-120-gdac69c56
The 5G/NR module for the ns-3 simulator
Loading...
Searching...
No Matches
nr-mac-scheduler-tdma-ai.cc
1// Copyright (c) 2024 Seoul National University (SNU)
2// Copyright (c) 2024 Centre Tecnologic de Telecomunicacions de Catalunya (CTTC)
3//
4// SPDX-License-Identifier: GPL-2.0-only
5
6#include "nr-mac-scheduler-tdma-ai.h"
7
8#include "ns3/boolean.h"
9#include "ns3/callback.h"
10#include "ns3/log.h"
11
12#include <algorithm>
13#include <functional>
14
15namespace ns3
16{
17
18NS_LOG_COMPONENT_DEFINE("NrMacSchedulerTdmaAi");
19NS_OBJECT_ENSURE_REGISTERED(NrMacSchedulerTdmaAi);
20
21TypeId
23{
24 static TypeId tid =
25 TypeId("ns3::NrMacSchedulerTdmaAi")
26 .SetParent<NrMacSchedulerTdmaQos>()
27 .AddConstructor<NrMacSchedulerTdmaAi>()
28 .AddAttribute("NotifyCbDl",
29 "The callback function to notify the AI model for the downlink",
30 CallbackValue(MakeNullCallback<NrMacSchedulerUeInfoAi::NotifyCb>()),
31 MakeCallbackAccessor(&NrMacSchedulerTdmaAi::m_notifyCbDl),
32 MakeCallbackChecker())
33 .AddAttribute("NotifyCbUl",
34 "The callback function to notify the AI model for the uplink",
35 CallbackValue(MakeNullCallback<NrMacSchedulerUeInfoAi::NotifyCb>()),
36 MakeCallbackAccessor(&NrMacSchedulerTdmaAi::m_notifyCbUl),
37 MakeCallbackChecker())
38 .AddAttribute("ActiveDlAi",
39 "The flag to activate the AI model for the downlink",
40 BooleanValue(false),
41 MakeBooleanAccessor(&NrMacSchedulerTdmaAi::m_activeDlAi),
42 MakeBooleanChecker())
43 .AddAttribute("ActiveUlAi",
44 "The flag to activate the AI model for the uplink",
45 BooleanValue(false),
46 MakeBooleanAccessor(&NrMacSchedulerTdmaAi::m_activeUlAi),
47 MakeBooleanChecker());
48 return tid;
49}
50
55
56std::shared_ptr<NrMacSchedulerUeInfo>
59{
60 NS_LOG_FUNCTION(this);
61 return std::make_shared<NrMacSchedulerUeInfoAi>(
62 m_alpha,
63 params.m_rnti,
64 params.m_beamId,
65 std::bind(&NrMacSchedulerTdmaAi::GetNumRbPerRbg, this));
66}
67
68std::function<bool(const NrMacSchedulerNs3::UePtrAndBufferReq& lhs,
78
79std::function<bool(const NrMacSchedulerNs3::UePtrAndBufferReq& lhs,
89
90void
92{
93 NS_LOG_FUNCTION(this);
94 m_notifyCbDl = notifyCb;
95 m_activeDlAi = true;
96}
97
98void
100{
101 NS_LOG_FUNCTION(this);
102 m_notifyCbUl = notifyCb;
103 m_activeUlAi = true;
104}
105
106std::vector<NrMacSchedulerUeInfoAi::LcObservation>
108 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
109{
110 NS_LOG_FUNCTION(this);
111 std::vector<NrMacSchedulerUeInfoAi::LcObservation> observations;
112 for (const auto& ue : ueVector)
113 {
114 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
115 std::vector<NrMacSchedulerUeInfoAi::LcObservation> ueObservation =
116 uePtr->GetDlObservation();
117 observations.insert(observations.end(), ueObservation.begin(), ueObservation.end());
118 }
119 return observations;
120}
121
122std::vector<NrMacSchedulerUeInfoAi::LcObservation>
124 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
125{
126 NS_LOG_FUNCTION(this);
127 std::vector<NrMacSchedulerUeInfoAi::LcObservation> observations;
128 for (const auto& ue : ueVector)
129 {
130 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
131 std::vector<NrMacSchedulerUeInfoAi::LcObservation> ueObservation =
132 uePtr->GetUlObservation();
133 observations.insert(observations.end(), ueObservation.begin(), ueObservation.end());
134 }
135 return observations;
136}
137
138bool
140{
141 return false;
142}
143
144bool
146{
147 return false;
148}
149
150float
152 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
153{
154 NS_LOG_FUNCTION(this);
155 float reward = 0.0;
156 for (const auto& ue : ueVector)
157 {
158 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
159 reward += uePtr->GetDlReward();
160 }
161 return reward;
162}
163
164float
166 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
167{
168 NS_LOG_FUNCTION(this);
169 float reward = 0.0;
170 for (const auto& ue : ueVector)
171 {
172 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
173 reward += uePtr->GetUlReward();
174 }
175 return reward;
176}
177
178void
180 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
181{
182 NS_LOG_FUNCTION(this);
183 if (!m_notifyCbDl.IsNull())
184 {
185 std::string extraInfo = "";
188 this,
189 std::placeholders::_1,
190 ueVector);
191 m_notifyCbDl(GetUeObservationsDl(ueVector),
193 GetUeRewardsDl(ueVector),
194 extraInfo,
195 updateWeightsFn);
196 }
197}
198
199void
201 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
202{
203 NS_LOG_FUNCTION(this);
204 if (!m_notifyCbUl.IsNull())
205 {
206 std::string extraInfo = "";
209 this,
210 std::placeholders::_1,
211 ueVector);
212 m_notifyCbUl(GetUeObservationsUl(ueVector),
214 GetUeRewardsUl(ueVector),
215 extraInfo,
216 updateWeightsFn);
217 }
218}
219
220void
223 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
224{
225 NS_LOG_FUNCTION(this);
226 for (const auto& ue : ueVector)
227 {
228 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
229 NrMacSchedulerUeInfoAi::Weights weights = ueWeights.at(uePtr->m_rnti);
230 uePtr->UpdateDlWeights(weights);
231 }
232}
233
234void
237 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
238{
239 NS_LOG_FUNCTION(this);
240 for (const auto& ue : ueVector)
241 {
242 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
243 NrMacSchedulerUeInfoAi::Weights weights = ueWeights.at(uePtr->m_rnti);
244 uePtr->UpdateUlWeights(weights);
245 }
246}
247
248} // namespace ns3
bool m_activeUlAi
Flag for activating AI for uplink.
bool m_activeDlAi
Flag for activating AI for downlink.
uint64_t GetNumRbPerRbg() const
Private function that is used to get the number of resource blocks per resource block group and also ...
std::pair< UePtr, uint32_t > UePtrAndBufferReq
Pair between a pointer to NrMacSchedulerUeInfo and its buffer occupancy.
NrMacSchedulerTdmaAi()
NrMacSchedulerTdmaAi constructor.
std::shared_ptr< NrMacSchedulerUeInfo > CreateUeRepresentation(const NrMacCschedSapProvider::CschedUeConfigReqParameters &params) const override
Create a UE representation of the type NrMacSchedulerUeInfoAi.
void CallNotifyDlFn(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const override
Call the notify callback function in the OpenGymEnv class in the ns3-gym module for downlink.
void UpdateAllUeWeightsDl(const NrMacSchedulerUeInfoAi::UeWeightsMap &ueWeights, const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Update weights of all UEs for downlink.
bool GetIsGameOverUl() const
Check if the uplink game is over.
void UpdateAllUeWeightsUl(const NrMacSchedulerUeInfoAi::UeWeightsMap &ueWeights, const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Update weights of all UEs for uplink.
float GetUeRewardsDl(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Get rewards for downlink.
void CallNotifyUlFn(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const override
Call the notify callback function in the OpenGymEnv class in the ns3-gym module for uplink.
static TypeId GetTypeId()
GetTypeId.
std::function< bool(const NrMacSchedulerNs3::UePtrAndBufferReq &lhs, const NrMacSchedulerNs3::UePtrAndBufferReq &rhs)> GetUeCompareUlFn() const override
Return the comparison function to sort UL UEs according to the scheduler policy.
void SetNotifyCbDl(NrMacSchedulerUeInfoAi::NotifyCb notifyCb)
Set the notify callback function for downlink.
std::vector< NrMacSchedulerUeInfoAi::LcObservation > GetUeObservationsDl(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Get UE observations for downlink.
std::function< bool(const NrMacSchedulerNs3::UePtrAndBufferReq &lhs, const NrMacSchedulerNs3::UePtrAndBufferReq &rhs)> GetUeCompareDlFn() const override
Return the comparison function to sort DL UEs according to the scheduler policy.
bool GetIsGameOverDl() const
Check if the downlink game is over.
std::vector< NrMacSchedulerUeInfoAi::LcObservation > GetUeObservationsUl(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Get UE observations for uplink.
float GetUeRewardsUl(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Get rewards for uplink.
void SetNotifyCbUl(NrMacSchedulerUeInfoAi::NotifyCb notifyCb)
Set the notify callback function for uplink.
Assign entire symbols in a QoS-based fashion.
static bool CompareUeWeightsUl(const NrMacSchedulerNs3::UePtrAndBufferReq &lue, const NrMacSchedulerNs3::UePtrAndBufferReq &rue)
comparison function object (i.e. an object that satisfies the requirements of Compare) which returns ...
static bool CompareUeWeightsDl(const NrMacSchedulerNs3::UePtrAndBufferReq &lue, const NrMacSchedulerNs3::UePtrAndBufferReq &rue)
comparison function object (i.e. an object that satisfies the requirements of Compare) which returns ...
std::unordered_map< uint8_t, double > Weights
A hash map for weights.
std::unordered_map< uint8_t, Weights > UeWeightsMap
A hash map for UE weights.
Callback< void, const std::vector< LcObservation > &, bool, float, const std::string &, const UpdateAllUeWeightsFn & > NotifyCb
A callback type for notifying with specific parameters.
std::function< void(const UeWeightsMap &)> UpdateAllUeWeightsFn
A function type for updating the weights of all UEs.
static bool CompareUeWeightsUl(const NrMacSchedulerNs3::UePtrAndBufferReq &lue, const NrMacSchedulerNs3::UePtrAndBufferReq &rue)
comparison function object (i.e. an object that satisfies the requirements of Compare) which returns ...
static bool CompareUeWeightsDl(const NrMacSchedulerNs3::UePtrAndBufferReq &lue, const NrMacSchedulerNs3::UePtrAndBufferReq &rue)
comparison function object (i.e. an object that satisfies the requirements of Compare) which returns ...